"""Neural style transfer (Gatys et al.) on top of a pretrained CNN.

Standard PyTorch NST recipe: deep-copy the CNN, splice transparent
``ContentLoss`` / ``StyleLoss`` recorder modules after the chosen layers,
then optimise the *input image* (not the network) with L-BFGS against a
weighted sum of the recorded losses.
"""

import copy

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import transforms

from Models_Class.NST_class import (
    ContentLoss,
    Normalization,
    StyleLoss,
)

# Relative weighting of the two objectives; style dominates by design.
style_weight = 1e8
content_weight = 1e1


def image_loader(image_path, loader, device):
    """Load an image file and return a 1xCxHxW float tensor on *device*.

    Args:
        image_path: Path to the image file.
        loader: torchvision transform pipeline mapping PIL image -> CxHxW tensor.
        device: Target device for the returned tensor.
    """
    image = Image.open(image_path).convert('RGB')
    image = loader(image).unsqueeze(0)  # add batch dimension
    return image.to(device, torch.float)


def save_image(tensor, path="output.png"):
    """Save a 1xCxHxW image tensor to *path*.

    The tensor is cloned first so the caller's copy is never mutated.
    """
    image = tensor.cpu().clone()
    image = image.squeeze(0)  # drop batch dimension for PIL conversion
    image = transforms.ToPILImage()(image)
    image.save(path)


def gram_matrix(input):
    """Return the normalised Gram matrix of a BxCxHxW feature map.

    Dividing by the total number of elements keeps the style-loss scale
    independent of the feature-map size, so layers contribute comparably.
    """
    a, b, c, d = input.size()
    features = input.view(a * b, c * d)  # flatten each channel to a row
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)


def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img,
                               content_layers, style_layers, device):
    """Build the truncated loss network and return it with its loss recorders.

    Walks the (copied) CNN layer by layer, renaming layers ``conv_i`` /
    ``relu_i`` / ``pool_i`` / ``bn_i``, and inserts a ``ContentLoss`` or
    ``StyleLoss`` module right after every layer named in *content_layers*
    / *style_layers*.  The model is then truncated after the last loss
    module, since deeper layers cannot affect any loss.

    Returns:
        (model, style_losses, content_losses) — the nn.Sequential loss
        network and the lists of inserted loss modules.

    Raises:
        RuntimeError: if the CNN contains a layer type other than
            Conv2d/ReLU/MaxPool2d/BatchNorm2d.
    """
    cnn = copy.deepcopy(cnn)  # never mutate the caller's network

    normalization = Normalization(normalization_mean, normalization_std).to(device)

    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)

    i = 0  # incremented on every conv layer; names the following layers too
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = f'conv_{i}'
        elif isinstance(layer, nn.ReLU):
            name = f'relu_{i}'
            # In-place ReLU would corrupt the activations the loss modules
            # captured; swap in an out-of-place copy.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f'pool_{i}'
        elif isinstance(layer, nn.BatchNorm2d):
            name = f'bn_{i}'
        else:
            raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')

        model.add_module(name, layer)

        if name in content_layers:
            # Record the content target at this depth (detached: it is a
            # fixed target, not part of the graph being optimised).
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module(f"content_loss_{i}", content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module(f"style_loss_{i}", style_loss)
            style_losses.append(style_loss)

    # Truncate everything after the last loss module — deeper layers can
    # never influence any recorded loss.
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], (ContentLoss, StyleLoss)):
            break
    model = model[:i + 1]

    # Only the input image is optimised; freeze the network so autograd
    # does not waste time/memory on parameter gradients.
    model.eval()
    model.requires_grad_(False)

    return model, style_losses, content_losses


def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img,
                       content_layers, style_layers, device, num_steps=300):
    """Run the style-transfer optimisation and return the stylised image.

    Args:
        cnn: Pretrained feature extractor (e.g. VGG ``features``).
        normalization_mean / normalization_std: Per-channel input stats.
        content_img / style_img: Target images, 1xCxHxW tensors on *device*.
        input_img: Starting image, optimised in place.
        content_layers / style_layers: Layer names at which to attach losses.
        device: Torch device for the loss network.
        num_steps: Approximate number of closure evaluations (L-BFGS may
            evaluate the closure several times per ``step`` call, so the
            actual count can overshoot slightly).

    Returns:
        The optimised ``input_img``, clamped to [0, 1].
    """
    print("Building the style transfer model..")
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, normalization_mean, normalization_std,
        style_img, content_img, content_layers, style_layers, device
    )

    # L-BFGS optimises the pixels of the input image directly.
    optimizer = optim.LBFGS([input_img.requires_grad_()])

    print("Optimizing..")
    run = [0]  # mutable counter shared with the closure
    while run[0] <= num_steps:

        def closure():
            # Keep pixel values in a valid image range between steps
            # (no_grad: this correction must not enter the graph).
            with torch.no_grad():
                input_img.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_img)  # forward pass populates the .loss recorders

            style_score = sum(sl.loss for sl in style_losses)
            content_score = sum(cl.loss for cl in content_losses)

            loss = style_weight * style_score + content_weight * content_score
            loss.backward()

            if run[0] % 50 == 0:
                print(f"Step {run[0]}:")
                print(f"  Style Loss: {style_score.item():.4f}")
                print(f"  Content Loss: {content_score.item():.4f}")
                print(f"  Total Loss: {loss.item():.4f}\n")
            run[0] += 1
            return loss

        optimizer.step(closure)

    # Final projection back into valid image range.
    with torch.no_grad():
        input_img.clamp_(0, 1)
    return input_img