heramb04 committed on
Commit
3369069
·
verified ·
1 Parent(s): ba38386

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +170 -0
  2. requirements.txt +6 -0
  3. style_transfer.py +130 -0
  4. utils.py +34 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torchvision.models as models
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+
8
+ from PIL import Image
9
+ from io import BytesIO
10
+ import random
11
+ from datasets import load_dataset
12
+ from datetime import datetime
13
+
14
# ✅ Device setup: prefer GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 📦 Image preprocessing: resize to a fixed 512x512 and convert to a float
# tensor in [0, 1]. ImageNet mean/std normalization is NOT applied here —
# it is done inside the model by the Normalization module.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor()
])
22
+
23
def load_image(img):
    """Convert a PIL image to a (1, C, 512, 512) float tensor on `device`."""
    rgb = img.convert('RGB')
    tensor = transform(rgb)
    # Add the batch dimension expected by the VGG feature extractor.
    return tensor.unsqueeze(0).to(device)
26
+
27
+ # 🔧 NST Core Classes
28
class Normalization(nn.Module):
    """Per-channel (x - mean) / std normalization for image batches."""

    def __init__(self, mean, std):
        super().__init__()
        # Reshape to (C, 1, 1) so the stats broadcast over (B, C, H, W).
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)

    def forward(self, img):
        return img.sub(self.mean).div(self.std)
35
+
36
class ContentLoss(nn.Module):
    """Transparent probe layer that records MSE to a fixed content target.

    `forward` returns its input unchanged so the module can be spliced into
    a Sequential without affecting downstream activations; the loss is read
    back later via the `loss` attribute.
    """

    def __init__(self, target):
        super().__init__()
        # Detach: the target is a constant, not part of the gradient graph.
        self.target = target.detach()
        self.loss = 0

    def forward(self, input):
        self.loss = nn.functional.mse_loss(input, self.target)
        return input
44
+
45
def gram_matrix(input):
    """Return the normalized Gram matrix of a (B, C, H, W) feature map.

    The Gram matrix captures feature-channel correlations (texture/style)
    independent of spatial layout. Normalizing by the number of elements
    keeps its magnitude comparable across layers of different sizes.

    BUG FIX: the original used ``input.view(c, h * w)``, which raises a
    RuntimeError for any batch size > 1; flattening over ``b * c`` (as in
    the official PyTorch NST tutorial) is identical for b == 1 and correct
    for larger batches.
    """
    b, c, h, w = input.size()
    features = input.view(b * c, h * w)
    G = torch.mm(features, features.t())
    return G.div(b * c * h * w)
50
+
51
class StyleLoss(nn.Module):
    """Transparent probe layer that records MSE between Gram matrices.

    The target Gram matrix is computed once from the style image's features;
    `forward` passes its input through untouched and stores the style loss
    in the `loss` attribute for later accumulation.
    """

    def __init__(self, target_feature):
        super().__init__()
        # Constant optimization target — detached from the autograd graph.
        self.target = gram_matrix(target_feature).detach()
        self.loss = 0

    def forward(self, input):
        self.loss = nn.functional.mse_loss(gram_matrix(input), self.target)
        return input
60
+
61
+ # 🧠 Model builder
62
def get_model_losses(cnn, norm_mean, norm_std, style_img, content_img):
    """Build a truncated VGG feature extractor with loss probes spliced in.

    Walks ``cnn``'s layers in order, naming them conv_i/relu_i/pool_i/bn_i,
    inserting a ContentLoss after conv_4 and a StyleLoss after each of
    conv_1..conv_5, then trims everything past the last loss layer.

    Returns:
        (model, style_losses, content_losses) — the trimmed nn.Sequential
        plus the loss modules, whose ``.loss`` attributes are populated on
        each forward pass.
    """
    norm = Normalization(norm_mean, norm_std).to(device)
    model = nn.Sequential(norm)

    content_losses, style_losses = [], []
    i = 0  # conv-layer counter; non-conv layers reuse the current count
    for layer in cnn.children():
        name = None
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = f"conv_{i}"
        elif isinstance(layer, nn.ReLU):
            name = f"relu_{i}"
            # In-place ReLU would clobber the activations the loss modules
            # need, so replace it with an out-of-place copy.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f"pool_{i}"
        elif isinstance(layer, nn.BatchNorm2d):
            name = f"bn_{i}"
        if name:
            model.add_module(name, layer)
            if name == "conv_4":
                # Content target: the content image's activations up to here.
                target = model(content_img).detach()
                content_loss = ContentLoss(target)
                model.add_module(f"content_loss_{i}", content_loss)
                content_losses.append(content_loss)
            if name in ["conv_1", "conv_2", "conv_3", "conv_4", "conv_5"]:
                # Style target: Gram matrix of the style image's activations.
                target_feature = model(style_img).detach()
                style_loss = StyleLoss(target_feature)
                model.add_module(f"style_loss_{i}", style_loss)
                style_losses.append(style_loss)

    # Find the last loss module and drop every layer after it — deeper
    # layers contribute nothing to the objective.
    for j in range(len(model) - 1, -1, -1):
        if isinstance(model[j], ContentLoss) or isinstance(model[j], StyleLoss):
            break
    return model[:j + 1], style_losses, content_losses
97
+
98
+ # 🎲 Random selector from Hugging Face dataset
99
def get_random_image_pair():
    """Pick two distinct random paintings from the HF dataset.

    Returns:
        (content, style): two RGB PIL images.
    """
    ds = load_dataset("heramb04/Famous-paintings", split="train")
    # Sample indices instead of list(ds): materializing the whole dataset
    # just to pick two rows is O(dataset) in time and memory.
    indices = random.sample(range(len(ds)), 2)
    imgs = []
    for idx in indices:
        entry = ds[idx]["image"]
        # `datasets` normally decodes image columns straight to PIL.Image;
        # the original assumed a raw {"bytes": ...} dict, which only occurs
        # with decoding disabled. Handle both shapes.
        if isinstance(entry, Image.Image):
            img = entry
        else:
            img = Image.open(BytesIO(entry["bytes"]))
        imgs.append(img.convert("RGB"))
    return imgs[0], imgs[1]
104
+
105
+ # 🖌️ NST logic
106
def run_nst(content_pil, style_pil, steps=300):
    """Run neural style transfer and return the stylized PIL image.

    Args:
        content_pil: PIL image supplying the structure (A).
        style_pil: PIL image supplying the texture/style (B).
        steps: number of LBFGS closure evaluations to run.
    """
    content = load_image(content_pil)
    style = load_image(style_pil)
    # Start from the content image so structure is preserved from step one.
    input_img = content.clone().requires_grad_(True)

    cnn = models.vgg19(pretrained=True).features.to(device).eval()
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    norm_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    model, style_losses, content_losses = get_model_losses(cnn, norm_mean, norm_std, style, content)
    # Only the input image is optimized — freeze the network weights so
    # backward() does not accumulate useless gradients in VGG.
    for p in model.parameters():
        p.requires_grad_(False)
    optimizer = optim.LBFGS([input_img])
    run = [0]

    while run[0] <= steps:
        def closure():
            # Keep pixels in the valid image range between LBFGS line searches.
            input_img.data.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_img)  # populates .loss on every probe module
            style_score = sum(sl.loss for sl in style_losses)
            content_score = sum(cl.loss for cl in content_losses)
            # Style is weighted 1e6x — raw style losses are tiny Gram MSEs.
            loss = content_score + 1e6 * style_score
            loss.backward()
            run[0] += 1
            return loss
        optimizer.step(closure)

    # BUG FIX: the final LBFGS step can leave pixels outside [0, 1]; clamp
    # before converting to an image (as in the official PyTorch NST tutorial),
    # otherwise ToPILImage wraps/clips values and produces artifacts.
    input_img.data.clamp_(0, 1)
    output = input_img.detach().cpu().squeeze(0)
    return transforms.ToPILImage()(output)
134
+
135
+ # 🎛️ Gradio UI
136
# 🎛️ Gradio UI: two entry points share one output pipeline —
# stylize a user-uploaded pair, or pull a random pair from the HF dataset.
with gr.Blocks(title="Neural Style Transfer — A + B = C") as demo:
    gr.Markdown("## 🎨 Neural Style Transfer<br>Upload two images OR pick random paintings to remix")

    with gr.Row():
        with gr.Column():
            # Left column: user inputs and controls.
            content_input = gr.Image(label="🖼️ Content Image", type="pil")
            style_input = gr.Image(label="🎨 Style Image", type="pil")
            steps_slider = gr.Slider(100, 500, value=300, step=50, label="Optimization Steps")
            upload_button = gr.Button("✨ Stylize Uploaded Images")
            random_button = gr.Button("🎲 Pick Random & Generate")

        with gr.Column():
            # Right column: read-only previews forming the A + B = C triptych.
            gr.Markdown("### 🧠 A + B = C")
            content_preview = gr.Image(label="A: Content", interactive=False)
            style_preview = gr.Image(label="B: Style", interactive=False)
            output_preview = gr.Image(label="C: Stylized Output", interactive=False)

    # Upload path: only the output preview changes.
    upload_button.click(
        fn=run_nst,
        inputs=[content_input, style_input, steps_slider],
        outputs=output_preview
    )

    def random_nst_wrapper(steps):
        # Random path: fetch two paintings, stylize, and return all three
        # images so A, B and C previews update together.
        content_img, style_img = get_random_image_pair()
        result = run_nst(content_img, style_img, steps)
        return content_img, style_img, result

    random_button.click(
        fn=random_nst_wrapper,
        inputs=[steps_slider],
        outputs=[content_preview, style_preview, output_preview]
    )

demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow
5
+ datasets
6
+ matplotlib
style_transfer.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torchvision.transforms as transforms
5
+ import torchvision.models as models
6
+ from PIL import Image
7
+ import torchvision.transforms.functional as TF
8
+
9
# 🚀 Device configuration: prefer GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 🔧 Preprocessing: fixed 512x512 resize + [0, 1] float tensor conversion.
# ImageNet normalization is applied later inside the model (Normalization).
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor()
])
17
+
18
def load_image(img):
    """Convert a PIL image to a (1, C, 512, 512) float tensor on `device`."""
    rgb = img.convert("RGB")
    tensor = transform(rgb)
    # Batch dimension required by the VGG feature extractor.
    return tensor.unsqueeze(0).to(device)
21
+
22
+ # 🎯 Loss modules
23
class Normalization(nn.Module):
    """Per-channel (x - mean) / std normalization for image batches."""

    def __init__(self, mean, std):
        super().__init__()
        # (C,) -> (C, 1, 1) so the stats broadcast over (B, C, H, W).
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)

    def forward(self, img):
        return img.sub(self.mean).div(self.std)
30
+
31
class ContentLoss(nn.Module):
    """Transparent probe layer that records MSE to a fixed content target.

    Returns its input unchanged from `forward` so it can sit inside a
    Sequential without altering activations; read the result via `loss`.
    """

    def __init__(self, target):
        super().__init__()
        # The target is a constant — keep it out of the autograd graph.
        self.target = target.detach()
        self.loss = 0

    def forward(self, input):
        self.loss = nn.functional.mse_loss(input, self.target)
        return input
39
+
40
def gram_matrix(input):
    """Return the normalized Gram matrix of a (B, C, H, W) feature map.

    Channel-correlation statistics capture texture/style independent of
    spatial layout; element-count normalization keeps magnitudes comparable
    across layers.

    BUG FIX: ``input.view(c, h * w)`` raises for batch size > 1; flatten
    over ``b * c`` (PyTorch NST tutorial form) — identical for b == 1,
    correct for larger batches.
    """
    b, c, h, w = input.size()
    features = input.view(b * c, h * w)
    G = torch.mm(features, features.t())
    return G.div(b * c * h * w)
45
+
46
class StyleLoss(nn.Module):
    """Transparent probe layer that records MSE between Gram matrices.

    The style target is the Gram matrix of the style image's activations,
    computed once at construction; `forward` is a pass-through that stores
    the current style loss in `loss`.
    """

    def __init__(self, target_feature):
        super().__init__()
        # Fixed optimization target, detached from the autograd graph.
        self.target = gram_matrix(target_feature).detach()
        self.loss = 0

    def forward(self, input):
        self.loss = nn.functional.mse_loss(gram_matrix(input), self.target)
        return input
55
+
56
+ # 🧬 Model builder
57
def get_model_losses(cnn, norm_mean, norm_std, style_img, content_img):
    """Build a truncated VGG feature extractor with loss probes spliced in.

    Walks ``cnn``'s layers in order, naming them conv_i/relu_i/pool_i/bn_i,
    inserting a ContentLoss after conv_4 and a StyleLoss after each of
    conv_1..conv_5, then trims everything past the last loss layer.

    Returns:
        (model, style_losses, content_losses) — the trimmed nn.Sequential
        plus the probe modules, whose ``.loss`` attributes are populated on
        each forward pass.
    """
    normalization = Normalization(norm_mean, norm_std).to(device)
    model = nn.Sequential(normalization)

    content_losses = []
    style_losses = []

    i = 0  # conv-layer counter; non-conv layers reuse the current count
    for layer in cnn.children():
        name = None
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = f"conv_{i}"
        elif isinstance(layer, nn.ReLU):
            name = f"relu_{i}"
            # In-place ReLU would clobber activations the probes saved;
            # swap in an out-of-place copy.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f"pool_{i}"
        elif isinstance(layer, nn.BatchNorm2d):
            name = f"bn_{i}"
        if name:
            model.add_module(name, layer)
            if name == "conv_4":
                # Content target: the content image's activations up to here.
                target = model(content_img).detach()
                content_loss = ContentLoss(target)
                model.add_module(f"content_loss_{i}", content_loss)
                content_losses.append(content_loss)
            if name in ["conv_1", "conv_2", "conv_3", "conv_4", "conv_5"]:
                # Style target: Gram matrix of the style image's activations.
                target_feature = model(style_img).detach()
                style_loss = StyleLoss(target_feature)
                model.add_module(f"style_loss_{i}", style_loss)
                style_losses.append(style_loss)

    # Locate the last loss module and drop every layer after it — deeper
    # layers contribute nothing to the objective.
    for j in range(len(model) - 1, -1, -1):
        if isinstance(model[j], ContentLoss) or isinstance(model[j], StyleLoss):
            break

    return model[:j + 1], style_losses, content_losses
95
+
96
+ # ✨ Stylization pipeline
97
def run_nst(content_pil, style_pil, steps=300):
    """Run neural style transfer and return the stylized PIL image.

    Args:
        content_pil: PIL image supplying the structure.
        style_pil: PIL image supplying the texture/style.
        steps: number of LBFGS closure evaluations to run.
    """
    content = load_image(content_pil)
    style = load_image(style_pil)
    # Start from the content image so structure is preserved from step one.
    input_img = content.clone().requires_grad_(True)

    cnn = models.vgg19(pretrained=True).features.to(device).eval()
    norm_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    norm_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    model, style_losses, content_losses = get_model_losses(
        cnn, norm_mean, norm_std, style, content
    )
    # Only the input image is optimized — freeze the network weights so
    # backward() does not accumulate useless gradients in VGG.
    for p in model.parameters():
        p.requires_grad_(False)

    optimizer = optim.LBFGS([input_img])
    run = [0]

    while run[0] <= steps:
        def closure():
            # Keep pixels in the valid image range between line searches.
            input_img.data.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_img)  # populates .loss on every probe module

            style_score = sum(sl.loss for sl in style_losses)
            content_score = sum(cl.loss for cl in content_losses)

            # Style is weighted 1e6x — raw style losses are tiny Gram MSEs.
            loss = content_score + 1e6 * style_score
            loss.backward()

            run[0] += 1
            return loss
        optimizer.step(closure)

    # BUG FIX: the final LBFGS step can leave pixels outside [0, 1]; clamp
    # before converting to an image (as in the official PyTorch NST
    # tutorial), otherwise to_pil_image produces wrap/clip artifacts.
    input_img.data.clamp_(0, 1)
    output = input_img.detach().cpu().squeeze(0)
    return TF.to_pil_image(output)
utils.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from datetime import datetime
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+
7
+ # 📦 Resize + convert image
8
def resize_image(image, size=(512, 512)):
    """Return *image* converted to RGB and resized to *size* (width, height)."""
    rgb = image.convert("RGB")
    return rgb.resize(size)
10
+
11
+ # 📂 Save output with timestamp
12
def save_output(image, save_dir="output", prefix="stylized"):
    """Save *image* into *save_dir* as '<prefix>_<timestamp>.jpg'.

    Creates the directory if needed and returns the full file path.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Timestamped name avoids clobbering earlier outputs.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    path = os.path.join(save_dir, f"{prefix}_{stamp}.jpg")
    image.save(path)
    return path
19
+
20
+ # 📜 Log content + style pairing
21
def log_run(content_name, style_name, log_path="log.txt"):
    """Append a timestamped 'content + style' pairing line to *log_path*."""
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_path, "a") as log_file:
        log_file.write(f"{stamp} | Content: {content_name} + Style: {style_name} = Stylized\n")
26
+
27
+ # 🖼️ Triptych visualizer: A + B = C
28
def show_triptych(content_img, style_img, output_img):
    """Display content (A), style (B) and output (C) side by side."""
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panels = [
        (content_img, "A: Content"),
        (style_img, "B: Style"),
        (output_img, "C: Output"),
    ]
    for ax, (img, title) in zip(axs, panels):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')
    plt.suptitle("A + B = C", fontsize=20)
    plt.show()