|
|
| import torch |
| import numpy as np |
| import cv2 |
| import torch.nn as nn |
|
|
|
|
|
|
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Attributes:
        val: The most recent value passed to ``update``.
        sum: Running total of ``val * n`` over all updates.
        count: Running total of the weights ``n``.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new observation.

        Args:
            val: The observed value (typically a per-batch mean).
            n: Weight of the observation, e.g. the number of samples it
                summarises.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: a zero-weight update on a fresh meter
        # (e.g. an empty batch) must not raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
|
|
|
|
|
|
|
|
def worker_init_fn(_):
    """Re-seed NumPy's global RNG so each DataLoader worker gets its own stream.

    The new seed is the first word of the current NumPy RNG state plus the
    worker id, wrapped into the range accepted by ``np.random.seed``.

    Args:
        _: The worker id passed by ``DataLoader``. It is only used as a
            fallback when ``torch.utils.data.get_worker_info()`` returns
            ``None`` (i.e. the function is called outside a worker process).

    Returns:
        None.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        worker_id = worker_info.id
    else:
        # Not running inside a worker process: fall back to the argument.
        worker_id = _ if isinstance(_, int) else 0

    # Convert to a Python int before adding so the sum cannot overflow the
    # uint32 state word, then wrap into the valid seed range [0, 2**32 - 1].
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2**32)
|
|
|
|
|
|
def recover_image(image_tensor, MEAN=(0.5, 0.5, 0.5), STD=(0.5, 0.5, 0.5)):
    """Undo normalisation and convert an image tensor to cv2-style uint8 arrays.

    Computes ``image_tensor * STD + MEAN`` per channel, moves channels last,
    reverses the channel order (RGB -> BGR), scales by 255 and clips to
    ``[0, 255]`` before casting to ``uint8`` (fractions are truncated).

    Args:
        image_tensor: Tensor of shape [C, H, W] or [B, C, H, W] with C == 3.
        MEAN: Per-channel mean that was subtracted during normalisation.
        STD: Per-channel std that was divided by during normalisation.

    Returns:
        A C-contiguous ``uint8`` numpy array of shape [B, H, W, C].
    """
    if image_tensor.ndim == 3:
        image_tensor = image_tensor.unsqueeze(0)

    device = image_tensor.device
    std = torch.as_tensor(STD, dtype=torch.float32, device=device).reshape(-1, 1, 1)
    mean = torch.as_tensor(MEAN, dtype=torch.float32, device=device).reshape(-1, 1, 1)

    # detach() drops autograd tracking so the result can be handed to numpy.
    x = (image_tensor.detach() * std + mean).cpu().numpy()

    # [B, C, H, W] -> [B, H, W, C]
    image_rgb = np.transpose(x, (0, 2, 3, 1))
    # Reverse the channel axis: RGB -> BGR.
    image_bgr = image_rgb[:, :, :, [2, 1, 0]]

    image_save = np.clip(image_bgr * 255, 0, 255).astype('uint8')

    # Fancy indexing gives no guarantee about memory layout; make the result
    # C-contiguous so it is safe to pass to OpenCV functions.
    return np.ascontiguousarray(image_save)
|
|
|
|
|
|
def align_images(image_list, h, w):
    """Tile a list of equally-sized images into a single grid image.

    If ``len(image_list)`` does not equal ``h * w``, the grid shape is
    recomputed as ``h = int(sqrt(n))`` and ``w = ceil(n / h)``, and the
    remaining cells are filled with all-zero images shaped like the first
    image.

    Args:
        image_list: Non-empty list of images accepted by ``cv2.hconcat`` /
            ``cv2.vconcat``. The list is not modified.
        h: Requested number of grid rows.
        w: Requested number of grid columns.

    Returns:
        The grid image produced by concatenating rows horizontally and then
        stacking the rows vertically.

    Raises:
        ValueError: If ``image_list`` is empty.
    """
    # Work on a copy so padding tiles never leak into the caller's list.
    images = list(image_list)
    if not images:
        raise ValueError("align_images() requires at least one image")

    if len(images) != h * w:
        # Requested grid does not fit: fall back to a near-square layout.
        h = int(np.sqrt(len(images)))
        w = int(np.ceil(len(images) / h))

    missing = h * w - len(images)
    if missing > 0:
        blank = np.zeros_like(images[0])
        images.extend([blank] * missing)

    rows = [images[i * w:(i + 1) * w] for i in range(h)]
    row_images = [cv2.hconcat(row) for row in rows]
    final_image = cv2.vconcat(row_images)
    return final_image
|
|
|
|
|
|
|
|