import time

import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import models, transforms
from torchvision.models import VGG19_Weights

# ✅ Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ImageNet statistics used by the pretrained VGG19 backbone — shared by
# load_image (normalize) and tensor_to_image (unnormalize) so they stay in sync.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


# --- Image Utilities ---
def load_image(img, max_size=384):
    """Convert a PIL image into a normalized 4-D tensor on `device`.

    Args:
        img: PIL image (any mode; converted to RGB).
        max_size: target size for the image's shorter edge; aspect ratio
            is preserved by transforms.Resize.

    Returns:
        Float tensor of shape (1, 3, H, W), ImageNet-normalized, on `device`.
    """
    transform = transforms.Compose([
        transforms.Resize(max_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    image = transform(img.convert('RGB')).unsqueeze(0)
    return image.to(device)


def tensor_to_image(tensor):
    """Invert ImageNet normalization and convert a (1,3,H,W) tensor to PIL.

    The inverse transform uses mean' = -mean/std and std' = 1/std so that
    Normalize(mean', std') undoes Normalize(mean, std).
    """
    unnormalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
        std=[1 / s for s in IMAGENET_STD],
    )
    image = tensor.clone().detach().squeeze(0)
    image = unnormalize(image)
    # Clamp to the displayable range before PIL conversion.
    image = torch.clamp(image, 0, 1)
    return transforms.ToPILImage()(image)


# --- Style Transfer Utilities ---
def gram_matrix(tensor):
    """Return the (b*c, b*c) Gram matrix of feature correlations.

    NOTE(review): assumes batch size 1 — for b > 1 the batch dimension is
    folded into the channel dimension, mixing samples.
    """
    b, c, h, w = tensor.size()
    features = tensor.view(b * c, h * w)
    return torch.mm(features, features.t())


class StyleTransferNet(nn.Module):
    """Gatys-style optimization-based neural style transfer over VGG19 features.

    Content is matched at conv_4; style (via Gram matrices) at
    conv_1/3/5/9. `forward` runs the optimization loop and returns the
    stylized image tensor.
    """

    def __init__(self, style_img, content_img):
        super().__init__()
        weights = VGG19_Weights.DEFAULT
        self.vgg = models.vgg19(weights=weights).features.to(device).eval()
        # The backbone is a fixed feature extractor: only the input image is
        # optimized, so freeze VGG weights to avoid accumulating gradients
        # on ~143M parameters every backward pass (saves memory and time;
        # gradients w.r.t. the input are unaffected).
        for param in self.vgg.parameters():
            param.requires_grad_(False)
        self.style_img = style_img
        self.content_img = content_img
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_3', 'conv_5', 'conv_9']

    def get_features(self, x):
        """Run x through VGG19, collecting activations at the tracked conv layers."""
        features = {}
        i = 0
        for layer in self.vgg.children():
            x = layer(x)
            if isinstance(layer, nn.Conv2d):
                i += 1
                name = f'conv_{i}'
                if name in self.content_layers + self.style_layers:
                    features[name] = x
        return features

    def forward(self, input_img, steps=100, style_weight=1e6, content_weight=1e5):
        """Optimize a copy of input_img to blend content and style.

        Args:
            input_img: starting image tensor (typically the content image).
            steps: number of Adam optimization steps.
            style_weight: multiplier for the summed Gram-matrix MSE loss.
            content_weight: multiplier for the content-feature MSE loss.

        Returns:
            The optimized image tensor (still normalized; see tensor_to_image).
        """
        input_img = input_img.clone().requires_grad_(True)
        optimizer = optim.Adam([input_img], lr=0.02)
        # Targets are fixed for the whole run — compute them once up front.
        style_features = self.get_features(self.style_img)
        content_features = self.get_features(self.content_img)
        style_grams = {k: gram_matrix(v) for k, v in style_features.items()}

        for step in range(steps):
            optimizer.zero_grad()
            target_features = self.get_features(input_img)
            style_loss = 0
            content_loss = 0
            for layer in self.style_layers:
                target_gram = gram_matrix(target_features[layer])
                style_loss += torch.mean((target_gram - style_grams[layer]) ** 2)
            for layer in self.content_layers:
                content_loss += torch.mean(
                    (target_features[layer] - content_features[layer]) ** 2
                )
            total_loss = style_weight * style_loss + content_weight * content_loss
            total_loss.backward()
            optimizer.step()
        return input_img


# --- Gradio App ---
def style_transfer_app(content_img, style_img, content_weight_ui, style_weight_ui, steps):
    """Gradio callback: run style transfer and report elapsed time.

    Args:
        content_img / style_img: PIL images from the UI.
        content_weight_ui / style_weight_ui: intuitive 1-10 slider values.
        steps: number of optimization steps.

    Returns:
        (stylized PIL image, human-readable timing string).
    """
    start_time = time.time()
    content = load_image(content_img)
    style = load_image(style_img)
    # Map intuitive UI weights (1-10) to actual loss-scale values
    content_weight = content_weight_ui * 1e5
    style_weight = style_weight_ui * 1e6
    model = StyleTransferNet(style, content)
    output = model(content, steps=int(steps),
                   content_weight=content_weight, style_weight=style_weight)
    stylized = tensor_to_image(output)
    elapsed = round(time.time() - start_time)
    # Report the measured wall-clock time (the original text said "Estimated",
    # but this is the actual elapsed duration).
    time_note = f"🕒 Processing time: {elapsed} seconds for {steps} steps."
    return stylized, time_note


# --- Launch Interface ---
# Guarded so the module can be imported (e.g. for testing) without
# starting a web server; behavior when run as a script is unchanged.
if __name__ == "__main__":
    gr.Interface(
        fn=style_transfer_app,
        inputs=[
            gr.Image(type="pil", label="🖼️ Content Image"),
            gr.Image(type="pil", label="🎨 Style Image"),
            gr.Slider(1, 10, value=1, step=1,
                      label="Content Weight (1 = weak structure, 10 = strong)"),
            gr.Slider(1, 10, value=6, step=1,
                      label="Style Weight (1 = subtle, 10 = strong style)"),
            gr.Slider(50, 300, value=100, step=50,
                      label="Steps (speed vs quality)")
        ],
        outputs=[
            gr.Image(type="pil", label="🧠 Stylized Output"),
            gr.Textbox(label="⏱️ Time Info")
        ],
        title="🎨 Fast AI Neural Style Transfer",
        description="Upload content and style images, then tune how much structure vs style you want. Powered by PyTorch + VGG19.",
        allow_flagging="never"
    ).launch(share=True)