# maomao88 — commit 7dff5ef: "add intermediate steps"
import torch
import torch.optim as optim
from torch import nn
from utils import get_features, gram_matrix
# Per-layer multipliers applied to each layer's style (Gram-matrix) loss.
# Earlier conv blocks are weighted more heavily than deeper ones.
style_weights = dict(
    conv1_1=1.0,
    conv2_1=0.75,
    conv3_1=0.2,
    conv4_1=0.2,
    conv5_1=0.2,
)

# Balance between the two loss terms in generate_image:
#   total_loss = content_weight * content_loss + style_weight * style_loss
content_weight = 1  # alpha
style_weight = 1e9  # beta
def generate_image(
    model: nn.Module,
    content: torch.Tensor,
    style: torch.Tensor,
    target: torch.Tensor,
    steps=2700,
    content_wt=content_weight,
    style_wt=style_weight,
    lr=0.003,
    yield_every=10,
):
    """Run style transfer by optimising ``target`` in place with Adam.

    Every step minimises ``content_wt * content_loss + style_wt * style_loss``:

    * ``content_loss`` is the mean squared difference between the ``'conv4_2'``
      activations of ``target`` and ``content`` (via ``get_features``).
    * ``style_loss`` sums, over the layers in the module-level ``style_weights``,
      the weighted mean squared difference between the Gram matrices (via
      ``gram_matrix``) of the target and style activations, each divided by
      ``d * h * w`` of that layer's 4-D target feature map.

    Args:
        model: Feature extractor handed to ``get_features``. NOTE(review): the
            loop presumably expects its parameters to be frozen; otherwise every
            ``backward()`` also accumulates gradients into them — confirm.
        content: Image whose ``'conv4_2'`` activations ``target`` should match.
        style: Image whose per-layer Gram matrices ``target`` should match.
        target: Tensor optimised in place. Adam optimises it directly, so it is
            presumably a leaf tensor with ``requires_grad=True``.
        steps: Number of optimisation steps.
        content_wt: Weight of the content loss (alpha).
        style_wt: Weight of the style loss (beta).
        lr: Adam learning rate.
        yield_every: Report progress every this many steps.

    Yields:
        ``(target, status)`` every ``yield_every`` steps and always after the
        final step. ``target`` is the live tensor being optimised, not a copy.
        ``status`` is ``"Processing step {ii} of {steps}"``, or
        ``"✅ Completed"`` on the final step.

    Raises:
        ValueError: If ``yield_every`` is less than 1.
    """
    if yield_every < 1:
        raise ValueError(f"yield_every must be >= 1, got {yield_every}")

    # Content activations and style Grams are constant targets. Computing them
    # without autograd avoids keeping their graphs alive. It also avoids
    # back-propagating through an already-freed graph on the second step if
    # ``model`` has trainable parameters.
    with torch.no_grad():
        content_features = get_features(content, model)
        style_features = get_features(style, model)
        # Only the layers that contribute to the style loss need a Gram matrix.
        style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_weights}

    optimizer = optim.Adam([target], lr=lr)

    for ii in range(1, steps + 1):
        target_features = get_features(target, model)
        content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)

        style_loss = 0
        for layer, layer_weight in style_weights.items():
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            layer_style_loss = layer_weight * torch.mean((target_gram - style_grams[layer]) ** 2)
            # Normalise by the size of this layer's feature map.
            _, d, h, w = target_feature.shape
            style_loss += layer_style_loss / (d * h * w)

        total_loss = content_wt * content_loss + style_wt * style_loss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Always report the final step, so callers receive the finished image
        # even when ``steps`` is not a multiple of ``yield_every``.
        is_last = ii == steps
        if is_last or ii % yield_every == 0:
            status = "✅ Completed" if is_last else f"Processing step {ii} of {steps}"
            yield target, status