| from util.utils import pad_to_multiple_of_256, split_into_blocks, merge_blocks, crop_to_original_shape |
| import glob |
| import os |
| import torch |
| import numpy as np |
| import math |
| from torch.nn import functional as F |
| import PIL.Image as Image |
| from torchvision import utils as vutils |
|
|
|
|
def load_img(p, padding=True, factor=64):
    """Load an image file as a float tensor in [0, 1], optionally zero-padded.

    The image is converted to 3-channel RGB first. Grayscale ("L"), palette
    ("P"), RGBA, CMYK, ... inputs therefore all yield the same 3-channel
    layout that downstream code expects.

    Args:
        p: Path of the image file to read.
        padding: If True, zero-pad height and width up to the next multiple of
            ``factor``. Padding is split as evenly as possible between the two
            sides of each axis; an odd extra pixel goes to the bottom/right.
        factor: Positive integer that the padded height and width must be
            divisible by.

    Returns:
        ``(x, h, w)``: ``x`` is a float tensor of shape ``(1, 3, H, W)``.
        ``h`` and ``w`` are the image height and width *before* padding.

    Raises:
        ValueError: If ``padding`` is True and ``factor`` is not positive.
    """
    # The context manager closes the underlying file handle.
    # convert("RGB") guarantees exactly three uint8 channels, so div(255)
    # maps the data to [0, 1].
    with Image.open(p) as im:
        # np.array (not asarray) copies into a writable buffer; torch.from_numpy
        # warns on the read-only array that asarray returns for PIL images.
        arr = np.array(im.convert("RGB"))
    x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float().div(255)
    h, w = x.shape[2:4]

    if padding:
        if factor <= 0:
            raise ValueError(f"factor must be a positive integer, got {factor!r}")
        # (-n) % factor is the distance from n up to the next multiple of factor
        # (0 when n is already a multiple); same as factor*ceil(n/factor) - n.
        dh = (-h) % factor
        dw = (-w) % factor
        top, left = dh // 2, dw // 2
        # F.pad pads the last dim first: (left, right, top, bottom).
        x = F.pad(x, (left, dw - left, top, dh - top))
    return x, h, w
|
|
def save_img(img: torch.Tensor, vis_path: str, input_p: str, rec: bool = False) -> str:
    """Save ``img`` under ``vis_path``, reusing the file name of ``input_p``.

    Args:
        img: Tensor handed to ``torchvision.utils.save_image``. Presumably an
            image or batch of images with values in [0, 1], since save_image
            scales by 255.
        vis_path: Output directory. It is created (with parents) if missing.
        input_p: Path of the source image. Only its base name is used for the
            output file.
        rec: If True, write into the ``rec`` sub-directory of ``vis_path``
            instead of ``vis_path`` itself.

    Returns:
        The path of the file that was written.
    """
    out_dir = os.path.join(vis_path, "rec") if rec else vis_path
    # makedirs also creates vis_path when rec is True. exist_ok avoids the
    # race between checking for the directory and creating it.
    os.makedirs(out_dir, exist_ok=True)
    # basename works even when input_p has no directory part. The old rfind('/')
    # slicing kept only the last character in that case.
    img_name = os.path.join(out_dir, os.path.basename(input_p))
    # Detach from any autograd graph and take a private host copy, so the
    # caller's tensor is never touched.
    cpu_img = img.detach().to("cpu", copy=True)
    # nrow only affects the grid layout when img holds more than one image.
    vutils.save_image(cpu_img, img_name, nrow=8)
    return img_name
|
|
|
|
# Round-trip check for the block utilities. Each test image is padded to a
# multiple of 256, split into blocks, merged back and cropped to its
# pre-split shape. The padded input (vis_path/<name>) and the reconstruction
# (vis_path/rec/<name>) are both saved so they can be compared.
# The dataset directory can be overridden with the EVAL_DIR environment
# variable; the default is the original hard-coded location.
EVAL_DIR = os.environ.get("EVAL_DIR", "/home/t2vg-a100-G4-10/project/qyp/datasets/test")
eval_path = sorted(glob.glob(os.path.join(EVAL_DIR, '*.jpg')))
vis_path = "./test_crop/"
os.makedirs(vis_path, exist_ok=True)
if not eval_path:
    print("no .jpg files found in", EVAL_DIR)

for input_p in eval_path:
    x, hx, wx = load_img(input_p, padding=True, factor=64)
    print("ori height", hx, "ori width", wx)
    # NOTE(review): ori_shape is the shape *after* load_img's 64-padding, so
    # the final crop restores that padded size, not the raw (hx, wx) image.
    ori_shape = x.shape
    print("input shape", ori_shape)
    # NOTE(review): the meaning of the second argument (0) is defined in
    # util.utils; presumably the padding fill value. Confirm there.
    x = pad_to_multiple_of_256(x, 0)
    save_img(x, vis_path, input_p, rec=False)
    print("shape after padding", x.shape)
    _, _, new_h, new_w = x.shape
    blocks = split_into_blocks(x)
    print("new shape", blocks.shape)
    # Keep the real batch and channel counts instead of hard-coding 3 channels.
    new_shape = [ori_shape[0], ori_shape[1], new_h, new_w]
    x = merge_blocks(blocks, new_shape)
    print("shape after merge", x.shape)
    x = crop_to_original_shape(x, ori_shape)
    print("shape after crop", x.shape)
    save_img(x, vis_path, input_p, rec=True)