# Style_Sync / app.py
# Last updated by heramb04 in commit d945be4 ("Update app.py").
import functools
import random

import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from datasets import load_dataset
from PIL import Image
# ✅ Device setup: use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 📦 Image preprocessing: force every image to a 512x512 square (aspect ratio is
# not preserved), then turn the PIL image into a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ]
)
def load_image(img):
    """Turn a PIL image into a single-image batch tensor on ``device``.

    Args:
        img: A PIL image in any mode; it is converted to RGB first.

    Returns:
        ``transform(img)`` with a leading batch dimension added, i.e. a tensor of
        shape (1, 3, H, W) where H and W come from the module-level ``transform``
        (512x512 in this file), moved to ``device``.

    Raises:
        ValueError: If ``img`` is None, which is what Gradio passes for an image
            input the user left empty.
    """
    if img is None:
        # Fail with a readable message instead of
        # "'NoneType' object has no attribute 'convert'".
        raise ValueError("Please provide both a content image and a style image.")
    rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0).to(device)
# 🔧 NST Core Classes
class Normalization(nn.Module):
    """Standardize an image tensor per channel: ``(img - mean) / std``.

    ``mean`` and ``std`` are reshaped to ``(C, 1, 1)`` so that each channel's
    statistic broadcasts over the trailing height/width dimensions of the input.
    """

    def __init__(self, mean, std):
        super().__init__()
        per_channel = (-1, 1, 1)  # one value per channel, broadcast over H and W
        self.mean = mean.view(per_channel)
        self.std = std.view(per_channel)

    def forward(self, img):
        centered = img - self.mean
        return centered / self.std
class ContentLoss(nn.Module):
    """Pass-through layer that records the MSE between its input and a fixed target.

    ``forward`` returns its input unchanged, so the layer can be spliced into an
    ``nn.Sequential``; the most recent loss is stored on ``self.loss``.
    """

    def __init__(self, target):
        super().__init__()
        # Detached: the target is a constant, not something to back-propagate into.
        self.target = target.detach()
        self.loss = 0  # placeholder until the first forward pass

    def forward(self, input):
        mse = nn.functional.mse_loss
        self.loss = mse(input, self.target)
        return input
def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D feature map.

    Args:
        input: Tensor of shape ``(b, c, h, w)``.

    Returns:
        A ``(b * c, b * c)`` tensor equal to ``F @ F.T / (b * c * h * w)``, where
        ``F`` is ``input`` flattened to ``(b * c, h * w)``. For the single-image
        batches this app uses (``b == 1``) this is the usual ``(c, c)`` channel
        Gram matrix normalized by ``c * h * w``.
    """
    b, c, h, w = input.size()
    # Fold the batch into the rows (as in the PyTorch NST tutorial) so batches
    # with b > 1 no longer fail in view(); reshape also accepts non-contiguous
    # inputs, which view() rejects.
    features = input.reshape(b * c, h * w)
    G = torch.mm(features, features.t())
    return G.div(b * c * h * w)
class StyleLoss(nn.Module):
    """Pass-through layer scoring how far an input's Gram matrix is from a style target.

    The target Gram matrix is computed once from ``target_feature`` (a 4-D
    ``(b, c, h, w)`` feature map, as ``gram_matrix`` requires). Each forward pass
    stores the MSE between the input's Gram matrix and that target on
    ``self.loss`` and returns the input unchanged.
    """

    def __init__(self, target_feature):
        super().__init__()
        style_gram = gram_matrix(target_feature)
        self.target = style_gram.detach()  # constant target, excluded from autograd
        self.loss = 0  # placeholder until the first forward pass

    def forward(self, input):
        current_gram = gram_matrix(input)
        self.loss = nn.functional.mse_loss(current_gram, self.target)
        return input
# 🧠 Model builder
def get_model_losses(
    cnn,
    norm_mean,
    norm_std,
    style_img,
    content_img,
    *,
    content_layers=("conv_4",),
    style_layers=("conv_1", "conv_2", "conv_3", "conv_4", "conv_5"),
):
    """Build an NST feature extractor with content/style loss probes spliced in.

    The layers of ``cnn`` (e.g. VGG19's ``.features`` trunk, as ``run_nst``
    passes) are copied into a new ``nn.Sequential`` behind a ``Normalization``
    layer. Conv layers are numbered from 1 (``conv_1``, ``conv_2``, ...). ReLU,
    max-pool and batch-norm layers reuse the number of the preceding conv
    (``relu_1``, ``pool_2``, ...). Layer types other than these four are skipped.

    Args:
        cnn: Module whose ``children()`` are the feature layers to reuse.
        norm_mean: Per-channel mean tensor for the input normalization.
        norm_std: Per-channel std tensor for the input normalization.
        style_img: Style image batch; its features at ``style_layers`` become the
            ``StyleLoss`` targets.
        content_img: Content image batch; its features at ``content_layers``
            become the ``ContentLoss`` targets.
        content_layers: Layer names after which a ``ContentLoss`` is inserted.
        style_layers: Layer names after which a ``StyleLoss`` is inserted.

    Returns:
        ``(model, style_losses, content_losses)``. ``model`` is truncated right
        after the last inserted loss layer. The two lists hold the inserted loss
        modules in network order.
    """
    norm = Normalization(norm_mean, norm_std).to(device)
    model = nn.Sequential(norm)
    content_losses, style_losses = [], []

    i = 0  # number of the most recent conv layer; shared by the layers after it
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = f"conv_{i}"
        elif isinstance(layer, nn.ReLU):
            name = f"relu_{i}"
            # Swap in an out-of-place ReLU: an in-place one could modify the
            # activations a preceding loss layer saved for its backward pass.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f"pool_{i}"
        elif isinstance(layer, nn.BatchNorm2d):
            name = f"bn_{i}"
        else:
            # NOTE(review): unrecognized layer types are silently skipped (as before).
            continue

        model.add_module(name, layer)

        if name in content_layers:
            # Targets are constants, so don't build an autograd graph for them.
            with torch.no_grad():
                target = model(content_img)
            content_loss = ContentLoss(target)
            model.add_module(f"content_loss_{i}", content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            with torch.no_grad():
                target_feature = model(style_img)
            style_loss = StyleLoss(target_feature)
            model.add_module(f"style_loss_{i}", style_loss)
            style_losses.append(style_loss)

    # Layers after the last loss probe cannot influence any loss, so drop them.
    last = max(
        (
            idx
            for idx, module in enumerate(model)
            if isinstance(module, (ContentLoss, StyleLoss))
        ),
        default=0,
    )
    return model[: last + 1], style_losses, content_losses
# 🎲 Random selector from Hugging Face dataset
@functools.lru_cache(maxsize=1)
def _load_paintings():
    """Load the paintings dataset once per process and reuse it on later calls."""
    return load_dataset("heramb04/Famous-paintings", split="train")


def get_random_image_pair():
    """Return the images of two different, randomly chosen dataset rows as RGB PIL images.

    Raises:
        ValueError: If the dataset has fewer than two rows (from ``random.sample``).
    """
    ds = _load_paintings()
    # Sample row indices rather than list(ds): materializing the whole dataset
    # would decode every image just to keep two of them.
    first, second = random.sample(range(len(ds)), 2)
    return ds[first]["image"].convert("RGB"), ds[second]["image"].convert("RGB")
# πŸ–ŒοΈ NST logic
@functools.lru_cache(maxsize=1)
def _load_vgg19_features():
    """Load (once per process) VGG19's convolutional trunk, frozen and in eval mode on ``device``."""
    # NOTE(review): torchvision >= 0.13 deprecates ``pretrained=`` in favour of
    # ``weights=models.VGG19_Weights.DEFAULT``; kept as-is for older versions.
    cnn = models.vgg19(pretrained=True).features.to(device).eval()
    # Only the image is optimized. Freezing the weights avoids computing a
    # gradient for every VGG parameter on each backward pass.
    cnn.requires_grad_(False)
    return cnn


def run_nst(content_pil, style_pil, steps=300, *, style_weight=1e6, content_weight=1):
    """Stylize ``content_pil`` with the style of ``style_pil`` using L-BFGS.

    Args:
        content_pil: Content image (PIL).
        style_pil: Style image (PIL).
        steps: Optimization budget, counted in closure evaluations. L-BFGS
            evaluates the closure several times per ``step()``, so the final
            count can overshoot ``steps`` slightly.
        style_weight: Multiplier on the summed style loss.
        content_weight: Multiplier on the summed content loss.

    Returns:
        The stylized image as a PIL image, at the size produced by
        ``load_image`` (512x512 in this file).
    """
    content = load_image(content_pil)
    style = load_image(style_pil)
    input_img = content.clone().requires_grad_(True)

    cnn = _load_vgg19_features()
    # ImageNet channel statistics expected by the pretrained VGG weights.
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    norm_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
    model, style_losses, content_losses = get_model_losses(cnn, norm_mean, norm_std, style, content)

    optimizer = optim.LBFGS([input_img])
    run = [0]  # closure evaluations so far (a list so the closure can mutate it)

    def closure():
        # Keep pixels in the valid [0, 1] range before each evaluation.
        with torch.no_grad():
            input_img.clamp_(0, 1)
        optimizer.zero_grad()
        model(input_img)  # populates .loss on every loss layer
        style_score = sum(sl.loss for sl in style_losses)
        content_score = sum(cl.loss for cl in content_losses)
        loss = content_weight * content_score + style_weight * style_score
        loss.backward()
        run[0] += 1
        return loss

    while run[0] <= steps:
        optimizer.step(closure)

    # The last L-BFGS step can push pixels outside [0, 1]. ToPILImage scales a
    # float tensor by 255 and casts it to uint8, where out-of-range values would
    # wrap around into speckle artefacts, so clamp once more before converting.
    with torch.no_grad():
        input_img.clamp_(0, 1)
    output = input_img.detach().cpu().squeeze(0)
    return transforms.ToPILImage()(output)
# πŸŽ›οΈ Gradio UI
with gr.Blocks(title="Neural Style Transfer — A + B = C") as demo:
    gr.Markdown("## 🎨 Neural Style Transfer<br>Upload two images OR pick random paintings to remix")
    with gr.Row():
        with gr.Column():
            content_input = gr.Image(label="🖼️ Content Image", type="pil")
            style_input = gr.Image(label="🎨 Style Image", type="pil")
            steps_slider = gr.Slider(100, 500, value=300, step=50, label="Optimization Steps")
            upload_button = gr.Button("✨ Stylize Uploaded Images")
            random_button = gr.Button("🎲 Pick Random & Generate")
        with gr.Column():
            gr.Markdown("### 🧠 A + B = C")
            content_preview = gr.Image(label="A: Content", interactive=False)
            style_preview = gr.Image(label="B: Style", interactive=False)
            output_preview = gr.Image(label="C: Stylized Output", interactive=False)

    def upload_nst_wrapper(content_img, style_img, steps):
        """Stylize the uploaded pair and echo the inputs into the A/B previews.

        Without the echo, A and B would keep showing the images from an earlier
        random run while C shows the result for the uploaded pair.
        """
        result = run_nst(content_img, style_img, steps)
        return content_img, style_img, result

    upload_button.click(
        fn=upload_nst_wrapper,
        inputs=[content_input, style_input, steps_slider],
        outputs=[content_preview, style_preview, output_preview],
    )

    def random_nst_wrapper(steps):
        """Pick two random paintings, stylize one with the other, and return all three images."""
        content_img, style_img = get_random_image_pair()
        result = run_nst(content_img, style_img, steps)
        return content_img, style_img, result

    random_button.click(
        fn=random_nst_wrapper,
        inputs=[steps_slider],
        outputs=[content_preview, style_preview, output_preview],
    )

demo.launch(share=True)