Fgdfgfthgr commited on
Commit
110b2ff
·
verified ·
1 Parent(s): 4936d6c

Delete network_diffusion_unet.py

Browse files
Files changed (1) hide show
  1. network_diffusion_unet.py +0 -389
network_diffusion_unet.py DELETED
@@ -1,389 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
- from torch.utils.checkpoint import checkpoint
5
-
6
-
7
- class SinusoidalEmbedding(nn.Module):
8
- def __init__(self, embedding_dim=128, base=1000, scaling=1000):
9
- super().__init__()
10
- self.embedding_dim = embedding_dim
11
- half_dim = embedding_dim // 2
12
- freqs = torch.exp(-math.log(base) * torch.arange(0, half_dim) / half_dim)
13
- # at base 1000, max-range = +=500pi = -1571 to 1571
14
- self.scaling = nn.parameter.Buffer(torch.tensor(scaling))
15
- self.freqs = nn.parameter.Buffer(freqs)
16
-
17
- def forward(self, scaler):
18
- scaler = scaler * self.scaling
19
- args = scaler[:, None] * self.freqs[None]
20
- embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
21
- return embedding
22
-
23
-
24
- class SinusoidalPositionalEmbedding2D(nn.Module):
25
-
26
- def __init__(self, embedding_dim):
27
- super().__init__()
28
- assert embedding_dim % 2 == 0, "embedding_dim must be even"
29
- self.embedding_dim = embedding_dim
30
- half_dim = self.embedding_dim // 2
31
- div_term = torch.exp(torch.arange(0, half_dim, 2) * (-math.log(100.0) / half_dim))
32
- # Since our grid size is small, 100 should be enough
33
- self.div_term = nn.parameter.Buffer(div_term.to(torch.float32))
34
-
35
- def forward(self, height, width):
36
- """Generate embeddings for a grid of size (height, width)."""
37
-
38
- # Generate grid coordinates
39
- y_pos = torch.arange(height, dtype=torch.float32, device=self.div_term.device)
40
- x_pos = torch.arange(width, dtype=torch.float32, device=self.div_term.device)
41
-
42
- # Compute sinusoidal components for height and width
43
- y_sin = torch.sin(y_pos[:, None] * self.div_term[None, :])
44
- y_cos = torch.cos(y_pos[:, None] * self.div_term[None, :])
45
- x_sin = torch.sin(x_pos[:, None] * self.div_term[None, :])
46
- x_cos = torch.cos(x_pos[:, None] * self.div_term[None, :])
47
-
48
- # Interleave sin and cos components
49
- y_embed = torch.stack([y_sin, y_cos], dim=-1).view(height, -1)
50
- x_embed = torch.stack([x_sin, x_cos], dim=-1).view(width, -1)
51
-
52
- # Combine height and width embeddings
53
- pos_embed = torch.cat([y_embed[:, None, :].expand(-1, width, -1),
54
- x_embed[None, :, :].expand(height, -1, -1)], dim=-1)
55
- return pos_embed.view(height * width, self.embedding_dim)
56
-
57
-
58
- class ImageLinearAttention(nn.Module):
59
- def __init__(self, chan, kernel_size=3, heads=4, norm_queries=True, embd_dim=None):
60
- super().__init__()
61
- self.chan = chan
62
- self.heads = heads
63
- self.key_dim = key_dim = chan // heads
64
- self.value_dim = value_dim = chan // heads
65
- self.norm_queries = norm_queries
66
-
67
- # Convolutional projections for Q, K, V
68
- self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, padding='same', padding_mode='replicate')
69
- self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, padding='same', padding_mode='replicate')
70
- self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, padding='same', padding_mode='replicate')
71
- self.to_out = nn.Conv2d(value_dim * heads, chan, kernel_size, padding='same', padding_mode='replicate')
72
-
73
- # Adaptive normalization: Project embedding to scale/shift for group norm
74
- if embd_dim is not None:
75
- self.norm = nn.GroupNorm(1, key_dim * heads, affine=False) # Normalize without inherent affine params
76
- self.emb_proj = nn.Linear(embd_dim, 2 * key_dim * heads) # Project emb to scale/shift
77
- else:
78
- self.norm = nn.GroupNorm(1, key_dim * heads, affine=True)
79
- self.emb_proj = None
80
-
81
- def forward(self, x, emb=None):
82
- b, c, h, w = x.shape
83
- heads = self.heads
84
- key_dim = self.key_dim
85
-
86
- # Project input to queries, keys, and values
87
- q = self.to_q(x)
88
- k = self.to_k(x)
89
- v = self.to_v(x)
90
-
91
- # Apply adaptive normalization if embedding is provided
92
- if emb is not None and self.emb_proj is not None:
93
- emb_params = self.emb_proj(emb).view(b, 2, -1) # (b, 2, key_dim * heads)
94
- scale, shift = emb_params[:, 0], emb_params[:, 1] # Split into scale and shift
95
- # Normalize and modulate Q, K, V
96
- q = self.norm(q)
97
- k = self.norm(k)
98
- v = self.norm(v)
99
- # Apply scale and shift across spatial dimensions
100
- q = q * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
101
- k = k * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
102
- v = v * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
103
-
104
- # Reshape Q, K, V for multi-head attention
105
- q = q.view(b, heads, key_dim, h * w)
106
- k = k.view(b, heads, key_dim, h * w)
107
- v = v.view(b, heads, self.value_dim, h * w)
108
-
109
- # Scale queries and keys
110
- q = q * (key_dim ** -0.25)
111
- k = k * (key_dim ** -0.25)
112
-
113
- # Softmax on keys along the sequence dimension
114
- k = k.softmax(dim=-1)
115
- if self.norm_queries:
116
- q = q.softmax(dim=-2)
117
-
118
- # Compute context and output
119
- context = torch.einsum('bhdn,bhen->bhde', k, v)
120
- out = torch.einsum('bhdn,bhde->bhen', q, context)
121
- out = out.reshape(b, -1, h, w)
122
- out = self.to_out(out)
123
- return x + out
124
-
125
-
126
- class ResConvBlock(nn.Module):
127
- def __init__(self, channels, time_dim):
128
- super().__init__()
129
- self.first_conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False, padding_mode='replicate')
130
- self.second_conv = nn.Conv2d(channels, channels, 3, padding=1, padding_mode='replicate')
131
- self.gn1 = nn.GroupNorm(8, channels, affine=True)
132
- self.gn2 = nn.GroupNorm(8, channels, affine=False)
133
- self.embd_affine = nn.Linear(time_dim, channels * 2)
134
- self.act = nn.LeakyReLU(inplace=True)
135
-
136
- def forward(self, x, t_emb):
137
- # Get affine parameters from time embedding
138
- affine_params = self.embd_affine(t_emb)
139
- scale, shift = affine_params.chunk(2, dim=1)
140
-
141
- # First convolution path
142
- h = self.first_conv(self.act(self.gn1(x)))
143
-
144
- # Second convolution path with adaptive normalization
145
- h = self.gn2(h)
146
- h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
147
- h = self.second_conv(self.act(h))
148
-
149
- return x + h
150
-
151
-
152
- class DiTLayer(nn.Module):
153
- def __init__(self, d_model, embd_dim, nhead, dim_feedforward=1024):
154
- super().__init__()
155
- self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
156
- self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=False)
157
- self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
158
- self.embd_affine = nn.Linear(embd_dim, 6*d_model)
159
- self.ffn = nn.Sequential(
160
- nn.Linear(d_model, dim_feedforward),
161
- nn.LeakyReLU(0.2, inplace=True),
162
- nn.Linear(dim_feedforward, d_model),
163
- )
164
-
165
- def forward(self, x, embd):
166
- affine_params = self.embd_affine(embd)
167
- scale1, scale2, shift1, shift2, alpha1, alpha2 = affine_params.chunk(6, dim=1)
168
-
169
- # Self-attention block
170
- x = self.norm1(x)
171
- x = x * (1 + scale1[None, :, :]) + shift1[None, :, :]
172
-
173
- attn_output, _ = self.attn(x, x, x)
174
- x = x + attn_output * alpha1[None, :, :]
175
-
176
- # Feedforward block
177
- x = self.norm2(x)
178
- x = x * (1 + scale2[None, :, :]) + shift2[None, :, :]
179
- ffn_output = self.ffn(x)
180
- x = x + ffn_output * alpha2[None, :, :]
181
- return x
182
-
183
-
184
- class DiTBlock(nn.Module):
185
- def __init__(self, channels, embd_dim, patch_size, nhead, num_layers):
186
- super().__init__()
187
- self.patch_size = patch_size
188
- self.patchify = nn.Unfold(kernel_size=patch_size, stride=patch_size)
189
- hidden_size = channels * patch_size**2
190
- self.pos_embd = SinusoidalPositionalEmbedding2D(hidden_size)
191
- self.dit_layers = nn.ModuleList([
192
- DiTLayer(hidden_size, embd_dim, nhead, 2*hidden_size)
193
- for _ in range(num_layers)
194
- ])
195
-
196
- def forward(self, x, embd):
197
- B, C, H, W = x.shape
198
- H_p, W_p = H // self.patch_size, W // self.patch_size
199
- x = self.patchify(x).permute(0, 2, 1) # [B, num_patches, d_main]
200
- pos_embd = self.pos_embd(H_p, W_p).to(dtype=x.dtype)
201
- x = x + pos_embd.unsqueeze(0)
202
- x = x.permute(1, 0, 2) # [num_patches, B, d_main)
203
- for dit_layer in self.dit_layers:
204
- x = dit_layer(x, embd)
205
- x = x.permute(1, 2, 0) # [B, d_main, num_patches]
206
- x = nn.functional.fold(x, (H, W), (self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size))
207
- return x
208
-
209
-
210
- class UpBlock(nn.Module):
211
- def __init__(self, in_ch, out_ch, time_dim, cat):
212
- super().__init__()
213
- self.res = ResConvBlock(in_ch, time_dim)
214
- self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
215
- self.cat = cat
216
-
217
- def forward(self, x, t_emb, skip=None):
218
- x = self.res(x, t_emb)
219
- x = self.up(x)
220
- if self.cat:
221
- x = torch.cat([x, skip], dim=1)
222
- else:
223
- x = x + skip
224
- return x
225
-
226
- class UpBlockWithDit(nn.Module):
227
- def __init__(self, in_ch, mid_ch, out_ch, patch_size, nhead, time_dim, layers, cat):
228
- super().__init__()
229
- self.res = ResConvBlock(in_ch, time_dim)
230
- self.down_map = nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False)
231
- self.down_norm = nn.GroupNorm(4, mid_ch)
232
- self.dit = DiTBlock(mid_ch, time_dim, patch_size, nhead, layers)
233
- self.up_map = nn.Conv2d(mid_ch, in_ch, kernel_size=1)
234
- self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
235
- self.cat = cat
236
-
237
- def forward(self, x, embd, skip=None):
238
- x = self.res(x, embd)
239
- h = self.down_norm(self.down_map(x))
240
- h = self.dit(h, embd)
241
- h = self.up_map(h)
242
- x = x + h
243
- x = self.up(x)
244
- if self.cat:
245
- x = torch.cat([x, skip], dim=1)
246
- else:
247
- x = x + skip
248
- return x
249
-
250
-
251
- def run_block(module, *args):
252
- return module(*args)
253
-
254
-
255
- class ConditionalUNet(nn.Module):
256
- def __init__(self, base_ch=16, embd_dim=64, depth=5):
257
- super().__init__()
258
- self.depth = depth
259
- self.time_embd = SinusoidalEmbedding(embd_dim)
260
- self.waterlevel_embd = SinusoidalEmbedding(embd_dim, 10)
261
- embd_dim *= 2
262
-
263
- # Input channels = noisy height (1) + ridge map (1) + lake map (1)
264
- self.expand = nn.Conv2d(4, base_ch, 3, padding=1, padding_mode='replicate')
265
-
266
- # Encoder layers
267
- self.enc_blocks = nn.ModuleList()
268
- self.enc_dit_blocks = nn.ModuleList()
269
- self.down_convs = nn.ModuleList()
270
- current_ch = base_ch
271
-
272
- for i in range(depth):
273
- self.enc_blocks.append(ResConvBlock(current_ch, embd_dim))
274
- if i < depth - 1:
275
- self.down_convs.append(
276
- nn.Conv2d(current_ch, current_ch * 2, 4, stride=2, padding=1, padding_mode='replicate')
277
- )
278
- current_ch *= 2
279
-
280
- # Bottleneck
281
- self.bottleneck = nn.Conv2d(current_ch, current_ch * 2, 4, stride=2, padding=1, padding_mode='replicate')
282
- current_ch *= 2
283
-
284
- # Decoder layers
285
- self.up_blocks = nn.ModuleList()
286
- for i in range(depth):
287
- cat = (i == depth - 1) # Only concatenate in the final up block
288
- self.up_blocks.append(UpBlock(current_ch, current_ch // 2, embd_dim, cat))
289
- current_ch = current_ch // 2 * (2 if cat else 1)
290
-
291
- self.out = ResConvBlock(current_ch, embd_dim)
292
- self.final = nn.Conv2d(current_ch, 1, 1)
293
-
294
-
295
- def forward(self, x, map_average, ridge_map, basin_map, water_level, t):
296
- t_embed = self.time_embd(t).to(x.dtype)
297
- waterlevel_embd = self.waterlevel_embd(water_level).to(x.dtype)
298
- embeds = torch.cat([t_embed, waterlevel_embd], dim=1)
299
-
300
- h = torch.cat([x, ridge_map, basin_map, map_average], dim=1)
301
- h = checkpoint(run_block, self.expand, h, use_reentrant=False) if self.training else self.expand(h)
302
-
303
- # Encoder
304
- skips = []
305
- for i in range(self.depth):
306
- h = checkpoint(run_block, self.enc_blocks[i], h, embeds, use_reentrant=False) if self.training else self.enc_blocks[i](h, embeds)
307
- skips.append(h)
308
- if i < self.depth - 1:
309
- h = checkpoint(run_block, self.down_convs[i], h, use_reentrant=False) if self.training else self.down_convs[i](h)
310
-
311
- # Bottleneck
312
- h = checkpoint(run_block, self.bottleneck, h, use_reentrant=False) if self.training else self.bottleneck(h)
313
-
314
- # Decoder
315
- for i in range(self.depth):
316
- h = checkpoint(run_block, self.up_blocks[i], h, embeds, skips[-(i + 1)], use_reentrant=False) if self.training else self.up_blocks[i](h, embeds, skips[-(i + 1)])
317
-
318
- h = checkpoint(run_block, self.out, h, embeds, use_reentrant=False) if self.training else self.out(h, embeds)
319
- h = checkpoint(run_block, self.final, h, use_reentrant=False) if self.training else self.final(h)
320
- return h
321
-
322
-
323
- class ConditionalUNetDiT(nn.Module):
324
- def __init__(self, base_ch=8, embd_dim=16):
325
- super().__init__()
326
- self.time_embd = SinusoidalEmbedding(embd_dim, scaling=1000)
327
- self.waterlevel_embd = SinusoidalEmbedding(embd_dim, scaling=1)
328
- embd_dim *= 2
329
-
330
- # Input channels = noisy height (1) + ridge map (1) + lake map (1)
331
- self.expand = nn.Conv2d(3, base_ch, 3, padding=1, padding_mode='replicate')
332
- self.enc_0 = ResConvBlock(base_ch, embd_dim)
333
-
334
- self.down0 = nn.Conv2d(base_ch, base_ch * 2, 4, stride=2, padding=1, padding_mode='replicate') # 1024->512
335
- self.enc_1 = ResConvBlock(base_ch * 2, embd_dim)
336
- #self.enc_1_dit = DiTBlock(base_ch * 2, 16, 1024, 8, 4)
337
-
338
- self.down1 = nn.Conv2d(base_ch * 2, base_ch * 4, 4, stride=2, padding=1, padding_mode='replicate') # 512->256
339
-
340
- self.up1 = UpBlockWithDit(base_ch * 4, base_ch, base_ch * 2, 8, 8, embd_dim, 6, False) # 256->512
341
- self.up0 = UpBlockWithDit(base_ch * 2, base_ch//2, base_ch, 16, 16, embd_dim, 3, True) # 512->1024
342
- self.out = ResConvBlock(base_ch * 2, embd_dim)
343
- self.final = nn.Conv2d(base_ch * 2, 1, 1)
344
-
345
- def initialize(self):
346
- for name, m in self.named_modules():
347
- if isinstance(m, nn.Linear) and ('embd_affine' in name or 'water_level_affine' in name):
348
- m.weight.data.zero_()
349
- m.bias.data.zero_()
350
- if isinstance(m, nn.Conv2d) and 'second_conv' in name:
351
- m.weight.data.zero_()
352
- m.bias.data.zero_()
353
-
354
- def forward(self, x, ridge_map, basin_map, water_level, t):
355
- t_embed = self.time_embd(t).to(x.dtype)
356
- waterlevel_embd = self.waterlevel_embd(water_level).to(x.dtype)
357
- embeds = torch.cat([t_embed, waterlevel_embd], dim=1)
358
- # x: noisy height map, ridge_map: binary edges, basin_map: binary basins, water_level: the estimate sea level
359
- h0 = torch.cat([x, ridge_map, basin_map], dim=1) # concat condition
360
- # encode
361
- h0 = self.expand(h0)
362
- h0 = self.enc_0(h0, embeds)
363
- h1 = self.down0(h0)
364
- h1 = self.enc_1(h1, embeds) # 512x512
365
- #h1 = checkpoint(run_block, self.enc_1_dit, h1, water_level, use_reentrant=False) if self.training else self.enc_1_dit(h1, water_level)
366
- h2 = self.down1(h1) # 256x256
367
- # decode with skip connections
368
- out = self.up1(h2, embeds, h1) # 512x512
369
- out = self.up0(out, embeds, h0) # 1024x1024
370
- out = self.out(out, embeds)
371
- out = self.final(out)
372
- return out # predicted noise for diffusion loss
373
-
374
-
375
-
376
- if __name__ == "__main__":
377
- #a = ConditionalUNet()
378
- #t = SinusoidalEmbedding(256)
379
- #t_embd = t(torch.randint(0, 100, (1,)))
380
- #x = torch.randn(1, 1, 256, 256)
381
- #r = torch.randn(1, 1, 256, 256)
382
- #c = a(x, r, t_embd)
383
- #print(c)
384
- #print(c.shape)
385
- network = ConditionalUNetDiT()
386
- for name, m in network.named_modules():
387
- if isinstance(m, nn.Linear) and 'time_affine':
388
- m.weight.data.zero_()
389
- m.bias.data.zero_()