Nano233 committed on
Commit
a60a7da
·
verified ·
1 Parent(s): 1f8377e

Update main

Browse files
Files changed (1) hide show
  1. app.py +48 -38
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.optim as optim
@@ -5,13 +6,14 @@ from torchvision import models, transforms
5
  from torchvision.models import VGG19_Weights
6
  from PIL import Image
7
  import gradio as gr
 
8
 
9
  # ✅ Use GPU if available
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  print("Using device:", device)
12
 
13
  # --- Image Utilities ---
14
- def load_image(img, max_size=512):
15
  transform = transforms.Compose([
16
  transforms.Resize(max_size),
17
  transforms.ToTensor(),
@@ -60,48 +62,55 @@ class StyleTransferNet(nn.Module):
60
  features[name] = x
61
  return features
62
 
63
- def forward(self, input_img, steps=300, style_weight=1e6, content_weight=0.25):
64
  input_img = input_img.clone().requires_grad_(True)
65
- optimizer = optim.LBFGS([input_img])
66
 
67
  style_features = self.get_features(self.style_img)
68
  content_features = self.get_features(self.content_img)
69
  style_grams = {k: gram_matrix(v) for k, v in style_features.items()}
70
 
71
- run = [0]
72
- while run[0] <= steps:
73
- def closure():
74
- optimizer.zero_grad()
75
- target_features = self.get_features(input_img)
76
- style_loss = 0
77
- content_loss = 0
78
-
79
- for layer in self.style_layers:
80
- target_feature = target_features[layer]
81
- target_gram = gram_matrix(target_feature)
82
- style_gram = style_grams[layer]
83
- style_loss += torch.mean((target_gram - style_gram)**2)
84
-
85
- for layer in self.content_layers:
86
- target_feature = target_features[layer]
87
- content_feature = content_features[layer]
88
- content_loss += torch.mean((target_feature - content_feature)**2)
89
-
90
- total_loss = style_weight * style_loss + content_weight * content_loss
91
- total_loss.backward(retain_graph=True)
92
- run[0] += 1
93
- return total_loss
94
-
95
- optimizer.step(closure)
96
  return input_img
97
 
98
  # --- Gradio App ---
99
def style_transfer_app(content_img, style_img, content_weight, style_weight, steps):
    """Gradio callback: stylize ``content_img`` with ``style_img``.

    Loads both PIL images as tensors, runs the style-transfer network, and
    converts the optimized tensor back to a PIL image for display.
    """
    content = load_image(content_img)
    style = load_image(style_img)

    net = StyleTransferNet(style, content)
    result = net(
        content,
        steps=int(steps),
        content_weight=content_weight,
        style_weight=style_weight,
    )
    return tensor_to_image(result)
 
 
 
 
 
105
 
106
  # --- Launch Interface ---
107
  gr.Interface(
@@ -109,14 +118,15 @@ gr.Interface(
109
  inputs=[
110
  gr.Image(type="pil", label="🖼️ Content Image"),
111
  gr.Image(type="pil", label="🎨 Style Image"),
112
- gr.Slider(0.05, 1.0, value=0.25, step=0.05, label="Content Weight"),
113
- gr.Slider(1e5, 5e6, value=1e6, step=1e5, label="Style Weight"),
114
- gr.Slider(50, 500, value=300, step=50, label="Steps")
115
  ],
116
- outputs=gr.Image(type="pil", label="🧠 Stylized Output"),
117
- title="🧠 AI Neural Style Transfer Lab",
118
- description="Upload a content image and a style image. Then tweak the controls below to explore the balance between structure and stylization. Powered by PyTorch + VGG19.",
 
 
 
119
  allow_flagging="never"
120
  ).launch(share=True)
121
-
122
- gr.Interface(...).launch()
 
1
+
2
  import torch
3
  import torch.nn as nn
4
  import torch.optim as optim
 
6
  from torchvision.models import VGG19_Weights
7
  from PIL import Image
8
  import gradio as gr
9
+ import time
10
 
11
  # ✅ Use GPU if available
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  print("Using device:", device)
14
 
15
  # --- Image Utilities ---
16
+ def load_image(img, max_size=384):
17
  transform = transforms.Compose([
18
  transforms.Resize(max_size),
19
  transforms.ToTensor(),
 
62
  features[name] = x
63
  return features
64
 
65
+ def forward(self, input_img, steps=100, style_weight=1e6, content_weight=1e5):
66
  input_img = input_img.clone().requires_grad_(True)
67
+ optimizer = optim.Adam([input_img], lr=0.02)
68
 
69
  style_features = self.get_features(self.style_img)
70
  content_features = self.get_features(self.content_img)
71
  style_grams = {k: gram_matrix(v) for k, v in style_features.items()}
72
 
73
+ for step in range(steps):
74
+ optimizer.zero_grad()
75
+ target_features = self.get_features(input_img)
76
+ style_loss = 0
77
+ content_loss = 0
78
+
79
+ for layer in self.style_layers:
80
+ target_feature = target_features[layer]
81
+ target_gram = gram_matrix(target_feature)
82
+ style_gram = style_grams[layer]
83
+ style_loss += torch.mean((target_gram - style_gram)**2)
84
+
85
+ for layer in self.content_layers:
86
+ target_feature = target_features[layer]
87
+ content_feature = content_features[layer]
88
+ content_loss += torch.mean((target_feature - content_feature)**2)
89
+
90
+ total_loss = style_weight * style_loss + content_weight * content_loss
91
+ total_loss.backward()
92
+ optimizer.step()
93
+
 
 
 
 
94
  return input_img
95
 
96
  # --- Gradio App ---
97
def style_transfer_app(content_img, style_img, content_weight_ui, style_weight_ui, steps):
    """Gradio callback: run style transfer and report elapsed wall-clock time.

    The UI sliders give intuitive 1-10 weights; they are scaled here to the
    magnitudes the optimizer actually needs before the network is invoked.
    Returns the stylized PIL image plus a human-readable timing note.
    """
    t0 = time.time()

    content = load_image(content_img)
    style = load_image(style_img)

    # Map intuitive UI weights (1-10) to actual values
    actual_content_weight = content_weight_ui * 1e5
    actual_style_weight = style_weight_ui * 1e6

    net = StyleTransferNet(style, content)
    result = net(
        content,
        steps=int(steps),
        content_weight=actual_content_weight,
        style_weight=actual_style_weight,
    )
    stylized = tensor_to_image(result)
    elapsed = round(time.time() - t0)

    # Estimated time display
    estimate_note = f"🕒 Estimated processing time: {elapsed} seconds for {steps} steps."
    return stylized, estimate_note
114
 
115
  # --- Launch Interface ---
116
  gr.Interface(
 
118
  inputs=[
119
  gr.Image(type="pil", label="🖼️ Content Image"),
120
  gr.Image(type="pil", label="🎨 Style Image"),
121
+ gr.Slider(1, 10, value=1, step=1, label="Content Weight (1 = weak structure, 10 = strong)"),
122
+ gr.Slider(1, 10, value=6, step=1, label="Style Weight (1 = subtle, 10 = strong style)"),
123
+ gr.Slider(50, 300, value=100, step=50, label="Steps (speed vs quality)")
124
  ],
125
+ outputs=[
126
+ gr.Image(type="pil", label="🧠 Stylized Output"),
127
+ gr.Textbox(label="⏱️ Time Info")
128
+ ],
129
+ title="🎨 Fast AI Neural Style Transfer",
130
+ description="Upload content and style images, then tune how much structure vs style you want. Powered by PyTorch + VGG19.",
131
  allow_flagging="never"
132
  ).launch(share=True)