import torch
from torchvision import transforms, models
from PIL import Image
import numpy as np

# prefer CUDA, then Apple's MPS backend, then fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
def load_model(d=device):
    weights = models.VGG19_Weights.DEFAULT
    model = models.vgg19(weights=weights).features  # only use the feature layers of the model
    # freeze all VGG parameters -- only the target image will be optimized, never the network
    # https://pytorch.org/docs/stable/generated/torch.Tensor.requires_grad_.html
    for param in model.parameters():
        param.requires_grad_(False)
    model.to(device=d)
    return model
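
# Usage sketch: build the frozen feature extractor once and reuse it.
# eval() is not strictly required here (the feature layers contain no
# dropout or batch norm), but it makes the inference-only intent explicit.
vgg = load_model()
vgg.eval()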

# max_size limits the image size to 400 pixels
def load_image(image, max_size=400, shape=None):
    # expects a PIL image, e.g. Image.open(img_path).convert('RGB')
    # if either the horizontal or vertical image size exceeds max_size, clamp it to max_size
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)
    # an explicit shape (e.g. the content image's size) overrides the computed size
    if shape is not None:
        size = shape
    in_transform = transforms.Compose([
        transforms.Resize(size),  # Resize will scale the smaller edge of the image to 'size'
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5),  # maps pixel values from [0, 1] to [-1, 1]
                             (0.5, 0.5, 0.5))])
    image = in_transform(image).unsqueeze(0)  # add a batch dimension: (3, H, W) -> (1, 3, H, W)
    return image
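
# Usage sketch ('content.jpg' and 'style.jpg' are hypothetical paths): load
# both images and force the style image to the content image's height/width
# so their feature maps line up layer by layer
content = load_image(Image.open('content.jpg').convert('RGB')).to(device)
style = load_image(Image.open('style.jpg').convert('RGB'),
                   shape=content.shape[-2:]).to(device)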

# reverse of load_image: convert a tensor back into a displayable numpy image
def im_convert(tensor):
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()  # drop the batch dimension
    image = image.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))  # undo the normalization
    image = image.clip(0, 1)
    return image
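
# Usage sketch: im_convert returns an (H, W, 3) float array in [0, 1] that
# matplotlib can display directly (matplotlib is an extra dependency here)
import matplotlib.pyplot as plt
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.imshow(im_convert(content))
ax2.imshow(im_convert(style))
plt.show()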

def get_features(image, model):
    # VGG19 module indices mapped to readable layer names
    layers = {'0': 'conv1_1',   # Style Extraction
              '5': 'conv2_1',   # Style Extraction
              '10': 'conv3_1',  # Style Extraction
              '19': 'conv4_1',  # Style Extraction
              '21': 'conv4_2',  # Content Extraction
              '28': 'conv5_1'}  # Style Extraction
    features = {}
    for name, layer in model._modules.items():
        # feed the image through the network layer by layer, keeping the outputs we care about
        image = layer(image)
        if name in layers:
            features[layers[name]] = image
    return features
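
# Usage sketch: the content and style representations are extracted once,
# up front; they stay fixed during optimization because VGG is frozen
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)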

# The Gram matrix measures correlations between feature maps: it discards the
# spatial layout (the content) and keeps only the style information
def gram_matrix(tensor):
    _, d, h, w = tensor.size()  # d is depth (channels), h is height, w is width
    tensor = tensor.view(d, h * w)  # reshape so each row holds one flattened feature map
    gram = torch.mm(tensor, tensor.t())  # (d, d) matrix of channel correlations
    return gram
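
# A minimal sketch of the optimization loop that ties the pieces together,
# following the usual Gatys-style recipe; the layer weights, learning rate,
# and step count below are illustrative assumptions, not tuned values
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

style_weights = {'conv1_1': 1.0, 'conv2_1': 0.75, 'conv3_1': 0.2,
                 'conv4_1': 0.2, 'conv5_1': 0.2}
content_weight = 1    # alpha
style_weight = 1e6    # beta

# the target image is the only thing being optimized; start it from the content image
target = content.clone().requires_grad_(True)
optimizer = torch.optim.Adam([target], lr=0.003)

for step in range(2000):
    target_features = get_features(target, vgg)

    # content loss: distance to the content representation at conv4_2
    content_loss = torch.mean((target_features['conv4_2'] -
                               content_features['conv4_2']) ** 2)

    # style loss: weighted distance between Gram matrices at each style layer
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        _, d, h, w = target_feature.shape
        target_gram = gram_matrix(target_feature)
        layer_loss = style_weights[layer] * torch.mean((target_gram - style_grams[layer]) ** 2)
        style_loss += layer_loss / (d * h * w)  # normalize by feature map size

    total_loss = content_weight * content_loss + style_weight * style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()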