"""Neural style transfer demo with a Gradio UI.

The script optimises a copy of the content image so that its VGG19 feature
statistics match a style image, and exposes that through a small web UI that
accepts uploaded images or picks two random paintings from a dataset.
"""

import functools
import random
from typing import Sequence

import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from datasets import load_dataset
from PIL import Image

# Device setup: use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image preprocessing: every image is resized to 512x512 so that content and
# style feature maps have matching shapes, then converted to a [0, 1] tensor.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

# Hugging Face dataset used by the "pick random" button.
PAINTINGS_DATASET = "heramb04/Famous-paintings"

# Layers (named conv_<n> in traversal order) after which losses are attached.
DEFAULT_CONTENT_LAYERS = ("conv_4",)
DEFAULT_STYLE_LAYERS = ("conv_1", "conv_2", "conv_3", "conv_4", "conv_5")


def load_image(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a preprocessed batch of one on ``device``."""
    image = img.convert("RGB")
    return transform(image).unsqueeze(0).to(device)


# NST core classes
class Normalization(nn.Module):
    """Normalise an image batch channel-wise with the given mean and std."""

    def __init__(self, mean, std):
        super().__init__()
        # Buffers (rather than plain attributes) so that ``.to(device)`` on the
        # module actually moves them along with it.
        self.register_buffer("mean", torch.as_tensor(mean).view(-1, 1, 1))
        self.register_buffer("std", torch.as_tensor(std).view(-1, 1, 1))

    def forward(self, img):
        return (img - self.mean) / self.std


class ContentLoss(nn.Module):
    """Pass-through layer that records the MSE between its input and a target."""

    def __init__(self, target):
        super().__init__()
        # Detached: the target is a fixed value, not part of the graph.
        self.target = target.detach()
        self.loss = 0

    def forward(self, input):
        self.loss = nn.functional.mse_loss(input, self.target)
        return input


def gram_matrix(input):
    """Return the normalised Gram matrix of a ``(b, c, h, w)`` feature batch.

    The batch and channel dimensions are flattened together, so the result is
    ``(b * c, b * c)``. For the batch-of-one case used in this script that is
    the usual ``(c, c)`` matrix divided by ``c * h * w``.
    """
    b, c, h, w = input.size()
    # The original ``view(c, h * w)`` ignored ``b`` and failed for b > 1.
    features = input.reshape(b * c, h * w)
    gram = torch.mm(features, features.t())
    return gram.div(b * c * h * w)


class StyleLoss(nn.Module):
    """Pass-through layer that records the MSE between Gram matrices."""

    def __init__(self, target_feature):
        super().__init__()
        self.target = gram_matrix(target_feature).detach()
        self.loss = 0

    def forward(self, input):
        gram = gram_matrix(input)
        self.loss = nn.functional.mse_loss(gram, self.target)
        return input


# Model builder
def get_model_losses(
    cnn,
    norm_mean,
    norm_std,
    style_img,
    content_img,
    content_layers: Sequence[str] = DEFAULT_CONTENT_LAYERS,
    style_layers: Sequence[str] = DEFAULT_STYLE_LAYERS,
):
    """Build a truncated copy of ``cnn`` with loss layers inserted.

    Args:
        cnn: Module whose children are walked in order. Conv2d, ReLU,
            MaxPool2d and BatchNorm2d children are copied over; any other
            child is skipped.
        norm_mean: Per-channel mean for the leading ``Normalization`` layer.
        norm_std: Per-channel std for the leading ``Normalization`` layer.
        style_img: Style image batch used to compute style targets.
        content_img: Content image batch used to compute content targets.
        content_layers: Names of layers followed by a ``ContentLoss``.
        style_layers: Names of layers followed by a ``StyleLoss``.

    Returns:
        ``(model, style_losses, content_losses)`` where ``model`` ends at the
        last inserted loss layer.
    """
    norm = Normalization(norm_mean, norm_std).to(device)
    model = nn.Sequential(norm)
    content_losses, style_losses = [], []

    conv_index = 0
    # Targets are constants, so no autograd graph is needed while building.
    with torch.no_grad():
        for layer in cnn.children():
            if isinstance(layer, nn.Conv2d):
                conv_index += 1
                name = f"conv_{conv_index}"
            elif isinstance(layer, nn.ReLU):
                name = f"relu_{conv_index}"
                # An in-place ReLU would overwrite the activations that the
                # loss layer inserted just before it has recorded.
                layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                name = f"pool_{conv_index}"
            elif isinstance(layer, nn.BatchNorm2d):
                name = f"bn_{conv_index}"
            else:
                continue

            model.add_module(name, layer)

            if name in content_layers:
                target = model(content_img).detach()
                content_loss = ContentLoss(target)
                model.add_module(f"content_loss_{conv_index}", content_loss)
                content_losses.append(content_loss)

            if name in style_layers:
                target_feature = model(style_img).detach()
                style_loss = StyleLoss(target_feature)
                model.add_module(f"style_loss_{conv_index}", style_loss)
                style_losses.append(style_loss)

    # Drop everything after the last loss layer; those layers only cost time.
    last_loss_index = 0
    for index, module in enumerate(model):
        if isinstance(module, (ContentLoss, StyleLoss)):
            last_loss_index = index

    return model[:last_loss_index + 1], style_losses, content_losses


@functools.lru_cache(maxsize=1)
def _load_paintings():
    """Load the paintings dataset once and reuse it across requests."""
    return load_dataset(PAINTINGS_DATASET, split="train")


# Random selector from the Hugging Face dataset
def get_random_image_pair():
    """Return two distinct random dataset images as RGB PIL images."""
    ds = _load_paintings()
    # Sample indices instead of ``list(ds)``, which would decode every image
    # in the dataset just to pick two of them.
    first, second = random.sample(range(len(ds)), 2)
    return (
        ds[first]["image"].convert("RGB"),
        ds[second]["image"].convert("RGB"),
    )


@functools.lru_cache(maxsize=1)
def _load_vgg19_features():
    """Return the frozen, pretrained VGG19 feature extractor on ``device``."""
    weights_enum = getattr(models, "VGG19_Weights", None)
    if weights_enum is not None:
        vgg = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision releases only know the ``pretrained`` flag.
        vgg = models.vgg19(pretrained=True)
    features = vgg.features.to(device).eval()
    # Only the image is optimised; freezing the weights avoids computing and
    # accumulating gradients that are never used.
    features.requires_grad_(False)
    return features


# NST logic
def run_nst(content_pil, style_pil, steps=300, style_weight=1e6, content_weight=1.0):
    """Stylise ``content_pil`` with the style of ``style_pil``.

    Args:
        content_pil: PIL image providing the content.
        style_pil: PIL image providing the style.
        steps: Loss-evaluation budget; optimiser steps are taken until the
            number of closure evaluations exceeds this value.
        style_weight: Multiplier applied to the summed style losses.
        content_weight: Multiplier applied to the summed content losses.

    Returns:
        The stylised image as a PIL image.

    Raises:
        gr.Error: If either image is missing.
    """
    if content_pil is None or style_pil is None:
        raise gr.Error("Please provide both a content image and a style image.")

    steps = int(steps)
    content = load_image(content_pil)
    style = load_image(style_pil)
    input_img = content.clone().requires_grad_(True)

    cnn = _load_vgg19_features()
    norm_mean = torch.tensor([0.485, 0.456, 0.406], device=device)
    norm_std = torch.tensor([0.229, 0.224, 0.225], device=device)

    model, style_losses, content_losses = get_model_losses(
        cnn, norm_mean, norm_std, style, content
    )
    optimizer = optim.LBFGS([input_img])

    evaluations = 0

    def closure():
        nonlocal evaluations
        # Keep the image in the valid range before each evaluation.
        with torch.no_grad():
            input_img.clamp_(0, 1)
        optimizer.zero_grad()
        model(input_img)
        style_score = sum(sl.loss for sl in style_losses)
        content_score = sum(cl.loss for cl in content_losses)
        loss = content_weight * content_score + style_weight * style_score
        loss.backward()
        evaluations += 1
        return loss

    while evaluations <= steps:
        optimizer.step(closure)

    # The last optimiser update happens after the closure's clamp, so clamp
    # once more; out-of-range values would otherwise be corrupted when the
    # tensor is converted to an 8-bit image.
    with torch.no_grad():
        input_img.clamp_(0, 1)

    output = input_img.detach().cpu().squeeze(0)
    return transforms.ToPILImage()(output)


def random_nst_wrapper(steps):
    """Pick two random paintings, stylise them and return all three images."""
    content_img, style_img = get_random_image_pair()
    result = run_nst(content_img, style_img, steps)
    return content_img, style_img, result


# Gradio UI
with gr.Blocks(title="Neural Style Transfer — A + B = C") as demo:
    gr.Markdown(
        "## 🎨 Neural Style Transfer\n"
        "Upload two images OR pick random paintings to remix"
    )

    with gr.Row():
        with gr.Column():
            content_input = gr.Image(label="🖼️ Content Image", type="pil")
            style_input = gr.Image(label="🎨 Style Image", type="pil")
            steps_slider = gr.Slider(
                100, 500, value=300, step=50, label="Optimization Steps"
            )
            upload_button = gr.Button("✨ Stylize Uploaded Images")
            random_button = gr.Button("🎲 Pick Random & Generate")

        with gr.Column():
            gr.Markdown("### 🧠 A + B = C")
            content_preview = gr.Image(label="A: Content", interactive=False)
            style_preview = gr.Image(label="B: Style", interactive=False)
            output_preview = gr.Image(label="C: Stylized Output", interactive=False)

    upload_button.click(
        fn=run_nst,
        inputs=[content_input, style_input, steps_slider],
        outputs=output_preview,
    )

    random_button.click(
        fn=random_nst_wrapper,
        inputs=[steps_slider],
        outputs=[content_preview, style_preview, output_preview],
    )


if __name__ == "__main__":
    demo.launch(share=True)