Dhenenjay commited on
Commit
97b55e6
·
verified ·
1 Parent(s): a9abf27

Upload unet.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. unet.py +352 -0
unet.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """E3Diff UNet Architecture - exact copy from original with fixed imports."""
2
+
3
+ import math
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from inspect import isfunction
8
+ from softpool import soft_pool2d, SoftPool2d
9
+
10
+
11
def exists(x):
    """True when *x* carries a value, i.e. is not None."""
    return x is not None
13
+
14
+
15
def default(val, d):
    """Return *val* when set; otherwise the fallback *d*, calling it first if it is a plain function."""
    if val is None:
        return d() if isfunction(d) else d
    return val
19
+
20
+
21
class PositionalEncoding(nn.Module):
    """Sinusoidal embedding of a scalar noise level: first half sin, second half cos."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        half = self.dim // 2
        # Geometrically spaced frequencies from 1 down to ~1/1e4.
        step = torch.arange(half, dtype=noise_level.dtype,
                            device=noise_level.device) / half
        scaled = noise_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step).unsqueeze(0)
        return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
35
+
36
+
37
class FeatureWiseAffine(nn.Module):
    """Inject a noise embedding into a feature map, additively or as a (gamma, beta) affine pair."""

    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super().__init__()
        self.use_affine_level = use_affine_level
        # Affine mode needs twice the channels: one half for gamma, one for beta.
        width = out_channels * 2 if use_affine_level else out_channels
        self.noise_func = nn.Sequential(
            nn.Linear(in_channels, width)
        )

    def forward(self, x, noise_embed):
        b = x.shape[0]
        mapped = self.noise_func(noise_embed).view(b, -1, 1, 1)
        if self.use_affine_level:
            gamma, beta = mapped.chunk(2, dim=1)
            return (1 + gamma) * x + beta
        return x + mapped
54
+
55
+
56
class Swish(nn.Module):
    """Swish activation, x * sigmoid(x) (a.k.a. SiLU)."""

    def forward(self, x):
        return torch.sigmoid(x) * x
59
+
60
+
61
class Upsample(nn.Module):
    """2x nearest-neighbour upsampling followed by a channel-preserving 3x3 conv."""

    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, 3, padding=1)

    def forward(self, x):
        grown = self.up(x)
        return self.conv(grown)
69
+
70
+
71
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
78
+
79
+
80
class Block(nn.Module):
    """Basic conv unit: GroupNorm -> Swish -> (optional Dropout) -> 3x3 Conv."""

    def __init__(self, dim, dim_out, groups=32, dropout=0, stride=1):
        super().__init__()
        layers = [
            nn.GroupNorm(groups, dim),
            Swish(),
            # Dropout layer only when a non-zero rate is requested.
            nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
            nn.Conv2d(dim, dim_out, 3, stride=stride, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
92
+
93
+
94
class ResnetBlock(nn.Module):
    """Residual block conditioned on a noise-level embedding and an external feature map *c*."""

    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0,
                 use_affine_level=False, norm_groups=32):
        super().__init__()
        self.noise_func = FeatureWiseAffine(
            noise_level_emb_dim, dim_out, use_affine_level)
        # 1x1 projection that mixes the condition feature map into the block output.
        self.c_func = nn.Conv2d(dim_out, dim_out, 1)

        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        # Skip path only needs a projection when channel counts differ.
        self.res_conv = nn.Identity() if dim == dim_out else nn.Conv2d(dim, dim_out, 1)

    def forward(self, x, time_emb, c):
        out = self.block1(x)
        out = self.noise_func(out, time_emb)
        out = self.block2(out)
        out = out + self.c_func(c)
        return out + self.res_conv(x)
112
+
113
+
114
class SelfAttention(nn.Module):
    """Full spatial self-attention over an HxW feature map, with a residual connection."""

    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()
        self.n_head = n_head
        self.norm = nn.GroupNorm(norm_groups, in_channel)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input, t=None, save_flag=None, file_num=None):
        b, c, h, w = input.shape
        heads = self.n_head
        d = c // heads

        normed = self.norm(input)
        qkv = self.qkv(normed).view(b, heads, d * 3, h, w)
        q, k, v = qkv.chunk(3, dim=2)

        # Similarity between every pair of spatial positions, scaled by sqrt(channels).
        scores = torch.einsum("bnchw, bncyx -> bnhwyx", q, k).contiguous()
        scores = scores / math.sqrt(c)
        weights = torch.softmax(scores.view(b, heads, h, w, -1), -1)
        weights = weights.view(b, heads, h, w, h, w)

        attended = torch.einsum("bnhwyx, bncyx -> bnchw", weights, v).contiguous()
        projected = self.out(attended.view(b, c, h, w))

        return projected + input
142
+
143
+
144
class ResnetBlocWithAttn(nn.Module):
    """ResnetBlock optionally followed by spatial self-attention at the same width."""

    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32,
                 dropout=0, with_attn=False, size=256):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(
            dim, dim_out, noise_level_emb_dim,
            norm_groups=norm_groups, dropout=dropout)
        if with_attn:
            self.attn = SelfAttention(dim_out, norm_groups=norm_groups)

    def forward(self, x, time_emb, c, t=0, save_flag=False, file_i=0):
        out = self.res_block(x, time_emb, c)
        if not self.with_attn:
            return out
        return self.attn(out, t=t, save_flag=save_flag, file_num=file_i)
158
+
159
+
160
class ResBlock_normal(nn.Module):
    """Plain two-convolution residual block (no time/condition inputs), used by CPEN.

    Fix: the original forward unpacked ``b, c, h, w = x.shape`` and then
    immediately clobbered ``h`` with a tensor — the dead, shadowing
    unpacking has been removed.
    """

    def __init__(self, dim, dim_out, dropout=0, norm_groups=32):
        super().__init__()
        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        # 1x1 conv on the skip path only when the channel count changes.
        self.res_conv = nn.Conv2d(
            dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        return h + self.res_conv(x)
173
+
174
+
175
class CPEN(nn.Module):
    """Condition Pyramid Encoder Network - EXACT architecture from E3Diff.

    Encodes the conditioning image into a five-level feature pyramid
    (64/128/256/512/1024 channels at full down to 1/16 resolution),
    soft-pooling by 2x between stages.
    """

    def __init__(self, inchannel=1):
        super().__init__()
        self.pool = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.E1 = nn.Sequential(
            nn.Conv2d(inchannel, 64, kernel_size=3, padding=1),
            Swish()
        )

        self.E2 = nn.Sequential(
            ResBlock_normal(64, 128, dropout=0, norm_groups=16),
            ResBlock_normal(128, 128, dropout=0, norm_groups=16),
        )

        self.E3 = nn.Sequential(
            ResBlock_normal(128, 256, dropout=0, norm_groups=16),
            ResBlock_normal(256, 256, dropout=0, norm_groups=16),
        )

        self.E4 = nn.Sequential(
            ResBlock_normal(256, 512, dropout=0, norm_groups=16),
            ResBlock_normal(512, 512, dropout=0, norm_groups=16),
        )

        self.E5 = nn.Sequential(
            ResBlock_normal(512, 512, dropout=0, norm_groups=16),
            ResBlock_normal(512, 1024, dropout=0, norm_groups=16),
        )

    def forward(self, x):
        # Run the five stages in order, pooling 2x before every stage after the first.
        feats = []
        out = x
        for idx, stage in enumerate((self.E1, self.E2, self.E3, self.E4, self.E5)):
            if idx > 0:
                out = self.pool(out)
            out = stage(out)
            feats.append(out)
        # (x1..x5): 64ch full-res, 128ch 1/2, 256ch 1/4, 512ch 1/8, 1024ch 1/16.
        return tuple(feats)
222
+
223
+
224
class UNet(nn.Module):
    """E3Diff conditional diffusion UNet.

    ``forward`` receives the conditioning image stacked channel-wise in
    front of the noisy image; the condition is split off, encoded by
    :class:`CPEN` into a 5-level pyramid, and injected into every resnet
    block on both the down and up paths.
    """
    def __init__(
        self,
        in_channel=6,
        out_channel=3,
        inner_channel=32,
        norm_groups=32,
        channel_mults=(1, 2, 4, 8, 8),
        attn_res=(8,),
        res_blocks=3,
        dropout=0,
        with_noise_level_emb=True,
        image_size=128,
        lowres_cond=True,
        condition_ch=3
    ):
        # NOTE(review): `lowres_cond` is never read in this class — confirm callers need it.
        super().__init__()

        if with_noise_level_emb:
            noise_level_channel = inner_channel
            # Maps the scalar noise level to an `inner_channel`-dim embedding.
            self.noise_level_mlp = nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel)
            )
        else:
            noise_level_channel = None
            self.noise_level_mlp = None

        self.res_blocks = res_blocks
        num_mults = len(channel_mults)
        self.num_mults = num_mults
        pre_channel = inner_channel
        # Tracks skip-connection channel counts; popped in reverse on the up path.
        feat_channels = [pre_channel]
        now_res = image_size

        # Down path: stem conv, then (res_blocks x ResnetBlocWithAttn + Downsample) per level.
        downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)]
        for ind in range(num_mults):
            is_last = (ind == num_mults - 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks):
                downs.append(ResnetBlocWithAttn(
                    pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel,
                    norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, size=now_res))
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            if not is_last:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res // 2
        self.downs = nn.ModuleList(downs)

        # Bottleneck: one attention block, one plain block, at the lowest resolution.
        self.mid = nn.ModuleList([
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=True, size=now_res),
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=False, size=now_res)
        ])

        # Up path: (res_blocks + 1) blocks per level so every stored skip
        # (including the one from Downsample) is concatenated exactly once.
        ups = []
        for ind in reversed(range(num_mults)):
            is_last = (ind < 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks + 1):
                ups.append(ResnetBlocWithAttn(
                    pre_channel + feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel,
                    norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, size=now_res))
                pre_channel = channel_mult
            if not is_last:
                ups.append(Upsample(pre_channel))
                now_res = now_res * 2
        self.ups = nn.ModuleList(ups)

        self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)

        # Condition encoder producing the 5-level pyramid consumed in forward().
        self.condition = CPEN(inchannel=condition_ch)
        self.condition_ch = condition_ch

    def forward(self, x, time, img_s1=None, class_label=None, return_condition=False, t_ori=0):
        # NOTE(review): `img_s1`, `class_label` and `t_ori` are unused here — confirm
        # whether callers rely on them elsewhere.
        # Split the stacked input: conditioning image first, noisy image after.
        condition = x[:, :self.condition_ch, ...].clone()
        x = x[:, self.condition_ch:, ...]

        c1, c2, c3, c4, c5 = self.condition(condition)
        c_base = [c1, c2, c3, c4, c5]

        # One pyramid feature per down-path resnet block: res_blocks copies per
        # level (Downsample layers consume no condition), giving
        # num_mults * res_blocks entries — matching the ResnetBlocWithAttn count.
        c = []
        for i in range(len(c_base)):
            for _ in range(self.res_blocks):
                c.append(c_base[i])

        t = self.noise_level_mlp(time) if exists(self.noise_level_mlp) else None

        # Down path; every layer's output is stored as a skip connection.
        feats = []
        i = 0
        for layer in self.downs:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t, c[i])
                i += 1
            else:
                x = layer(x)
            feats.append(x)

        # Bottleneck always conditions on the deepest pyramid feature.
        for layer in self.mid:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t, c5)
            else:
                x = layer(x)

        # Up path consumes the pyramid deepest-first, res_blocks + 1 copies per level.
        c_base = [c5, c4, c3, c2, c1]
        c = []
        for i in range(len(c_base)):
            for _ in range(self.res_blocks + 1):
                c.append(c_base[i])

        i = 0
        for layer in self.ups:
            if isinstance(layer, ResnetBlocWithAttn):
                # Concatenate the matching skip feature before each resnet block.
                x = layer(torch.cat((x, feats.pop()), dim=1), t, c[i])
                i += 1
            else:
                x = layer(x)

        if not return_condition:
            return self.final_conv(x)
        else:
            return self.final_conv(x), [c1, c2, c3, c4, c5]