Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
| from PIL import Image | |
| import torchvision.transforms.functional as TF | |
# Device configuration: prefer the GPU whenever PyTorch reports one available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Preprocessing: every input image is resized to a fixed 512x512 square and
# then converted from a PIL image to a tensor.
_TARGET_SIZE = (512, 512)
transform = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),
        transforms.ToTensor(),
    ]
)
def load_image(img):
    """Turn a PIL image into a batched tensor on ``device``.

    The image is forced to three RGB channels, passed through the
    module-level ``transform``, given a leading batch dimension of size 1,
    and moved to ``device``.
    """
    rgb = img.convert("RGB")
    batch = transform(rgb).unsqueeze(0)
    return batch.to(device)
# Loss modules
class Normalization(nn.Module):
    """Normalise an image batch channel-wise: ``(img - mean) / std``.

    ``mean`` and ``std`` hold one value per channel. They are reshaped to
    ``(C, 1, 1)`` so they broadcast over the spatial dimensions of ``img``.
    They are registered as non-persistent buffers, so ``.to(device)`` on the
    module moves them too, while ``state_dict()`` is unaffected.
    """

    def __init__(self, mean, std):
        super().__init__()
        # as_tensor also accepts plain Python sequences; tensors are used as-is.
        self.register_buffer(
            "mean", torch.as_tensor(mean).view(-1, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.as_tensor(std).view(-1, 1, 1), persistent=False
        )

    def forward(self, img):
        return (img - self.mean) / self.std
class ContentLoss(nn.Module):
    """Pass-through layer that records the MSE between its input and a target.

    The input is returned unchanged. The most recent loss value is stored in
    ``self.loss`` (``0`` until the first forward pass) for the optimisation
    loop to read.
    """

    def __init__(self, target):
        super().__init__()
        self.loss = 0
        # Detached so the target acts as a constant, outside the autograd graph.
        self.target = target.detach()

    def forward(self, input):
        mse = nn.functional.mse_loss
        self.loss = mse(input, self.target)
        return input
def gram_matrix(input):
    """Return the normalised Gram matrix of a 4-D feature map.

    ``input`` is unpacked as ``(b, c, h, w)``. Each of the ``b * c`` feature
    maps is flattened into a row. The matrix of pairwise row inner products
    is then divided by the total element count ``b * c * h * w``.

    For a single-image batch (``b == 1``) the result has shape ``(c, c)``;
    for larger batches it has shape ``(b * c, b * c)``.
    """
    b, c, h, w = input.size()
    # Fold the batch into the rows (view(c, ...) broke for b > 1); reshape
    # rather than view so non-contiguous inputs also work.
    features = input.reshape(b * c, h * w)
    gram = torch.mm(features, features.t())
    return gram.div(b * c * h * w)
class StyleLoss(nn.Module):
    """Pass-through layer that scores how far its input's style is from a target.

    "Style" is the Gram matrix computed by ``gram_matrix``. The loss is the
    MSE between the input's Gram matrix and the detached target's, stored in
    ``self.loss`` (``0`` until the first forward pass). The input itself is
    returned unchanged.
    """

    def __init__(self, target_feature):
        super().__init__()
        target_gram = gram_matrix(target_feature)
        self.target = target_gram.detach()
        self.loss = 0

    def forward(self, input):
        current_gram = gram_matrix(input)
        self.loss = nn.functional.mse_loss(current_gram, self.target)
        return input
# Model builder
def get_model_losses(
    cnn,
    norm_mean,
    norm_std,
    style_img,
    content_img,
    content_layers=("conv_4",),
    style_layers=("conv_1", "conv_2", "conv_3", "conv_4", "conv_5"),
):
    """Build a truncated feature extractor with loss probes spliced in.

    The children of ``cnn`` are copied, in order, into a new ``nn.Sequential``
    that starts with a ``Normalization`` layer.

    Layer naming:
        Conv, ReLU, MaxPool and BatchNorm layers are named after the index of
        the most recent conv (``conv_1``, ``relu_1``, ``pool_2``, ...). Any
        other layer type is skipped. ReLUs are replaced by non-inplace ones.

    Loss probes:
        A ``ContentLoss`` (target: ``content_img``'s activation at that
        point) is inserted after every layer named in ``content_layers``. A
        ``StyleLoss`` (target: ``style_img``'s activation) is inserted after
        every layer named in ``style_layers``. Layers after the last probe
        are dropped.

    Args:
        cnn: module whose ``children()`` are the feature layers.
        norm_mean: per-channel mean used by ``Normalization``.
        norm_std: per-channel std used by ``Normalization``.
        style_img: image batch used to compute the style targets.
        content_img: image batch used to compute the content targets.
        content_layers: collection of layer names followed by a
            ``ContentLoss``. Defaults reproduce the previously fixed choice.
        style_layers: collection of layer names followed by a ``StyleLoss``.
            Defaults reproduce the previously fixed choice.

    Returns:
        ``(model, style_losses, content_losses)``: the truncated
        ``nn.Sequential`` and the inserted probe modules, in insertion order.
    """
    normalization = Normalization(norm_mean, norm_std).to(device)
    model = nn.Sequential(normalization)
    content_losses = []
    style_losses = []

    conv_index = 0
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            conv_index += 1
            name = f"conv_{conv_index}"
        elif isinstance(layer, nn.ReLU):
            name = f"relu_{conv_index}"
            # Out-of-place ReLU: presumably so it cannot overwrite activations
            # that the loss probes' autograd graphs still reference.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f"pool_{conv_index}"
        elif isinstance(layer, nn.BatchNorm2d):
            name = f"bn_{conv_index}"
        else:
            # NOTE(review): unrecognised layer types are silently skipped.
            continue
        model.add_module(name, layer)

        if name in content_layers:
            # Targets are constants; no_grad avoids building a throwaway graph
            # through the network weights.
            with torch.no_grad():
                target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module(f"content_loss_{conv_index}", content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            with torch.no_grad():
                target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module(f"style_loss_{conv_index}", style_loss)
            style_losses.append(style_loss)

    # Layers after the last probe never contribute to any loss, so drop them.
    # With no probes at all, only the normalisation layer (index 0) is kept.
    last_probe = max(
        (
            idx
            for idx, module in enumerate(model)
            if isinstance(module, (ContentLoss, StyleLoss))
        ),
        default=0,
    )
    return model[: last_probe + 1], style_losses, content_losses
# Stylization pipeline
def run_nst(content_pil, style_pil, steps=300, style_weight=1e6, content_weight=1):
    """Restyle ``content_pil`` with the style of ``style_pil``; return a PIL image.

    Both images go through ``load_image``. The output starts as a copy of the
    content image and is optimised with L-BFGS against a pretrained VGG-19
    feature extractor built by ``get_model_losses``. The minimised objective
    is ``content_weight * content_loss + style_weight * style_loss``.

    Args:
        content_pil: content image (PIL).
        style_pil: style image (PIL).
        steps: budget of loss evaluations. ``optimizer.step`` keeps being
            called while the evaluation counter is <= ``steps``, and one step
            may evaluate the closure several times, so the final count can
            exceed ``steps``.
        style_weight: weight of the summed style losses (default 1e6).
        content_weight: weight of the summed content losses (default 1).

    Returns:
        The stylised image, converted with ``TF.to_pil_image``.
    """
    content = load_image(content_pil)
    style = load_image(style_pil)
    input_img = content.clone().requires_grad_(True)

    # NOTE(review): newer torchvision deprecates `pretrained=` in favour of
    # `weights=`. Kept as-is because the installed version is unknown. The
    # weights are also reloaded on every call.
    cnn = models.vgg19(pretrained=True).features.to(device).eval()
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    norm_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    model, style_losses, content_losses = get_model_losses(
        cnn, norm_mean, norm_std, style, content
    )
    # Only the image is optimised. Freezing the network stops backward() from
    # computing (and accumulating) gradients for every VGG weight; gradients
    # still flow through the frozen layers to input_img.
    model.requires_grad_(False)

    optimizer = optim.LBFGS([input_img])
    run = [0]  # evaluation counter, mutated inside the closure

    def closure():
        # Keep pixel values inside [0, 1] before every evaluation.
        with torch.no_grad():
            input_img.clamp_(0, 1)
        optimizer.zero_grad()
        model(input_img)
        style_score = sum(sl.loss for sl in style_losses)
        content_score = sum(cl.loss for cl in content_losses)
        loss = content_weight * content_score + style_weight * style_score
        loss.backward()
        run[0] += 1
        return loss

    while run[0] <= steps:
        optimizer.step(closure)

    # The last L-BFGS step can push pixels outside [0, 1]. torchvision's
    # float->uint8 conversion in to_pil_image does not clamp, so clamp here.
    output = input_img.detach().clamp(0, 1).cpu().squeeze(0)
    return TF.to_pil_image(output)