himanshu-skid19 committed
Commit af99973 · 1 Parent(s): 411bad7

Update app.py

it was not redundant...

Files changed (1)
  1. app.py +108 -0
app.py CHANGED
@@ -7,6 +7,114 @@ import tensorflow as tf
 import math
 import torch.nn.functional as F
 
+from torch import nn
+
+
+
+class Block(nn.Module):
+    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
+        super().__init__()
+        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
+        if up:
+            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
+            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
+            self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')  # unused
+
+        else:
+            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
+            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
+            self.maxpool = nn.MaxPool2d(4, 2, 1)  # unused
+        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
+        self.bnorm1 = nn.BatchNorm2d(out_ch)
+        self.bnorm2 = nn.BatchNorm2d(out_ch)
+        self.silu = nn.SiLU()
+        self.relu = nn.ReLU()
+
+    def forward(self, x, t):
+        # First conv
+        h = self.silu(self.bnorm1(self.conv1(x)))
+        # Time embedding
+        time_emb = self.relu(self.time_mlp(t))
+        # Extend last 2 dimensions so it broadcasts over H and W
+        time_emb = time_emb[(...,) + (None,) * 2]
+        # Add time channel
+        h = h + time_emb
+        # Second conv
+        h = self.silu(self.bnorm2(self.conv2(h)))
+        # Down- or upsample
+        return self.transform(h)
+
+
+class SinusoidalPositionEmbeddings(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, time):
+        device = time.device
+        half_dim = self.dim // 2
+        embeddings = math.log(10000) / (half_dim - 1)
+        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
+        embeddings = time[:, None] * embeddings[None, :]
+        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
+        # TODO: Double check the ordering here
+        return embeddings
+
+
+class SimpleUnet(nn.Module):
+    """
+    A simplified variant of the U-Net architecture.
+    """
+    def __init__(self):
+        super().__init__()
+        image_channels = 3
+        down_channels = (32, 64, 128, 256, 512)
+        up_channels = (512, 256, 128, 64, 32)
+        out_dim = 3
+        time_emb_dim = 32
+
+        # Time embedding
+        self.time_mlp = nn.Sequential(
+            SinusoidalPositionEmbeddings(time_emb_dim),
+            nn.Linear(time_emb_dim, time_emb_dim),
+            nn.ReLU()
+        )
+
+        # Initial projection
+        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
+
+        # Downsample
+        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1],
+                                          time_emb_dim)
+                                    for i in range(len(down_channels)-1)])
+        # Upsample
+        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1],
+                                        time_emb_dim, up=True)
+                                  for i in range(len(up_channels)-1)])
+
+        # Edit: Corrected a bug found by Jakub C (see YouTube comment)
+        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)
+
+    def forward(self, x, timestep):
+        # Embed time
+        t = self.time_mlp(timestep)
+        # Initial conv
+        x = self.conv0(x)
+        # U-Net
+        residual_inputs = []
+        for down in self.downs:
+            x = down(x, t)
+            residual_inputs.append(x)
+        for up in self.ups:
+            residual_x = residual_inputs.pop()
+            # Add residual x as additional channels (skip connection)
+            x = torch.cat((x, residual_x), dim=1)
+            x = up(x, t)
+        return self.output(x)
+
+
+model = SimpleUnet()
+print("Num params: ", sum(p.numel() for p in model.parameters()))
+model  # bare expression; a notebook leftover with no effect in a script
 
 def linear_beta_schedule(timesteps):
     beta_start = 0.0001
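
A quick way to sanity-check the added model (not part of the commit) is a dummy forward pass. The sketch below assumes the classes from the diff are in scope; the batch size, the 64x64 resolution, and the timestep range of 300 are illustrative choices, not values taken from the commit. The spatial size must be divisible by 2**4 = 16, since each of the four down Blocks halves H and W and each up Block doubles them again.

    import torch

    model = SimpleUnet()
    x = torch.randn(2, 3, 64, 64)    # (batch, channels, H, W)
    t = torch.randint(0, 300, (2,))  # one diffusion timestep per sample
    with torch.no_grad():
        out = model(x, t)
    print(out.shape)  # expected: torch.Size([2, 3, 64, 64]), same shape as the input

Each down Block's output is pushed onto residual_inputs and later concatenated channel-wise in the up path, which is why an up Block's first convolution takes 2*in_ch input channels.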