himanshu-skid19 commited on
Commit
411bad7
·
1 Parent(s): 7e639df

Update app.py

Browse files

removed the redundant part

Files changed (1) hide show
  1. app.py +0 -100
app.py CHANGED
@@ -24,106 +24,6 @@ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
24
  sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
25
  posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
26
 
27
- class Block(nn.Module):
28
- def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
29
- super().__init__()
30
- self.time_mlp = nn.Linear(time_emb_dim, out_ch)
31
- if up:
32
- self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
33
- self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
34
- self.Upsample = nn.Upsample(scale_factor = 2, mode ='bilinear')
35
-
36
- else:
37
- self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
38
- self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
39
- self.maxpool = nn.MaxPool2d(4, 2, 1)
40
- self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
41
- self.bnorm1 = nn.BatchNorm2d(out_ch)
42
- self.bnorm2 = nn.BatchNorm2d(out_ch)
43
- self.silu = nn.SiLU()
44
- self.relu = nn.ReLU()
45
-
46
- def forward(self, x, t, ):
47
- # First Conv
48
- h = (self.silu(self.bnorm1(self.conv1(x))))
49
- # Time embedding
50
- time_emb = self.relu(self.time_mlp(t))
51
- # Extend last 2 dimensions
52
- time_emb = time_emb[(..., ) + (None, ) * 2]
53
- # Add time channel
54
- h = h + time_emb
55
- # Second Conv
56
- h = (self.silu(self.bnorm2(self.conv2(h))))
57
- # Down or Upsample
58
- return self.transform(h)
59
-
60
-
61
- class SinusoidalPositionEmbeddings(nn.Module):
62
- def __init__(self, dim):
63
- super().__init__()
64
- self.dim = dim
65
-
66
- def forward(self, time):
67
- device = time.device
68
- half_dim = self.dim // 2
69
- embeddings = math.log(10000) / (half_dim - 1)
70
- embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
71
- embeddings = time[:, None] * embeddings[None, :]
72
- embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
73
- # TODO: Double check the ordering here
74
- return embeddings
75
-
76
-
77
- class SimpleUnet(nn.Module):
78
- """
79
- A simplified variant of the Unet architecture.
80
- """
81
- def __init__(self):
82
- super().__init__()
83
- image_channels = 3
84
- down_channels = (32, 64, 128, 256, 512)
85
- up_channels = (512, 256, 128, 64, 32)
86
- out_dim = 3
87
- time_emb_dim = 32
88
-
89
- # Time embedding
90
- self.time_mlp = nn.Sequential(
91
- SinusoidalPositionEmbeddings(time_emb_dim),
92
- nn.Linear(time_emb_dim, time_emb_dim),
93
- nn.ReLU()
94
- )
95
-
96
- # Initial projection
97
- self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
98
-
99
- # Downsample
100
- self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
101
- time_emb_dim) \
102
- for i in range(len(down_channels)-1)])
103
- # Upsample
104
- self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
105
- time_emb_dim, up=True) \
106
- for i in range(len(up_channels)-1)])
107
-
108
- # Edit: Corrected a bug found by Jakub C (see YouTube comment)
109
- self.output = nn.Conv2d(up_channels[-1], out_dim, 1)
110
-
111
- def forward(self, x, timestep):
112
- # Embedd time
113
- t = self.time_mlp(timestep)
114
- # Initial conv
115
- x = self.conv0(x)
116
- # Unet
117
- residual_inputs = []
118
- for down in self.downs:
119
- x = down(x, t)
120
- residual_inputs.append(x)
121
- for up in self.ups:
122
- residual_x = residual_inputs.pop()
123
- # Add residual x as additional channels
124
- x = torch.cat((x, residual_x), dim=1)
125
- x = up(x, t)
126
- return self.output(x)
127
 
128
  def extract(a, t, x_shape):
129
  batch_size = t.shape[0]
 
24
  sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
25
  posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def extract(a, t, x_shape):
29
  batch_size = t.shape[0]