import torch
from torchvision import transforms, models
from PIL import Image
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available()
                      else "mps" if torch.backends.mps.is_available()
                      else "cpu")

def load_model(d=device):
    weights = models.VGG19_Weights.DEFAULT
    # use only the convolutional feature layers of the model
    model = models.vgg19(weights=weights).features
    # freeze all VGG parameters; in style transfer only the generated image is optimized
    # https://pytorch.org/docs/stable/generated/torch.Tensor.requires_grad_.html
    for param in model.parameters():
        param.requires_grad_(False)
    model.to(device=d)
    return model
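
# Usage sketch (the variable name 'vgg' is illustrative, not from the original file):
# vgg = load_model()
# all(not p.requires_grad for p in vgg.parameters())  # -> True: weights are frozen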

# max_size caps the resized image at 400 pixels on its smaller edge
def load_image(image, max_size=400, shape=None):
    # expects a PIL image, e.g. Image.open(img_path).convert('RGB')
    # if either the horizontal or vertical size exceeds max_size, cap the target size at max_size
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)
    # an explicit shape (e.g. another image's height and width) overrides the computed size
    if shape is not None:
        size = shape
    in_transform = transforms.Compose([
        transforms.Resize(size),  # Resize scales the smaller edge of the image to 'size'
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5),
                             (0.5, 0.5, 0.5))])
    # add a batch dimension: (3, H, W) -> (1, 3, H, W)
    image = in_transform(image).unsqueeze(0)
    return image
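
# Usage sketch (the file names are assumed, not from the original file):
# content = load_image(Image.open('content.jpg').convert('RGB')).to(device)
# style = load_image(Image.open('style.jpg').convert('RGB'),
#                    shape=content.shape[-2:]).to(device)  # match the content size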

# undo the batching and normalization so a tensor can be displayed as an image
def im_convert(tensor):
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()   # drop the batch dimension
    image = image.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    # invert Normalize: x * std + mean
    image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
    image = image.clip(0, 1)
    return image
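
# Display sketch (assumes matplotlib is installed; it is not imported above):
# import matplotlib.pyplot as plt
# plt.imshow(im_convert(content)); plt.axis('off'); plt.show()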

def get_features(image, model):
    # indices of the VGG19 Sequential mapped to readable layer names
    layers = {'0': 'conv1_1',   # style extraction
              '5': 'conv2_1',   # style extraction
              '10': 'conv3_1',  # style extraction
              '19': 'conv4_1',  # style extraction
              '21': 'conv4_2',  # content extraction
              '28': 'conv5_1'}  # style extraction
    features = {}
    for name, layer in model._modules.items():
        # feed the image through the network one layer at a time,
        # keeping the activations of the layers listed above
        image = layer(image)
        if name in layers:
            features[layers[name]] = image
    return features
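
# Usage sketch (variable names are illustrative):
# content_features = get_features(content, vgg)
# style_features = get_features(style, vgg)
# content_features['conv4_2'].shape  # the content representation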

# The Gram matrix discards spatial (content) information and keeps only the
# correlations between feature maps, which represent style
def gram_matrix(tensor):
    _, d, h, w = tensor.size()      # batch (assumed to be 1), depth, height, width
    tensor = tensor.view(d, h * w)  # flatten each feature map into a row vector
    gram = torch.mm(tensor, tensor.t())
    return gram
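

# A minimal end-to-end sketch, not part of the original file. It assumes
# 'content.jpg' and 'style.jpg' exist in the working directory, and it only
# exercises the helpers above; the optimization loop that actually blends
# content and style is out of scope here.
if __name__ == "__main__":
    vgg = load_model()

    content = load_image(Image.open("content.jpg").convert("RGB")).to(device)
    # resize the style image to the content image's height and width so the
    # later loss terms compare tensors of matching shape
    style = load_image(Image.open("style.jpg").convert("RGB"),
                       shape=content.shape[-2:]).to(device)

    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)

    # one Gram matrix per style layer; conv4_2 is reserved for content
    style_grams = {layer: gram_matrix(style_features[layer])
                   for layer in style_features if layer != 'conv4_2'}
    for layer, gram in style_grams.items():
        print(layer, tuple(gram.shape))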