Spaces:
Sleeping
Sleeping
| from PIL import Image | |
| import torch | |
| import torch.optim as optim | |
| from torchvision import transforms | |
| import torch.nn as nn | |
| from Models_Class.NST_class import ( | |
| ContentLoss, | |
| Normalization, | |
| StyleLoss, | |
| ) | |
| import copy | |
# Module-level multipliers for the two terms of the objective minimized in
# run_style_transfer:
#     loss = style_weight * style_score + content_weight * content_score
# With these values the style term is scaled 10**7 times more than the content term.
style_weight, content_weight = 1e8, 1e1
def image_loader(image_path, loader, device):
    """Load an image file as a batched float tensor on ``device``.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).
        loader: Callable applied to the RGB PIL image. Its result must support
            ``unsqueeze`` (presumably a torchvision transform pipeline ending
            in ``ToTensor`` -- TODO confirm against callers).
        device: Target device forwarded to ``Tensor.to``.

    Returns:
        ``loader(image)`` with a new leading dimension of size 1, cast to
        ``torch.float`` and moved to ``device``.
    """
    # Open in a context manager so the underlying file handle is closed
    # deterministically. convert() returns an independent, fully loaded copy,
    # so it remains valid after the source image is closed.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    batched = loader(image).unsqueeze(0)
    return batched.to(device, torch.float)
def save_image(tensor, path="output.png"):
    """Write ``tensor`` to ``path`` as an image file.

    The tensor is copied to the CPU (leaving the caller's tensor untouched),
    a leading dimension of size 1 is dropped if present (``squeeze(0)``), and
    the result is converted with ``transforms.ToPILImage`` before being saved.

    Args:
        tensor: Image tensor; presumably ``(1, C, H, W)`` or ``(C, H, W)``
            with values in [0, 1] -- TODO confirm against callers.
        path: Destination file; the format is chosen by PIL from the suffix.
    """
    cpu_copy = tensor.cpu().clone()
    to_pil = transforms.ToPILImage()
    to_pil(cpu_copy.squeeze(0)).save(path)
def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D tensor.

    The tensor's size is unpacked as ``(a, b, c, d)``. The first two
    dimensions are merged into rows and the last two into columns, giving a
    ``(a * b, c * d)`` matrix ``F``. The result is ``F @ F.T`` divided by the
    total element count ``a * b * c * d``.

    Args:
        input: 4-D tensor, presumably ``(batch, channels, height, width)``
            feature maps -- TODO confirm against callers. Non-contiguous
            tensors (e.g. produced by ``transpose``/``permute``) are accepted.
            The parameter name shadows the builtin but is kept for
            keyword-argument compatibility.

    Returns:
        Tensor of shape ``(a * b, a * b)``.
    """
    a, b, c, d = input.size()
    # reshape rather than view: view raises on non-contiguous tensors, while
    # reshape returns a view when possible and copies only when it must.
    features = input.reshape(a * b, c * d)
    gram = torch.mm(features, features.t())
    return gram.div(a * b * c * d)
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img, content_layers, style_layers, device):
    """Build a truncated copy of ``cnn`` with content/style loss probes inserted.

    A ``Normalization`` module comes first, followed by the children of a deep
    copy of ``cnn`` (the caller's network is never modified). Layers are named
    ``conv_N``, ``relu_N``, ``pool_N`` or ``bn_N``, where ``N`` is the number
    of ``Conv2d`` layers seen so far.

    Right after a layer whose name is in ``content_layers``, a ``ContentLoss``
    is inserted. Its target is the partial model's output on ``content_img``.
    Likewise, a ``StyleLoss`` is inserted after each layer named in
    ``style_layers``, with a target computed from ``style_img``. Every layer
    after the last inserted probe is dropped. If no probe was inserted, only
    the normalization module is kept.

    Args:
        cnn: Module whose ``children()`` are only ``Conv2d``, ``ReLU``,
            ``MaxPool2d`` or ``BatchNorm2d`` (presumably a VGG feature
            extractor -- TODO confirm against callers).
        normalization_mean, normalization_std: Forwarded to ``Normalization``.
        style_img, content_img: Input tensors used to compute the probe targets.
        content_layers, style_layers: Containers of layer names (e.g.
            ``'conv_4'``) tested with ``in``.
        device: Device the ``Normalization`` module is moved to.

    Returns:
        ``(model, style_losses, content_losses)``: the truncated
        ``nn.Sequential`` and the lists of inserted ``StyleLoss`` /
        ``ContentLoss`` modules, in network order. Note that style comes first.

    Raises:
        RuntimeError: if ``cnn`` contains any other kind of child module.
    """
    cnn = copy.deepcopy(cnn)  # keep the caller's network untouched
    normalization = Normalization(normalization_mean, normalization_std).to(device)

    content_losses = []
    style_losses = []
    model = nn.Sequential(normalization)

    conv_index = 0  # layers following a conv share its index: conv_1, relu_1, ...
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            conv_index += 1
            name = f'conv_{conv_index}'
        elif isinstance(layer, nn.ReLU):
            name = f'relu_{conv_index}'
            # Use an out-of-place ReLU so it cannot overwrite, in place, the
            # tensor an inserted loss probe just received and returned.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f'pool_{conv_index}'
        elif isinstance(layer, nn.BatchNorm2d):
            name = f'bn_{conv_index}'
        else:
            raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')
        model.add_module(name, layer)

        if name in content_layers:
            # Targets are constants. no_grad avoids building an autograd graph
            # through the whole prefix only to discard it with detach().
            with torch.no_grad():
                target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module(f"content_loss_{conv_index}", content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            with torch.no_grad():
                target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module(f"style_loss_{conv_index}", style_loss)
            style_losses.append(style_loss)

    # Nothing after the last probe contributes to any loss, so cut it off.
    # Index 0 (normalization) is the fallback when no probe was inserted.
    last_probe = 0
    for idx, module in enumerate(model):
        if isinstance(module, (ContentLoss, StyleLoss)):
            last_probe = idx
    model = model[:last_probe + 1]

    return model, style_losses, content_losses
def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img, content_layers, style_layers, device,
                       num_steps=300):
    """Optimize ``input_img`` in place with L-BFGS toward the style/content objective.

    The objective is
    ``style_weight * sum(style losses) + content_weight * sum(content losses)``,
    using the module-level weights and the loss probes built by
    ``get_style_model_and_losses``.

    Args:
        cnn, normalization_mean, normalization_std, content_img, style_img,
        content_layers, style_layers, device: Forwarded to
            ``get_style_model_and_losses``.
        input_img: Tensor that is optimized in place. ``requires_grad_()`` is
            called on it, and its values are clamped to [0, 1].
        num_steps: The outer loop keeps calling ``optimizer.step`` while the
            closure-evaluation counter is ``<= num_steps``. One L-BFGS step may
            evaluate the closure several times, so the total number of
            evaluations can exceed ``num_steps``.

    Returns:
        ``input_img`` itself (same tensor object), clamped to [0, 1].

    Progress is printed every 50 closure evaluations.
    """
    print("Building the style transfer model..")
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, normalization_mean, normalization_std,
        style_img, content_img, content_layers, style_layers, device)

    # Only the image is optimized. The model is built from a deep copy of cnn,
    # so freezing it here leaves the caller's network alone. It also stops
    # backward() from computing (and accumulating) gradients for every conv weight.
    model.requires_grad_(False)

    optimizer = optim.LBFGS([input_img.requires_grad_()])
    print("Optimizing..")

    run = [0]  # mutable cell so closure() can advance the shared counter
    while run[0] <= num_steps:
        def closure():
            # Keep pixels in the valid range before each evaluation. Doing it
            # under no_grad keeps the clamp out of the autograd graph (the
            # supported replacement for mutating ``.data``).
            with torch.no_grad():
                input_img.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_img)  # loss probes record their losses as a side effect
            style_score = sum(sl.loss for sl in style_losses)
            content_score = sum(cl.loss for cl in content_losses)
            loss = style_weight * style_score + content_weight * content_score
            loss.backward()
            if run[0] % 50 == 0:
                print(f"Step {run[0]}:")
                print(f" Style Loss: {style_score.item():.4f}")
                print(f" Content Loss: {content_score.item():.4f}")
                print(f" Total Loss: {loss.item():.4f}\n")
            run[0] += 1
            return loss

        optimizer.step(closure)

    # Final correction: the last L-BFGS update may have left pixels outside [0, 1].
    with torch.no_grad():
        input_img.clamp_(0, 1)
    return input_img