Spaces:
Runtime error
Runtime error
| import numpy as np | |
| from PIL import Image | |
| from torchvision import transforms | |
def requires_grad(model, flag=True):
    """Toggle gradient tracking on every parameter of *model*.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        flag: Value assigned to each parameter's ``requires_grad`` attribute.
    """
    params = model.parameters()
    for param in params:
        param.requires_grad = flag
def get_keys(d, name):
    """Extract the entries belonging to submodule *name* from a state dict.

    If *d* wraps its weights under a ``'state_dict'`` key (e.g. a saved
    checkpoint), that inner mapping is used instead. Only keys of the form
    ``"<name>.<rest>"`` are kept, and they are returned with the
    ``"<name>."`` prefix stripped.

    Args:
        d: A state dict, or a dict containing one under ``'state_dict'``.
        name: Submodule prefix to select, without the trailing dot.

    Returns:
        A new dict mapping the stripped keys to the original values.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    # Match on "name." rather than a bare "name" prefix: otherwise a sibling
    # such as "encoder2.w" would be selected for name="encoder" and returned
    # under the malformed key ".w".
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
def load_img(path_img, img_size=(256, 256)):
    """Load an image and convert it to a normalized, batched tensor.

    Args:
        path_img: A path or file object accepted by ``PIL.Image.open``, a
            ``numpy.ndarray`` of image data accepted by ``Image.fromarray``,
            or an already-open ``PIL.Image.Image``.
        img_size: Target ``(height, width)`` passed to ``transforms.Resize``.

    Returns:
        A tensor of shape ``(1, 3, H, W)`` whose values are mapped from
        ToTensor's ``[0, 1]`` range to ``[-1, 1]`` by the Normalize step.
    """
    transform = transforms.Compose(
        [transforms.Resize(img_size),
         transforms.ToTensor(),
         # (x - 0.5) / 0.5 per RGB channel: [0, 1] -> [-1, 1].
         transforms.Normalize((0.5, 0.5, 0.5),
                              (0.5, 0.5, 0.5))])
    if isinstance(path_img, np.ndarray):
        # Force 3 channels, as the file branch does: grayscale (H, W) or
        # RGBA (H, W, 4) arrays would otherwise yield a tensor whose channel
        # count does not match the 3-channel Normalize above.
        img = Image.fromarray(path_img).convert('RGB')
    elif isinstance(path_img, Image.Image):
        img = path_img.convert('RGB')
    else:
        # convert() loads the pixels into a new image, so the opened file
        # can be released as soon as the with-block ends.
        with Image.open(path_img) as opened:
            img = opened.convert('RGB')
    # Prepend a batch dimension: (3, H, W) -> (1, 3, H, W).
    return transform(img).unsqueeze(0)