heramb04 committed on
Commit
d945be4
·
verified ·
1 Parent(s): b54ae8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -169
app.py CHANGED
@@ -1,170 +1,168 @@
1
- import gradio as gr
2
- import torch
3
- import torchvision.transforms as transforms
4
- import torchvision.models as models
5
- import torch.nn as nn
6
- import torch.optim as optim
7
-
8
- from PIL import Image
9
- from io import BytesIO
10
- import random
11
- from datasets import load_dataset
12
- from datetime import datetime
13
-
14
- # βœ… Device setup
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
-
17
- # πŸ“¦ Image preprocessing
18
- transform = transforms.Compose([
19
- transforms.Resize((512, 512)),
20
- transforms.ToTensor()
21
- ])
22
-
23
- def load_image(img):
24
- image = img.convert('RGB')
25
- return transform(image).unsqueeze(0).to(device)
26
-
27
- # πŸ”§ NST Core Classes
28
- class Normalization(nn.Module):
29
- def __init__(self, mean, std):
30
- super().__init__()
31
- self.mean = mean.view(-1, 1, 1)
32
- self.std = std.view(-1, 1, 1)
33
- def forward(self, img):
34
- return (img - self.mean) / self.std
35
-
36
- class ContentLoss(nn.Module):
37
- def __init__(self, target):
38
- super().__init__()
39
- self.target = target.detach()
40
- self.loss = 0
41
- def forward(self, input):
42
- self.loss = nn.functional.mse_loss(input, self.target)
43
- return input
44
-
45
- def gram_matrix(input):
46
- b, c, h, w = input.size()
47
- features = input.view(c, h * w)
48
- G = torch.mm(features, features.t())
49
- return G.div(c * h * w)
50
-
51
- class StyleLoss(nn.Module):
52
- def __init__(self, target_feature):
53
- super().__init__()
54
- self.target = gram_matrix(target_feature).detach()
55
- self.loss = 0
56
- def forward(self, input):
57
- G = gram_matrix(input)
58
- self.loss = nn.functional.mse_loss(G, self.target)
59
- return input
60
-
61
- # 🧠 Model builder
62
- def get_model_losses(cnn, norm_mean, norm_std, style_img, content_img):
63
- norm = Normalization(norm_mean, norm_std).to(device)
64
- model = nn.Sequential(norm)
65
-
66
- content_losses, style_losses = [], []
67
- i = 0
68
- for layer in cnn.children():
69
- name = None
70
- if isinstance(layer, nn.Conv2d):
71
- i += 1
72
- name = f"conv_{i}"
73
- elif isinstance(layer, nn.ReLU):
74
- name = f"relu_{i}"
75
- layer = nn.ReLU(inplace=False)
76
- elif isinstance(layer, nn.MaxPool2d):
77
- name = f"pool_{i}"
78
- elif isinstance(layer, nn.BatchNorm2d):
79
- name = f"bn_{i}"
80
- if name:
81
- model.add_module(name, layer)
82
- if name == "conv_4":
83
- target = model(content_img).detach()
84
- content_loss = ContentLoss(target)
85
- model.add_module(f"content_loss_{i}", content_loss)
86
- content_losses.append(content_loss)
87
- if name in ["conv_1", "conv_2", "conv_3", "conv_4", "conv_5"]:
88
- target_feature = model(style_img).detach()
89
- style_loss = StyleLoss(target_feature)
90
- model.add_module(f"style_loss_{i}", style_loss)
91
- style_losses.append(style_loss)
92
-
93
- for j in range(len(model) - 1, -1, -1):
94
- if isinstance(model[j], ContentLoss) or isinstance(model[j], StyleLoss):
95
- break
96
- return model[:j + 1], style_losses, content_losses
97
-
98
- # 🎲 Random selector from Hugging Face dataset
99
- def get_random_image_pair():
100
- ds = load_dataset("heramb04/Famous-paintings", split="train")
101
- samples = random.sample(list(ds), 2)
102
- imgs = [Image.open(BytesIO(sample["image"]["bytes"])).convert("RGB") for sample in samples]
103
- return imgs[0], imgs[1]
104
-
105
- # πŸ–ŒοΈ NST logic
106
- def run_nst(content_pil, style_pil, steps=300):
107
- content = load_image(content_pil)
108
- style = load_image(style_pil)
109
- input_img = content.clone().requires_grad_(True)
110
-
111
- cnn = models.vgg19(pretrained=True).features.to(device).eval()
112
- norm_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
113
- norm_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
114
-
115
- model, style_losses, content_losses = get_model_losses(cnn, norm_mean, norm_std, style, content)
116
- optimizer = optim.LBFGS([input_img])
117
- run = [0]
118
-
119
- while run[0] <= steps:
120
- def closure():
121
- input_img.data.clamp_(0, 1)
122
- optimizer.zero_grad()
123
- model(input_img)
124
- style_score = sum(sl.loss for sl in style_losses)
125
- content_score = sum(cl.loss for cl in content_losses)
126
- loss = content_score + 1e6 * style_score
127
- loss.backward()
128
- run[0] += 1
129
- return loss
130
- optimizer.step(closure)
131
-
132
- output = input_img.clone().detach().cpu().squeeze(0)
133
- return transforms.ToPILImage()(output)
134
-
135
- # πŸŽ›οΈ Gradio UI
136
- with gr.Blocks(title="Neural Style Transfer β€” A + B = C") as demo:
137
- gr.Markdown("## 🎨 Neural Style Transfer<br>Upload two images OR pick random paintings to remix")
138
-
139
- with gr.Row():
140
- with gr.Column():
141
- content_input = gr.Image(label="πŸ–ΌοΈ Content Image", type="pil")
142
- style_input = gr.Image(label="🎨 Style Image", type="pil")
143
- steps_slider = gr.Slider(100, 500, value=300, step=50, label="Optimization Steps")
144
- upload_button = gr.Button("✨ Stylize Uploaded Images")
145
- random_button = gr.Button("🎲 Pick Random & Generate")
146
-
147
- with gr.Column():
148
- gr.Markdown("### 🧠 A + B = C")
149
- content_preview = gr.Image(label="A: Content", interactive=False)
150
- style_preview = gr.Image(label="B: Style", interactive=False)
151
- output_preview = gr.Image(label="C: Stylized Output", interactive=False)
152
-
153
- upload_button.click(
154
- fn=run_nst,
155
- inputs=[content_input, style_input, steps_slider],
156
- outputs=output_preview
157
- )
158
-
159
- def random_nst_wrapper(steps):
160
- content_img, style_img = get_random_image_pair()
161
- result = run_nst(content_img, style_img, steps)
162
- return content_img, style_img, result
163
-
164
- random_button.click(
165
- fn=random_nst_wrapper,
166
- inputs=[steps_slider],
167
- outputs=[content_preview, style_preview, output_preview]
168
- )
169
-
170
  demo.launch(share=True)
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torchvision.models as models
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+
8
+ from PIL import Image
9
+ import random
10
+ from datasets import load_dataset
11
+
12
# Device setup: run on GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image preprocessing: force every image to 512x512 and convert it to a
# float tensor in [0, 1] (NST here assumes both images share one size).
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor()
])
20
+
21
def load_image(img):
    """Turn a PIL image into a (1, C, H, W) float tensor on the active device.

    The image is coerced to RGB, resized/tensorized by the module-level
    `transform`, and given a leading batch dimension of 1.
    """
    rgb = img.convert("RGB")
    tensor = transform(rgb)
    return tensor.unsqueeze(0).to(device)
24
+
25
+ # πŸ”§ NST Core Classes
26
class Normalization(nn.Module):
    """Channel-wise image normalization layer: (img - mean) / std.

    Placed first in the NST model so raw [0, 1] images are shifted into
    the statistics VGG was trained with.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Reshape to (C, 1, 1) so the per-channel stats broadcast over H and W.
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)

    def forward(self, img):
        return img.sub(self.mean).div(self.std)
33
+
34
class ContentLoss(nn.Module):
    """Pass-through layer that records MSE against a fixed content target.

    The target activations are detached so no gradient ever flows into
    them; `forward` returns its input unchanged, so the layer can be
    spliced transparently into an nn.Sequential feature stack.
    """

    def __init__(self, target):
        super().__init__()
        self.loss = 0
        self.target = target.detach()

    def forward(self, input):
        # Side effect only: stash the loss, hand the activations through.
        self.loss = nn.functional.mse_loss(input, self.target)
        return input
42
+
43
def gram_matrix(input):
    """Return the normalized Gram matrix of a conv feature map.

    Each channel is flattened to a vector and the matrix of channel dot
    products is computed, normalized by the element count so the style
    loss magnitude is independent of feature-map size.

    Args:
        input: feature tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (B*C, B*C).
    """
    b, c, h, w = input.size()
    # Fold the batch dimension into the channel dimension: the original
    # view(c, h*w) silently assumed B == 1 and crashed for larger batches.
    features = input.view(b * c, h * w)
    G = torch.mm(features, features.t())
    # Normalize so large/deep layers don't dominate the style loss.
    return G.div(b * c * h * w)
48
+
49
class StyleLoss(nn.Module):
    """Pass-through layer that records MSE between Gram matrices.

    The style target's Gram matrix is captured once (detached); every
    forward pass stores the MSE between the current activations' Gram
    matrix and that target, returning the input untouched.
    """

    def __init__(self, target_feature):
        super().__init__()
        self.loss = 0
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        current = gram_matrix(input)
        self.loss = nn.functional.mse_loss(current, self.target)
        return input
58
+
59
+ # 🧠 Model builder
60
def get_model_losses(cnn, norm_mean, norm_std, style_img, content_img):
    """Assemble a truncated VGG feature stack with loss probes spliced in.

    Walks `cnn`'s children, renaming them conv_i/relu_i/pool_i/bn_i
    (i = conv counter), inserts a ContentLoss probe after conv_4 and a
    StyleLoss probe after each of conv_1..conv_5, then trims every layer
    past the last probe since nothing beyond it affects the losses.

    Returns:
        (model, style_losses, content_losses) — the Sequential model plus
        the probe layers whose `.loss` fields are refreshed on forward.
    """
    style_layer_names = {"conv_1", "conv_2", "conv_3", "conv_4", "conv_5"}
    content_layer_names = {"conv_4"}

    model = nn.Sequential(Normalization(norm_mean, norm_std).to(device))
    style_losses = []
    content_losses = []

    conv_idx = 0
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            conv_idx += 1
            name = f"conv_{conv_idx}"
        elif isinstance(layer, nn.ReLU):
            name = f"relu_{conv_idx}"
            # In-place ReLU would overwrite activations the probes hold.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f"pool_{conv_idx}"
        elif isinstance(layer, nn.BatchNorm2d):
            name = f"bn_{conv_idx}"
        else:
            # Unrecognized layer types are dropped (VGG features has none).
            continue

        model.add_module(name, layer)

        if name in content_layer_names:
            target = model(content_img).detach()
            probe = ContentLoss(target)
            model.add_module(f"content_loss_{conv_idx}", probe)
            content_losses.append(probe)

        if name in style_layer_names:
            target = model(style_img).detach()
            probe = StyleLoss(target)
            model.add_module(f"style_loss_{conv_idx}", probe)
            style_losses.append(probe)

    # Trim everything after the last probe; later layers contribute nothing.
    last = len(model) - 1
    while last > 0 and not isinstance(model[last], (ContentLoss, StyleLoss)):
        last -= 1
    return model[:last + 1], style_losses, content_losses
95
+
96
+ # 🎲 Random selector from Hugging Face dataset
97
def get_random_image_pair():
    """Return two distinct random paintings (as PIL RGB images).

    The Hugging Face dataset is loaded once and memoized on the function
    object — the original reloaded it on every button click and also
    materialized the whole split with list(ds) just to sample two items.

    Returns:
        (content_img, style_img): two PIL.Image objects in RGB mode.
    """
    ds = getattr(get_random_image_pair, "_ds_cache", None)
    if ds is None:
        ds = load_dataset("heramb04/Famous-paintings", split="train")
        get_random_image_pair._ds_cache = ds
    # Sample two distinct row indices instead of copying the dataset.
    first, second = random.sample(range(len(ds)), 2)
    return ds[first]["image"].convert("RGB"), ds[second]["image"].convert("RGB")
102
+
103
+ # πŸ–ŒοΈ NST logic
104
def run_nst(content_pil, style_pil, steps=300):
    """Run neural style transfer (Gatys et al.) and return the result.

    Args:
        content_pil: PIL image supplying the content.
        style_pil: PIL image supplying the style.
        steps: number of L-BFGS closure evaluations to run.

    Returns:
        The stylized image as a PIL.Image.
    """
    content = load_image(content_pil)
    style = load_image(style_pil)
    input_img = content.clone().requires_grad_(True)

    # Load VGG-19 once and reuse it across calls — re-instantiating the
    # network on every request was the dominant per-click cost.
    # NOTE(review): pretrained=True is deprecated in newer torchvision in
    # favor of weights=VGG19_Weights.DEFAULT; kept for compatibility.
    cnn = getattr(run_nst, "_cnn_cache", None)
    if cnn is None:
        cnn = models.vgg19(pretrained=True).features.to(device).eval()
        run_nst._cnn_cache = cnn

    # ImageNet normalization statistics expected by VGG.
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    norm_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    model, style_losses, content_losses = get_model_losses(cnn, norm_mean, norm_std, style, content)
    optimizer = optim.LBFGS([input_img])
    run = [0]

    while run[0] <= steps:
        def closure():
            # L-BFGS can push pixels outside [0, 1]; clamp before each eval.
            input_img.data.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_img)  # refreshes each probe's .loss as a side effect
            style_score = sum(sl.loss for sl in style_losses)
            content_score = sum(cl.loss for cl in content_losses)
            loss = content_score + 1e6 * style_score
            loss.backward()
            run[0] += 1
            return loss
        optimizer.step(closure)

    # Bug fix: the final optimizer step can leave pixels outside [0, 1];
    # without this clamp ToPILImage receives out-of-range values and the
    # saved image shows wrapped/garbled colors.
    input_img.data.clamp_(0, 1)
    output = input_img.detach().cpu().squeeze(0)
    return transforms.ToPILImage()(output)
132
+
133
+ # πŸŽ›οΈ Gradio UI
134
# Gradio front-end: two upload slots plus a "random paintings" shortcut.
with gr.Blocks(title="Neural Style Transfer β€” A + B = C") as demo:
    gr.Markdown("## 🎨 Neural Style Transfer<br>Upload two images OR pick random paintings to remix")

    with gr.Row():
        with gr.Column():
            # Left column: inputs and controls.
            img_content = gr.Image(label="πŸ–ΌοΈ Content Image", type="pil")
            img_style = gr.Image(label="🎨 Style Image", type="pil")
            slider_steps = gr.Slider(100, 500, value=300, step=50, label="Optimization Steps")
            btn_upload = gr.Button("✨ Stylize Uploaded Images")
            btn_random = gr.Button("🎲 Pick Random & Generate")

        with gr.Column():
            # Right column: the two sources and the stylized result.
            gr.Markdown("### 🧠 A + B = C")
            preview_content = gr.Image(label="A: Content", interactive=False)
            preview_style = gr.Image(label="B: Style", interactive=False)
            preview_output = gr.Image(label="C: Stylized Output", interactive=False)

    def _stylize_random(steps):
        # Pick two random paintings, stylize, and echo back all three images.
        a_img, b_img = get_random_image_pair()
        return a_img, b_img, run_nst(a_img, b_img, steps)

    btn_upload.click(
        fn=run_nst,
        inputs=[img_content, img_style, slider_steps],
        outputs=preview_output,
    )

    btn_random.click(
        fn=_stylize_random,
        inputs=[slider_steps],
        outputs=[preview_content, preview_style, preview_output],
    )


demo.launch(share=True)