import torch
from torchvision import transforms, models
from PIL import Image
import numpy as np

# Prefer CUDA, then Apple-Silicon MPS, then CPU.
# FIX: use torch.backends.mps.is_available() — the documented availability check.
# torch.mps.is_available() does not exist in all torch versions and can raise
# AttributeError at import time.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)


def load_model(d=device):
    """Load a frozen VGG19 feature extractor and move it to device *d*.

    Only the convolutional ``features`` sub-module is kept (the classifier
    head is unused for style transfer). All parameters are frozen so that
    optimization updates only the generated image, never the network.

    Args:
        d: target ``torch.device`` (defaults to the module-level ``device``).

    Returns:
        ``torch.nn.Sequential``: the VGG19 feature layers, gradients disabled.
    """
    weights = models.VGG19_Weights.DEFAULT
    model = models.vgg19(weights=weights).features  # only the feature layers

    # https://pytorch.org/docs/stable/generated/torch.Tensor.requires_grad_.html
    for param in model.parameters():
        param.requires_grad_(False)

    model.to(device=d)
    return model


# max_size limits the image size to 400 pixels
def load_image(image, max_size=400, shape=None):
    """Convert a PIL image into a normalized 4-D tensor ready for VGG.

    Args:
        image: a ``PIL.Image`` (assumed already RGB — TODO confirm callers
            convert, since the ``.convert('RGB')`` call is commented out).
        max_size: cap on the smaller edge after resizing.
        shape: optional explicit size; when given it overrides ``max_size``
            (used to force the style image to match the content image).

    Returns:
        ``torch.Tensor`` of shape (1, 3, H, W), normalized to roughly [-1, 1].
    """
    # image = Image.open(img_path).convert('RGB')

    # If either edge exceeds max_size, clamp to max_size; otherwise keep size.
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)

    if shape is not None:
        size = shape

    in_transform = transforms.Compose([
        transforms.Resize(size),  # Resize scales the smaller edge to 'size'
        transforms.ToTensor(),
        # Mean/std of 0.5 maps pixels to [-1, 1]; im_convert() inverts this.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Add a batch dimension: (3, H, W) -> (1, 3, H, W).
    image = in_transform(image).unsqueeze(0)
    return image


def im_convert(tensor):
    """Undo load_image(): turn a (1, 3, H, W) tensor back into an H×W×3 array.

    Detaches from the autograd graph, moves to CPU, drops the batch dim,
    reorders channels last, inverts the 0.5/0.5 normalization, and clips
    to [0, 1] for display.

    Returns:
        ``numpy.ndarray`` of shape (H, W, 3) with values in [0, 1].
    """
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)  # CHW -> HWC
    # Invert transforms.Normalize((0.5,)*3, (0.5,)*3).
    image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
    image = image.clip(0, 1)
    return image


def get_features(image, model):
    """Run *image* through *model*, capturing selected intermediate outputs.

    Layer indices follow torchvision's VGG19 ``features`` Sequential:
    conv4_2 supplies the content representation; the conv*_1 layers supply
    the style representations (Gatys et al.).

    Args:
        image: input tensor of shape (1, 3, H, W).
        model: the VGG19 feature extractor from ``load_model``.

    Returns:
        dict mapping layer names (e.g. 'conv4_2') to activation tensors.
    """
    layers = {'0': 'conv1_1',   # Style Extraction
              '5': 'conv2_1',   # Style Extraction
              '10': 'conv3_1',  # Style Extraction
              '19': 'conv4_1',  # Style Extraction
              '21': 'conv4_2',  # Content Extraction
              '28': 'conv5_1'}  # Style Extraction

    features = {}
    # Feed the image through the network layer by layer, keeping the
    # activations of the layers of interest.
    for name, layer in model._modules.items():
        image = layer(image)
        if name in layers:
            features[layers[name]] = image

    return features


# Eliminate content information and keep only style (feature correlations).
def gram_matrix(tensor):
    """Compute the Gram matrix of a (1, d, h, w) activation tensor.

    The batch dimension is discarded (assumes batch size 1); each of the
    d feature maps is flattened to length h*w, and the d×d matrix of
    inner products between feature maps is returned.

    Returns:
        ``torch.Tensor`` of shape (d, d).
    """
    _, d, h, w = tensor.size()  # d is depth, h is height, w is width
    tensor = tensor.view(d, h * w)  # reshape into a 2-D tensor
    gram = torch.mm(tensor, tensor.t())
    return gram