# Spaces: Running — Hugging Face Spaces status banner captured along with the
# source; commented out so the file parses as Python.
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torchvision import models, transforms | |
| from torchvision.models import VGG19_Weights | |
| from PIL import Image | |
| import gradio as gr | |
| import time | |
# --- Device selection ---
# Prefer a CUDA GPU when one is available; otherwise fall back to CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_ok else torch.device("cpu")
print("Using device:", device)
| # --- Image Utilities --- | |
def load_image(img, max_size=384):
    """Convert a PIL image into a normalized, batched tensor on `device`.

    The image is resized so its shorter side equals `max_size`, converted
    to a tensor, and normalized with the ImageNet statistics that the
    pretrained VGG19 backbone expects.
    """
    preprocess = transforms.Compose([
        transforms.Resize(max_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    rgb = img.convert('RGB')          # drop alpha / palette modes
    batch = preprocess(rgb).unsqueeze(0)  # add batch dim -> (1, 3, H, W)
    return batch.to(device)
def tensor_to_image(tensor):
    """Undo the ImageNet normalization and return a displayable PIL image."""
    # Inverse of Normalize(mean, std): x = (x_norm - (-mean/std)) / (1/std)
    inv_normalize = transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.225]
    )
    img = tensor.clone().detach().squeeze(0)  # drop batch dim
    img = inv_normalize(img).clamp(0, 1)      # keep values in display range
    return transforms.ToPILImage()(img)
| # --- Style Transfer Utilities --- | |
def gram_matrix(tensor):
    """Return the (b*c) x (b*c) Gram matrix of a feature map.

    Each channel is flattened to a row vector; the matrix of pairwise
    channel inner products captures the style statistics of the map.
    """
    batch, channels, height, width = tensor.size()
    rows = tensor.reshape(batch * channels, height * width)
    return rows @ rows.t()
class StyleTransferNet(nn.Module):
    """Gatys-style neural style transfer driven by VGG19 features.

    Optimizes a copy of the content image so its VGG feature maps match
    the content image (content loss) while its Gram matrices match the
    style image (style loss).
    """

    def __init__(self, style_img, content_img):
        super().__init__()
        weights = VGG19_Weights.DEFAULT
        self.vgg = models.vgg19(weights=weights).features.to(device).eval()
        # The backbone is a fixed feature extractor: freeze its parameters
        # so feature passes do not build autograd graphs through the
        # network weights.
        for param in self.vgg.parameters():
            param.requires_grad_(False)
        self.style_img = style_img
        self.content_img = content_img
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_3', 'conv_5', 'conv_9']

    def get_features(self, x):
        """Run `x` through VGG19 and collect the layers of interest.

        Conv layers are named 'conv_1', 'conv_2', ... in encounter order;
        only names listed in content_layers/style_layers are kept.
        """
        features = {}
        i = 0
        for layer in self.vgg.children():
            x = layer(x)
            if isinstance(layer, nn.Conv2d):
                i += 1
                name = f'conv_{i}'
                if name in self.content_layers + self.style_layers:
                    features[name] = x
        return features

    def forward(self, input_img, steps=100, style_weight=1e6, content_weight=1e5):
        """Iteratively optimize a copy of `input_img`; return the stylized tensor.

        Args:
            input_img: content image tensor used as the starting point.
            steps: number of Adam optimization steps.
            style_weight / content_weight: multipliers for the two loss terms.
        """
        input_img = input_img.clone().requires_grad_(True)
        optimizer = optim.Adam([input_img], lr=0.02)

        # BUG FIX: the style/content targets are constants. Previously their
        # feature maps carried autograd graphs (VGG params required grad), so
        # the first backward() freed those graphs and the second optimization
        # step raised "Trying to backward through the graph a second time".
        # Compute them under no_grad so they are plain constant tensors.
        with torch.no_grad():
            style_features = self.get_features(self.style_img)
            content_features = self.get_features(self.content_img)
            style_grams = {k: gram_matrix(v) for k, v in style_features.items()}

        for step in range(steps):
            optimizer.zero_grad()
            target_features = self.get_features(input_img)
            style_loss = 0
            content_loss = 0
            for layer in self.style_layers:
                target_gram = gram_matrix(target_features[layer])
                style_loss += torch.mean((target_gram - style_grams[layer]) ** 2)
            for layer in self.content_layers:
                content_loss += torch.mean(
                    (target_features[layer] - content_features[layer]) ** 2)
            total_loss = style_weight * style_loss + content_weight * content_loss
            total_loss.backward()
            optimizer.step()
        return input_img
| # --- Gradio App --- | |
def style_transfer_app(content_img, style_img, content_weight_ui, style_weight_ui, steps):
    """Gradio callback: run style transfer and report wall-clock time.

    Returns a (PIL image, status string) pair for the two UI outputs.
    """
    t0 = time.time()

    content = load_image(content_img)
    style = load_image(style_img)

    # The UI exposes friendly 1-10 sliders; scale them up to the
    # magnitudes the optimizer actually expects.
    content_weight = content_weight_ui * 1e5
    style_weight = style_weight_ui * 1e6

    net = StyleTransferNet(style, content)
    result = net(content, steps=int(steps),
                 content_weight=content_weight, style_weight=style_weight)
    stylized = tensor_to_image(result)

    elapsed = round(time.time() - t0)
    estimate_note = f"π Estimated processing time: {elapsed} seconds for {steps} steps."
    return stylized, estimate_note
| # --- Launch Interface --- | |
| gr.Interface( | |
| fn=style_transfer_app, | |
| inputs=[ | |
| gr.Image(type="pil", label="πΌοΈ Content Image"), | |
| gr.Image(type="pil", label="π¨ Style Image"), | |
| gr.Slider(1, 10, value=1, step=1, label="Content Weight (1 = weak structure, 10 = strong)"), | |
| gr.Slider(1, 10, value=6, step=1, label="Style Weight (1 = subtle, 10 = strong style)"), | |
| gr.Slider(50, 300, value=100, step=50, label="Steps (speed vs quality)") | |
| ], | |
| outputs=[ | |
| gr.Image(type="pil", label="π§ Stylized Output"), | |
| gr.Textbox(label="β±οΈ Time Info") | |
| ], | |
| title="π¨ Fast AI Neural Style Transfer", | |
| description="Upload content and style images, then tune how much structure vs style you want. Powered by PyTorch + VGG19.", | |
| allow_flagging="never" | |
| ).launch(share=True) | |