Spaces:
Sleeping
Sleeping
File size: 4,264 Bytes
3369069 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import torchvision.transforms.functional as TF
# 🚀 Device configuration: use the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 🔧 Preprocessing applied to every input image.
transform = transforms.Compose(
    [
        # Exact (height, width) target: aspect ratio is not preserved.
        transforms.Resize((512, 512)),
        # PIL image -> CHW float tensor.
        transforms.ToTensor(),
    ]
)
def load_image(img):
    """Turn an image object into a batched tensor on ``device``.

    ``img`` is expected to provide ``.convert`` (i.e. a PIL image). It is
    forced to RGB, passed through the module-level ``transform`` pipeline,
    given a leading batch dimension of size 1 and moved to ``device``.
    """
    rgb_image = img.convert("RGB")
    single = transform(rgb_image)
    batch = single.unsqueeze(0)
    return batch.to(device)
# 🎯 Loss modules
class Normalization(nn.Module):
    """Channel-wise normalization layer: ``(img - mean) / std``.

    ``mean`` and ``std`` are reshaped to ``(C, 1, 1)`` so they broadcast over
    the channel axis of a ``(..., C, H, W)`` input.

    They are registered as (non-persistent) buffers rather than plain
    attributes. As a result, ``.to(device)``, ``.cuda()`` and dtype casts
    applied to this module, or to a container holding it, move them too.
    They are still not written to ``state_dict``, exactly as before.

    Args:
        mean: per-channel means (tensor or sequence of floats).
        std: per-channel standard deviations (tensor or sequence of floats).
    """

    def __init__(self, mean, std):
        super().__init__()
        # torch.as_tensor keeps existing tensors as-is and also accepts lists.
        self.register_buffer(
            "mean", torch.as_tensor(mean).view(-1, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.as_tensor(std).view(-1, 1, 1), persistent=False
        )

    def forward(self, img):
        return (img - self.mean) / self.std
class ContentLoss(nn.Module):
    """Pass-through probe measuring distance to a fixed content target.

    Calling the module returns ``input`` unchanged. As a side effect it
    stores ``nn.functional.mse_loss(input, target)`` in ``self.loss``, where
    the optimization loop can read it after a forward pass.

    Args:
        target: feature map to compare against. It is stored detached, so it
            acts as a constant.
    """

    def __init__(self, target):
        super().__init__()
        frozen_target = target.detach()
        self.target = frozen_target
        # Placeholder until the first forward pass.
        self.loss = 0

    def forward(self, input):
        mse = nn.functional.mse_loss
        self.loss = mse(input, self.target)
        return input
def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D feature map.

    Args:
        input: tensor of shape ``(b, c, h, w)``.

    Returns:
        ``F @ F.T / (b * c * h * w)``, where ``F`` is ``input`` flattened to
        ``(b * c, h * w)``. The result has shape ``(b * c, b * c)``; for the
        usual ``b == 1`` this is the ``(c, c)`` channel-correlation matrix.
    """
    b, c, h, w = input.size()
    # Fold the batch into the feature rows (the original view(c, h*w)
    # ignored b and crashed for b > 1). reshape, unlike view, also accepts
    # non-contiguous inputs.
    features = input.reshape(b * c, h * w)
    gram = torch.mm(features, features.t())
    return gram.div(b * c * h * w)
class StyleLoss(nn.Module):
    """Pass-through probe measuring Gram-matrix distance to a style target.

    The Gram matrix of ``target_feature`` is computed once (via
    ``gram_matrix``) and stored detached. Each forward pass stores the MSE
    between the input's Gram matrix and that target in ``self.loss``, then
    returns ``input`` unchanged.

    Args:
        target_feature: feature map taken from the style image.
    """

    def __init__(self, target_feature):
        super().__init__()
        style_gram = gram_matrix(target_feature)
        self.target = style_gram.detach()
        # Placeholder until the first forward pass.
        self.loss = 0

    def forward(self, input):
        current_gram = gram_matrix(input)
        self.loss = nn.functional.mse_loss(current_gram, self.target)
        return input
# 🧬 Model builder
def get_model_losses(cnn, norm_mean, norm_std, style_img, content_img):
    """Build a truncated copy of ``cnn`` with content/style loss probes.

    Layers from ``cnn.children()`` are appended after a ``Normalization``
    layer and named by the running conv index: ``conv_i``, ``relu_i``,
    ``pool_i``, ``bn_i``. Other layer types are left out, as before. ReLUs
    are replaced with non-inplace copies.

    Probes are inserted as follows:

    - a ``ContentLoss`` after ``conv_4``, with its target taken from
      ``content_img``;
    - a ``StyleLoss`` after ``conv_1`` .. ``conv_5``, with targets taken from
      ``style_img``.

    The model ends at the last probe.

    Targets are computed by pushing each image through each layer exactly
    once, under ``torch.no_grad()``. The previous version re-ran the whole
    growing prefix at every tap. It also passed the style image through the
    content probe, which crashed whenever the two images differed in spatial
    size.

    Returns:
        ``(model, style_losses, content_losses)`` where the lists hold the
        probe modules in insertion order.
    """
    content_layers = {"conv_4"}
    style_layers = {"conv_1", "conv_2", "conv_3", "conv_4", "conv_5"}
    pending = content_layers | style_layers  # taps not yet inserted

    normalization = Normalization(norm_mean, norm_std).to(device)
    model = nn.Sequential(normalization)
    content_losses = []
    style_losses = []

    # Running activations of each image at the current end of ``model``.
    # Loss probes are identities, so they need not be applied here.
    with torch.no_grad():
        content_feat = normalization(content_img)
        style_feat = normalization(style_img)

    i = 0
    for layer in cnn.children():
        if not pending:
            break  # every probe is placed; later layers would be trimmed anyway
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = f"conv_{i}"
        elif isinstance(layer, nn.ReLU):
            name = f"relu_{i}"
            # An inplace ReLU would overwrite activations the probes read.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f"pool_{i}"
        elif isinstance(layer, nn.BatchNorm2d):
            name = f"bn_{i}"
        else:
            continue  # unrecognised layer types are not added to the model

        model.add_module(name, layer)
        with torch.no_grad():
            content_feat = layer(content_feat)
            style_feat = layer(style_feat)

        if name in content_layers:
            content_loss = ContentLoss(content_feat)
            model.add_module(f"content_loss_{i}", content_loss)
            content_losses.append(content_loss)
        if name in style_layers:
            style_loss = StyleLoss(style_feat)
            model.add_module(f"style_loss_{i}", style_loss)
            style_losses.append(style_loss)
        pending.discard(name)

    # Keep everything up to and including the last probe. If there are no
    # probes at all, keep just the normalization layer (index 0).
    last_probe = max(
        (
            idx
            for idx, module in enumerate(model)
            if isinstance(module, (ContentLoss, StyleLoss))
        ),
        default=0,
    )
    return model[: last_probe + 1], style_losses, content_losses
# ✨ Stylization pipeline
def run_nst(content_pil, style_pil, steps=300, style_weight=1e6, content_weight=1):
    """Render ``content_pil`` in the style of ``style_pil``.

    Both images go through ``load_image``. Optimization starts from a copy of
    the content image and updates its pixels with L-BFGS. The loss comes from
    the probes built by ``get_model_losses`` on a frozen VGG19 feature
    extractor.

    Args:
        content_pil: content image (anything ``load_image`` accepts).
        style_pil: style image.
        steps: number of loss evaluations to run before stopping. L-BFGS
            evaluates the closure several times per ``step`` call, so the
            final count can overshoot this.
        style_weight: multiplier on the summed style losses. The default is
            the value previously hard-coded.
        content_weight: multiplier on the summed content losses.

    Returns:
        The stylized image as a PIL image. Pixels are clamped to [0, 1]
        before conversion.
    """
    content = load_image(content_pil)
    style = load_image(style_pil)
    input_img = content.clone().requires_grad_(True)

    if hasattr(models, "VGG19_Weights"):
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
    else:  # torchvision < 0.13 only understands ``pretrained``
        vgg = models.vgg19(pretrained=True)
    cnn = vgg.features.to(device).eval()
    # VGG is a fixed feature extractor. Without this, every backward pass
    # also computes (and accumulates) gradients for all of its weights.
    cnn.requires_grad_(False)

    norm_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    norm_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
    model, style_losses, content_losses = get_model_losses(
        cnn, norm_mean, norm_std, style, content
    )
    optimizer = optim.LBFGS([input_img])
    evaluations = 0

    def closure():
        nonlocal evaluations
        # Keep pixels in the valid range; done without recording autograd ops.
        with torch.no_grad():
            input_img.clamp_(0, 1)
        optimizer.zero_grad()
        model(input_img)  # forward pass fills each probe's ``.loss``
        style_score = sum(sl.loss for sl in style_losses)
        content_score = sum(cl.loss for cl in content_losses)
        loss = content_weight * content_score + style_weight * style_score
        loss.backward()
        evaluations += 1
        return loss

    while evaluations <= steps:
        optimizer.step(closure)

    # The last L-BFGS update happens after the final in-closure clamp. Values
    # outside [0, 1] would wrap when to_pil_image casts to bytes.
    with torch.no_grad():
        input_img.clamp_(0, 1)
    output = input_img.detach().cpu().squeeze(0)
    return TF.to_pil_image(output)