Nano233 committed on
Commit
146cb65
·
verified ·
1 Parent(s): 673bae7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -120
app.py CHANGED
@@ -1,120 +1,122 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- from torchvision import models, transforms
5
- from torchvision.models import VGG19_Weights
6
- from PIL import Image
7
- import gradio as gr
8
-
9
- # ✅ Use GPU if available
10
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- print("Using device:", device)
12
-
13
- # --- Image Utilities ---
14
- def load_image(img, max_size=512):
15
- transform = transforms.Compose([
16
- transforms.Resize(max_size),
17
- transforms.ToTensor(),
18
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
19
- std=[0.229, 0.224, 0.225])
20
- ])
21
- image = img.convert('RGB')
22
- image = transform(image).unsqueeze(0)
23
- return image.to(device)
24
-
25
- def tensor_to_image(tensor):
26
- unnormalize = transforms.Normalize(
27
- mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
28
- std=[1 / 0.229, 1 / 0.224, 1 / 0.225]
29
- )
30
- image = tensor.clone().detach().squeeze(0)
31
- image = unnormalize(image)
32
- image = torch.clamp(image, 0, 1)
33
- return transforms.ToPILImage()(image)
34
-
35
- # --- Style Transfer Utilities ---
36
- def gram_matrix(tensor):
37
- b, c, h, w = tensor.size()
38
- features = tensor.view(b * c, h * w)
39
- return torch.mm(features, features.t())
40
-
41
- class StyleTransferNet(nn.Module):
42
- def __init__(self, style_img, content_img):
43
- super().__init__()
44
- weights = VGG19_Weights.DEFAULT
45
- self.vgg = models.vgg19(weights=weights).features.to(device).eval()
46
- self.style_img = style_img
47
- self.content_img = content_img
48
- self.content_layers = ['conv_4']
49
- self.style_layers = ['conv_1', 'conv_3', 'conv_5', 'conv_9']
50
-
51
- def get_features(self, x):
52
- features = {}
53
- i = 0
54
- for layer in self.vgg.children():
55
- x = layer(x)
56
- if isinstance(layer, nn.Conv2d):
57
- i += 1
58
- name = f'conv_{i}'
59
- if name in self.content_layers + self.style_layers:
60
- features[name] = x
61
- return features
62
-
63
- def forward(self, input_img, steps=300, style_weight=1e6, content_weight=0.25):
64
- input_img = input_img.clone().requires_grad_(True)
65
- optimizer = optim.LBFGS([input_img])
66
-
67
- style_features = self.get_features(self.style_img)
68
- content_features = self.get_features(self.content_img)
69
- style_grams = {k: gram_matrix(v) for k, v in style_features.items()}
70
-
71
- run = [0]
72
- while run[0] <= steps:
73
- def closure():
74
- optimizer.zero_grad()
75
- target_features = self.get_features(input_img)
76
- style_loss = 0
77
- content_loss = 0
78
-
79
- for layer in self.style_layers:
80
- target_feature = target_features[layer]
81
- target_gram = gram_matrix(target_feature)
82
- style_gram = style_grams[layer]
83
- style_loss += torch.mean((target_gram - style_gram)**2)
84
-
85
- for layer in self.content_layers:
86
- target_feature = target_features[layer]
87
- content_feature = content_features[layer]
88
- content_loss += torch.mean((target_feature - content_feature)**2)
89
-
90
- total_loss = style_weight * style_loss + content_weight * content_loss
91
- total_loss.backward(retain_graph=True)
92
- run[0] += 1
93
- return total_loss
94
-
95
- optimizer.step(closure)
96
- return input_img
97
-
98
- # --- Gradio App ---
99
- def style_transfer_app(content_img, style_img, content_weight, style_weight, steps):
100
- content = load_image(content_img)
101
- style = load_image(style_img)
102
- model = StyleTransferNet(style, content)
103
- output = model(content, steps=int(steps), content_weight=content_weight, style_weight=style_weight)
104
- return tensor_to_image(output)
105
-
106
- # --- Launch Interface ---
107
- gr.Interface(
108
- fn=style_transfer_app,
109
- inputs=[
110
- gr.Image(type="pil", label="🖼️ Content Image"),
111
- gr.Image(type="pil", label="🎨 Style Image"),
112
- gr.Slider(0.05, 1.0, value=0.25, step=0.05, label="Content Weight"),
113
- gr.Slider(1e5, 5e6, value=1e6, step=1e5, label="Style Weight"),
114
- gr.Slider(50, 500, value=300, step=50, label="Steps")
115
- ],
116
- outputs=gr.Image(type="pil", label="🧠 Stylized Output"),
117
- title="🧠 AI Neural Style Transfer Lab",
118
- description="Upload a content image and a style image. Then tweak the controls below to explore the balance between structure and stylization. Powered by PyTorch + VGG19.",
119
- allow_flagging="never"
120
- ).launch(share=True)
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import models, transforms
5
+ from torchvision.models import VGG19_Weights
6
+ from PIL import Image
7
+ import gradio as gr
8
+
9
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)
12
+
13
+ # --- Image Utilities ---
14
def load_image(img, max_size=512):
    """Convert a PIL image into a normalized batch-of-one tensor on ``device``.

    The image is converted to RGB and, if its longer edge exceeds
    ``max_size``, downscaled with the aspect ratio preserved so that the
    longer edge becomes ``max_size``. Images that already fit keep their
    native resolution (no upscaling). The result is converted to a tensor and
    normalized per channel with the mean/std constants below.

    Args:
        img: A PIL image (must provide ``convert`` and ``size``).
        max_size: Upper bound, in pixels, for the longer image edge.

    Returns:
        The normalized image tensor with a leading batch dimension added,
        moved to the module-level ``device``.

    Raises:
        ValueError: If ``max_size`` is not positive.
    """
    if max_size <= 0:
        raise ValueError(f"max_size must be positive, got {max_size}")

    image = img.convert('RGB')
    width, height = image.size
    longest_edge = max(width, height)

    pipeline = []
    if longest_edge > max_size:
        # Resize(int) would match the *shorter* edge, which upsamples small
        # images and leaves the longer edge unbounded.  Passing an explicit
        # (height, width) keeps the longer edge at exactly max_size.
        scale = max_size / longest_edge
        target_hw = (
            max(1, round(height * scale)),
            max(1, round(width * scale)),
        )
        pipeline.append(transforms.Resize(target_hw))
    pipeline.append(transforms.ToTensor())
    pipeline.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    )

    transform = transforms.Compose(pipeline)
    return transform(image).unsqueeze(0).to(device)
24
+
25
def tensor_to_image(tensor):
    """Turn a normalized batch-of-one image tensor back into a PIL image.

    The tensor is copied and detached, its batch dimension is squeezed away,
    the per-channel normalization is inverted, values are clamped to
    ``[0, 1]`` and the result is converted with ``ToPILImage``.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)

    # Normalize computes (x - mean) / std, so using mean = -m/s and
    # std = 1/s gives x * s + m, i.e. the inverse of the forward transform.
    inverse = transforms.Normalize(
        mean=[-m / s for m, s in zip(channel_mean, channel_std)],
        std=[1 / s for s in channel_std],
    )

    restored = inverse(tensor.clone().detach().squeeze(0))
    to_pil = transforms.ToPILImage()
    return to_pil(torch.clamp(restored, 0, 1))
34
+
35
+ # --- Style Transfer Utilities ---
36
def gram_matrix(tensor):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    The input of size ``(b, c, h, w)`` is flattened to ``(b * c, h * w)`` and
    multiplied by its own transpose.  The product is divided by the number of
    elements in the input so that its magnitude does not grow with the
    spatial resolution; without this, feature maps taken from images of
    different sizes yield Gram matrices on different scales and cannot be
    compared meaningfully.

    Args:
        tensor: A 4-D tensor.

    Returns:
        A ``(b * c, b * c)`` tensor.
    """
    b, c, h, w = tensor.size()
    # reshape (rather than view) also accepts non-contiguous inputs.
    features = tensor.reshape(b * c, h * w)
    gram = torch.mm(features, features.t())
    return gram / (b * c * h * w)
40
+
41
class StyleTransferNet(nn.Module):
    """Optimization-based style transfer driven by VGG19 feature maps.

    An instance holds a style image and a content image.  Calling it with an
    initial image runs L-BFGS on a copy of that image, minimizing a weighted
    sum of a style loss (mean squared difference between Gram matrices at
    ``style_layers``) and a content loss (mean squared difference between
    feature maps at ``content_layers``).
    """

    # Feature extractor shared by all instances so the pretrained weights are
    # loaded and moved to the device only once per process.
    _shared_vgg = None

    def __init__(self, style_img, content_img):
        """Store the target images and attach the shared VGG19 feature stack.

        Args:
            style_img: Preprocessed style image tensor.
            content_img: Preprocessed content image tensor.
        """
        super().__init__()
        self.vgg = self._load_vgg()
        self.style_img = style_img
        self.content_img = content_img
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_3', 'conv_5', 'conv_9']

    @classmethod
    def _load_vgg(cls):
        """Return the cached, frozen VGG19 ``features`` module in eval mode."""
        if cls._shared_vgg is None:
            weights = VGG19_Weights.DEFAULT
            vgg = models.vgg19(weights=weights).features.to(device).eval()
            # Only the image is optimized; freezing the weights stops autograd
            # from computing (and accumulating) gradients for them.
            for param in vgg.parameters():
                param.requires_grad_(False)
            cls._shared_vgg = vgg
        return cls._shared_vgg

    def get_features(self, x):
        """Run ``x`` through the VGG stack and collect the requested layers.

        Convolutions are numbered in traversal order (``conv_1``, ``conv_2``,
        ...).  For every name listed in ``content_layers`` or
        ``style_layers`` the activation after the ReLU that follows that
        convolution is recorded.  Traversal stops as soon as all requested
        layers have been collected.

        NOTE(review): torchvision builds VGG with ``nn.ReLU(inplace=True)``;
        if so, storing the raw conv output (as this method previously did)
        already yielded the post-ReLU values because the stored tensor was
        overwritten in place.  Capturing at the ReLU makes that explicit --
        confirm against the installed torchvision.

        Args:
            x: Input image tensor.

        Returns:
            Dict mapping layer name to its activation tensor.
        """
        wanted = set(self.content_layers) | set(self.style_layers)
        features = {}
        conv_index = 0
        pending = None  # wanted conv whose ReLU output is still to come
        for layer in self.vgg.children():
            x = layer(x)
            if isinstance(layer, nn.Conv2d):
                conv_index += 1
                name = f'conv_{conv_index}'
                pending = name if name in wanted else None
            elif isinstance(layer, nn.ReLU) and pending is not None:
                features[pending] = x
                pending = None
                if len(features) == len(wanted):
                    # Deeper layers are never read; skip their cost.
                    break
        return features

    def forward(self, input_img, steps=300, style_weight=1e6, content_weight=0.25):
        """Optimize a copy of ``input_img`` and return the stylized result.

        Args:
            input_img: Starting image tensor; it is cloned, not modified.
            steps: Loss evaluations after which no further optimizer step is
                started (a step already in progress may evaluate a few more).
            style_weight: Multiplier applied to the style loss.
            content_weight: Multiplier applied to the content loss.

        Returns:
            The optimized image tensor, detached from the autograd graph.
        """
        input_img = input_img.clone().requires_grad_(True)
        optimizer = optim.LBFGS([input_img])

        # The targets are constants of the optimization.  Building them
        # without autograd avoids keeping their graphs alive and removes the
        # need for backward(retain_graph=True).
        with torch.no_grad():
            style_features = self.get_features(self.style_img)
            content_features = self.get_features(self.content_img)
            style_grams = {
                name: gram_matrix(style_features[name])
                for name in self.style_layers
            }
            content_targets = {
                name: content_features[name] for name in self.content_layers
            }

        evaluations = 0

        def closure():
            # L-BFGS may call this several times within one optimizer step.
            nonlocal evaluations
            optimizer.zero_grad()
            target_features = self.get_features(input_img)

            style_loss = 0
            for name in self.style_layers:
                target_gram = gram_matrix(target_features[name])
                style_loss += torch.mean((target_gram - style_grams[name]) ** 2)

            content_loss = 0
            for name in self.content_layers:
                diff = target_features[name] - content_targets[name]
                content_loss += torch.mean(diff ** 2)

            total_loss = style_weight * style_loss + content_weight * content_loss
            total_loss.backward()
            evaluations += 1
            return total_loss

        while evaluations <= steps:
            optimizer.step(closure)
        return input_img.detach()
97
+
98
+ # --- Gradio App ---
99
def style_transfer_app(content_img, style_img, content_weight, style_weight, steps):
    """Gradio callback: stylize ``content_img`` using ``style_img``.

    Args:
        content_img: PIL image supplying the structure, or ``None`` when
            nothing was uploaded.
        style_img: PIL image supplying the style, or ``None`` when nothing
            was uploaded.
        content_weight: Multiplier for the content loss.
        style_weight: Multiplier for the style loss.
        steps: Optimization step budget; converted to ``int``.

    Returns:
        The stylized result as a PIL image.

    Raises:
        gr.Error: If either image is missing, so the UI shows a readable
            message instead of an ``AttributeError`` from ``None.convert``.
    """
    if content_img is None or style_img is None:
        raise gr.Error("Please upload both a content image and a style image.")

    content = load_image(content_img)
    style = load_image(style_img)
    model = StyleTransferNet(style, content)
    # The optimization starts from the content image itself.
    output = model(
        content,
        steps=int(steps),
        content_weight=content_weight,
        style_weight=style_weight,
    )
    return tensor_to_image(output)
105
+
106
+ # --- Launch Interface ---
107
# Build the UI: two image uploads plus three sliders that are passed, in this
# order, to style_transfer_app(content_img, style_img, content_weight,
# style_weight, steps).
# NOTE(review): newer Gradio releases may have renamed `allow_flagging` --
# confirm against the pinned Gradio version.
demo = gr.Interface(
    fn=style_transfer_app,
    inputs=[
        gr.Image(type="pil", label="🖼️ Content Image"),
        gr.Image(type="pil", label="🎨 Style Image"),
        gr.Slider(0.05, 1.0, value=0.25, step=0.05, label="Content Weight"),
        gr.Slider(1e5, 5e6, value=1e6, step=1e5, label="Style Weight"),
        gr.Slider(50, 500, value=300, step=50, label="Steps"),
    ],
    outputs=gr.Image(type="pil", label="🧠 Stylized Output"),
    title="🧠 AI Neural Style Transfer Lab",
    description="Upload a content image and a style image. Then tweak the controls below to explore the balance between structure and stylization. Powered by PyTorch + VGG19.",
    allow_flagging="never",
)

# A second, placeholder `gr.Interface(...).launch()` call used to follow here.
# It passed a literal Ellipsis instead of fn/inputs/outputs and could only
# fail once reached, so the app is launched exactly once.
demo.launch(share=True)