# Neural Style Transfer demo — Gradio UI on top of PyTorch + VGG19.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchvision.models import VGG19_Weights
from PIL import Image
import gradio as gr
import time
# ✅ Prefer the GPU when one is available; everything below moves tensors here.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using device:", device)
# --- Image Utilities ---
def load_image(img, max_size=384):
    """Preprocess a PIL image into a normalized 4-D batch tensor on `device`.

    Resizes the shorter side to `max_size`, converts to RGB, applies the
    standard ImageNet mean/std normalization (matching the VGG19 weights),
    and adds a leading batch dimension of 1.
    """
    preprocess = transforms.Compose([
        transforms.Resize(max_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    rgb = img.convert('RGB')
    batch = preprocess(rgb).unsqueeze(0)  # (1, 3, H, W)
    return batch.to(device)
def tensor_to_image(tensor):
    """Invert the ImageNet normalization and convert a (1, C, H, W) tensor to PIL.

    Detaches from the graph first, so it is safe to call on the optimized
    image while it still has requires_grad=True.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    # Normalize with (-mean/std, 1/std) to undo Normalize(mean, std).
    denormalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(imagenet_mean, imagenet_std)],
        std=[1 / s for s in imagenet_std]
    )
    img = tensor.clone().detach().squeeze(0)
    img = denormalize(img).clamp(0, 1)  # clip back into displayable range
    return transforms.ToPILImage()(img)
# --- Style Transfer Utilities ---
def gram_matrix(tensor):
    """Return the Gram matrix of a (b, c, h, w) feature map.

    Channels are flattened to rows, so the result is (b*c, b*c) — the
    unnormalized channel-correlation matrix used as the style target.
    """
    batch, channels, height, width = tensor.size()
    flat = tensor.reshape(batch * channels, height * width)
    return flat @ flat.t()
class StyleTransferNet(nn.Module):
    """Optimization-based (Gatys-style) neural style transfer over VGG19.

    Holds fixed style/content images and a frozen VGG19 feature extractor;
    `forward` iteratively optimizes a copy of the input image so its VGG
    features match the content image and its Gram matrices match the style
    image.
    """
    def __init__(self, style_img, content_img):
        super().__init__()
        weights = VGG19_Weights.DEFAULT
        self.vgg = models.vgg19(weights=weights).features.to(device).eval()
        # Fix: freeze VGG — only the input image is optimized, and leaving
        # requires_grad=True wastes memory/time accumulating unused grads
        # into the network parameters on every backward pass.
        for param in self.vgg.parameters():
            param.requires_grad_(False)
        self.style_img = style_img
        self.content_img = content_img
        # Layer names follow the conv-counter used in get_features.
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_3', 'conv_5', 'conv_9']
    def get_features(self, x):
        """Run x through VGG, collecting activations at the named conv layers."""
        features = {}
        i = 0
        for layer in self.vgg.children():
            x = layer(x)
            if isinstance(layer, nn.Conv2d):
                i += 1
                name = f'conv_{i}'
                if name in self.content_layers + self.style_layers:
                    features[name] = x
        return features
    def forward(self, input_img, steps=100, style_weight=1e6, content_weight=1e5):
        """Return a stylized copy of input_img after `steps` Adam updates.

        style_weight/content_weight scale the two loss terms; the input
        tensor itself is never modified (a clone is optimized).
        """
        input_img = input_img.clone().requires_grad_(True)
        optimizer = optim.Adam([input_img], lr=0.02)
        # Fix: the targets are constants — compute them without building an
        # autograd graph so no activations are retained for backward.
        with torch.no_grad():
            style_features = self.get_features(self.style_img)
            content_features = self.get_features(self.content_img)
            style_grams = {k: gram_matrix(v) for k, v in style_features.items()}
        for step in range(steps):
            optimizer.zero_grad()
            target_features = self.get_features(input_img)
            style_loss = 0
            content_loss = 0
            # Style term: mean-squared difference of Gram matrices.
            for layer in self.style_layers:
                target_feature = target_features[layer]
                target_gram = gram_matrix(target_feature)
                style_gram = style_grams[layer]
                style_loss += torch.mean((target_gram - style_gram)**2)
            # Content term: mean-squared difference of raw activations.
            for layer in self.content_layers:
                target_feature = target_features[layer]
                content_feature = content_features[layer]
                content_loss += torch.mean((target_feature - content_feature)**2)
            total_loss = style_weight * style_loss + content_weight * content_loss
            total_loss.backward()
            optimizer.step()
        return input_img
# --- Gradio App ---
def style_transfer_app(content_img, style_img, content_weight_ui, style_weight_ui, steps):
    """Gradio handler: run style transfer and report the elapsed time.

    Args:
        content_img / style_img: PIL images from the UI.
        content_weight_ui / style_weight_ui: intuitive 1-10 slider values.
        steps: number of optimization steps (slider value).
    Returns:
        (stylized PIL image, human-readable timing message).
    """
    start_time = time.time()
    content = load_image(content_img)
    style = load_image(style_img)
    # Map intuitive UI weights (1-10) to the magnitudes the optimizer expects.
    content_weight = content_weight_ui * 1e5
    style_weight = style_weight_ui * 1e6
    model = StyleTransferNet(style, content)
    output = model(content, steps=int(steps), content_weight=content_weight, style_weight=style_weight)
    stylized = tensor_to_image(output)
    elapsed = round(time.time() - start_time)
    # Fix: this is the *measured* wall-clock time, not an estimate — the old
    # message ("Estimated processing time") mislabeled it.
    time_note = f"🕒 Processing time: {elapsed} seconds for {steps} steps."
    return stylized, time_note
# --- Launch Interface ---
# Builds the Gradio UI around style_transfer_app and starts the server.
# The two weight sliders are intuitive 1-10 scales; style_transfer_app maps
# them onto the raw 1e5/1e6 loss-weight magnitudes.
gr.Interface(
    fn=style_transfer_app,
    inputs=[
        gr.Image(type="pil", label="🖼️ Content Image"),
        gr.Image(type="pil", label="🎨 Style Image"),
        gr.Slider(1, 10, value=1, step=1, label="Content Weight (1 = weak structure, 10 = strong)"),
        gr.Slider(1, 10, value=6, step=1, label="Style Weight (1 = subtle, 10 = strong style)"),
        gr.Slider(50, 300, value=100, step=50, label="Steps (speed vs quality)")
    ],
    outputs=[
        gr.Image(type="pil", label="🧠 Stylized Output"),
        gr.Textbox(label="⏱️ Time Info")
    ],
    title="🎨 Fast AI Neural Style Transfer",
    description="Upload content and style images, then tune how much structure vs style you want. Powered by PyTorch + VGG19.",
    allow_flagging="never"  # disable the flagging button entirely
).launch(share=True)  # share=True also exposes a temporary public URL