Spaces:
Runtime error
Runtime error
| import math | |
| import random | |
| from PIL import Image | |
| import blobfile as bf | |
| #from mpi4py import MPI | |
| import numpy as np | |
| from torch.utils.data import DataLoader, Dataset | |
| import os | |
| import torchvision.transforms as transforms | |
| import torch as th | |
| from functools import partial | |
| import cv2 | |
def get_params(size, resize_size, crop_size):
    """Draw random augmentation parameters for one image.

    The image is conceptually resized so that its shorter side equals
    ``resize_size``. The aspect ratio is kept, and the long side is scaled by
    the same factor and truncated to an int. A random top-left corner for a
    ``crop_size`` square crop is then sampled inside that resized frame,
    together with a random horizontal-flip decision. All randomness comes
    from the module-global ``random`` generator.

    Args:
        size: ``(w, h)`` pair of the source image. This is presumably PIL's
            ``Image.size`` order; confirm with callers.
        resize_size: target length of the shorter side after resizing.
        crop_size: side length of the square crop.

    Returns:
        A dict with two keys:

        - ``'crop_pos'``: ``(x, y)`` integer offsets. ``x`` is in
          ``[0, max(0, new_w - crop_size)]`` and ``y`` is in
          ``[0, max(0, new_h - crop_size)]``, so both are 0 when the resized
          frame is smaller than the crop.
        - ``'flip'``: a bool, True when ``random.random() > 0.5``.

    Raises:
        ZeroDivisionError: if the shorter side of ``size`` is 0
            (for Python numeric inputs).
    """
    w, h = size
    short_side, long_side = min(w, h), max(w, h)
    width_is_shorter = w == short_side
    # Scale the long side by the same factor that maps the short side to resize_size.
    long_side = int(resize_size * long_side / short_side)
    short_side = resize_size
    if width_is_shorter:
        new_w, new_h = short_side, long_side
    else:
        new_w, new_h = long_side, short_side
    # Clamp to 0 so randint never gets an empty range when the crop is larger
    # than the resized frame. Builtin max keeps the bounds as plain ints.
    x = random.randint(0, max(0, new_w - crop_size))
    y = random.randint(0, max(0, new_h - crop_size))
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}
def get_transform(params, resize_size, crop_size, method=Image.BICUBIC, flip=True, crop=True):
    """Compose the PIL-level geometric transforms for one sample.

    The pipeline always resizes the image to a ``crop_size`` x ``crop_size``
    square using ``method``. If ``flip`` is truthy, it then appends a step
    that mirrors the image horizontally when ``params['flip']`` is truthy.

    Args:
        params: parameter dict, presumably the one returned by
            ``get_params``. Only ``params['flip']`` is read, and only when
            the pipeline is applied.
        resize_size: accepted but not used by this function.
        crop_size: side length of the output square.
        method: PIL resampling filter forwarded to the resize step.
        flip: whether to include the params-controlled flip step.
        crop: accepted but not used by this function.

    Returns:
        A ``transforms.Compose`` wrapping the steps above.

    NOTE(review): ``resize_size``, ``crop`` and ``params['crop_pos']`` are
    ignored, so no random crop is applied here. Confirm this is intentional.
    """
    def _resize_to_square(img):
        return __scale(img, crop_size, method)

    steps = [transforms.Lambda(_resize_to_square)]
    if flip:
        # The flip decision is looked up at call time, not at build time.
        steps.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
    return transforms.Compose(steps)
def get_tensor(normalize=True, toTensor=True):
    """Build the tensor-conversion pipeline.

    Args:
        normalize: if truthy, append a per-channel ``Normalize`` with mean
            0.5 and std 0.5 for three channels.
        toTensor: if truthy, start the pipeline with ``transforms.ToTensor()``.

    Returns:
        A ``transforms.Compose`` of the selected steps, in the order
        ToTensor then Normalize. It is empty if both flags are falsy.
    """
    steps = []
    if toTensor:
        steps.append(transforms.ToTensor())
    if normalize:
        steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)
def normalize():
    """Return a three-channel ``Normalize`` transform with mean 0.5 and std 0.5."""
    half = (0.5, 0.5, 0.5)
    return transforms.Normalize(half, half)
def __scale(img, target_width, method=Image.BICUBIC):
    """Resize ``img`` to a ``target_width`` x ``target_width`` square.

    The aspect ratio is not preserved. ``method`` is forwarded to
    ``img.resize`` as the resampling filter.
    """
    square = (target_width, target_width)
    return img.resize(square, method)
def __flip(img, flip):
    """Return ``img`` mirrored left-right when ``flip`` is truthy.

    When ``flip`` is falsy, the same object is returned unchanged.
    """
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img