# mimc_rl/deploited/test_crop.py
# Last commit 9cf79cf by wangyanhui666: "fine tune decoder with mask"
from util.utils import pad_to_multiple_of_256, split_into_blocks, merge_blocks, crop_to_original_shape
import glob
import os
import torch
import numpy as np
import math
from torch.nn import functional as F
import PIL.Image as Image
from torchvision import utils as vutils
def load_img(p, padding=True, factor=64):
    """Load an image file as a float tensor in [0, 1], optionally zero-padded.

    Every PIL mode (grayscale, palette, RGBA, 1-bit, ...) is converted to RGB,
    so the result always has shape (1, 3, H', W'). When ``padding`` is true,
    H' and W' are the original height/width rounded up to the next multiple
    of ``factor``. The zero padding is split as evenly as possible between
    the two sides of each axis, and an odd leftover pixel goes to the
    bottom/right. Otherwise H' == h and W' == w.

    Args:
        p: Path of the image file.
        padding: Whether to zero-pad the spatial dims up to a multiple of
            ``factor``.
        factor: Alignment of the padded height and width.

    Returns:
        Tuple ``(x, h, w)``: the (possibly padded) tensor and the image's
        original, unpadded height and width.
    """
    # The context manager closes the underlying file handle.
    # convert('RGB') guarantees exactly 3 uint8 channels for every mode;
    # e.g. palette images would otherwise yield raw palette indices.
    # np.array (not asarray) gives a writable copy, which torch.from_numpy
    # accepts without a warning.
    with Image.open(p) as im:
        arr = np.array(im.convert('RGB'))  # (H, W, 3) uint8
    x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float().div(255)
    h, w = x.shape[2:4]
    if padding:
        dh = factor * math.ceil(h / factor) - h
        dw = factor * math.ceil(w / factor) - w
        # Pad evenly on both sides; an odd extra pixel goes to bottom/right.
        # F.pad order for the last two dims is (left, right, top, bottom).
        top, left = dh // 2, dw // 2
        x = F.pad(x, (left, dw - left, top, dh - top))
    return x, h, w
def save_img(img: torch.Tensor, vis_path, input_p, rec=False):
    """Save ``img`` as an image file named after ``input_p``.

    The output file name is the basename of ``input_p``. It is written into
    ``vis_path``, or into ``<vis_path>/rec`` when ``rec`` is true. Missing
    directories are created. The tensor is detached and copied to CPU, then
    handed to ``torchvision.utils.save_image`` with ``nrow=8``.

    Args:
        img: Image tensor accepted by ``torchvision.utils.save_image``.
        vis_path: Output directory.
        input_p: Path of the source image; only its basename is used.
        rec: If true, write into the ``rec`` sub-directory of ``vis_path``.
    """
    # Detached CPU copy, so the caller's tensor (possibly on GPU) is untouched.
    img = img.detach().to(torch.device('cpu'), copy=True)
    out_dir = os.path.join(vis_path, 'rec') if rec else vis_path
    # makedirs also creates vis_path itself when writing into vis_path/rec.
    os.makedirs(out_dir, exist_ok=True)
    # basename() replaces input_p[input_p.rfind('/'):]. That slice kept only
    # the last character when input_p contained no '/' (rfind returned -1).
    img_name = os.path.join(out_dir, os.path.basename(input_p))
    vutils.save_image(img, img_name, nrow=8)
# Machine-specific default directory of *.jpg test images. Override it on
# the command line: python test_crop.py [EVAL_DIR [VIS_DIR]]
DEFAULT_EVAL_DIR = '/home/t2vg-a100-G4-10/project/qyp/datasets/test'


def main(eval_dir=DEFAULT_EVAL_DIR, vis_dir='./test_crop', factor=64):
    """Run every ``*.jpg`` in ``eval_dir`` through the util.utils helpers.

    Each image is loaded with ``load_img`` (zero-padded to a multiple of
    ``factor``), then passed through ``pad_to_multiple_of_256`` ->
    ``split_into_blocks`` -> ``merge_blocks`` -> ``crop_to_original_shape``.
    The tensor shape is printed after each step. The tensor returned by
    ``pad_to_multiple_of_256`` is saved to ``vis_dir`` and the final cropped
    tensor to ``vis_dir/rec`` (both via ``save_img``), so the two can be
    compared by eye. Presumably this checks that split/merge/crop is a
    lossless round trip; confirm against util.utils.

    Args:
        eval_dir: Directory searched (non-recursively) for ``*.jpg`` files.
        vis_dir: Output directory for saved images; created if missing.
        factor: Alignment passed to ``load_img``.
    """
    eval_paths = sorted(glob.glob(os.path.join(eval_dir, '*.jpg')))
    os.makedirs(vis_dir, exist_ok=True)
    for input_p in eval_paths:
        x, hx, wx = load_img(input_p, padding=True, factor=factor)
        print("ori height", hx, "ori width", wx)
        # Captured after load_img's factor-alignment padding. The final crop
        # therefore targets this aligned shape, not the raw hx x wx size.
        ori_shape = x.shape
        print("input shape", ori_shape)
        x = pad_to_multiple_of_256(x, 0)
        save_img(x, vis_dir, input_p, rec=False)
        print("shape after padding", x.shape)
        # Take the channel count from the tensor instead of hard-coding 3.
        _, channels, new_h, new_w = x.shape
        x = split_into_blocks(x)
        print("new shape", x.shape)
        new_shape = [ori_shape[0], channels, new_h, new_w]
        x = merge_blocks(x, new_shape)
        print("shape after merge", x.shape)
        x = crop_to_original_shape(x, ori_shape)
        print("shape after crop", x.shape)
        save_img(x, vis_dir, input_p, rec=True)


if __name__ == '__main__':
    import sys
    # Optional positional overrides: [EVAL_DIR [VIS_DIR]].
    main(*sys.argv[1:3])