Fgdfgfthgr committed on
Commit
c5373a4
·
verified ·
1 Parent(s): 1e3f3a2

Upload 4 files

data_utils.py ADDED
@@ -0,0 +1,99 @@
import random

import torch
import imageio.v3 as imageio
import torchvision.transforms.v2.functional as T_F

from pathlib import Path
from torchvision.datasets.folder import has_file_allowed_extension


def make_dataset_t(image_dir, extensions=(".tif", ".tiff")):
    """Collect (heightmap, ridge, basins) file triplets from a directory."""
    image_dir = Path(image_dir)
    images = [
        (path, image_dir / f'Ridge_{path.name}', image_dir / f'Basins_{path.name}')
        for path in sorted(image_dir.iterdir())
        if (has_file_allowed_extension(path.name, extensions)
            and not path.name.startswith('Ridge_') and not path.name.startswith('Basins_'))
    ]
    return images


def make_dataset_t_v(image_dir, extensions=(".tif", ".tiff")):
    """Collect file triplets and split them 95/5 into train and validation lists."""
    images = make_dataset_t(image_dir, extensions)

    # Shuffle in place before splitting; note this is unseeded (see the note below)
    random.shuffle(images)

    split_idx = int(0.95 * len(images))
    return images[:split_idx], images[split_idx:]
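Note that the shuffle in make_dataset_t_v is unseeded, so the train/validation split changes on every run; a training job resumed in a fresh process will validate on images it previously trained on. A minimal sketch of a deterministic variant (the function name and the seed parameter are illustrative additions, not part of the uploaded code; it reuses make_dataset_t from above):

def make_dataset_t_v_seeded(image_dir, extensions=(".tif", ".tiff"), seed=0):
    # Same listing as make_dataset_t, but shuffled with a dedicated
    # random.Random instance so the split is reproducible and the global
    # RNG state used by the augmentations is left untouched.
    images = make_dataset_t(image_dir, extensions)
    random.Random(seed).shuffle(images)
    split_idx = int(0.95 * len(images))
    return images[:split_idx], images[split_idx:]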
def augmentations(image, label1, label2):
    """Apply the same random flips and rotation to the image and both label maps."""
    if random.random() < 0.5:
        image, label1, label2 = T_F.vflip(image), T_F.vflip(label1), T_F.vflip(label2)
    if random.random() < 0.5:
        image, label1, label2 = T_F.hflip(image), T_F.hflip(label1), T_F.hflip(label2)
    angles = [90, 180, 270]
    angle = random.choice(angles)
    if random.random() < 0.75:
        image, label1, label2 = T_F.rotate(image, angle), T_F.rotate(label1, angle), T_F.rotate(label2, angle)
    return image, label1, label2


# Global statistics of the heightmap dataset, used to normalize inputs
mean, std = (149.95293407563648, 330.8314960521203)
target_water_level_range = [-100, 300]
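The hard-coded mean/std above are global statistics over the heightmap dataset. If the dataset changes, comparable values can be recomputed in a single accumulation pass; the following is a sketch (the helper name is mine, and it assumes the triplet listing produced by make_dataset_t):

import imageio.v3 as imageio
import numpy as np

def compute_dataset_stats(triplets):
    # Accumulate first and second moments over every heightmap pixel
    count, total, total_sq = 0, 0.0, 0.0
    for height_path, _ridge_path, _basins_path in triplets:
        img = imageio.imread(str(height_path)).astype(np.float64)
        count += img.size
        total += float(img.sum())
        total_sq += float(np.square(img).sum())
    mean = total / count
    std = float(np.sqrt(total_sq / count - mean ** 2))
    return mean, std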
class TrainDataset(torch.utils.data.Dataset):
    def __init__(self, train_split):
        self.train_split = train_split

    def __len__(self):
        return len(self.train_split)

    def __getitem__(self, index):
        pair = self.train_split[index]
        img = torch.from_numpy(imageio.imread(str(pair[0])))[None, :]
        img = (img - mean) / std
        ridge = torch.from_numpy(imageio.imread(str(pair[1])))[None, :].to(torch.float16)
        basins = torch.from_numpy(imageio.imread(str(pair[2])))[None, :]
        water_level = random.randint(*target_water_level_range)
        basins = (basins >= water_level).to(torch.float16)
        img, ridge, basins = augmentations(img, ridge, basins)
        return img, ridge, basins, torch.tensor(water_level, dtype=torch.float16)


class ValDataset(torch.utils.data.Dataset):
    def __init__(self, val_split):
        self.val_split = val_split

    def __len__(self):
        return len(self.val_split)

    def __getitem__(self, index):
        pair = self.val_split[index]
        img = torch.from_numpy(imageio.imread(str(pair[0])))[None, :]
        img = (img - mean) / std
        ridge = torch.from_numpy(imageio.imread(str(pair[1])))[None, :].to(torch.float16)
        basins = torch.from_numpy(imageio.imread(str(pair[2])))[None, :]
        target_level = random.randint(*target_water_level_range)
        basins = (basins >= target_level).to(torch.float16)
        return img, ridge, basins, torch.tensor(target_level, dtype=torch.float16)


if __name__ == '__main__':
    train_split, val_split = make_dataset_t_v('dataset')

    train_dataset = TrainDataset(train_split)
    val_dataset = ValDataset(val_split)

    print(train_dataset[0])
    print(val_dataset[0])
mass_generate_examples.py ADDED
@@ -0,0 +1,122 @@
import torch
import math

import torch.utils.data

import imageio.v3 as imageio
import lightning.pytorch as pl
import matplotlib.pyplot as plt

from network_diffusion_unet import ConditionalUNetDiT
from safetensors.torch import load_file

class PLModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ConditionalUNetDiT(8, 16)

    @torch.no_grad()
    def inference_step(self, ridge_map, basin_map, water_level, num_steps=50):
        device = self.device
        b = ridge_map.shape[0]
        x = torch.randn_like(ridge_map, device=device, dtype=torch.float16)
        water_level = torch.tensor((water_level,), device=device, dtype=torch.float16).expand(b,)
        time = torch.linspace(0, 1, num_steps + 1, device=device, dtype=torch.float16)

        for i in range(num_steps):
            t = torch.full((b,), time[i], device=device, dtype=torch.float16)
            dt = torch.full((b, 1, 1, 1), time[i + 1] - time[i], device=device, dtype=torch.float16)

            v = self.model(x, ridge_map, basin_map, water_level, t)

            x = x + dt * v

        return x
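For context, inference_step above is a plain Euler integrator for the learned flow: x starts as pure Gaussian noise at t = 0 and is pushed along the model's predicted velocity field until t = 1, conditioned on the ridge map, basin map, and water level. Each iteration applies

    x_{t+\Delta t} = x_t + \Delta t \, v_\theta(x_t, \mathrm{ridge}, \mathrm{basin}, \mathrm{level}, t), \qquad x_0 \sim \mathcal{N}(0, I), \qquad \Delta t = 1 / \text{num\_steps}

so num_steps = 10 below means ten equal-sized steps from noise to a normalized heightmap. This matches the rectified-flow training objective in pl_module_rectifiedflow.py, where the regression target is the constant velocity x_data - noise.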
if __name__ == "__main__":
    #model = PLModule.load_from_checkpoint('FlashScape.ckpt').to(device='cuda', dtype=torch.float16)
    model = PLModule()
    model.model.load_state_dict(load_file('FlashScape.safetensors'))
    model.to(device='cuda', dtype=torch.float16)
    model.eval()

    test_ridge = torch.from_numpy(imageio.imread('dataset_large/Ridge_11417648.tiff'))[None, None, :].to(dtype=torch.float16, device='cuda')
    test_basin = torch.from_numpy(imageio.imread('dataset_large/Basins_11417648.tiff'))[None, None, :].to(dtype=torch.float16, device='cuda')
    gt = torch.from_numpy(imageio.imread('dataset_large/11417648.tiff'))[None, None, :].to(dtype=torch.float16, device='cuda')
    water_level = 300.0
    num_steps = 10
    num_images = 4

    test_basin = (test_basin >= water_level).to(torch.float16)
    test_ridge = test_ridge.expand(num_images, -1, -1, -1)
    test_basin = test_basin.expand(num_images, -1, -1, -1)
    generated = model.inference_step(test_ridge, test_basin, water_level, num_steps)
    # Denormalize back to the original elevation range (mean/std from data_utils)
    generated = generated * 330.8314960521203 + 149.95293407563648

    # Prepare images for visualization
    ridge_display = test_ridge[0, 0].cpu().float()
    basin_display = test_basin[0, 0].cpu().float()
    gt_display = gt[0, 0].cpu().float()
    generated_display = generated[:, 0].cpu().float()  # Drop the channel dim, keep the batch

    # Calculate the grid layout: ridge condition + basin condition + gt + generated images
    total_images = num_images + 3
    cols = min(6, total_images)  # Maximum 6 columns for readability
    rows = math.ceil(total_images / cols)

    # Figure size based on the grid layout, with a little slack for the colorbar
    base_height_per_image = 5  # inches per image height
    base_width_per_image = 5  # inches per image width
    fig_width = cols * base_width_per_image + 0.1
    fig_height = rows * base_height_per_image

    # Create figure with subplots
    fig, axes = plt.subplots(rows, cols, figsize=(fig_width, fig_height))

    # total_images >= 3 guarantees more than one subplot, so axes is an ndarray;
    # flatten it (whether 1D or 2D) for uniform indexing
    axes = axes.flatten()

    # Hide unused subplots
    for i in range(total_images, len(axes)):
        axes[i].set_visible(False)

    # Plot the two condition images
    axes[0].imshow(ridge_display, cmap='gray')
    axes[0].set_title('Ridge Condition', fontsize=12, pad=2)
    axes[0].set_axis_off()

    axes[1].imshow(basin_display, cmap='gray')
    axes[1].set_title('Basin Condition', fontsize=12, pad=2)
    axes[1].set_axis_off()

    # Plot the ground truth image
    axes[2].imshow(gt_display, cmap='gray')
    axes[2].set_title('Ground Truth', fontsize=12, pad=2)
    axes[2].set_axis_off()

    # Plot the generated images
    for i in range(num_images):
        im = axes[i + 3].imshow(generated_display[i], cmap='gray')
        axes[i + 3].set_title(f'Generated {i + 1}', fontsize=10, pad=2)
        axes[i + 3].set_axis_off()

    # Shared colorbar, scaled to the last generated image
    cbar = fig.colorbar(im, ax=axes.tolist(), shrink=0.8, location='right')
    cbar.set_label('Elevation', fontsize=14)

    plt.savefig('result_grid.png', bbox_inches='tight', dpi=300)
    plt.show()
network_diffusion_unet.py ADDED
@@ -0,0 +1,366 @@
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SinusoidalEmbedding(nn.Module):
    def __init__(self, embedding_dim=128, scaling=1000):
        super().__init__()
        self.embedding_dim = embedding_dim
        half_dim = embedding_dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(0, half_dim) / half_dim)
        self.scaling = nn.parameter.Buffer(torch.tensor(scaling))
        self.freqs = nn.parameter.Buffer(freqs)

    def forward(self, scaler):
        scaler = scaler * self.scaling
        args = scaler[:, None] * self.freqs[None]
        embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        return embedding


class SinusoidalPositionalEmbedding2D(nn.Module):

    def __init__(self, embedding_dim):
        super().__init__()
        assert embedding_dim % 2 == 0, "embedding_dim must be even"
        self.embedding_dim = embedding_dim
        half_dim = self.embedding_dim // 2
        div_term = torch.exp(torch.arange(0, half_dim, 2, dtype=torch.float32) * (-math.log(10000.0) / half_dim))
        self.div_term = nn.parameter.Buffer(div_term)

    def forward(self, height, width):
        """Generate embeddings for a grid of size (height, width)."""

        # Generate grid coordinates
        y_pos = torch.arange(height, dtype=torch.float32, device=self.div_term.device)
        x_pos = torch.arange(width, dtype=torch.float32, device=self.div_term.device)

        # Compute sinusoidal components for height and width
        y_sin = torch.sin(y_pos[:, None] * self.div_term[None, :])
        y_cos = torch.cos(y_pos[:, None] * self.div_term[None, :])
        x_sin = torch.sin(x_pos[:, None] * self.div_term[None, :])
        x_cos = torch.cos(x_pos[:, None] * self.div_term[None, :])

        # Interleave sin and cos components
        y_embed = torch.stack([y_sin, y_cos], dim=-1).view(height, -1)
        x_embed = torch.stack([x_sin, x_cos], dim=-1).view(width, -1)

        # Combine height and width embeddings
        pos_embed = torch.cat([y_embed[:, None, :].expand(-1, width, -1),
                               x_embed[None, :, :].expand(height, -1, -1)], dim=-1)
        return pos_embed.view(height * width, self.embedding_dim)


class ImageLinearAttention(nn.Module):
    def __init__(self, chan, kernel_size=3, heads=4, norm_queries=True, embd_dim=None):
        super().__init__()
        self.chan = chan
        self.heads = heads
        self.key_dim = key_dim = chan // heads
        self.value_dim = value_dim = chan // heads
        self.norm_queries = norm_queries

        # Convolutional projections for Q, K, V
        self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, padding='same', padding_mode='replicate')
        self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, padding='same', padding_mode='replicate')
        self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, padding='same', padding_mode='replicate')
        self.to_out = nn.Conv2d(value_dim * heads, chan, kernel_size, padding='same', padding_mode='replicate')

        # Adaptive normalization: Project embedding to scale/shift for group norm
        if embd_dim is not None:
            self.norm = nn.GroupNorm(1, key_dim * heads, affine=False)  # Normalize without inherent affine params
            self.emb_proj = nn.Linear(embd_dim, 2 * key_dim * heads)  # Project emb to scale/shift
        else:
            self.norm = nn.GroupNorm(1, key_dim * heads, affine=True)
            self.emb_proj = None

    def forward(self, x, emb=None):
        b, c, h, w = x.shape
        heads = self.heads
        key_dim = self.key_dim

        # Project input to queries, keys, and values
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # Apply adaptive normalization if embedding is provided
        if emb is not None and self.emb_proj is not None:
            emb_params = self.emb_proj(emb).view(b, 2, -1)  # (b, 2, key_dim * heads)
            scale, shift = emb_params[:, 0], emb_params[:, 1]  # Split into scale and shift
            # Normalize and modulate Q, K, V
            q = self.norm(q)
            k = self.norm(k)
            v = self.norm(v)
            # Apply scale and shift across spatial dimensions
            q = q * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
            k = k * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
            v = v * (1 + scale[:, :, None, None]) + shift[:, :, None, None]

        # Reshape Q, K, V for multi-head attention
        q = q.view(b, heads, key_dim, h * w)
        k = k.view(b, heads, key_dim, h * w)
        v = v.view(b, heads, self.value_dim, h * w)

        # Scale queries and keys
        q = q * (key_dim ** -0.25)
        k = k * (key_dim ** -0.25)

        # Softmax on keys along the sequence dimension
        k = k.softmax(dim=-1)
        if self.norm_queries:
            q = q.softmax(dim=-2)

        # Compute context and output
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhdn,bhde->bhen', q, context)
        out = out.reshape(b, -1, h, w)
        out = self.to_out(out)
        return x + out
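A note on ImageLinearAttention (which is not wired into the UNets below in this commit): the two einsums implement linear attention. Instead of forming the N x N attention matrix over the N = h*w pixel positions, the softmax-normalized keys are first aggregated with the values into a small per-head context matrix, which the queries then read out:

    \mathrm{context}_{de} = \sum_{n=1}^{N} k_{dn} \, v_{en}, \qquad \mathrm{out}_{en} = \sum_{d} q_{dn} \, \mathrm{context}_{de}

This costs O(N * d_k * d_v) time and only O(d_k * d_v) extra memory per head, instead of the O(N^2) of dense attention, which is what makes attention affordable on large feature maps.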
class ResConvBlock(nn.Module):
    def __init__(self, channels, time_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False, padding_mode='replicate')
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, padding_mode='replicate')
        self.gn1 = nn.GroupNorm(8, channels, affine=True)
        self.gn2 = nn.GroupNorm(8, channels, affine=False)
        self.time_affine = nn.Linear(time_dim, channels * 2)
        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x, t_emb):
        # Get affine parameters from the time embedding
        affine_params = self.time_affine(t_emb)
        scale, shift = affine_params.chunk(2, dim=1)

        # First convolution path
        h = self.conv1(self.act(self.gn1(x)))

        # Second convolution path with adaptive normalization
        h = self.gn2(h)
        h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
        h = self.conv2(self.act(h))

        return x + h


class DiTLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=1024):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(dim_feedforward, d_model),
        )

    def forward(self, src):
        # Self-attention block (pre-norm, residual)
        attn_output, _ = self.attn(self.norm1(src), self.norm1(src), self.norm1(src))
        src = src + attn_output

        # Feedforward block
        ffn_output = self.ffn(self.norm2(src))
        src = src + ffn_output
        return src


class DiTBlock(nn.Module):
    def __init__(self, channels, patch_size, hidden_size, nhead, num_layers=2):
        super().__init__()
        self.patch_size = patch_size
        self.patchify = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.patch_embedding_in = nn.Linear(channels * patch_size**2, hidden_size)
        self.pos_embd = SinusoidalPositionalEmbedding2D(hidden_size)
        self.waterlevel_embd = SinusoidalEmbedding(hidden_size, 10)
        self.patch_embedding_out = nn.Linear(hidden_size, channels * patch_size**2)
        self.dit_layers = nn.ModuleList([
            DiTLayer(hidden_size, nhead, 2 * hidden_size)
            for _ in range(num_layers)
        ])
        self.norm = nn.GroupNorm(8, channels)

    def forward(self, src, water_level):
        B, C, H, W = src.shape
        H_p, W_p = H // self.patch_size, W // self.patch_size
        x = self.norm(src)
        # Patchify to a token sequence and embed
        x = self.patchify(x).permute(0, 2, 1)
        x = self.patch_embedding_in(x)
        pos_embd = self.pos_embd(H_p, W_p).to(dtype=x.dtype)
        x = x + pos_embd.unsqueeze(0)
        # Append the water level as an extra class-style token
        water_level_cls = self.waterlevel_embd(water_level).unsqueeze(1)
        x = torch.cat((x, water_level_cls), dim=1)
        for dit_layer in self.dit_layers:
            x = dit_layer(x)
        x = self.patch_embedding_out(x).permute(0, 2, 1)
        # Drop the water-level token, then fold the tokens back into a feature map
        x = x[:, :, :-1]
        x = nn.functional.fold(x, (H, W), (self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size))
        return src + x
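DiTBlock relies on nn.Unfold and nn.functional.fold being exact inverses when the stride equals the kernel size, so the token sequence can be mapped back to the feature map with no overlap artifacts. A minimal standalone sketch of that round trip (the shapes and sizes here are illustrative, not taken from the model):

import torch
import torch.nn as nn

B, C, H, W, P = 2, 8, 32, 32, 16
x = torch.randn(B, C, H, W)

patchify = nn.Unfold(kernel_size=P, stride=P)
tokens = patchify(x).permute(0, 2, 1)  # (B, num_patches, C*P*P)
assert tokens.shape == (B, (H // P) * (W // P), C * P * P)

# With stride == kernel_size the patches tile the image exactly once each,
# so fold reconstructs the original tensor bit for bit.
restored = nn.functional.fold(tokens.permute(0, 2, 1), (H, W), (P, P), stride=(P, P))
assert torch.equal(restored, x)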
class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim, cat):
        super().__init__()
        self.res = ResConvBlock(in_ch, time_dim)
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.cat = cat

    def forward(self, x, t_emb, skip=None):
        x = self.res(x, t_emb)
        x = self.up(x)
        # Merge with the encoder skip: concatenate or add, depending on `cat`
        if self.cat:
            x = torch.cat([x, skip], dim=1)
        else:
            x = x + skip
        return x


class UpBlockWithDit(nn.Module):
    def __init__(self, in_ch, out_ch, patch_size, hidden_size, nhead, time_dim, cat):
        super().__init__()
        self.res = ResConvBlock(in_ch, time_dim)
        self.dit = DiTBlock(in_ch, patch_size, hidden_size, nhead, 4)
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.cat = cat

    def forward(self, x, t_emb, water_level, skip=None):
        x = self.res(x, t_emb)
        x = self.dit(x, water_level)
        x = self.up(x)
        if self.cat:
            x = torch.cat([x, skip], dim=1)
        else:
            x = x + skip
        return x


def run_block(module, *args):
    # Thin wrapper so torch.utils.checkpoint can re-run arbitrary modules
    return module(*args)


class ConditionalUNet(nn.Module):
    def __init__(self, base_ch=16, embd_dim=64, depth=5):
        super().__init__()
        self.depth = depth
        self.time_embd = SinusoidalEmbedding(embd_dim)
        self.waterlevel_embd = SinusoidalEmbedding(embd_dim, 10)
        embd_dim *= 2

        # Input channels = noisy height (1) + ridge map (1) + basin map (1) + map average (1)
        self.expand = nn.Conv2d(4, base_ch, 3, padding=1, padding_mode='replicate')

        # Encoder layers
        self.enc_blocks = nn.ModuleList()
        self.down_convs = nn.ModuleList()
        current_ch = base_ch

        for i in range(depth):
            self.enc_blocks.append(ResConvBlock(current_ch, embd_dim))
            if i < depth - 1:
                self.down_convs.append(
                    nn.Conv2d(current_ch, current_ch * 2, 4, stride=2, padding=1, padding_mode='replicate')
                )
                current_ch *= 2

        # Bottleneck
        self.bottleneck = nn.Conv2d(current_ch, current_ch * 2, 4, stride=2, padding=1, padding_mode='replicate')
        current_ch *= 2

        # Decoder layers
        self.up_blocks = nn.ModuleList()
        for i in range(depth):
            cat = (i == depth - 1)  # Only concatenate in the final up block
            self.up_blocks.append(UpBlock(current_ch, current_ch // 2, embd_dim, cat))
            current_ch = current_ch // 2 * (2 if cat else 1)

        self.out = ResConvBlock(current_ch, embd_dim)
        self.final = nn.Conv2d(current_ch, 1, 1)

    def forward(self, x, map_average, ridge_map, basin_map, water_level, t):
        t_embed = self.time_embd(t).to(x.dtype)
        waterlevel_embd = self.waterlevel_embd(water_level).to(x.dtype)
        embeds = torch.cat([t_embed, waterlevel_embd], dim=1)

        h = torch.cat([x, ridge_map, basin_map, map_average], dim=1)
        h = checkpoint(run_block, self.expand, h, use_reentrant=False) if self.training else self.expand(h)

        # Encoder
        skips = []
        for i in range(self.depth):
            h = checkpoint(run_block, self.enc_blocks[i], h, embeds, use_reentrant=False) if self.training else self.enc_blocks[i](h, embeds)
            skips.append(h)
            if i < self.depth - 1:
                h = checkpoint(run_block, self.down_convs[i], h, use_reentrant=False) if self.training else self.down_convs[i](h)

        # Bottleneck
        h = checkpoint(run_block, self.bottleneck, h, use_reentrant=False) if self.training else self.bottleneck(h)

        # Decoder
        for i in range(self.depth):
            h = checkpoint(run_block, self.up_blocks[i], h, embeds, skips[-(i + 1)], use_reentrant=False) if self.training else self.up_blocks[i](h, embeds, skips[-(i + 1)])

        h = checkpoint(run_block, self.out, h, embeds, use_reentrant=False) if self.training else self.out(h, embeds)
        h = checkpoint(run_block, self.final, h, use_reentrant=False) if self.training else self.final(h)
        return h


class ConditionalUNetDiT(nn.Module):
    def __init__(self, base_ch=8, embd_dim=16):
        super().__init__()
        self.time_embd = SinusoidalEmbedding(embd_dim, 1000)

        # Input channels = noisy height (1) + ridge map (1) + basin map (1)
        self.expand = nn.Conv2d(3, base_ch, 3, padding=1, padding_mode='replicate')
        self.enc_0 = ResConvBlock(base_ch, embd_dim)

        self.down0 = nn.Conv2d(base_ch, base_ch * 2, 4, stride=2, padding=1, padding_mode='replicate')  # 1024->512
        self.enc_1 = ResConvBlock(base_ch * 2, embd_dim)
        self.enc_1_dit = DiTBlock(base_ch * 2, 16, 1024, 8, 4)

        self.down1 = nn.Conv2d(base_ch * 2, base_ch * 4, 4, stride=2, padding=1, padding_mode='replicate')  # 512->256

        self.up1 = UpBlockWithDit(base_ch * 4, base_ch * 2, 8, 1024, 8, embd_dim, False)  # 256->512
        self.up0 = UpBlockWithDit(base_ch * 2, base_ch, 16, 1024, 8, embd_dim, True)  # 512->1024
        self.out = ResConvBlock(base_ch * 2, embd_dim)
        self.final = nn.Conv2d(base_ch * 2, 1, 1)

    def forward(self, x, ridge_map, basin_map, water_level, t):
        t_embed = self.time_embd(t).to(x.dtype)
        # x: noisy height map, ridge_map: binary edges, basin_map: binary basins, water_level: the target sea level
        h0 = torch.cat([x, ridge_map, basin_map], dim=1)  # concat condition
        # encode
        h0 = checkpoint(run_block, self.expand, h0, use_reentrant=False) if self.training else self.expand(h0)
        h0 = checkpoint(run_block, self.enc_0, h0, t_embed, use_reentrant=False) if self.training else self.enc_0(h0, t_embed)
        h1 = checkpoint(run_block, self.down0, h0, use_reentrant=False) if self.training else self.down0(h0)
        h1 = checkpoint(run_block, self.enc_1, h1, t_embed, use_reentrant=False) if self.training else self.enc_1(h1, t_embed)
        h1 = checkpoint(run_block, self.enc_1_dit, h1, water_level, use_reentrant=False) if self.training else self.enc_1_dit(h1, water_level)  # 512x512
        h2 = checkpoint(run_block, self.down1, h1, use_reentrant=False) if self.training else self.down1(h1)  # 256x256
        # decode with skip connections
        out = checkpoint(run_block, self.up1, h2, t_embed, water_level, h1, use_reentrant=False) if self.training else self.up1(h2, t_embed, water_level, h1)  # 512x512
        out = checkpoint(run_block, self.up0, out, t_embed, water_level, h0, use_reentrant=False) if self.training else self.up0(out, t_embed, water_level, h0)  # 1024x1024
        out = checkpoint(run_block, self.out, out, t_embed, use_reentrant=False) if self.training else self.out(out, t_embed)
        out = self.final(out)
        return out  # predicted velocity for the rectified-flow objective


if __name__ == "__main__":
    #a = ConditionalUNet()
    #t = SinusoidalEmbedding(256)
    #t_embd = t(torch.randint(0, 100, (1,)))
    #x = torch.randn(1, 1, 256, 256)
    #r = torch.randn(1, 1, 256, 256)
    #c = a(x, r, t_embd)
    #print(c)
    #print(c.shape)
    network = ConditionalUNetDiT()
    # Zero-init the adaptive affine layers so each residual block starts as identity
    for name, m in network.named_modules():
        if isinstance(m, nn.Linear) and 'time_affine' in name:
            m.weight.data.zero_()
            m.bias.data.zero_()
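A quick shape smoke test for ConditionalUNetDiT, as a sketch: the 64x64 input is chosen only because it keeps every DiT patch size (16 and 8) divisible into its downsampled feature map, while the training code uses 1024x1024. It also assumes a PyTorch version providing nn.parameter.Buffer (2.5+), which the modules above rely on.

import torch
from network_diffusion_unet import ConditionalUNetDiT

net = ConditionalUNetDiT(base_ch=8, embd_dim=16).eval()
x = torch.randn(2, 1, 64, 64)                        # noisy heightmap
ridge = torch.randn(2, 1, 64, 64)                    # ridge condition
basin = torch.randint(0, 2, (2, 1, 64, 64)).float()  # binary basin condition
water_level = torch.tensor([0.0, 150.0])
t = torch.rand(2)

with torch.no_grad():
    v = net(x, ridge, basin, water_level, t)
assert v.shape == x.shape  # the predicted velocity matches the input shape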
pl_module_rectifiedflow.py ADDED
@@ -0,0 +1,181 @@
import torch
import math

import data_utils
import torch.utils.data

import imageio.v3 as imageio
import lightning.pytorch as pl
import torch.nn as nn

from network_diffusion_unet import ConditionalUNetDiT
from loss_fn import L1andGDL
from adam_atan2_pytorch import AdamAtan2
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.utilities import grad_norm
from lightning.pytorch.callbacks import LearningRateMonitor, StochasticWeightAveraging, LearningRateFinder

def convert_uniform_to_custom(u):
    """Map uniform samples u in [0, 1] onto the custom timestep distribution (see the note below)."""
    #return 0.5 - torch.cos((1/3) * torch.acos(1 - 2 * u) + math.pi / 3)
    return 0.5 + 2 * torch.cos((2 * math.pi - torch.arccos((11/16)*(1-2*u)))/3)
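A derivation (mine, not from the original code) of what this samples: substituting t = 0.5 + 2 cos(theta) together with the triple-angle identity cos(3*theta) = 4 cos^3(theta) - 3 cos(theta) shows the returned value is the inverse CDF, evaluated at u, of the quadratic density

    p(t) = \tfrac{12}{11}\bigl(1 - (t - \tfrac{1}{2})^2\bigr), \quad t \in [0, 1], \qquad F(t) = \tfrac{1}{2} + \tfrac{12}{11}(t - \tfrac{1}{2}) - \tfrac{4}{11}(t - \tfrac{1}{2})^3

Both endpoints check out (u = 0 gives t = 0 and u = 1 gives t = 1), and the density peaks at t = 1/2, so training timesteps are concentrated around the middle of the noise-to-data trajectory.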
class PLModule(pl.LightningModule):
    def __init__(self, mid_visual_ridge, mid_visual_basins, mid_visual_gt):
        super().__init__()
        self.save_hyperparameters()
        self.lr = 6e-4
        self.wd = 5e-5
        self.model = ConditionalUNetDiT(base_ch=8, embd_dim=16)
        #self.map_average = torch.from_numpy(imageio.imread(map_average)).unsqueeze(0)
        #self.map_average = (self.map_average - self.map_average.mean()) / self.map_average.std()
        self.loss_fn = L1andGDL()
        self.val_metrics = []
        # Fixed sample paths used for visualization at the end of each epoch
        self.mid_visual_ridge, self.mid_visual_basins = mid_visual_ridge, mid_visual_basins
        self.mid_visual_gt = mid_visual_gt
        self.initialize_model()

    def initialize_model(self):
        # Zero-init the adaptive affine layers so each residual block starts as identity
        for name, m in self.model.named_modules():
            if isinstance(m, nn.Linear) and ('time_affine' in name or 'water_level_affine' in name):
                m.weight.data.zero_()
                m.bias.data.zero_()

    def configure_optimizers(self):
        opt = AdamAtan2(self.parameters(), lr=self.lr, decoupled_wd=True, weight_decay=self.wd)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, 100, eta_min=1e-7)
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch", "frequency": 1},
        }

    def _step(self, batch, batch_idx):
        x0, ridge_map, basin_map, water_level = batch
        b = water_level.shape[0]
        #map_average = self.map_average.expand((b, -1, -1, -1)).to(self.device)

        noise = torch.randn_like(x0, device=self.device, dtype=x0.dtype)
        t = torch.rand((b,), device=self.device)
        t = convert_uniform_to_custom(t).to(x0.dtype)

        # Linear interpolation between noise (t=0) and data (t=1)
        xt = t.view(-1, 1, 1, 1) * x0 + (1 - t.view(-1, 1, 1, 1)) * noise
        v = x0 - noise

        predicted_v = self.model(xt, ridge_map, basin_map, water_level, t)  # Predict velocity v
        loss = self.loss_fn(predicted_v, v)  # Loss between predicted and target v
        return loss
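_step above implements the rectified-flow (flow-matching) objective: draw a timestep from the custom distribution, form the linear interpolant between noise and data, and regress the constant velocity that carries one to the other,

    x_t = t \, x_0 + (1 - t) \, \epsilon, \qquad v^\star = \frac{d x_t}{d t} = x_0 - \epsilon, \qquad \mathcal{L} = \ell\bigl(v_\theta(x_t, c, t), v^\star\bigr)

where the conditioning c is the ridge map, basin map, and water level, and the loss ell here is the L1 plus gradient-difference combination from loss_fn.L1andGDL rather than plain MSE. The sampler in inference_step below integrates exactly this velocity field from t = 0 to t = 1.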
    def training_step(self, batch, batch_idx):
        loss = self._step(batch, batch_idx)
        self.logger.experiment.add_scalar("Train/Loss", loss.detach(), self.global_step)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch, batch_idx)
        self.val_metrics.append(loss.detach())
        return loss

    @torch.no_grad()
    def inference_step(self, ridge_map, basin_map, water_level, num_steps=50):
        device = self.device
        b = ridge_map.shape[0]
        # Start from pure Gaussian noise and integrate the learned velocity field
        x = torch.randn_like(ridge_map, device=device)
        water_level = torch.tensor((water_level,), device=device).expand(b,)
        time = torch.linspace(0, 1, num_steps + 1, device=device)

        for i in range(num_steps):
            t = torch.full((b,), time[i], device=device)
            dt = torch.full((b, 1, 1, 1), time[i + 1] - time[i], device=device)

            v = self.model(x, ridge_map, basin_map, water_level, t)

            x = x + dt * v

        return x

    def on_train_epoch_end(self):
        # Generate one sample from the fixed visualization inputs and log it
        sea_level = 0.0
        ridge_map = torch.from_numpy(imageio.imread(self.mid_visual_ridge))[None, None, :].to(device=self.device, dtype=torch.float32)
        basin_map = torch.from_numpy(imageio.imread(self.mid_visual_basins))[None, None, :].to(device=self.device)
        basin_map = (basin_map >= sea_level).to(torch.float32)
        output = self.inference_step(ridge_map, basin_map, sea_level)
        mid_visual_result = output.squeeze([1])
        self.logger.experiment.add_scalar("Visualize/Min", mid_visual_result.min(), self.current_epoch)
        self.logger.experiment.add_scalar("Visualize/Max", mid_visual_result.max(), self.current_epoch)
        self.logger.experiment.add_scalar("Visualize/Mean", mid_visual_result.mean(), self.current_epoch)
        # Min-max normalize for display
        mid_visual_result = (mid_visual_result - mid_visual_result.min()) / (mid_visual_result.max() - mid_visual_result.min())
        self.logger.experiment.add_image('Visualize/Model Output', mid_visual_result, self.current_epoch)

        vram_data = torch.cuda.mem_get_info()
        vram_usage = (vram_data[1] - vram_data[0]) / (1024 ** 2)  # MiB currently in use
        self.logger.experiment.add_scalar("Other/VRAM Usage", vram_usage, self.current_epoch)
        torch.cuda.reset_peak_memory_stats()
        if self.current_epoch == 0:
            # Log the conditioning inputs and the ground truth once
            mid_visual_gt = torch.from_numpy(imageio.imread(self.mid_visual_gt))[None, :]
            mid_visual_gt = (mid_visual_gt - mid_visual_gt.min()) / (mid_visual_gt.max() - mid_visual_gt.min())
            self.logger.experiment.add_image('Visualize/Ridge', ridge_map.squeeze([1]), self.current_epoch)
            self.logger.experiment.add_image('Visualize/Basin', basin_map.squeeze([1]), self.current_epoch)
            self.logger.experiment.add_image('Visualize/GT', mid_visual_gt, self.current_epoch)

    def on_validation_epoch_end(self):
        epoch_average = torch.stack(self.val_metrics).nanmean(dim=0)
        self.logger.experiment.add_scalar("Val/Loss", epoch_average, self.current_epoch)
        self.val_metrics.clear()

    #def on_before_optimizer_step(self, optimizer):
    #    norms = grad_norm(self.model, norm_type=2)
    #    self.log_dict(norms, logger=True)


# Example usage
if __name__ == "__main__":
    torch.set_float32_matmul_precision('medium')
    if torch.cuda.is_available() and torch.version.cuda:
        print('Optimising compute and memory use via cuDNN (NVIDIA GPUs only).')
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.allow_tf32 = True
    elif torch.cuda.is_available() and torch.version.hip:
        print('Optimising compute using TunableOp (AMD GPUs only).')
        torch.cuda.tunable.enable()
        torch.cuda.tunable.set_filename('TunableOp_results')

    train_split, val_split = data_utils.make_dataset_t_v('dataset_large')

    callbacks = []
    callbacks.append(LearningRateMonitor(logging_interval='epoch'))
    model_checkpoint = pl.callbacks.ModelCheckpoint(dirpath="", filename="FlashScape",
                                                    save_weights_only=False,
                                                    enable_version_counter=False, save_last=False)
    callbacks.append(model_checkpoint)
    swa_callback = StochasticWeightAveraging(1e-5, 0.8, int(0.2 * 100 - 1))
    callbacks.append(swa_callback)
    #lr_finder = LearningRateFinder(1e-5, 0.1)
    #callbacks.append(lr_finder)
    #model = PLModule.load_from_checkpoint('FlashScape V2.ckpt')
    trainer = pl.Trainer(max_epochs=100, log_every_n_steps=1,
                         logger=TensorBoardLogger('lightning_logs', name='FlashScape Dit No MapAvg Zero Init'),
                         accelerator="gpu", enable_checkpointing=True,
                         precision='16-mixed', enable_progress_bar=True, num_sanity_val_steps=0, callbacks=callbacks)
    with trainer.init_module():
        model = PLModule('dataset_large/Ridge_11417648.tiff',
                         'dataset_large/Basins_11417648.tiff',
                         'dataset_large/11417648.tiff')
        model = torch.compile(model)

    train_dataset = data_utils.TrainDataset(train_split)
    val_dataset = data_utils.ValDataset(val_split)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=8,
                                               num_workers=8, pin_memory=False, persistent_workers=True, shuffle=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=8,
                                             num_workers=8, pin_memory=False, persistent_workers=True)

    trainer.fit(model,
                val_dataloaders=val_loader,
                train_dataloaders=train_loader)
                #ckpt_path='FlashScape.ckpt')