from typing import Iterator, Tuple

import torch
import torch.optim as optim
from torch import nn

from utils import get_features, gram_matrix

# weight early layers more heavily
style_weights = {
    'conv1_1': 1.,
    'conv2_1': 0.75,
    'conv3_1': 0.2,
    'conv4_1': 0.2,
    'conv5_1': 0.2,
}

# the balance between style and content
content_weight = 1  # alpha
style_weight = 1e9  # beta


def generate_image(
    model: nn.Module,
    content: torch.Tensor,
    style: torch.Tensor,
    target: torch.Tensor,
    steps: int = 2700,
    content_wt: float = content_weight,
) -> Iterator[Tuple[torch.Tensor, str]]:
    """Optimise ``target`` towards the content/style objective, yielding progress.

    This is a generator: no optimisation happens until it is iterated.
    Each iteration extracts features from ``target`` with ``get_features``,
    computes a content loss on layer ``'conv4_2'`` and a weighted Gram-matrix
    style loss over the layers listed in ``style_weights``, and takes one
    Adam step (lr=0.003) on ``target``.

    Args:
        model: Network passed to ``get_features`` to extract layer features.
        content: Image whose ``'conv4_2'`` features define the content loss.
        style: Image whose per-layer Gram matrices define the style loss.
        target: Tensor being optimised; it is the only parameter handed to
            the optimiser. NOTE(review): presumably created with
            ``requires_grad=True`` by the caller -- confirm.
        steps: Number of optimisation steps to run.
        content_wt: Multiplier for the content loss; the style loss is
            multiplied by the module-level ``style_weight``.

    Yields:
        ``(target, status)`` every 10th step and on the final step. The same
        ``target`` object is yielded each time (it is never rebound here).
        ``status`` is a progress message, or ``"✅ Completed"`` on the final
        step.
    """
    # Content and style references do not depend on `target`, so they are
    # computed once up front.
    content_features = get_features(content, model)
    style_features = get_features(style, model)

    # apply gram_matrix to each of the style features for that same layer
    style_grams = {
        layer: gram_matrix(features)
        for layer, features in style_features.items()
    }

    optimizer = optim.Adam([target], lr=0.003)

    for ii in range(1, steps + 1):
        target_features = get_features(target, model)

        content_loss = torch.mean(
            (target_features['conv4_2'] - content_features['conv4_2']) ** 2
        )

        # calculate the style loss
        style_loss = 0
        for layer, layer_weight in style_weights.items():
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            style_gram = style_grams[layer]

            layer_style_loss = layer_weight * torch.mean(
                (target_gram - style_gram) ** 2
            )

            # Normalise by the feature-map size so that layers with larger
            # activations do not dominate the sum.
            _, d, h, w = target_feature.shape
            style_loss += layer_style_loss / (d * h * w)

        total_loss = content_wt * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Report every 10th step, and always report the final step so the
        # consumer receives the finished result and the completion status
        # even when `steps` is not a multiple of 10.
        is_last_step = ii == steps
        if ii % 10 == 0 or is_last_step:
            if is_last_step:
                status = "✅ Completed"
            else:
                status = f"Processing step {ii} of {steps}"
            yield target, status