Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.optim as optim | |
| from torch import nn | |
| from utils import get_features, gram_matrix | |
# Per-layer weights for the style loss; earlier conv layers are weighted
# more heavily than deeper ones.
style_weights = dict(
    conv1_1=1.0,
    conv2_1=0.75,
    conv3_1=0.2,
    conv4_1=0.2,
    conv5_1=0.2,
)

# Balance between the content and style terms of the total loss.
content_weight = 1  # alpha: multiplier on the content loss
style_weight = 1e9  # beta: multiplier on the style loss
def generate_image(
    model: nn.Module,
    content: torch.Tensor,
    style: torch.Tensor,
    target: torch.Tensor,
    steps: int = 2700,
    content_wt=content_weight,
    style_wt=None,
    lr: float = 0.003,
    report_every: int = 10,
):
    """Run style transfer by optimizing ``target`` in place with Adam.

    This is a generator: every ``report_every`` steps, and always after the
    final step, it yields ``(target, status)`` so a caller can display
    progress and the final result.

    Args:
        model: Feature extractor passed to ``get_features``.
        content: Image whose ``conv4_2`` features the result should match.
        style: Image whose Gram matrices, on the layers listed in
            ``style_weights``, the result should match.
        target: Image tensor being optimized; it is updated in place by Adam.
            NOTE(review): presumably the caller creates it as a leaf tensor
            with ``requires_grad=True`` — Adam cannot move it otherwise.
        steps: Number of optimization steps. ``steps < 1`` yields nothing.
        content_wt: Weight (alpha) on the content loss.
        style_wt: Weight (beta) on the style loss. ``None`` means use the
            module-level ``style_weight`` at call time.
        lr: Adam learning rate.
        report_every: Yield a progress update every this many steps.

    Yields:
        ``(target, status)``. ``target`` is the same tensor object on every
        yield (not a copy). ``status`` is ``"Processing step i of N"``, or
        ``"✅ Completed"`` after the last step.

    Raises:
        ValueError: If ``report_every`` is less than 1.
    """
    if report_every < 1:
        raise ValueError(f"report_every must be >= 1, got {report_every}")
    if style_wt is None:
        # Read the global at call time, matching the original behaviour.
        style_wt = style_weight

    # The content and style references are fixed, so their features need no
    # autograd graph; building one would only waste memory.
    with torch.no_grad():
        content_features = get_features(content, model)
        style_features = get_features(style, model)
        # Precompute the Gram matrix of every style layer once, up front.
        style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

    optimizer = optim.Adam([target], lr=lr)
    for step in range(1, steps + 1):
        target_features = get_features(target, model)

        # Content loss: mean squared error at the conv4_2 layer.
        content_loss = torch.mean(
            (target_features['conv4_2'] - content_features['conv4_2']) ** 2
        )

        # Style loss: weighted MSE between Gram matrices, each layer's term
        # divided by that layer's feature-map size (d * h * w).
        style_loss = 0
        for layer, layer_weight in style_weights.items():
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            layer_style_loss = layer_weight * torch.mean(
                (target_gram - style_grams[layer]) ** 2
            )
            _, d, h, w = target_feature.shape
            style_loss += layer_style_loss / (d * h * w)

        total_loss = content_wt * content_loss + style_wt * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Always report the last step, so the caller receives the final image
        # even when steps is not a multiple of report_every (or is smaller).
        if step == steps:
            yield target, "✅ Completed"
        elif step % report_every == 0:
            yield target, f"Processing step {step} of {steps}"