mgreg555 committed on
Commit
34c931d
·
1 Parent(s): 0184996

Create app.py

Browse files

first commit to main

Files changed (1) hide show
  1. app.py +522 -0
app.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Generalas.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1xauVtAe5BhUoYH2JItzQtqvFTDxKDSyZ
8
+ """
9
+
10
# Colab shell escapes (``!cmd``) are IPython-only syntax and raise a SyntaxError
# when this file is executed as a plain Python script, so the same commands are
# issued through subprocess instead. Like the shell escapes, a non-zero exit
# status does not abort the script (``check`` is left at its default).
import subprocess
import sys

# Download the pretrained checkpoint that is loaded from /content/model.pth below.
subprocess.run(
    [
        "gdown",
        "https://drive.google.com/uc?id=1-1pfFJoxzU6iYsGBmVclJA1hlXHjj-8B",
        "-O",
        "/content/model.pth",
    ]
)

# Install/upgrade the third-party packages using the running interpreter's pip.
subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "-q", "-U",
        "einops", "datasets", "matplotlib", "tqdm",
    ]
)
14
+
15
+ import math
16
+ from inspect import isfunction
17
+ from functools import partial
18
+
19
+ # %matplotlib inline
20
+ import matplotlib.pyplot as plt
21
+ from tqdm.auto import tqdm
22
+ from einops import rearrange
23
+
24
+ import torch
25
+ from torch import nn, einsum
26
+ import torch.nn.functional as F
27
+ import numpy as np
28
+
29
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
30
+
31
def exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    if x is None:
        return False
    return True
33
+
34
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    If the fallback ``d`` is a plain Python function it is called with no
    arguments and its result is returned; any other object is returned as-is.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
38
+
39
class Residual(nn.Module):
    """Wrap a callable so that its input is added back onto its output."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are passed straight through to
        # the wrapped callable.
        branch = self.fn(x, *args, **kwargs)
        return branch + x
46
+
47
def Upsample(dim):
    """Build a transposed convolution (kernel 4, stride 2, padding 1) that keeps ``dim`` channels."""
    return nn.ConvTranspose2d(
        in_channels=dim, out_channels=dim, kernel_size=4, stride=2, padding=1
    )
49
+
50
def Downsample(dim):
    """Build a strided convolution (kernel 4, stride 2, padding 1) that keeps ``dim`` channels."""
    return nn.Conv2d(
        in_channels=dim, out_channels=dim, kernel_size=4, stride=2, padding=1
    )
52
+
53
class SinusoidalPositionEmbeddings(nn.Module):
    """Map a 1-D tensor of time values to sine/cosine features.

    The output's last dimension is ``2 * (dim // 2)``: the sines of
    ``time * frequency`` followed by the cosines, for ``dim // 2``
    geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        # Log-spacing step between consecutive frequencies (from 1 down to 1/10000).
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -step)
        # Outer product: one row per time value, one column per frequency.
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
66
+
67
+ #ITT
68
+
69
class Block(nn.Module):
    """3x3 convolution, group normalisation, optional scale/shift, then SiLU."""

    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        normed = self.norm(self.proj(x))

        # When given, ``scale_shift`` is a (scale, shift) pair applied after
        # normalisation and before the activation.
        if exists(scale_shift):
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift

        return self.act(normed)
86
+
87
class ResnetBlock(nn.Module):
    """Two ``Block`` layers with a skip connection (https://arxiv.org/abs/1512.03385).

    When ``time_emb_dim`` is given, a SiLU + Linear projection of the time
    embedding is added to the features between the two blocks.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)

        # A 1x1 convolution matches the channel count on the skip path only
        # when input and output widths differ.
        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            projected = self.mlp(time_emb)
            # Broadcast the per-channel conditioning over both spatial axes.
            h = rearrange(projected, "b c -> b c 1 1") + h

        return self.block2(h) + self.res_conv(x)
111
+
112
class ConvNextBlock(nn.Module):
    """ConvNeXt-style residual block (https://arxiv.org/abs/2201.03545).

    A 7x7 depthwise convolution is followed by an optional GroupNorm, a 3x3
    convolution expanding to ``dim_out * mult`` channels, GELU, GroupNorm and
    a 3x3 convolution down to ``dim_out`` channels. The block input is added
    back through ``res_conv`` (a 1x1 convolution when ``dim != dim_out``,
    identity otherwise).

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        time_emb_dim: Width of the time embedding; when ``None`` the block has
            no conditioning MLP.
        mult: Channel expansion factor of the inner convolution.
        norm: Whether to apply GroupNorm before the first inner convolution.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        # groups=dim makes this a depthwise convolution (one filter per channel).
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        """Apply the block to ``x``, conditioning on ``time_emb`` when both it and the MLP exist."""
        h = self.ds_conv(x)

        # The previous in-branch ``assert exists(time_emb)`` was unreachable:
        # this condition already guarantees it, so it has been dropped.
        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            # Broadcast the per-channel conditioning over both spatial axes.
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)
145
+
146
class Attention(nn.Module):
    """Multi-head softmax self-attention over the flattened spatial positions of a 4-D input."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads
        # One bias-free 1x1 convolution yields queries, keys and values together.
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        q = q * self.scale

        scores = einsum("b h d i, b h d j -> b h i j", q, k)
        # Subtract the (detached) row-wise maximum before the softmax.
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        weights = scores.softmax(dim=-1)

        mixed = einsum("b h i j, b h d j -> b h i d", weights, v)
        mixed = rearrange(mixed, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(mixed)
170
+
171
class LinearAttention(nn.Module):
    """Multi-head attention that builds a key/value context matrix first.

    Queries are soft-maxed over the per-head channel axis and keys over the
    flattened spatial axis; the output projection is a 1x1 convolution
    followed by a single-group GroupNorm.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.GroupNorm(1, dim),
        )

    def forward(self, x):
        _, _, height, width = x.shape
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        # Contract keys with values over the position axis to get a (d, e)
        # matrix per head, then apply it to the queries.
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        mixed = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        mixed = rearrange(
            mixed, "b h c (x y) -> b (h c) x y", h=self.heads, x=height, y=width
        )
        return self.to_out(mixed)
198
+
199
class PreNorm(nn.Module):
    """Apply a single-group GroupNorm to the input before calling ``fn`` on it."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(num_groups=1, num_channels=dim)

    def forward(self, x):
        return self.fn(self.norm(x))
208
+
209
class Unet(nn.Module):
    """U-Net built from ConvNeXt or ResNet blocks with attention and optional time conditioning.

    The network is an initial 7x7 convolution, a stack of down stages
    (two blocks, linear attention, downsample), a bottleneck (block, full
    attention, block), a stack of up stages that concatenate the stored
    down-stage outputs, and a final block plus 1x1 convolution.

    Attribute names and construction order are kept stable because saved
    ``state_dict`` checkpoints are keyed on them.
    """

    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # ---- channel widths -------------------------------------------------
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim] + [dim * m for m in dim_mults]
        # Consecutive (input width, output width) pairs, one per resolution.
        in_out = list(zip(dims[:-1], dims[1:]))

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # ---- time embedding -------------------------------------------------
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # ---- encoder / decoder stages ---------------------------------------
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)
        last_index = num_resolutions - 1

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= last_index

            down_stage = nn.ModuleList(
                [
                    block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                    block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                    Downsample(dim_out) if not is_last else nn.Identity(),
                ]
            )
            self.downs.append(down_stage)

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # NOTE: this loop has one fewer entry than ``in_out``, so ``ind``
            # never reaches ``last_index`` and every up stage ends in Upsample.
            is_last = ind >= last_index

            up_stage = nn.ModuleList(
                [
                    # Input width is doubled by the concatenated skip tensor.
                    block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                    block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                    Upsample(dim_in) if not is_last else nn.Identity(),
                ]
            )
            self.ups.append(up_stage)

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim),
            nn.Conv2d(dim, out_dim, 1),
        )

    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        # Outputs of the down stages, consumed last-in-first-out on the way up.
        skips = []

        # Encoder path.
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            skips.append(x)
            x = downsample(x)

        # Bottleneck.
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Decoder path.
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, skips.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)
323
+
324
image_size = 64
channels = 3
batch_size = 32

best_model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4, 8)
)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written from a GPU still loads on a CPU-only host.
best_model.load_state_dict(
    torch.load(str("/content/model.pth"), map_location=device)
)
best_model.to(device)
335
+
336
def cosine_beta_schedule(timesteps, s=0.008):
    """Return ``timesteps`` betas from the cosine schedule of https://arxiv.org/abs/2102.09672.

    The cumulative alpha curve is a squared cosine evaluated on
    ``timesteps + 1`` evenly spaced points and normalised by its first value;
    betas are derived from consecutive ratios and clipped to [0.0001, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    cumulative = torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    cumulative = cumulative / cumulative[0]
    ratios = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - ratios, 0.0001, 0.9999)
346
+
347
def linear_beta_schedule(timesteps):
    """Return ``timesteps`` betas evenly spaced from 0.0001 to 0.02 inclusive."""
    first, last = 0.0001, 0.02
    return torch.linspace(first, last, timesteps)
351
+
352
def quadratic_beta_schedule(timesteps):
    """Return ``timesteps`` betas whose square roots are evenly spaced between sqrt(0.0001) and sqrt(0.02)."""
    first, last = 0.0001, 0.02
    roots = torch.linspace(first**0.5, last**0.5, timesteps)
    return roots ** 2
356
+
357
def sigmoid_beta_schedule(timesteps):
    """Return ``timesteps`` betas following a sigmoid over [-6, 6], rescaled into the 0.0001..0.02 range."""
    first, last = 0.0001, 0.02
    logits = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(logits) * (last - first) + first
362
+
363
timesteps = 200

# Beta schedule used throughout the diffusion process.
betas = linear_beta_schedule(timesteps=timesteps)

# Alphas, their running product, and that product shifted right by one step
# (padded with a leading 1.0).
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# Coefficients of the forward process q(x_t | x_0), read by q_sample and p_sample.
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# Variance of the posterior q(x_{t-1} | x_t, x_0).
posterior_variance = (
    betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
)
380
+
381
def extract(a, t, x_shape):
    """Pick ``a[t]`` for each index in ``t`` and shape it for broadcasting.

    The gather runs against a CPU copy of ``t``; the result is reshaped to
    ``(len(t), 1, ..., 1)`` with as many dimensions as ``x_shape`` and moved
    to ``t``'s device.
    """
    picked = a.gather(-1, t.cpu())
    broadcast_shape = (t.shape[0],) + (1,) * (len(x_shape) - 1)
    return picked.reshape(broadcast_shape).to(t.device)
386
+
387
+ # forward diffusion
388
def q_sample(x_start, t, noise=None):
    """Return the noised version of ``x_start`` at timestep(s) ``t``.

    Computes ``sqrt(alphas_cumprod[t]) * x_start +
    sqrt(1 - alphas_cumprod[t]) * noise``; standard normal noise is drawn
    when ``noise`` is not supplied.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    shape = x_start.shape
    signal_scale = extract(sqrt_alphas_cumprod, t, shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, shape)

    return signal_scale * x_start + noise_scale * noise
398
+
399
+ # def get_noisy_image(x_start, t):
400
+ # # add noise
401
+ # x_noisy = q_sample(x_start, t=t)
402
+
403
+ # # turn back into PIL image
404
+ # noisy_image = reverse_transform(x_noisy.squeeze())
405
+
406
+ # return noisy_image
407
+
408
+ import matplotlib.pyplot as plt
409
+
410
+ # use seed for reproducibility
411
+ torch.manual_seed(0)
412
+
413
+ # source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
414
def plot(imgs, with_orig=False, row_title=None, orig_img=None, **imshow_kwargs):
    """Draw a grid of images with matplotlib, one subplot per image.

    Args:
        imgs: A list of images (one row) or a list of lists of images (one
            inner list per row).
        with_orig: When true, ``orig_img`` is prepended to every row and the
            first cell is titled "Original image".
        row_title: Optional sequence of y-axis labels, one per row.
        orig_img: Image shown in the first column when ``with_orig`` is true.
            If omitted, a module-level ``image`` is used when one exists.
        **imshow_kwargs: Forwarded to ``Axes.imshow``.

    Raises:
        ValueError: If ``with_orig`` is true and no original image is available.
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    if with_orig and orig_img is None:
        # The function used to read an undefined global ``image`` here (a
        # NameError). Keep honouring such a global if the caller defined one,
        # otherwise fail with an explicit message.
        orig_img = globals().get("image")
        if orig_img is None:
            raise ValueError("with_orig=True requires orig_img to be provided")

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    _, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            # Hide tick marks and labels on every cell.
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
437
+
438
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    """Return the loss between the true noise and the model's noise prediction.

    ``x_start`` is noised to timestep ``t`` with ``q_sample`` and passed with
    ``t`` to ``denoise_model``. ``loss_type`` selects "l1", "l2" or "huber"
    (smooth L1); any other value raises ``NotImplementedError`` after the
    model has been evaluated.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    noisy_input = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(noisy_input, t)

    if loss_type == 'l1':
        return F.l1_loss(noise, predicted_noise)
    if loss_type == 'l2':
        return F.mse_loss(noise, predicted_noise)
    if loss_type == "huber":
        return F.smooth_l1_loss(noise, predicted_noise)
    raise NotImplementedError()
455
+
456
@torch.no_grad()
def p_sample(model, x, t, t_index):
    """Perform one reverse-diffusion step from ``x`` at timestep ``t``.

    The model's output is treated as the predicted noise and used to form the
    mean of the previous-step distribution. For ``t_index == 0`` that mean is
    returned as-is; otherwise Gaussian noise scaled by the square root of the
    posterior variance is added.
    """
    shape = x.shape
    beta_t = extract(betas, t, shape)
    sqrt_one_minus_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, shape)
    sqrt_recip_alpha_t = extract(sqrt_recip_alphas, t, shape)

    # Mean computed from the model's noise prediction (the original comment
    # cites "Equation 11 in the paper").
    predicted_noise = model(x, t)
    model_mean = sqrt_recip_alpha_t * (
        x - beta_t * predicted_noise / sqrt_one_minus_cumprod_t
    )

    if t_index == 0:
        # Final step: return the mean without adding noise.
        return model_mean

    variance_t = extract(posterior_variance, t, shape)
    fresh_noise = torch.randn_like(x)
    return model_mean + torch.sqrt(variance_t) * fresh_noise
477
+
478
+ # Algorithm 2 but save all images:
479
@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the full reverse process and return every intermediate batch.

    Starts from standard normal noise of the given ``shape`` on the model's
    device and applies ``p_sample`` for timesteps ``timesteps - 1`` down to 0.
    Returns a list with one NumPy array per step; the last entry is the final
    result.
    """
    device = next(model.parameters()).device
    batch = shape[0]

    # Pure noise for every example in the batch.
    img = torch.randn(shape, device=device)
    history = []

    progress = tqdm(
        reversed(range(0, timesteps)),
        desc='sampling loop time step',
        total=timesteps,
    )
    for step in progress:
        t = torch.full((batch,), step, device=device, dtype=torch.long)
        img = p_sample(model, img, t, step)
        history.append(img.cpu().numpy())
    return history
492
+
493
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    """Sample ``batch_size`` square images via ``p_sample_loop`` and return its per-step list."""
    shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=shape)
496
+
497
# Draw one batch of 64 images at start-up; ``samples`` holds the output of
# every sampling step, so ``samples[-1]`` is the finished batch.
sample_size = 64
samples = sample(
    best_model,
    image_size=image_size,
    batch_size=sample_size,
    channels=channels,
)
500
+
501
+ """UI"""
502
+
503
# ``!pip ...`` is IPython-only syntax and a SyntaxError in a plain Python
# script, so the installs are run through the current interpreter's pip.
# As with the shell escapes, a failed install does not abort the script.
import subprocess
import sys

subprocess.run([sys.executable, "-m", "pip", "install", "typing-extensions==3.7.4"])
subprocess.run([sys.executable, "-m", "pip", "install", "gradio"])
506
+
507
+ import gradio as gr
508
+
509
def show_picture(random_index):
    """Return one finished sample as a channels-last array clipped to [0, 1].

    Args:
        random_index: Position of the image within the final sampled batch
            (``samples[-1]``). It is cast to ``int`` because a UI slider may
            deliver the value as a float, which cannot index an array.

    Returns:
        A NumPy array with the channel axis moved last, shifted by ``+1`` and
        halved, then clipped to [0, 1] (assumes the model output is roughly
        in [-1, 1] — TODO confirm against the training transform).
    """
    index = int(random_index)
    image = (samples[-1][index].transpose(1, 2, 0) + 1.0) / 2.0
    return np.clip(image, 0.0, 1.0)


# Valid indices run from 0 to sample_size - 1; the previous maximum of
# ``sample_size`` let the slider select one past the end (IndexError).
demo = gr.Interface(
    fn=show_picture,
    inputs=gr.Slider(minimum=0, maximum=sample_size - 1, step=1),
    outputs="image",
)

if __name__ == "__main__":
    demo.launch(show_api=False, share=True)
518
+
519
+ # random_index = 7
520
+ # image=(samples[-1][random_index].transpose(1, 2, 0) + 1.0) / 2.0
521
+ # clipped_image = np.clip(image, 0.0, 1.0)
522
+ # plt.imshow(clipped_image) #clip/clamp