Delete mass_generate_examples.py
Browse files- mass_generate_examples.py +0 -122
mass_generate_examples.py
DELETED
|
@@ -1,122 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import math
|
| 3 |
-
|
| 4 |
-
import torch.utils.data
|
| 5 |
-
|
| 6 |
-
import imageio.v3 as imageio
|
| 7 |
-
import lightning.pytorch as pl
|
| 8 |
-
import matplotlib.pyplot as plt
|
| 9 |
-
|
| 10 |
-
from network_diffusion_unet import ConditionalUNetDiT
|
| 11 |
-
from safetensors.torch import load_file
|
| 12 |
-
|
| 13 |
-
class PLModule(pl.LightningModule):
|
| 14 |
-
def __init__(self):
|
| 15 |
-
super().__init__()
|
| 16 |
-
self.model = ConditionalUNetDiT(8, 16)
|
| 17 |
-
|
| 18 |
-
@torch.no_grad()
|
| 19 |
-
def inference_step(self, ridge_map, basin_map, water_level, num_steps=50):
|
| 20 |
-
device = self.device
|
| 21 |
-
b = ridge_map.shape[0]
|
| 22 |
-
x = torch.randn_like(ridge_map, device=device, dtype=torch.float16)
|
| 23 |
-
water_level = torch.tensor((water_level,), device=device, dtype=torch.float16).expand(b, )
|
| 24 |
-
time = torch.linspace(0, 1, num_steps + 1, device=device, dtype=torch.float16)
|
| 25 |
-
|
| 26 |
-
for i in range(num_steps):
|
| 27 |
-
t = torch.full((b,), time[i], device=device, dtype=torch.float16)
|
| 28 |
-
dt = torch.full((b, 1, 1, 1), time[i + 1] - time[i], device=device, dtype=torch.float16)
|
| 29 |
-
|
| 30 |
-
v = self.model(x, ridge_map, basin_map, water_level, t)
|
| 31 |
-
|
| 32 |
-
x = x + dt * v
|
| 33 |
-
|
| 34 |
-
return x
|
| 35 |
-
|
| 36 |
-
if __name__ == "__main__":
|
| 37 |
-
#model = PLModule.load_from_checkpoint('FlashScape.ckpt').to(device='cuda', dtype=torch.float16)
|
| 38 |
-
model = PLModule()
|
| 39 |
-
model.model.load_state_dict(load_file('FlashScape.safetensors'))
|
| 40 |
-
model.to(device='cuda', dtype=torch.float16)
|
| 41 |
-
model.eval()
|
| 42 |
-
|
| 43 |
-
test_ridge = torch.from_numpy(imageio.imread('dataset_large/Ridge_11417648.tiff'))[None, None, :].to(dtype=torch.float16, device='cuda')
|
| 44 |
-
test_basin = torch.from_numpy(imageio.imread('dataset_large/Basins_11417648.tiff'))[None, None, :].to(dtype=torch.float16, device='cuda')
|
| 45 |
-
gt = torch.from_numpy(imageio.imread('dataset_large/11417648.tiff'))[None, None, :].to(dtype=torch.float16, device='cuda')
|
| 46 |
-
water_level = 0.0
|
| 47 |
-
num_steps = 50
|
| 48 |
-
num_images = 16
|
| 49 |
-
|
| 50 |
-
test_basin = (test_basin >= water_level).to(torch.float16)
|
| 51 |
-
test_ridge = test_ridge.expand(num_images, -1, -1, -1)
|
| 52 |
-
test_basin = test_basin.expand(num_images, -1, -1, -1)
|
| 53 |
-
generated = model.inference_step(test_ridge, test_basin, water_level, num_steps)
|
| 54 |
-
# Back to original range
|
| 55 |
-
generated = generated * 330.8314960521203 + 149.95293407563648
|
| 56 |
-
|
| 57 |
-
# Prepare images for visualization
|
| 58 |
-
ridge_display = test_ridge[0, 0].cpu().float()
|
| 59 |
-
basin_display = test_basin[0, 0].cpu().float()
|
| 60 |
-
gt_display = gt[0, 0].cpu().float()
|
| 61 |
-
generated_display = generated[:, 0].cpu() # Remove channel dim
|
| 62 |
-
|
| 63 |
-
# Calculate optimal grid layout
|
| 64 |
-
total_images = num_images + 3 # condition1+ condition2 + gt + generated images
|
| 65 |
-
image_size = ridge_display.shape[0] # assuming square images
|
| 66 |
-
|
| 67 |
-
# Determine optimal number of columns (aim for roughly 4:3 aspect ratio)
|
| 68 |
-
max_cols = min(6, total_images) # Maximum 6 columns for readability
|
| 69 |
-
cols = min(max_cols, total_images)
|
| 70 |
-
rows = math.ceil(total_images / cols)
|
| 71 |
-
|
| 72 |
-
# Calculate figure size based on image dimensions and grid layout
|
| 73 |
-
base_height_per_image = 5 # inches per image height
|
| 74 |
-
base_width_per_image = 5 # inches per image width
|
| 75 |
-
|
| 76 |
-
fig_width = cols * base_width_per_image + 0.1 # +1 for colorbar space
|
| 77 |
-
fig_height = rows * base_height_per_image
|
| 78 |
-
|
| 79 |
-
# Create figure with subplots
|
| 80 |
-
fig, axes = plt.subplots(rows, cols, figsize=(fig_width, fig_height))
|
| 81 |
-
|
| 82 |
-
# Flatten axes array for easier indexing
|
| 83 |
-
if rows > 1 and cols > 1:
|
| 84 |
-
axes = axes.flatten()
|
| 85 |
-
elif rows == 1 and cols > 1:
|
| 86 |
-
axes = axes
|
| 87 |
-
elif rows > 1 and cols == 1:
|
| 88 |
-
axes = axes[:, 0]
|
| 89 |
-
else:
|
| 90 |
-
axes = [axes]
|
| 91 |
-
|
| 92 |
-
# Hide unused subplots
|
| 93 |
-
for i in range(total_images, len(axes)):
|
| 94 |
-
axes[i].set_visible(False)
|
| 95 |
-
|
| 96 |
-
# Plot condition image
|
| 97 |
-
im0 = axes[0].imshow(ridge_display, cmap='gray')
|
| 98 |
-
axes[0].set_title('Ridge Condition', fontsize=12, pad=2)
|
| 99 |
-
axes[0].set_axis_off()
|
| 100 |
-
|
| 101 |
-
# Plot condition image
|
| 102 |
-
im1 = axes[1].imshow(basin_display, cmap='gray')
|
| 103 |
-
axes[1].set_title(f'Basin Condition at level {water_level}', fontsize=12, pad=2)
|
| 104 |
-
axes[1].set_axis_off()
|
| 105 |
-
|
| 106 |
-
# Plot ground truth image
|
| 107 |
-
im2 = axes[2].imshow(gt_display, cmap='gray')
|
| 108 |
-
axes[2].set_title('Ground Truth', fontsize=12, pad=2)
|
| 109 |
-
axes[2].set_axis_off()
|
| 110 |
-
|
| 111 |
-
# Plot generated images
|
| 112 |
-
for i in range(num_images):
|
| 113 |
-
im = axes[i + 3].imshow(generated_display[i], cmap='gray')
|
| 114 |
-
axes[i + 3].set_title(f'Generated {i + 1}', fontsize=10, pad=2)
|
| 115 |
-
axes[i + 3].set_axis_off()
|
| 116 |
-
|
| 117 |
-
# Add colorbar
|
| 118 |
-
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8, location='right')
|
| 119 |
-
cbar.set_label('Elevation', fontsize=14)
|
| 120 |
-
|
| 121 |
-
plt.savefig('result_grid.png', bbox_inches='tight', dpi=300)
|
| 122 |
-
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|