Spaces:
Running
Running
Update main
Browse files
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.optim as optim
|
|
@@ -5,13 +6,14 @@ from torchvision import models, transforms
|
|
| 5 |
from torchvision.models import VGG19_Weights
|
| 6 |
from PIL import Image
|
| 7 |
import gradio as gr
|
|
|
|
| 8 |
|
| 9 |
# ✅ Use GPU if available
|
| 10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 11 |
print("Using device:", device)
|
| 12 |
|
| 13 |
# --- Image Utilities ---
|
| 14 |
-
def load_image(img, max_size=
|
| 15 |
transform = transforms.Compose([
|
| 16 |
transforms.Resize(max_size),
|
| 17 |
transforms.ToTensor(),
|
|
@@ -60,48 +62,55 @@ class StyleTransferNet(nn.Module):
|
|
| 60 |
features[name] = x
|
| 61 |
return features
|
| 62 |
|
| 63 |
-
def forward(self, input_img, steps=
|
| 64 |
input_img = input_img.clone().requires_grad_(True)
|
| 65 |
-
optimizer = optim.
|
| 66 |
|
| 67 |
style_features = self.get_features(self.style_img)
|
| 68 |
content_features = self.get_features(self.content_img)
|
| 69 |
style_grams = {k: gram_matrix(v) for k, v in style_features.items()}
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
run[0] += 1
|
| 93 |
-
return total_loss
|
| 94 |
-
|
| 95 |
-
optimizer.step(closure)
|
| 96 |
return input_img
|
| 97 |
|
| 98 |
# --- Gradio App ---
|
| 99 |
-
def style_transfer_app(content_img, style_img,
|
|
|
|
| 100 |
content = load_image(content_img)
|
| 101 |
style = load_image(style_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
model = StyleTransferNet(style, content)
|
| 103 |
output = model(content, steps=int(steps), content_weight=content_weight, style_weight=style_weight)
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# --- Launch Interface ---
|
| 107 |
gr.Interface(
|
|
@@ -109,14 +118,15 @@ gr.Interface(
|
|
| 109 |
inputs=[
|
| 110 |
gr.Image(type="pil", label="🖼️ Content Image"),
|
| 111 |
gr.Image(type="pil", label="🎨 Style Image"),
|
| 112 |
-
gr.Slider(
|
| 113 |
-
gr.Slider(
|
| 114 |
-
gr.Slider(50,
|
| 115 |
],
|
| 116 |
-
outputs=
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
| 119 |
allow_flagging="never"
|
| 120 |
).launch(share=True)
|
| 121 |
-
|
| 122 |
-
gr.Interface(...).launch()
|
|
|
|
| 1 |
+
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.optim as optim
|
|
|
|
| 6 |
from torchvision.models import VGG19_Weights
|
| 7 |
from PIL import Image
|
| 8 |
import gradio as gr
|
| 9 |
+
import time
|
| 10 |
|
| 11 |
# ✅ Use GPU if available
|
| 12 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
print("Using device:", device)
|
| 14 |
|
| 15 |
# --- Image Utilities ---
|
| 16 |
+
def load_image(img, max_size=384):
|
| 17 |
transform = transforms.Compose([
|
| 18 |
transforms.Resize(max_size),
|
| 19 |
transforms.ToTensor(),
|
|
|
|
| 62 |
features[name] = x
|
| 63 |
return features
|
| 64 |
|
| 65 |
+
def forward(self, input_img, steps=100, style_weight=1e6, content_weight=1e5):
    """Run optimization-based neural style transfer.

    Iteratively updates a copy of `input_img` with Adam so that its VGG
    feature Gram matrices match the style image and its raw VGG features
    match the content image.

    Args:
        input_img: starting image tensor (typically the content image).
        steps: number of Adam optimization steps.
        style_weight: multiplier on the accumulated style (Gram) loss.
        content_weight: multiplier on the accumulated content loss.

    Returns:
        The stylized image tensor, detached from the autograd graph.
    """
    # Optimize a clone so the caller's tensor is left untouched.
    input_img = input_img.clone().requires_grad_(True)
    optimizer = optim.Adam([input_img], lr=0.02)

    # Target features are constants for the whole optimization; computing
    # them under no_grad avoids building an unnecessary autograd graph.
    with torch.no_grad():
        style_features = self.get_features(self.style_img)
        content_features = self.get_features(self.content_img)
        style_grams = {k: gram_matrix(v) for k, v in style_features.items()}

    for step in range(steps):
        optimizer.zero_grad()
        target_features = self.get_features(input_img)
        style_loss = 0
        content_loss = 0

        # Style term: match Gram matrices (feature correlations) per layer.
        for layer in self.style_layers:
            target_gram = gram_matrix(target_features[layer])
            style_loss += torch.mean((target_gram - style_grams[layer]) ** 2)

        # Content term: match raw feature activations per layer.
        for layer in self.content_layers:
            content_loss += torch.mean(
                (target_features[layer] - content_features[layer]) ** 2
            )

        total_loss = style_weight * style_loss + content_weight * content_loss
        total_loss.backward()
        optimizer.step()

    # Fix: the original returned a tensor that still required grad, keeping
    # the final step's autograd graph alive and breaking downstream numpy
    # conversion (e.g. in image post-processing). Detach before returning.
    return input_img.detach()
|
| 95 |
|
| 96 |
# --- Gradio App ---
|
| 97 |
+
def style_transfer_app(content_img, style_img, content_weight_ui, style_weight_ui, steps):
    """Gradio handler: stylize `content_img` with `style_img` and report timing.

    The intuitive 1-10 slider values from the UI are scaled up to the
    magnitudes the optimizer actually expects before running the transfer.
    Returns the stylized PIL image and a human-readable timing note.
    """
    t0 = time.time()

    # Preprocess both uploads into model-ready tensors.
    content_tensor = load_image(content_img)
    style_tensor = load_image(style_img)

    # Scale the 1-10 UI sliders to the real loss-weight magnitudes.
    c_weight = 1e5 * content_weight_ui
    s_weight = 1e6 * style_weight_ui

    net = StyleTransferNet(style_tensor, content_tensor)
    result = net(
        content_tensor,
        steps=int(steps),
        content_weight=c_weight,
        style_weight=s_weight,
    )
    stylized = tensor_to_image(result)

    elapsed = round(time.time() - t0)
    estimate_note = f"🕒 Estimated processing time: {elapsed} seconds for {steps} steps."
    return stylized, estimate_note
|
| 114 |
|
| 115 |
# --- Launch Interface ---
|
| 116 |
gr.Interface(
|
|
|
|
| 118 |
inputs=[
|
| 119 |
gr.Image(type="pil", label="🖼️ Content Image"),
|
| 120 |
gr.Image(type="pil", label="🎨 Style Image"),
|
| 121 |
+
gr.Slider(1, 10, value=1, step=1, label="Content Weight (1 = weak structure, 10 = strong)"),
|
| 122 |
+
gr.Slider(1, 10, value=6, step=1, label="Style Weight (1 = subtle, 10 = strong style)"),
|
| 123 |
+
gr.Slider(50, 300, value=100, step=50, label="Steps (speed vs quality)")
|
| 124 |
],
|
| 125 |
+
outputs=[
|
| 126 |
+
gr.Image(type="pil", label="🧠 Stylized Output"),
|
| 127 |
+
gr.Textbox(label="⏱️ Time Info")
|
| 128 |
+
],
|
| 129 |
+
title="🎨 Fast AI Neural Style Transfer",
|
| 130 |
+
description="Upload content and style images, then tune how much structure vs style you want. Powered by PyTorch + VGG19.",
|
| 131 |
allow_flagging="never"
|
| 132 |
).launch(share=True)
|
|
|
|
|
|