waleed-12 committed on
Commit
48074a3
·
verified ·
1 Parent(s): 9a1f4d7

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +345 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py — DDPM Image Generation Demo
3
+ Deploy on Hugging Face Spaces (SDK: gradio)
4
+
5
+ Repository structure expected:
6
+ .
7
+ ├── app.py ← this file
8
+ ├── requirements.txt
9
+ └── ddpm_model.pth ← your trained weights (upload via git-lfs)
10
+ """
11
+
12
+ import math
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from PIL import Image
18
+ import torchvision.utils as vutils
19
+ import gradio as gr
20
+
21
+ # ──────────────────────────────────────────────────────────────
22
+ # 1. CONFIGURATION (must match your training config exactly)
23
+ # ──────────────────────────────────────────────────────────────
24
# Side length (pixels) of the square images the model generates;
# change to 256 if you trained at 256.
IMG_SIZE: int = 128

# Widths handed to the U-Net constructor.
BASE_CHANNELS: int = 64
TIME_EMB_DIM: int = 256

# Total diffusion timesteps and the end points of the linear beta schedule.
T: int = 300
BETA_START: float = 1e-4
BETA_END: float = 0.02

# Checkpoint file, looked up relative to the working directory.
MODEL_PATH: str = "ddpm_model.pth"

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE: str = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
32
+
33
+
34
+ # ──────────────────────────────────────────────────────────────
35
+ # 2. MODEL ARCHITECTURE (identical to training notebook)
36
+ # ──────────────────────────────────────────────────────────────
37
+
38
class SinusoidalTimeEmbedding(nn.Module):
    """
    Encode integer timesteps as sinusoidal features refined by a small MLP.

    The first ``dim // 2`` features are sines and the remaining ones cosines
    of the timestep multiplied by geometrically spaced frequencies (from 1
    down to 1/10000). The concatenation is then passed through
    ``Linear -> SiLU -> Linear``.

    Args:
        dim: Size of the embedding. It must be even: the sine and cosine
            halves together hold ``2 * (dim // 2)`` features, and the first
            linear layer expects exactly ``dim`` of them.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of timesteps.

        Args:
            t: 1-D tensor of timesteps, one per batch element.

        Returns:
            Tensor of shape ``(len(t), dim)``.
        """
        half = self.dim // 2
        # With dim == 2 there is a single frequency and ``half - 1`` is 0;
        # dividing by it would turn the whole embedding into NaN. Clamping
        # the denominator to 1 leaves every dim >= 4 unchanged.
        denom = max(half - 1, 1)
        freq = torch.exp(
            -math.log(10_000) * torch.arange(half, device=t.device) / denom
        )
        args = t[:, None].float() * freq[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        return self.mlp(emb)
60
+
61
+
62
class ResidualBlock(nn.Module):
    """
    Two-convolution residual block modulated by the timestep embedding.

    The embedding is projected to a per-channel scale and shift that are
    applied to the output of the first convolution. The input is added back
    at the end, through a 1x1 convolution when the channel count changes.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        time_emb_dim: Size of the timestep embedding fed to ``forward``.
        groups: Number of groups for both GroupNorm layers.
        dropout: Dropout probability applied before the second convolution.
    """

    def __init__(self, in_ch: int, out_ch: int, time_emb_dim: int,
                 groups: int = 8, dropout: float = 0.1):
        super().__init__()
        # One linear layer emits scale and shift together (2 * out_ch values).
        self.time_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_ch * 2),
        )
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """Return the block output for feature map ``x`` and embedding ``t_emb``."""
        hidden = self.norm1(x)
        hidden = self.conv1(F.silu(hidden))

        # Split the projection into its scale and shift halves and broadcast
        # them over the two spatial dimensions.
        scale, shift = self.time_proj(t_emb).chunk(2, dim=-1)
        gain = scale[:, :, None, None] + 1
        bias = shift[:, :, None, None]
        hidden = hidden * gain + bias

        hidden = F.silu(self.norm2(hidden))
        hidden = self.conv2(self.dropout(hidden))
        return hidden + self.shortcut(x)
82
+
83
+
84
class Downsample(nn.Module):
    """Halve height and width with a 3x3 convolution of stride 2."""

    def __init__(self, channels: int):
        super().__init__()
        # Channel count is preserved; only the spatial size shrinks.
        self.conv = nn.Conv2d(
            channels, channels, kernel_size=3, stride=2, padding=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the strided convolution."""
        reduced = self.conv(x)
        return reduced
92
+
93
+
94
class Upsample(nn.Module):
    """Double height and width by nearest-neighbour interpolation, then a 3x3 conv."""

    def __init__(self, channels: int):
        super().__init__()
        # Channel count is preserved by the convolution applied after resizing.
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` resized by a factor of two and passed through the conv."""
        enlarged = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(enlarged)
102
+
103
+
104
class UNet(nn.Module):
    """
    Simplified U-Net that predicts the noise present in a noisy image (DDPM).

    Three encoder stages widen the features (base -> 2x base -> 4x base
    channels, i.e. 64 -> 128 -> 256 with the defaults), each followed by a
    stride-2 downsample. After a two-block bottleneck, three decoder stages
    upsample and concatenate the output of the matching encoder stage as a
    skip connection. Every residual block receives the timestep embedding.

    Args:
        in_channels: Channels of the input image; the output has the same count.
        base_channels: Width of the first encoder stage.
        time_emb_dim: Size of the timestep embedding.
    """

    def __init__(self, in_channels: int = 3,
                 base_channels: int = 64,
                 time_emb_dim: int = 256):
        super().__init__()
        c1 = base_channels
        c2 = base_channels * 2
        c4 = base_channels * 4
        emb = time_emb_dim

        # Time embedding and input projection
        self.time_emb = SinusoidalTimeEmbedding(emb)
        self.init_conv = nn.Conv2d(in_channels, c1, 3, padding=1)

        # Encoder, stage 1 (full resolution)
        self.enc1_res1 = ResidualBlock(c1, c1, emb)
        self.enc1_res2 = ResidualBlock(c1, c1, emb)
        self.down1 = Downsample(c1)

        # Encoder, stage 2 (1/2 resolution)
        self.enc2_res1 = ResidualBlock(c1, c2, emb)
        self.enc2_res2 = ResidualBlock(c2, c2, emb)
        self.down2 = Downsample(c2)

        # Encoder, stage 3 (1/4 resolution)
        self.enc3_res1 = ResidualBlock(c2, c4, emb)
        self.enc3_res2 = ResidualBlock(c4, c4, emb)
        self.down3 = Downsample(c4)

        # Bottleneck (1/8 resolution)
        self.mid_res1 = ResidualBlock(c4, c4, emb)
        self.mid_res2 = ResidualBlock(c4, c4, emb)

        # Decoder: the first block of each stage takes the upsampled map
        # concatenated with the skip, hence the summed input widths.
        self.up3 = Upsample(c4)
        self.dec3_res1 = ResidualBlock(c4 + c4, c4, emb)
        self.dec3_res2 = ResidualBlock(c4, c4, emb)

        self.up2 = Upsample(c4)
        self.dec2_res1 = ResidualBlock(c4 + c2, c2, emb)
        self.dec2_res2 = ResidualBlock(c2, c2, emb)

        self.up1 = Upsample(c2)
        self.dec1_res1 = ResidualBlock(c2 + c1, c1, emb)
        self.dec1_res2 = ResidualBlock(c1, c1, emb)

        # Output head
        self.out_norm = nn.GroupNorm(8, c1)
        self.out_conv = nn.Conv2d(c1, in_channels, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Predict the noise in ``x`` at timesteps ``t``.

        Args:
            x: Batch of images, ``(B, in_channels, H, W)``. H and W should be
                divisible by 8 so that each upsampled decoder map has the
                same size as the encoder skip it is concatenated with.
            t: 1-D tensor with one timestep per batch element.

        Returns:
            Tensor with the same shape as ``x``.
        """
        t_emb = self.time_emb(t)
        feat = self.init_conv(x)

        # Encoder: keep each stage's output for the skip connections.
        skip1 = self.enc1_res1(feat, t_emb)
        skip1 = self.enc1_res2(skip1, t_emb)
        feat = self.down1(skip1)

        skip2 = self.enc2_res1(feat, t_emb)
        skip2 = self.enc2_res2(skip2, t_emb)
        feat = self.down2(skip2)

        skip3 = self.enc3_res1(feat, t_emb)
        skip3 = self.enc3_res2(skip3, t_emb)
        feat = self.down3(skip3)

        # Bottleneck
        feat = self.mid_res1(feat, t_emb)
        feat = self.mid_res2(feat, t_emb)

        # Decoder: upsample, concatenate the skip on the channel axis, refine.
        feat = torch.cat([self.up3(feat), skip3], dim=1)
        feat = self.dec3_res1(feat, t_emb)
        feat = self.dec3_res2(feat, t_emb)

        feat = torch.cat([self.up2(feat), skip2], dim=1)
        feat = self.dec2_res1(feat, t_emb)
        feat = self.dec2_res2(feat, t_emb)

        feat = torch.cat([self.up1(feat), skip1], dim=1)
        feat = self.dec1_res1(feat, t_emb)
        feat = self.dec1_res2(feat, t_emb)

        feat = F.silu(self.out_norm(feat))
        return self.out_conv(feat)
181
+
182
+
183
+ # ──────────────────────────────────────────────────────────────
184
+ # 3. NOISE SCHEDULE (pre-computed tensors on DEVICE)
185
+ # ──────────────────────────────────────────────────────────────
186
# Linear beta schedule with T entries, moved to the inference device.
betas = torch.linspace(start=BETA_START, end=BETA_END, steps=T).to(device=DEVICE)

# Per-step signal retention and its running product over the timesteps.
alphas = 1.0 - betas
alpha_hat = alphas.cumprod(dim=0)

# sqrt(1 - alpha_hat), the denominator of the noise coefficient used by
# the reverse update in generate_image.
sqrt_1m_ah = (1.0 - alpha_hat).sqrt()
190
+
191
+
192
+ # ──────────────────────────────────────────────────────────────
193
+ # 4. LOAD MODEL WEIGHTS
194
+ # ──────────────────────────────────────────────────────────────
195
# Build the network with the configured widths and place it on DEVICE.
model = UNet(
    in_channels=3,
    base_channels=BASE_CHANNELS,
    time_emb_dim=TIME_EMB_DIM,
).to(DEVICE)

# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while being loaded.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)

# Checkpoints saved from nn.DataParallel prefix every key with "module.".
# Remove it only at the start of a key: str.replace would also delete the
# substring inside names such as "submodule.weight" and corrupt them.
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

model.load_state_dict(state_dict)
model.eval()  # inference mode: disables the dropout layers
print(f"[INFO] Model loaded from '{MODEL_PATH}' on {DEVICE}")
210
+
211
+
212
+ # ──────────────────────────────────────────────────────────────
213
+ # 5. HELPER: tensor → PIL
214
+ # ──────────────────────────────────────────────────────────────
215
def tensor_to_pil(t: torch.Tensor) -> Image.Image:
    """
    Convert one image tensor with values in [-1, 1] to a uint8 PIL image.

    Args:
        t: Tensor of shape (3, H, W), or (1, 3, H, W) for a batch holding a
            single image. Values outside [-1, 1] are clamped.

    Returns:
        The image with pixel values rescaled to 0..255.

    Raises:
        ValueError: If ``t`` is neither a 3-D tensor nor a 4-D batch of one.
    """
    # Drop only a leading batch dimension of size 1. A bare squeeze() would
    # also remove a height or width of 1 and break the permute below.
    if t.dim() == 4 and t.shape[0] == 1:
        t = t[0]
    if t.dim() != 3:
        raise ValueError(
            f"expected a (3, H, W) or (1, 3, H, W) tensor, got shape {tuple(t.shape)}"
        )

    arr = (
        t.detach().cpu().clamp(-1, 1)
        .add(1).div(2)         # → [0, 1]
        .mul(255).byte()       # → uint8 (fractional part truncated)
        .permute(1, 2, 0)      # → (H, W, 3)
        .numpy()
    )
    return Image.fromarray(arr)
225
+
226
+
227
+ # ──────────────────────────────────────────────────────────────
228
+ # 6. GENERATION FUNCTION (called by Gradio)
229
+ # ──────────────────────────────────────────────────────────────
230
@torch.no_grad()
def generate_image(n_vis_steps: int = 7) -> tuple[Image.Image, Image.Image]:
    """
    Sample one image by running the DDPM reverse process from pure noise.

    The loop visits timesteps T-1 down to 1. At every step the U-Net
    predicts the noise, the posterior mean is computed from it, and fresh
    Gaussian noise is added on every step except the last one.

    NOTE(review): the loop stops at t = 1, so index 0 of the noise schedule
    is never used. Confirm this matches how timesteps were sampled during
    training.

    Args:
        n_vis_steps : how many intermediate frames to show in the
                      denoising-steps grid (evenly spaced between T-1 and 1);
                      converted with int(), so a float slider value works.

    Returns:
        final_pil : PIL image of the final generated output
        steps_pil : PIL image showing the captured frames side by side in
                    one row; the final image itself if nothing was captured
    """
    sample = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=DEVICE)

    # Timesteps whose result is kept for the progression strip.
    n_frames = int(n_vis_steps)
    capture_at = frozenset(np.linspace(T - 1, 1, n_frames, dtype=int).tolist())
    snapshots: list[torch.Tensor] = []

    for step in range(T - 1, 0, -1):
        t_batch = torch.full((1,), step, device=DEVICE, dtype=torch.long)

        # U-Net estimate of the noise contained in the current sample.
        eps_pred = model(sample, t_batch)

        # Posterior mean of the DDPM reverse update.
        coeff = betas[step] / sqrt_1m_ah[step]
        mean = (1.0 / torch.sqrt(alphas[step])) * (sample - coeff * eps_pred)

        if step > 1:
            sigma = torch.sqrt(betas[step])
            sample = mean + sigma * torch.randn_like(sample)
        else:
            # Final step: keep the mean, add no extra noise.
            sample = mean

        if step in capture_at:
            snapshots.append(sample.clone().cpu())

    # ── Final generated image ────────────────────────────────
    final_pil = tensor_to_pil(sample)

    # ── Intermediate steps grid ──────────────────────────────
    if not snapshots:
        return final_pil, final_pil

    stacked = torch.cat(snapshots, dim=0)  # (n, 3, H, W)
    grid = vutils.make_grid(
        stacked.clamp(-1, 1),
        nrow=len(snapshots),  # a single row holding every frame
        normalize=True,
        value_range=(-1, 1),
    )
    grid_hwc = grid.permute(1, 2, 0).numpy()
    steps_pil = Image.fromarray((grid_hwc * 255).astype(np.uint8))

    return final_pil, steps_pil
287
+
288
+
289
+ # ──────────────────────────────────────────────────────────────
290
+ # 7. GRADIO INTERFACE
291
+ # ──────────────────────────────────────────────────────────────
292
# Page layout: intro text, a slider for the number of preview frames, the
# generate button, the two output images, and a footer.
with gr.Blocks(title="DDPM Image Generator", theme=gr.themes.Soft()) as demo:

    gr.Markdown(
        """
        # 🖼️ DDPM Image Generator
        Generates a **new image from pure Gaussian noise** using a
        Denoising Diffusion Probabilistic Model trained from scratch in PyTorch.

        Click **Generate** to run the full reverse diffusion process.
        The right panel shows intermediate denoising steps so you can
        watch the image emerge from noise.
        """
    )

    with gr.Row():
        # The value is passed to generate_image as n_vis_steps.
        n_steps_slider = gr.Slider(
            minimum=4,
            maximum=12,
            value=7,
            step=1,
            label="Number of intermediate steps to visualise",
        )

    with gr.Row():
        btn = gr.Button("✨ Generate Image", variant="primary", scale=1)

    with gr.Row():
        out_final = gr.Image(
            label="Final Generated Image",
            type="pil",
            height=IMG_SIZE * 2,  # displayed at twice the generated size
        )
        out_steps = gr.Image(
            label="Intermediate Denoising Steps (noise → image)",
            type="pil",
        )

    # One click runs the full reverse process and fills both image panels.
    btn.click(
        fn=generate_image,
        inputs=[n_steps_slider],
        outputs=[out_final, out_steps],
    )

    gr.Markdown(
        """
        ---
        **Model:** Custom U-Net (64→128→256 channels) trained with MSE loss on image noise.
        **Assignment:** Generative AI (AI4009) — Spring 2026, NUCES.
        """
    )


if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Deep learning
3
+ torch
4
+ torchvision
5
+
6
+ # App framework
7
+ gradio
8
+
9
+ # Numerical / image utilities
10
+ numpy
11
+ Pillow
12
+ scikit-image