Fgdfgfthgr committed on
Commit
a33b794
·
verified ·
1 Parent(s): 55e667b

Upload 2 files

Browse files
Files changed (2) hide show
  1. FlashScape.safetensors +3 -0
  2. network_diffusion_unet.py +78 -55
FlashScape.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:611ea2049bb713991b76000f8b24296b252cd81a10cadd2de1995aa1d045e154
3
+ size 78154234
network_diffusion_unet.py CHANGED
@@ -5,11 +5,12 @@ from torch.utils.checkpoint import checkpoint
5
 
6
 
7
  class SinusoidalEmbedding(nn.Module):
8
- def __init__(self, embedding_dim=128, scaling=1000):
9
  super().__init__()
10
  self.embedding_dim = embedding_dim
11
  half_dim = embedding_dim // 2
12
- freqs = torch.exp(-math.log(10000) * torch.arange(0, half_dim) / half_dim)
 
13
  self.scaling = nn.parameter.Buffer(torch.tensor(scaling))
14
  self.freqs = nn.parameter.Buffer(freqs)
15
 
@@ -27,8 +28,9 @@ class SinusoidalPositionalEmbedding2D(nn.Module):
27
  assert embedding_dim % 2 == 0, "embedding_dim must be even"
28
  self.embedding_dim = embedding_dim
29
  half_dim = self.embedding_dim // 2
30
- div_term = torch.exp(torch.arange(0, half_dim, 2, dtype=torch.float32) * (-math.log(10000.0) / half_dim))
31
- self.div_term = nn.parameter.Buffer(div_term)
 
32
 
33
  def forward(self, height, width):
34
  """Generate embeddings for a grid of size (height, width)."""
@@ -124,83 +126,85 @@ class ImageLinearAttention(nn.Module):
124
  class ResConvBlock(nn.Module):
125
  def __init__(self, channels, time_dim):
126
  super().__init__()
127
- self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False, padding_mode='replicate')
128
- self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, padding_mode='replicate')
129
  self.gn1 = nn.GroupNorm(8, channels, affine=True)
130
  self.gn2 = nn.GroupNorm(8, channels, affine=False)
131
- self.time_affine = nn.Linear(time_dim, channels * 2)
132
  self.act = nn.LeakyReLU(inplace=True)
133
 
134
  def forward(self, x, t_emb):
135
  # Get affine parameters from time embedding
136
- affine_params = self.time_affine(t_emb)
137
  scale, shift = affine_params.chunk(2, dim=1)
138
 
139
  # First convolution path
140
- h = self.conv1(self.act(self.gn1(x)))
141
 
142
  # Second convolution path with adaptive normalization
143
  h = self.gn2(h)
144
  h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
145
- h = self.conv2(self.act(h))
146
 
147
  return x + h
148
 
149
 
150
  class DiTLayer(nn.Module):
151
- def __init__(self, d_model, nhead, dim_feedforward=1024):
152
  super().__init__()
153
- self.norm1 = nn.LayerNorm(d_model)
154
- self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
155
- self.norm2 = nn.LayerNorm(d_model)
 
156
  self.ffn = nn.Sequential(
157
  nn.Linear(d_model, dim_feedforward),
158
  nn.LeakyReLU(0.2, inplace=True),
159
  nn.Linear(dim_feedforward, d_model),
160
  )
161
 
162
- def forward(self, src):
 
 
 
163
  # Self-attention block
164
- attn_output, _ = self.attn(self.norm1(src), self.norm1(src), self.norm1(src))
165
- src = src + attn_output
 
 
 
166
 
167
  # Feedforward block
168
- ffn_output = self.ffn(self.norm2(src))
169
- src = src + ffn_output
170
- return src
 
 
171
 
172
 
173
  class DiTBlock(nn.Module):
174
- def __init__(self, channels, patch_size, hidden_size, nhead, num_layers=2):
175
  super().__init__()
176
  self.patch_size = patch_size
177
  self.patchify = nn.Unfold(kernel_size=patch_size, stride=patch_size)
178
- self.patch_embedding_in = nn.Linear(channels * patch_size**2, hidden_size)
179
  self.pos_embd = SinusoidalPositionalEmbedding2D(hidden_size)
180
- self.waterlevel_embd = SinusoidalEmbedding(hidden_size, 10)
181
- self.patch_embedding_out = nn.Linear(hidden_size, channels * patch_size**2)
182
  self.dit_layers = nn.ModuleList([
183
- DiTLayer(hidden_size, nhead, 2*hidden_size)
184
  for _ in range(num_layers)
185
  ])
186
- self.norm = nn.GroupNorm(8, channels)
187
 
188
- def forward(self, src, water_level):
189
- B, C, H, W = src.shape
190
  H_p, W_p = H // self.patch_size, W // self.patch_size
191
- x = self.norm(src)
192
- x = self.patchify(x).permute(0, 2, 1)
193
- x = self.patch_embedding_in(x)
194
  pos_embd = self.pos_embd(H_p, W_p).to(dtype=x.dtype)
195
  x = x + pos_embd.unsqueeze(0)
196
- water_level_cls = self.waterlevel_embd(water_level).unsqueeze(1)
197
- x = torch.cat((x, water_level_cls), dim=1)
198
  for dit_layer in self.dit_layers:
199
- x = dit_layer(x)
200
- x = self.patch_embedding_out(x).permute(0, 2, 1)
201
- x = x[:, :, :-1]
202
  x = nn.functional.fold(x, (H, W), (self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size))
203
- return src + x
204
 
205
 
206
  class UpBlock(nn.Module):
@@ -220,16 +224,22 @@ class UpBlock(nn.Module):
220
  return x
221
 
222
  class UpBlockWithDit(nn.Module):
223
- def __init__(self, in_ch, out_ch, patch_size, hidden_size, nhead, time_dim, cat):
224
  super().__init__()
225
  self.res = ResConvBlock(in_ch, time_dim)
226
- self.dit = DiTBlock(in_ch, patch_size, hidden_size, nhead, 4)
 
 
 
227
  self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
228
  self.cat = cat
229
 
230
- def forward(self, x, t_emb, water_level, skip=None):
231
- x = self.res(x, t_emb)
232
- x = self.dit(x, water_level)
 
 
 
233
  x = self.up(x)
234
  if self.cat:
235
  x = torch.cat([x, skip], dim=1)
@@ -313,7 +323,9 @@ class ConditionalUNet(nn.Module):
313
  class ConditionalUNetDiT(nn.Module):
314
  def __init__(self, base_ch=8, embd_dim=16):
315
  super().__init__()
316
- self.time_embd = SinusoidalEmbedding(embd_dim, 1000)
 
 
317
 
318
  # Input channels = noisy height (1) + ridge map (1) + lake map (1)
319
  self.expand = nn.Conv2d(3, base_ch, 3, padding=1, padding_mode='replicate')
@@ -321,30 +333,41 @@ class ConditionalUNetDiT(nn.Module):
321
 
322
  self.down0 = nn.Conv2d(base_ch, base_ch * 2, 4, stride=2, padding=1, padding_mode='replicate') # 1024->512
323
  self.enc_1 = ResConvBlock(base_ch * 2, embd_dim)
324
- self.enc_1_dit = DiTBlock(base_ch * 2, 16, 1024, 8, 4)
325
 
326
  self.down1 = nn.Conv2d(base_ch * 2, base_ch * 4, 4, stride=2, padding=1, padding_mode='replicate') # 512->256
327
 
328
- self.up1 = UpBlockWithDit(base_ch * 4, base_ch * 2, 8, 1024, 8, embd_dim, False) # 256->512
329
- self.up0 = UpBlockWithDit(base_ch * 2, base_ch, 16, 1024, 8, embd_dim, True) # 512->1024
330
  self.out = ResConvBlock(base_ch * 2, embd_dim)
331
  self.final = nn.Conv2d(base_ch * 2, 1, 1)
332
 
 
 
 
 
 
 
 
 
 
333
  def forward(self, x, ridge_map, basin_map, water_level, t):
334
  t_embed = self.time_embd(t).to(x.dtype)
 
 
335
  # x: noisy height map, ridge_map: binary edges, basin_map: binary basins, water_level: the estimate sea level
336
  h0 = torch.cat([x, ridge_map, basin_map], dim=1) # concat condition
337
  # encode
338
- h0 = checkpoint(run_block, self.expand, h0, use_reentrant=False) if self.training else self.expand(h0)
339
- h0 = checkpoint(run_block, self.enc_0, h0, t_embed, use_reentrant=False) if self.training else self.enc_0(h0, t_embed)
340
- h1 = checkpoint(run_block, self.down0, h0, use_reentrant=False) if self.training else self.down0(h0)
341
- h1 = checkpoint(run_block, self.enc_1, h1, t_embed, use_reentrant=False) if self.training else self.enc_1(h1, t_embed)
342
- h1 = checkpoint(run_block, self.enc_1_dit, h1, water_level, use_reentrant=False) if self.training else self.enc_1_dit(h1, water_level) # 512x512
343
- h2 = checkpoint(run_block, self.down1, h1, use_reentrant=False) if self.training else self.down1(h1) # 256x256
344
  # decode with skip connections
345
- out = checkpoint(run_block, self.up1, h2, t_embed, water_level, h1, use_reentrant=False) if self.training else self.up1(h2, t_embed, water_level, h1) # 512x512
346
- out = checkpoint(run_block, self.up0, out, t_embed, water_level, h0, use_reentrant=False) if self.training else self.up0(out, t_embed, water_level, h0) # 1024x1024
347
- out = checkpoint(run_block, self.out, out, t_embed, use_reentrant=False) if self.training else self.out(out, t_embed)
348
  out = self.final(out)
349
  return out # predicted noise for diffusion loss
350
 
 
5
 
6
 
7
  class SinusoidalEmbedding(nn.Module):
8
+ def __init__(self, embedding_dim=128, base=1000, scaling=1000):
9
  super().__init__()
10
  self.embedding_dim = embedding_dim
11
  half_dim = embedding_dim // 2
12
+ freqs = torch.exp(-math.log(base) * torch.arange(0, half_dim) / half_dim)
13
+ # at base 1000, max-range = +=500pi = -1571 to 1571
14
  self.scaling = nn.parameter.Buffer(torch.tensor(scaling))
15
  self.freqs = nn.parameter.Buffer(freqs)
16
 
 
28
  assert embedding_dim % 2 == 0, "embedding_dim must be even"
29
  self.embedding_dim = embedding_dim
30
  half_dim = self.embedding_dim // 2
31
+ div_term = torch.exp(torch.arange(0, half_dim, 2) * (-math.log(100.0) / half_dim))
32
+ # Since our grid size is small, 100 should be enough
33
+ self.div_term = nn.parameter.Buffer(div_term.to(torch.float32))
34
 
35
  def forward(self, height, width):
36
  """Generate embeddings for a grid of size (height, width)."""
 
126
class ResConvBlock(nn.Module):
    """Residual 3x3 conv block with embedding-conditioned adaptive norm.

    Branch: GN -> act -> conv -> GN -> (scale, shift from t_emb) -> act -> conv,
    added back onto the input. ``gn2`` is affine-free because its scale and
    shift are produced per-sample by the embedding projection.
    """

    def __init__(self, channels, time_dim):
        super().__init__()
        self.first_conv = nn.Conv2d(channels, channels, 3, padding=1,
                                    bias=False, padding_mode='replicate')
        self.second_conv = nn.Conv2d(channels, channels, 3, padding=1,
                                     padding_mode='replicate')
        self.gn1 = nn.GroupNorm(8, channels, affine=True)
        self.gn2 = nn.GroupNorm(8, channels, affine=False)
        # Projects the conditioning embedding to per-channel (scale, shift).
        self.embd_affine = nn.Linear(time_dim, channels * 2)
        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x, t_emb):
        # Per-channel modulation parameters from the conditioning embedding.
        scale, shift = self.embd_affine(t_emb).chunk(2, dim=1)

        # First convolution path.
        out = self.first_conv(self.act(self.gn1(x)))

        # Second path with adaptive normalization; (B, C) params broadcast
        # over the spatial dimensions.
        out = self.gn2(out)
        out = out * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
        out = self.second_conv(self.act(out))

        return x + out
150
 
151
 
152
class DiTLayer(nn.Module):
    """Transformer layer with adaLN-style conditioning.

    The conditioning embedding is projected to six per-feature vectors:
    scale/shift pairs that modulate each LayerNorm output, plus gates
    (alpha1, alpha2) on the attention and feed-forward outputs.

    Expects ``x`` as [seq, batch, d_model] (``batch_first=False``) and
    ``embd`` as [batch, embd_dim].

    NOTE(review): the residual stream here is the *modulated* normed
    activation, not the layer input — unlike standard adaLN-Zero DiT, a
    zero-initialized gate makes this layer return norm2(norm1(x)) rather
    than x. Confirm this is intentional before reusing elsewhere.
    """

    def __init__(self, d_model, embd_dim, nhead, dim_feedforward=1024):
        super().__init__()
        # affine=False: scale/shift come from the conditioning projection.
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=False)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.embd_affine = nn.Linear(embd_dim, 6 * d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(dim_feedforward, d_model),
        )

    def forward(self, x, embd):
        mod = self.embd_affine(embd)
        scale1, scale2, shift1, shift2, alpha1, alpha2 = mod.chunk(6, dim=1)

        # Attention sub-block with modulated pre-norm and gated output.
        x = self.norm1(x) * (1 + scale1[None, :, :]) + shift1[None, :, :]
        attn_out, _ = self.attn(x, x, x)
        x = x + attn_out * alpha1[None, :, :]

        # Feed-forward sub-block, same modulation pattern.
        x = self.norm2(x) * (1 + scale2[None, :, :]) + shift2[None, :, :]
        x = x + self.ffn(x) * alpha2[None, :, :]
        return x
182
 
183
 
184
class DiTBlock(nn.Module):
    """Patchify -> stacked DiT layers -> un-patchify.

    Splits a (B, C, H, W) feature map into non-overlapping ``patch_size``
    patches, treats each patch as a token of width ``channels *
    patch_size**2``, adds 2-D sinusoidal position embeddings, runs the token
    sequence through ``num_layers`` conditioned transformer layers, and folds
    the tokens back to the original spatial layout.

    NOTE(review): no residual is added around this block — the caller
    (e.g. UpBlockWithDit) is responsible for the skip connection.
    """

    def __init__(self, channels, embd_dim, patch_size, nhead, num_layers):
        super().__init__()
        self.patch_size = patch_size
        self.patchify = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        token_dim = channels * patch_size ** 2
        self.pos_embd = SinusoidalPositionalEmbedding2D(token_dim)
        self.dit_layers = nn.ModuleList(
            DiTLayer(token_dim, embd_dim, nhead, 2 * token_dim)
            for _ in range(num_layers)
        )

    def forward(self, x, embd):
        # assumes H and W are multiples of patch_size — TODO confirm callers
        _, _, height, width = x.shape
        grid_h = height // self.patch_size
        grid_w = width // self.patch_size

        tokens = self.patchify(x).permute(0, 2, 1)  # [B, num_patches, token_dim]
        pos = self.pos_embd(grid_h, grid_w).to(dtype=tokens.dtype)
        tokens = tokens + pos.unsqueeze(0)

        # MultiheadAttention here is batch_first=False: seq dim leads.
        tokens = tokens.permute(1, 0, 2)  # [num_patches, B, token_dim]
        for layer in self.dit_layers:
            tokens = layer(tokens, embd)
        tokens = tokens.permute(1, 2, 0)  # [B, token_dim, num_patches]

        return nn.functional.fold(
            tokens, (height, width),
            (self.patch_size, self.patch_size),
            stride=(self.patch_size, self.patch_size),
        )
208
 
209
 
210
  class UpBlock(nn.Module):
 
224
  return x
225
 
226
  class UpBlockWithDit(nn.Module):
227
+ def __init__(self, in_ch, mid_ch, out_ch, patch_size, nhead, time_dim, layers, cat):
228
  super().__init__()
229
  self.res = ResConvBlock(in_ch, time_dim)
230
+ self.down_map = nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False)
231
+ self.down_norm = nn.GroupNorm(4, mid_ch)
232
+ self.dit = DiTBlock(mid_ch, time_dim, patch_size, nhead, layers)
233
+ self.up_map = nn.Conv2d(mid_ch, in_ch, kernel_size=1)
234
  self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
235
  self.cat = cat
236
 
237
+ def forward(self, x, embd, skip=None):
238
+ x = self.res(x, embd)
239
+ h = self.down_norm(self.down_map(x))
240
+ h = self.dit(h, embd)
241
+ h = self.up_map(h)
242
+ x = x + h
243
  x = self.up(x)
244
  if self.cat:
245
  x = torch.cat([x, skip], dim=1)
 
323
  class ConditionalUNetDiT(nn.Module):
324
  def __init__(self, base_ch=8, embd_dim=16):
325
  super().__init__()
326
+ self.time_embd = SinusoidalEmbedding(embd_dim, scaling=1000)
327
+ self.waterlevel_embd = SinusoidalEmbedding(embd_dim, scaling=1)
328
+ embd_dim *= 2
329
 
330
  # Input channels = noisy height (1) + ridge map (1) + lake map (1)
331
  self.expand = nn.Conv2d(3, base_ch, 3, padding=1, padding_mode='replicate')
 
333
 
334
  self.down0 = nn.Conv2d(base_ch, base_ch * 2, 4, stride=2, padding=1, padding_mode='replicate') # 1024->512
335
  self.enc_1 = ResConvBlock(base_ch * 2, embd_dim)
336
+ #self.enc_1_dit = DiTBlock(base_ch * 2, 16, 1024, 8, 4)
337
 
338
  self.down1 = nn.Conv2d(base_ch * 2, base_ch * 4, 4, stride=2, padding=1, padding_mode='replicate') # 512->256
339
 
340
+ self.up1 = UpBlockWithDit(base_ch * 4, base_ch, base_ch * 2, 8, 8, embd_dim, 6, False) # 256->512
341
+ self.up0 = UpBlockWithDit(base_ch * 2, base_ch//2, base_ch, 16, 16, embd_dim, 3, True) # 512->1024
342
  self.out = ResConvBlock(base_ch * 2, embd_dim)
343
  self.final = nn.Conv2d(base_ch * 2, 1, 1)
344
 
345
+ def initialize(self):
346
+ for name, m in self.named_modules():
347
+ if isinstance(m, nn.Linear) and ('embd_affine' in name or 'water_level_affine' in name):
348
+ m.weight.data.zero_()
349
+ m.bias.data.zero_()
350
+ if isinstance(m, nn.Conv2d) and 'second_conv' in name:
351
+ m.weight.data.zero_()
352
+ m.bias.data.zero_()
353
+
354
  def forward(self, x, ridge_map, basin_map, water_level, t):
355
  t_embed = self.time_embd(t).to(x.dtype)
356
+ waterlevel_embd = self.waterlevel_embd(water_level).to(x.dtype)
357
+ embeds = torch.cat([t_embed, waterlevel_embd], dim=1)
358
  # x: noisy height map, ridge_map: binary edges, basin_map: binary basins, water_level: the estimate sea level
359
  h0 = torch.cat([x, ridge_map, basin_map], dim=1) # concat condition
360
  # encode
361
+ h0 = self.expand(h0)
362
+ h0 = self.enc_0(h0, embeds)
363
+ h1 = self.down0(h0)
364
+ h1 = self.enc_1(h1, embeds) # 512x512
365
+ #h1 = checkpoint(run_block, self.enc_1_dit, h1, water_level, use_reentrant=False) if self.training else self.enc_1_dit(h1, water_level)
366
+ h2 = self.down1(h1) # 256x256
367
  # decode with skip connections
368
+ out = self.up1(h2, embeds, h1) # 512x512
369
+ out = self.up0(out, embeds, h0) # 1024x1024
370
+ out = self.out(out, embeds)
371
  out = self.final(out)
372
  return out # predicted noise for diffusion loss
373