Upload 2 files
Browse files- FlashScape.safetensors +3 -0
- network_diffusion_unet.py +78 -55
FlashScape.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:611ea2049bb713991b76000f8b24296b252cd81a10cadd2de1995aa1d045e154
|
| 3 |
+
size 78154234
|
network_diffusion_unet.py
CHANGED
|
@@ -5,11 +5,12 @@ from torch.utils.checkpoint import checkpoint
|
|
| 5 |
|
| 6 |
|
| 7 |
class SinusoidalEmbedding(nn.Module):
|
| 8 |
-
def __init__(self, embedding_dim=128, scaling=1000):
|
| 9 |
super().__init__()
|
| 10 |
self.embedding_dim = embedding_dim
|
| 11 |
half_dim = embedding_dim // 2
|
| 12 |
-
freqs = torch.exp(-math.log(
|
|
|
|
| 13 |
self.scaling = nn.parameter.Buffer(torch.tensor(scaling))
|
| 14 |
self.freqs = nn.parameter.Buffer(freqs)
|
| 15 |
|
|
@@ -27,8 +28,9 @@ class SinusoidalPositionalEmbedding2D(nn.Module):
|
|
| 27 |
assert embedding_dim % 2 == 0, "embedding_dim must be even"
|
| 28 |
self.embedding_dim = embedding_dim
|
| 29 |
half_dim = self.embedding_dim // 2
|
| 30 |
-
div_term = torch.exp(torch.arange(0, half_dim, 2
|
| 31 |
-
|
|
|
|
| 32 |
|
| 33 |
def forward(self, height, width):
|
| 34 |
"""Generate embeddings for a grid of size (height, width)."""
|
|
@@ -124,83 +126,85 @@ class ImageLinearAttention(nn.Module):
|
|
| 124 |
class ResConvBlock(nn.Module):
|
| 125 |
def __init__(self, channels, time_dim):
|
| 126 |
super().__init__()
|
| 127 |
-
self.
|
| 128 |
-
self.
|
| 129 |
self.gn1 = nn.GroupNorm(8, channels, affine=True)
|
| 130 |
self.gn2 = nn.GroupNorm(8, channels, affine=False)
|
| 131 |
-
self.
|
| 132 |
self.act = nn.LeakyReLU(inplace=True)
|
| 133 |
|
| 134 |
def forward(self, x, t_emb):
|
| 135 |
# Get affine parameters from time embedding
|
| 136 |
-
affine_params = self.
|
| 137 |
scale, shift = affine_params.chunk(2, dim=1)
|
| 138 |
|
| 139 |
# First convolution path
|
| 140 |
-
h = self.
|
| 141 |
|
| 142 |
# Second convolution path with adaptive normalization
|
| 143 |
h = self.gn2(h)
|
| 144 |
h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
|
| 145 |
-
h = self.
|
| 146 |
|
| 147 |
return x + h
|
| 148 |
|
| 149 |
|
| 150 |
class DiTLayer(nn.Module):
|
| 151 |
-
def __init__(self, d_model, nhead, dim_feedforward=1024):
|
| 152 |
super().__init__()
|
| 153 |
-
self.norm1 = nn.LayerNorm(d_model)
|
| 154 |
-
self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=
|
| 155 |
-
self.norm2 = nn.LayerNorm(d_model)
|
|
|
|
| 156 |
self.ffn = nn.Sequential(
|
| 157 |
nn.Linear(d_model, dim_feedforward),
|
| 158 |
nn.LeakyReLU(0.2, inplace=True),
|
| 159 |
nn.Linear(dim_feedforward, d_model),
|
| 160 |
)
|
| 161 |
|
| 162 |
-
def forward(self,
|
|
|
|
|
|
|
|
|
|
| 163 |
# Self-attention block
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
# Feedforward block
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
| 171 |
|
| 172 |
|
| 173 |
class DiTBlock(nn.Module):
|
| 174 |
-
def __init__(self, channels,
|
| 175 |
super().__init__()
|
| 176 |
self.patch_size = patch_size
|
| 177 |
self.patchify = nn.Unfold(kernel_size=patch_size, stride=patch_size)
|
| 178 |
-
|
| 179 |
self.pos_embd = SinusoidalPositionalEmbedding2D(hidden_size)
|
| 180 |
-
self.waterlevel_embd = SinusoidalEmbedding(hidden_size, 10)
|
| 181 |
-
self.patch_embedding_out = nn.Linear(hidden_size, channels * patch_size**2)
|
| 182 |
self.dit_layers = nn.ModuleList([
|
| 183 |
-
DiTLayer(hidden_size, nhead, 2*hidden_size)
|
| 184 |
for _ in range(num_layers)
|
| 185 |
])
|
| 186 |
-
self.norm = nn.GroupNorm(8, channels)
|
| 187 |
|
| 188 |
-
def forward(self,
|
| 189 |
-
B, C, H, W =
|
| 190 |
H_p, W_p = H // self.patch_size, W // self.patch_size
|
| 191 |
-
x = self.
|
| 192 |
-
x = self.patchify(x).permute(0, 2, 1)
|
| 193 |
-
x = self.patch_embedding_in(x)
|
| 194 |
pos_embd = self.pos_embd(H_p, W_p).to(dtype=x.dtype)
|
| 195 |
x = x + pos_embd.unsqueeze(0)
|
| 196 |
-
|
| 197 |
-
x = torch.cat((x, water_level_cls), dim=1)
|
| 198 |
for dit_layer in self.dit_layers:
|
| 199 |
-
x = dit_layer(x)
|
| 200 |
-
x =
|
| 201 |
-
x = x[:, :, :-1]
|
| 202 |
x = nn.functional.fold(x, (H, W), (self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size))
|
| 203 |
-
return
|
| 204 |
|
| 205 |
|
| 206 |
class UpBlock(nn.Module):
|
|
@@ -220,16 +224,22 @@ class UpBlock(nn.Module):
|
|
| 220 |
return x
|
| 221 |
|
| 222 |
class UpBlockWithDit(nn.Module):
|
| 223 |
-
def __init__(self, in_ch, out_ch, patch_size,
|
| 224 |
super().__init__()
|
| 225 |
self.res = ResConvBlock(in_ch, time_dim)
|
| 226 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 227 |
self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
|
| 228 |
self.cat = cat
|
| 229 |
|
| 230 |
-
def forward(self, x,
|
| 231 |
-
x = self.res(x,
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
| 233 |
x = self.up(x)
|
| 234 |
if self.cat:
|
| 235 |
x = torch.cat([x, skip], dim=1)
|
|
@@ -313,7 +323,9 @@ class ConditionalUNet(nn.Module):
|
|
| 313 |
class ConditionalUNetDiT(nn.Module):
|
| 314 |
def __init__(self, base_ch=8, embd_dim=16):
|
| 315 |
super().__init__()
|
| 316 |
-
self.time_embd = SinusoidalEmbedding(embd_dim, 1000)
|
|
|
|
|
|
|
| 317 |
|
| 318 |
# Input channels = noisy height (1) + ridge map (1) + lake map (1)
|
| 319 |
self.expand = nn.Conv2d(3, base_ch, 3, padding=1, padding_mode='replicate')
|
|
@@ -321,30 +333,41 @@ class ConditionalUNetDiT(nn.Module):
|
|
| 321 |
|
| 322 |
self.down0 = nn.Conv2d(base_ch, base_ch * 2, 4, stride=2, padding=1, padding_mode='replicate') # 1024->512
|
| 323 |
self.enc_1 = ResConvBlock(base_ch * 2, embd_dim)
|
| 324 |
-
self.enc_1_dit = DiTBlock(base_ch * 2, 16, 1024, 8, 4)
|
| 325 |
|
| 326 |
self.down1 = nn.Conv2d(base_ch * 2, base_ch * 4, 4, stride=2, padding=1, padding_mode='replicate') # 512->256
|
| 327 |
|
| 328 |
-
self.up1 = UpBlockWithDit(base_ch * 4, base_ch * 2, 8,
|
| 329 |
-
self.up0 = UpBlockWithDit(base_ch * 2, base_ch,
|
| 330 |
self.out = ResConvBlock(base_ch * 2, embd_dim)
|
| 331 |
self.final = nn.Conv2d(base_ch * 2, 1, 1)
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
def forward(self, x, ridge_map, basin_map, water_level, t):
|
| 334 |
t_embed = self.time_embd(t).to(x.dtype)
|
|
|
|
|
|
|
| 335 |
# x: noisy height map, ridge_map: binary edges, basin_map: binary basins, water_level: the estimate sea level
|
| 336 |
h0 = torch.cat([x, ridge_map, basin_map], dim=1) # concat condition
|
| 337 |
# encode
|
| 338 |
-
h0 =
|
| 339 |
-
h0 =
|
| 340 |
-
h1 =
|
| 341 |
-
h1 =
|
| 342 |
-
h1 = checkpoint(run_block, self.enc_1_dit, h1, water_level, use_reentrant=False) if self.training else self.enc_1_dit(h1, water_level)
|
| 343 |
-
h2 =
|
| 344 |
# decode with skip connections
|
| 345 |
-
out =
|
| 346 |
-
out =
|
| 347 |
-
out =
|
| 348 |
out = self.final(out)
|
| 349 |
return out # predicted noise for diffusion loss
|
| 350 |
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class SinusoidalEmbedding(nn.Module):
|
| 8 |
+
def __init__(self, embedding_dim=128, base=1000, scaling=1000):
|
| 9 |
super().__init__()
|
| 10 |
self.embedding_dim = embedding_dim
|
| 11 |
half_dim = embedding_dim // 2
|
| 12 |
+
freqs = torch.exp(-math.log(base) * torch.arange(0, half_dim) / half_dim)
|
| 13 |
+
# at base 1000, max-range = +=500pi = -1571 to 1571
|
| 14 |
self.scaling = nn.parameter.Buffer(torch.tensor(scaling))
|
| 15 |
self.freqs = nn.parameter.Buffer(freqs)
|
| 16 |
|
|
|
|
| 28 |
assert embedding_dim % 2 == 0, "embedding_dim must be even"
|
| 29 |
self.embedding_dim = embedding_dim
|
| 30 |
half_dim = self.embedding_dim // 2
|
| 31 |
+
div_term = torch.exp(torch.arange(0, half_dim, 2) * (-math.log(100.0) / half_dim))
|
| 32 |
+
# Since our grid size is small, 100 should be enough
|
| 33 |
+
self.div_term = nn.parameter.Buffer(div_term.to(torch.float32))
|
| 34 |
|
| 35 |
def forward(self, height, width):
|
| 36 |
"""Generate embeddings for a grid of size (height, width)."""
|
|
|
|
| 126 |
class ResConvBlock(nn.Module):
|
| 127 |
def __init__(self, channels, time_dim):
|
| 128 |
super().__init__()
|
| 129 |
+
self.first_conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False, padding_mode='replicate')
|
| 130 |
+
self.second_conv = nn.Conv2d(channels, channels, 3, padding=1, padding_mode='replicate')
|
| 131 |
self.gn1 = nn.GroupNorm(8, channels, affine=True)
|
| 132 |
self.gn2 = nn.GroupNorm(8, channels, affine=False)
|
| 133 |
+
self.embd_affine = nn.Linear(time_dim, channels * 2)
|
| 134 |
self.act = nn.LeakyReLU(inplace=True)
|
| 135 |
|
| 136 |
def forward(self, x, t_emb):
|
| 137 |
# Get affine parameters from time embedding
|
| 138 |
+
affine_params = self.embd_affine(t_emb)
|
| 139 |
scale, shift = affine_params.chunk(2, dim=1)
|
| 140 |
|
| 141 |
# First convolution path
|
| 142 |
+
h = self.first_conv(self.act(self.gn1(x)))
|
| 143 |
|
| 144 |
# Second convolution path with adaptive normalization
|
| 145 |
h = self.gn2(h)
|
| 146 |
h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
|
| 147 |
+
h = self.second_conv(self.act(h))
|
| 148 |
|
| 149 |
return x + h
|
| 150 |
|
| 151 |
|
| 152 |
class DiTLayer(nn.Module):
|
| 153 |
+
def __init__(self, d_model, embd_dim, nhead, dim_feedforward=1024):
|
| 154 |
super().__init__()
|
| 155 |
+
self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
|
| 156 |
+
self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=False)
|
| 157 |
+
self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
|
| 158 |
+
self.embd_affine = nn.Linear(embd_dim, 6*d_model)
|
| 159 |
self.ffn = nn.Sequential(
|
| 160 |
nn.Linear(d_model, dim_feedforward),
|
| 161 |
nn.LeakyReLU(0.2, inplace=True),
|
| 162 |
nn.Linear(dim_feedforward, d_model),
|
| 163 |
)
|
| 164 |
|
| 165 |
+
def forward(self, x, embd):
|
| 166 |
+
affine_params = self.embd_affine(embd)
|
| 167 |
+
scale1, scale2, shift1, shift2, alpha1, alpha2 = affine_params.chunk(6, dim=1)
|
| 168 |
+
|
| 169 |
# Self-attention block
|
| 170 |
+
x = self.norm1(x)
|
| 171 |
+
x = x * (1 + scale1[None, :, :]) + shift1[None, :, :]
|
| 172 |
+
|
| 173 |
+
attn_output, _ = self.attn(x, x, x)
|
| 174 |
+
x = x + attn_output * alpha1[None, :, :]
|
| 175 |
|
| 176 |
# Feedforward block
|
| 177 |
+
x = self.norm2(x)
|
| 178 |
+
x = x * (1 + scale2[None, :, :]) + shift2[None, :, :]
|
| 179 |
+
ffn_output = self.ffn(x)
|
| 180 |
+
x = x + ffn_output * alpha2[None, :, :]
|
| 181 |
+
return x
|
| 182 |
|
| 183 |
|
| 184 |
class DiTBlock(nn.Module):
|
| 185 |
+
def __init__(self, channels, embd_dim, patch_size, nhead, num_layers):
|
| 186 |
super().__init__()
|
| 187 |
self.patch_size = patch_size
|
| 188 |
self.patchify = nn.Unfold(kernel_size=patch_size, stride=patch_size)
|
| 189 |
+
hidden_size = channels * patch_size**2
|
| 190 |
self.pos_embd = SinusoidalPositionalEmbedding2D(hidden_size)
|
|
|
|
|
|
|
| 191 |
self.dit_layers = nn.ModuleList([
|
| 192 |
+
DiTLayer(hidden_size, embd_dim, nhead, 2*hidden_size)
|
| 193 |
for _ in range(num_layers)
|
| 194 |
])
|
|
|
|
| 195 |
|
| 196 |
+
def forward(self, x, embd):
|
| 197 |
+
B, C, H, W = x.shape
|
| 198 |
H_p, W_p = H // self.patch_size, W // self.patch_size
|
| 199 |
+
x = self.patchify(x).permute(0, 2, 1) # [B, num_patches, d_main]
|
|
|
|
|
|
|
| 200 |
pos_embd = self.pos_embd(H_p, W_p).to(dtype=x.dtype)
|
| 201 |
x = x + pos_embd.unsqueeze(0)
|
| 202 |
+
x = x.permute(1, 0, 2) # [num_patches, B, d_main)
|
|
|
|
| 203 |
for dit_layer in self.dit_layers:
|
| 204 |
+
x = dit_layer(x, embd)
|
| 205 |
+
x = x.permute(1, 2, 0) # [B, d_main, num_patches]
|
|
|
|
| 206 |
x = nn.functional.fold(x, (H, W), (self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size))
|
| 207 |
+
return x
|
| 208 |
|
| 209 |
|
| 210 |
class UpBlock(nn.Module):
|
|
|
|
| 224 |
return x
|
| 225 |
|
| 226 |
class UpBlockWithDit(nn.Module):
|
| 227 |
+
def __init__(self, in_ch, mid_ch, out_ch, patch_size, nhead, time_dim, layers, cat):
|
| 228 |
super().__init__()
|
| 229 |
self.res = ResConvBlock(in_ch, time_dim)
|
| 230 |
+
self.down_map = nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False)
|
| 231 |
+
self.down_norm = nn.GroupNorm(4, mid_ch)
|
| 232 |
+
self.dit = DiTBlock(mid_ch, time_dim, patch_size, nhead, layers)
|
| 233 |
+
self.up_map = nn.Conv2d(mid_ch, in_ch, kernel_size=1)
|
| 234 |
self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
|
| 235 |
self.cat = cat
|
| 236 |
|
| 237 |
+
def forward(self, x, embd, skip=None):
|
| 238 |
+
x = self.res(x, embd)
|
| 239 |
+
h = self.down_norm(self.down_map(x))
|
| 240 |
+
h = self.dit(h, embd)
|
| 241 |
+
h = self.up_map(h)
|
| 242 |
+
x = x + h
|
| 243 |
x = self.up(x)
|
| 244 |
if self.cat:
|
| 245 |
x = torch.cat([x, skip], dim=1)
|
|
|
|
| 323 |
class ConditionalUNetDiT(nn.Module):
|
| 324 |
def __init__(self, base_ch=8, embd_dim=16):
|
| 325 |
super().__init__()
|
| 326 |
+
self.time_embd = SinusoidalEmbedding(embd_dim, scaling=1000)
|
| 327 |
+
self.waterlevel_embd = SinusoidalEmbedding(embd_dim, scaling=1)
|
| 328 |
+
embd_dim *= 2
|
| 329 |
|
| 330 |
# Input channels = noisy height (1) + ridge map (1) + lake map (1)
|
| 331 |
self.expand = nn.Conv2d(3, base_ch, 3, padding=1, padding_mode='replicate')
|
|
|
|
| 333 |
|
| 334 |
self.down0 = nn.Conv2d(base_ch, base_ch * 2, 4, stride=2, padding=1, padding_mode='replicate') # 1024->512
|
| 335 |
self.enc_1 = ResConvBlock(base_ch * 2, embd_dim)
|
| 336 |
+
#self.enc_1_dit = DiTBlock(base_ch * 2, 16, 1024, 8, 4)
|
| 337 |
|
| 338 |
self.down1 = nn.Conv2d(base_ch * 2, base_ch * 4, 4, stride=2, padding=1, padding_mode='replicate') # 512->256
|
| 339 |
|
| 340 |
+
self.up1 = UpBlockWithDit(base_ch * 4, base_ch, base_ch * 2, 8, 8, embd_dim, 6, False) # 256->512
|
| 341 |
+
self.up0 = UpBlockWithDit(base_ch * 2, base_ch//2, base_ch, 16, 16, embd_dim, 3, True) # 512->1024
|
| 342 |
self.out = ResConvBlock(base_ch * 2, embd_dim)
|
| 343 |
self.final = nn.Conv2d(base_ch * 2, 1, 1)
|
| 344 |
|
| 345 |
+
def initialize(self):
|
| 346 |
+
for name, m in self.named_modules():
|
| 347 |
+
if isinstance(m, nn.Linear) and ('embd_affine' in name or 'water_level_affine' in name):
|
| 348 |
+
m.weight.data.zero_()
|
| 349 |
+
m.bias.data.zero_()
|
| 350 |
+
if isinstance(m, nn.Conv2d) and 'second_conv' in name:
|
| 351 |
+
m.weight.data.zero_()
|
| 352 |
+
m.bias.data.zero_()
|
| 353 |
+
|
| 354 |
def forward(self, x, ridge_map, basin_map, water_level, t):
|
| 355 |
t_embed = self.time_embd(t).to(x.dtype)
|
| 356 |
+
waterlevel_embd = self.waterlevel_embd(water_level).to(x.dtype)
|
| 357 |
+
embeds = torch.cat([t_embed, waterlevel_embd], dim=1)
|
| 358 |
# x: noisy height map, ridge_map: binary edges, basin_map: binary basins, water_level: the estimate sea level
|
| 359 |
h0 = torch.cat([x, ridge_map, basin_map], dim=1) # concat condition
|
| 360 |
# encode
|
| 361 |
+
h0 = self.expand(h0)
|
| 362 |
+
h0 = self.enc_0(h0, embeds)
|
| 363 |
+
h1 = self.down0(h0)
|
| 364 |
+
h1 = self.enc_1(h1, embeds) # 512x512
|
| 365 |
+
#h1 = checkpoint(run_block, self.enc_1_dit, h1, water_level, use_reentrant=False) if self.training else self.enc_1_dit(h1, water_level)
|
| 366 |
+
h2 = self.down1(h1) # 256x256
|
| 367 |
# decode with skip connections
|
| 368 |
+
out = self.up1(h2, embeds, h1) # 512x512
|
| 369 |
+
out = self.up0(out, embeds, h0) # 1024x1024
|
| 370 |
+
out = self.out(out, embeds)
|
| 371 |
out = self.final(out)
|
| 372 |
return out # predicted noise for diffusion loss
|
| 373 |
|