# StyleTransferAI / app.py
# Author: Nano233 - commit a60a7da ("Update main", verified)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchvision.models import VGG19_Weights
from PIL import Image
import gradio as gr
import time
# Pick the compute device once at import time: the first CUDA GPU when one is
# visible to PyTorch, otherwise the CPU. Every tensor and model below is moved here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
# --- Image Utilities ---
def load_image(img, max_size=384):
    """Convert a PIL image into a normalized VGG input batch on ``device``.

    The image is converted to RGB and rescaled, keeping its aspect ratio, so
    that its *longer* edge is ``max_size`` pixels. It is then turned into a
    tensor normalized with the ImageNet mean/std expected by the VGG19 weights.

    Args:
        img: A ``PIL.Image.Image`` (the Gradio inputs use ``type="pil"``).
        max_size: Target length, in pixels, of the image's longer edge.

    Returns:
        A float tensor of shape ``(1, 3, H, W)`` on ``device`` with
        ``max(H, W) == max_size``.

    Raises:
        ValueError: If ``img`` is ``None`` (nothing was uploaded).
    """
    if img is None:
        raise ValueError("load_image() received no image")
    image = img.convert('RGB')
    width, height = image.size
    # transforms.Resize(<int>) sets the *shorter* edge to that value and leaves
    # the longer edge unbounded (4000x1000 -> 1536x384). Compute an explicit
    # (h, w) instead so the longer edge really is capped at max_size.
    scale = max_size / max(width, height)
    target_size = (max(1, round(height * scale)), max(1, round(width * scale)))
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0).to(device)
def tensor_to_image(tensor):
    """Turn a normalized ``(1, 3, H, W)`` image tensor back into a PIL image.

    This undoes the ImageNet normalization applied by ``load_image``, clips
    the result to the displayable ``[0, 1]`` range and converts it to PIL. The
    input is cloned and detached, so it may still require grad.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    # Normalize((x - (-m/s)) / (1/s)) == x * s + m, i.e. the inverse transform.
    denormalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(imagenet_mean, imagenet_std)],
        std=[1 / s for s in imagenet_std],
    )
    chw = tensor.clone().detach().squeeze(0)
    rgb = denormalize(chw).clamp(0, 1)
    to_pil = transforms.ToPILImage()
    return to_pil(rgb)
# --- Style Transfer Utilities ---
def gram_matrix(tensor: torch.Tensor) -> torch.Tensor:
    """Return the unnormalized Gram matrix of a 4-D feature map.

    The ``(b, c, h, w)`` input is flattened to ``b * c`` rows of ``h * w``
    values, and the inner product of every pair of rows is returned as a
    ``(b * c, b * c)`` matrix. For ``b == 1`` this is the ``c x c`` channel
    correlation matrix used for the style loss. No division by the element
    count is applied, so magnitudes grow with ``h * w``.
    """
    b, c, h, w = tensor.size()
    # reshape (unlike view) also accepts non-contiguous feature maps; for
    # contiguous input it returns the same zero-copy view.
    features = tensor.reshape(b * c, h * w)
    return torch.mm(features, features.t())
class StyleTransferNet(nn.Module):
    """Optimization-based style transfer on top of a frozen VGG19 feature stack.

    ``forward`` runs Adam on a copy of an input image so that its VGG
    activations match ``content_img`` at ``content_layers`` and its Gram
    matrices match ``style_img`` at ``style_layers``.

    Attributes:
        vgg: VGG19 ``features`` module in eval mode with frozen weights. It is
            shared by all instances and built once per process.
        style_img, content_img: Preprocessed target images, presumably the
            ``(1, 3, H, W)`` tensors produced by ``load_image`` (confirm at
            call sites).
        content_layers, style_layers: Names ``conv_<k>``, where ``k`` is the
            1-based index of an ``nn.Conv2d`` inside ``vgg``.
    """

    # Class-level cache: the pretrained VGG19 is built and moved to the device
    # once per process instead of on every request.
    _shared_vgg = None

    def __init__(self, style_img, content_img):
        super().__init__()
        self.vgg = self._load_vgg()
        self.style_img = style_img
        self.content_img = content_img
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_3', 'conv_5', 'conv_9']

    @classmethod
    def _load_vgg(cls):
        """Return the shared frozen VGG19 feature extractor, building it on first use."""
        if cls._shared_vgg is None:
            weights = VGG19_Weights.DEFAULT
            vgg = models.vgg19(weights=weights).features.to(device).eval()
            # Only the image is optimized. Freezing the weights stops autograd
            # from computing and accumulating gradients for every VGG parameter
            # on each step.
            vgg.requires_grad_(False)
            cls._shared_vgg = vgg
        return cls._shared_vgg

    def get_features(self, x):
        """Run ``x`` through ``self.vgg`` and collect the requested conv activations.

        Returns:
            Dict mapping each ``conv_<k>`` name listed in ``content_layers`` or
            ``style_layers`` (and present in the network) to its activation.
        """
        wanted = set(self.content_layers) | set(self.style_layers)
        features = {}
        conv_index = 0
        for layer in self.vgg.children():
            # Stop once every requested activation has been captured, but only
            # after the ReLU that follows the last captured conv. torchvision's
            # VGG uses ReLU(inplace=True), which rewrites the stored conv output
            # in place, so the returned features are post-activation. Running
            # that ReLU keeps the values identical to a full forward pass.
            if len(features) == len(wanted) and not isinstance(layer, nn.ReLU):
                break
            x = layer(x)
            if isinstance(layer, nn.Conv2d):
                conv_index += 1
                name = f'conv_{conv_index}'
                if name in wanted:
                    features[name] = x
        return features

    def forward(self, input_img, steps=100, style_weight=1e6, content_weight=1e5):
        """Optimize a copy of ``input_img`` toward the style and content targets.

        Args:
            input_img: Starting image tensor (the app passes the content image).
                It is cloned, so the caller's tensor is left untouched.
            steps: Number of Adam iterations (learning rate 0.02).
            style_weight: Multiplier on the summed Gram-matrix MSE over
                ``style_layers``.
            content_weight: Multiplier on the summed activation MSE over
                ``content_layers``.

        Returns:
            The optimized image tensor, with the same shape as ``input_img``.
            It still requires grad, so callers should ``detach()`` it.
        """
        input_img = input_img.clone().requires_grad_(True)
        optimizer = optim.Adam([input_img], lr=0.02)
        # The targets are constants. Computing them under no_grad guarantees
        # they carry no autograd history even if self.vgg's weights require
        # grad. Otherwise each backward() would also traverse the targets'
        # graph, whose buffers are freed by the first backward(), and the
        # second step would raise "Trying to backward through the graph a
        # second time".
        with torch.no_grad():
            style_features = self.get_features(self.style_img)
            content_features = self.get_features(self.content_img)
            style_grams = {k: gram_matrix(v) for k, v in style_features.items()}
        for _ in range(steps):
            optimizer.zero_grad()
            target_features = self.get_features(input_img)
            style_loss = 0
            content_loss = 0
            for layer in self.style_layers:
                target_gram = gram_matrix(target_features[layer])
                style_loss += torch.mean((target_gram - style_grams[layer]) ** 2)
            for layer in self.content_layers:
                content_loss += torch.mean(
                    (target_features[layer] - content_features[layer]) ** 2
                )
            total_loss = style_weight * style_loss + content_weight * content_loss
            total_loss.backward()
            optimizer.step()
        return input_img
# --- Gradio App ---
def style_transfer_app(content_img, style_img, content_weight_ui, style_weight_ui, steps):
    """Gradio callback that stylizes ``content_img`` with ``style_img``.

    Args:
        content_img: PIL image whose structure is kept (``None`` if not uploaded).
        style_img: PIL image whose textures and colors are transferred
            (``None`` if not uploaded).
        content_weight_ui: Slider value 1-10; multiplied by 1e5.
        style_weight_ui: Slider value 1-10; multiplied by 1e6.
        steps: Number of optimization steps from the slider (coerced to int).

    Returns:
        ``(stylized PIL image, text describing the measured processing time)``.

    Raises:
        gr.Error: If either image is missing, so Gradio shows a readable
            message instead of a traceback.
    """
    if content_img is None or style_img is None:
        raise gr.Error("Please upload both a content image and a style image.")
    start_time = time.perf_counter()  # monotonic clock for durations
    content = load_image(content_img)
    style = load_image(style_img)
    # Map the intuitive UI weights (1-10) to the magnitudes the losses need.
    content_weight = content_weight_ui * 1e5
    style_weight = style_weight_ui * 1e6
    n_steps = int(steps)  # slider values may arrive as floats
    model = StyleTransferNet(style, content)
    output = model(content, steps=n_steps, content_weight=content_weight, style_weight=style_weight)
    stylized = tensor_to_image(output)
    elapsed = round(time.perf_counter() - start_time)
    # This is the measured wall-clock time of this run, not an estimate.
    time_note = f"🕒 Processing time: {elapsed} seconds for {n_steps} steps."
    return stylized, time_note
# --- Launch Interface ---
# Build the Gradio UI and start serving it at module load.
# The labels use real emoji; the previous text was mojibake (UTF-8 bytes
# decoded as a legacy codepage) and rendered as garbage in the browser.
demo = gr.Interface(
    fn=style_transfer_app,
    inputs=[
        gr.Image(type="pil", label="🖼️ Content Image"),
        gr.Image(type="pil", label="🎨 Style Image"),
        gr.Slider(1, 10, value=1, step=1, label="Content Weight (1 = weak structure, 10 = strong)"),
        gr.Slider(1, 10, value=6, step=1, label="Style Weight (1 = subtle, 10 = strong style)"),
        gr.Slider(50, 300, value=100, step=50, label="Steps (speed vs quality)")
    ],
    outputs=[
        gr.Image(type="pil", label="🧠 Stylized Output"),
        gr.Textbox(label="⏱️ Time Info")
    ],
    title="🎨 Fast AI Neural Style Transfer",
    description="Upload content and style images, then tune how much structure vs style you want. Powered by PyTorch + VGG19.",
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the pinned Gradio version.
    allow_flagging="never"
)
demo.launch(share=True)