Spaces:
Configuration error
Configuration error
| import cv2 | |
| import numpy as np | |
| import torch | |
| from dataset.range_transform import inv_im_trans, inv_lll2rgb_trans | |
| from collections import defaultdict | |
| from PIL import Image | |
| from skimage import color, io | |
| import util.functional as F | |
class Normalize:
    """Callable transform that normalizes a 3-channel tensor in place.

    Channel 0 is sent through ``F.normalize`` with the arguments ``50`` and
    ``1``; channels 1-2 are sent through it with ``(0, 0)`` and ``(1, 1)``.
    NOTE(review): these look like (mean, std) pairs for a Lab image -- confirm
    against ``util.functional.normalize``.
    """

    def __init__(self):
        """Nothing to configure; the constants are fixed in ``__call__``."""

    def __call__(self, inputs):
        """Write the normalized channels back into ``inputs`` and return it."""
        head = slice(0, 1)
        tail = slice(1, 3)
        inputs[head, :, :] = F.normalize(inputs[head, :, :], 50, 1)
        inputs[tail, :, :] = F.normalize(inputs[tail, :, :], (0, 0), (1, 1))
        return inputs
def tensor_to_numpy(image):
    """Scale ``image`` by 255 and return it as a ``uint8`` NumPy array.

    ``image`` must provide ``.numpy()``. No clamping is done here, so callers
    are expected to pass values that fit into ``uint8`` after scaling.
    """
    scaled = image.numpy() * 255
    return scaled.astype('uint8')
def tensor_to_np_float(image):
    """Return ``image`` as a ``float32`` NumPy array; values are not rescaled."""
    as_array = image.numpy()
    return as_array.astype('float32')
def detach_to_cpu(x):
    """Return ``x`` detached from the autograd graph and placed on the CPU."""
    detached = x.detach()
    return detached.cpu()
def transpose_np(x):
    """Reorder the axes of a 3-D array to ``(1, 2, 0)``.

    For example, a channel-first ``(C, H, W)`` array becomes ``(H, W, C)``.
    """
    axis_order = (1, 2, 0)
    return np.transpose(x, axis_order)
def tensor_to_gray_im(x):
    """Convert a tensor to a ``uint8`` array with axes reordered to (1, 2, 0).

    The tensor is detached, moved to the CPU, scaled by 255 and cast to
    ``uint8``; no inverse normalization or clamping is applied.
    """
    return transpose_np(tensor_to_numpy(detach_to_cpu(x)))
def tensor_to_im(x):
    """Convert an image tensor to a ``uint8`` array with axes (1, 2, 0).

    The tensor is detached and moved to the CPU, passed through
    ``inv_im_trans`` (presumably the inverse of the dataset normalization --
    confirm in ``dataset.range_transform``), clamped to ``[0, 1]``, scaled by
    255 and cast to ``uint8``.
    """
    restored = inv_im_trans(detach_to_cpu(x)).clamp(0, 1)
    return transpose_np(tensor_to_numpy(restored))
# Lookup table from image-dict key to the caption text drawn for that key;
# keys that are missing here are captioned with the key itself.
key_captions = dict(im='Image', gt='GT')
def get_image_array(images, grid_shape, captions=None):
    """Tile captioned images into one ``uint8`` canvas.

    Every entry of ``images`` becomes one strip of the canvas: the first cell
    of the strip carries the caption text and the following cells hold the
    entry's images in list order.

    Args:
        images: Non-empty mapping from key to a list of cv2-style arrays.
            Each array is multiplied by 255 and cast to ``uint8`` (so float
            values in ``[0, 1]`` are presumably expected); 2-D arrays get a
            trailing channel axis. The number of image cells per strip is
            taken from the first entry.
        grid_shape: Pair unpacked as ``h, w``. Each cell spans ``w`` canvas
            rows and ``h`` canvas columns.
        captions: Optional mapping from key to caption text. Keys that are
            missing fall back to the key itself. Newline characters split a
            caption into several lines.

    Returns:
        ``uint8`` array of shape ``(w * len(images), h * (n + 1), 3)`` where
        ``n`` is the length of the first list in ``images``.

    Raises:
        StopIteration: If ``images`` is empty.
    """
    # Avoid a shared mutable default argument.
    if captions is None:
        captions = {}

    h, w = grid_shape
    cate_counts = len(images)
    rows_counts = len(next(iter(images.values())))

    font = cv2.FONT_HERSHEY_SIMPLEX
    dy = 40  # vertical pixel offset between successive caption lines

    # One extra cell per strip is reserved for the caption.
    output_image = np.zeros([w * cate_counts, h * (rows_counts + 1), 3], dtype=np.uint8)

    for col_cnt, (k, v) in enumerate(images.items()):
        # Default as key value itself
        caption = captions.get(k, k)

        # Handles new line character
        for i, line in enumerate(caption.split('\n')):
            cv2.putText(output_image, line, (10, col_cnt * w + 100 + i * dy),
                        font, 0.8, (255, 255, 255), 2, cv2.LINE_AA)

        # Put images
        for row_cnt, img in enumerate(v):
            if len(img.shape) == 2:
                # Single-channel image: add an axis so it broadcasts over RGB.
                img = img[..., np.newaxis]
            img = (img * 255).astype('uint8')

            output_image[col_cnt * w:(col_cnt + 1) * w,
                         (row_cnt + 1) * h:(row_cnt + 2) * h, :] = img

    return output_image
def base_transform(im, size):
    """Convert a CPU tensor to a ``float32`` image clipped to ``[0, 1]``.

    Args:
        im: Tensor providing ``.numpy()``. A 3-D tensor has its axes reordered
            to ``(1, 2, 0)``; any other input gets a trailing channel axis.
        size: Target size passed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``.

    Returns:
        The image resized with nearest-neighbour interpolation when its
        spatial size differs from ``size``, clipped to ``[0, 1]``.
    """
    im = tensor_to_np_float(im)
    if len(im.shape) == 3:
        im = im.transpose((1, 2, 0))
    else:
        im = im[:, :, None]

    # Resize only when needed. The previous check compared an int with the
    # whole ``size`` sequence, which is always unequal.
    if (im.shape[1], im.shape[0]) != tuple(size):
        im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)
    elif im.shape[2] == 1:
        # cv2.resize returns a 2-D array for single-channel input; mirror that
        # so the return shape does not depend on whether a resize happened.
        im = im[:, :, 0]

    return im.clip(0, 1)
def im_transform(im, size):
    """Turn an image tensor into a displayable array of the given size.

    The tensor is detached, moved to the CPU and passed through
    ``inv_im_trans`` before ``base_transform`` converts and resizes it.
    """
    restored = inv_im_trans(detach_to_cpu(im))
    return base_transform(restored, size=size)
def mask_transform(mask, size):
    """Turn a mask tensor into a displayable array of the given size.

    The tensor is detached and moved to the CPU, then handed to
    ``base_transform`` unchanged.
    """
    cpu_mask = detach_to_cpu(mask)
    return base_transform(cpu_mask, size=size)
def out_transform(mask, size):
    """Turn raw outputs into a displayable array of the given size.

    A sigmoid is applied to ``mask`` first; the result is detached, moved to
    the CPU and handed to ``base_transform``.
    """
    activated = torch.sigmoid(mask)
    return base_transform(detach_to_cpu(activated), size=size)
def lll2rgb_transform(mask, size):
    """Render only channel 0 of a 3-channel tensor as an RGB image.

    Channels 1 and 2 are zeroed, then the tensor goes through the same
    pipeline as ``lab2rgb_transform`` (inverse normalization,
    ``skimage.color.lab2rgb``, resize, clip).

    Args:
        mask: 3-D tensor whose channels 1 and 2 are to be ignored. It is not
            modified.
        size: Target size as accepted by ``cv2.resize`` (``(width, height)``).

    Returns:
        RGB image array clipped to ``[0, 1]``.
    """
    # Work on a copy: for a CPU tensor ``detach().cpu()`` shares storage with
    # ``mask``, so zeroing in place would overwrite the caller's data.
    lab = detach_to_cpu(mask).clone()
    lab[1:3, :, :] = 0
    return lab2rgb_transform(lab, size)
def lab2rgb_transform(mask, size):
    """Convert a normalized 3-channel tensor to an RGB image via Lab.

    The tensor is detached, moved to the CPU, passed through
    ``inv_lll2rgb_trans`` (presumably the inverse of the Lab normalization --
    confirm in ``dataset.range_transform``), converted to a channel-last
    ``float32`` array and fed to ``skimage.color.lab2rgb``.

    Args:
        mask: Tensor to convert; a 3-D tensor has its axes reordered to
            ``(1, 2, 0)``.
        size: Target size as accepted by ``cv2.resize`` (``(width, height)``).

    Returns:
        RGB image array, resized with nearest-neighbour interpolation when its
        spatial size differs from ``size``, clipped to ``[0, 1]``.
    """
    lab = inv_lll2rgb_trans(detach_to_cpu(mask))

    im = tensor_to_np_float(lab)
    if len(im.shape) == 3:
        im = im.transpose((1, 2, 0))
    else:
        im = im[:, :, None]

    im = color.lab2rgb(im)

    # Resize only when needed. The previous check compared an int with the
    # whole ``size`` sequence, which is always unequal.
    if (im.shape[1], im.shape[0]) != tuple(size):
        im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)

    return im.clip(0, 1)
def pool_pairs_221128_TransColorization(images, size, num_objects):
    """Build a captioned overview image for a colorization batch.

    For at most two batch items and every time step, three rows are filled:
    ``RGB`` (channel 0 of the input rendered alone), ``Mask_0`` (channel 0
    combined with either the first-frame ground truth or the step's
    ``masks_<t>`` entry) and ``GT_0_<names>`` (channel 0 combined with
    ``cls_gt``). The combined tensors are rendered by ``lab2rgb_transform``.

    Args:
        images: Dict with at least ``'rgb'``, ``'info'`` (holding ``'name'``),
            ``'first_frame_gt'``, ``'cls_gt'`` and ``'masks_<t>'`` for every
            ``t >= 1`` that is visualised.
        size: Cell size forwarded to the transforms and ``get_image_array``.
        num_objects: Per-item object counts; item ``bi`` falls back to the
            first-frame ground truth when ``num_objects[bi]`` is 0.

    Returns:
        The ``uint8`` canvas produced by ``get_image_array``.
    """
    req_images = defaultdict(list)

    batch, frames = images['rgb'].shape[:2]
    # limit the number of images saved
    batch = min(2, batch)

    # Fixed to a single entry; the per-batch maximum of num_objects is not used.
    max_num_objects = 1

    GT_suffix = ''.join(
        ' \n%s' % images['info']['name'][bi][-25:-4] for bi in range(batch)
    )

    for bi in range(batch):
        for ti in range(frames):
            frame = images['rgb'][bi, ti]
            req_images['RGB'].append(lll2rgb_transform(frame, size))

            for oi in range(max_num_objects):
                mask_key = 'Mask_%d' % oi
                gt_key = 'GT_%d_%s' % (oi, GT_suffix)

                if ti != 0 and oi < num_objects[bi]:
                    other_channels = images['masks_%d' % ti][bi][:]
                else:
                    other_channels = images['first_frame_gt'][bi][0, :]

                # Channel 0 of the input is stacked with the remaining channels
                # (presumably L + ab of a Lab image) before conversion.
                stacked = torch.cat([frame[:1, :, :], other_channels], dim=0)
                req_images[mask_key].append(lab2rgb_transform(stacked, size))

                stacked_gt = torch.cat(
                    [frame[:1, :, :], images['cls_gt'][bi, ti, :, :]], dim=0
                )
                req_images[gt_key].append(lab2rgb_transform(stacked_gt, size))

    return get_image_array(req_images, size, key_captions)
def pool_pairs_221128_TransColorization_val(images, size, num_objects):
    """Build the captioned overview image (``_val`` entry point).

    The body used to be a line-for-line copy of
    ``pool_pairs_221128_TransColorization``; it now delegates to that function
    so the two cannot drift apart silently. The ``_val`` suffix suggests this
    is the validation-time entry point.

    Args:
        images: Batch dict, see ``pool_pairs_221128_TransColorization``.
        size: Cell size forwarded to the transforms and ``get_image_array``.
        num_objects: Per-item object counts.

    Returns:
        The ``uint8`` canvas produced by ``get_image_array``.
    """
    return pool_pairs_221128_TransColorization(images, size, num_objects)
def pool_pairs(images, size, num_objects):
    """Build a captioned overview image of inputs, masks and ground truth.

    For at most two batch items and every time step, the row ``RGB`` receives
    the input frame, and for every object index ``oi`` the rows ``Mask_<oi>``
    and ``GT_<oi>_<names>`` receive the mask and the ground truth
    (``cls_gt == oi + 1``). The first-frame ground truth is used as the mask
    at ``t == 0`` and for object indices an item does not have.

    Args:
        images: Dict with at least ``'rgb'``, ``'info'`` (holding ``'name'``),
            ``'first_frame_gt'``, ``'cls_gt'`` and ``'masks_<t>'`` for every
            ``t >= 1`` that is visualised.
        size: Cell size forwarded to the transforms and ``get_image_array``.
        num_objects: Per-item object counts.

    Returns:
        The ``uint8`` canvas produced by ``get_image_array``.
    """
    req_images = defaultdict(list)

    batch, frames = images['rgb'].shape[:2]
    # limit the number of images saved
    batch = min(2, batch)

    # find max num objects
    max_num_objects = max(num_objects[:batch])

    GT_suffix = ''.join(
        ' \n%s' % images['info']['name'][bi][-25:-4] for bi in range(batch)
    )

    for bi in range(batch):
        for ti in range(frames):
            req_images['RGB'].append(im_transform(images['rgb'][bi, ti], size))

            for oi in range(max_num_objects):
                mask_key = 'Mask_%d' % oi
                gt_key = 'GT_%d_%s' % (oi, GT_suffix)

                if ti != 0 and oi < num_objects[bi]:
                    mask = images['masks_%d' % ti][bi][oi]
                else:
                    mask = images['first_frame_gt'][bi][0, oi]
                req_images[mask_key].append(mask_transform(mask, size))

                gt = images['cls_gt'][bi, ti, 0] == (oi + 1)
                req_images[gt_key].append(mask_transform(gt, size))

    return get_image_array(req_images, size, key_captions)