| import torch
|
| import math
|
|
|
| import torch.utils.data
|
|
|
| import imageio.v3 as imageio
|
| import lightning.pytorch as pl
|
| import matplotlib.pyplot as plt
|
|
|
| from network_diffusion_unet import ConditionalUNetDiT
|
| from safetensors.torch import load_file
|
|
|
class PLModule(pl.LightningModule):
    """Lightning module holding a ``ConditionalUNetDiT`` and a fixed-step sampler."""

    def __init__(self):
        super().__init__()
        self.model = ConditionalUNetDiT(8, 16)

    @torch.no_grad()
    def inference_step(self, ridge_map, basin_map, water_level, num_steps=50):
        """Generate a sample by explicit Euler integration of the model output.

        Starting from Gaussian noise shaped like ``ridge_map``, the model
        output is treated as a velocity and integrated over a uniform time
        grid from 0 to 1: ``x <- x + dt * model(x, ridge, basin, level, t)``.

        All tensors created here are float16 on ``self.device``.

        Args:
            ridge_map: Conditioning tensor; its first dimension is the batch
                size and its shape is used for the initial noise.
                (The ``(b, 1, 1, 1)`` step tensor suggests a 4-D layout --
                TODO confirm against the model.)
            basin_map: Second conditioning tensor, forwarded to the model
                unchanged.
            water_level: Python scalar; broadcast to one value per batch item.
            num_steps: Number of Euler steps.

        Returns:
            The integrated sample after ``num_steps`` updates.
        """
        half = torch.float16
        target = self.device
        batch = ridge_map.shape[0]

        # Initial state: pure noise with the same shape as the ridge condition.
        sample = torch.randn_like(ridge_map, device=target, dtype=half)

        # One scalar condition value per batch element (expanded view, no copy).
        water_level = torch.tensor((water_level,), device=target, dtype=half).expand(batch)

        # num_steps + 1 grid points -> num_steps intervals on [0, 1].
        schedule = torch.linspace(0, 1, num_steps + 1, device=target, dtype=half)

        # Walk over consecutive (start, stop) pairs of the time grid.
        for start, stop in zip(schedule[:-1], schedule[1:]):
            t = torch.full((batch,), start, device=target, dtype=half)
            # Shaped (b, 1, 1, 1) so the step size broadcasts over the sample.
            dt = torch.full((batch, 1, 1, 1), stop - start, device=target, dtype=half)

            velocity = self.model(sample, ridge_map, basin_map, water_level, t)

            sample = sample + dt * velocity

        return sample
|
|
|
# Affine map applied to the raw network output before display:
# elevation = output * _ELEVATION_SCALE + _ELEVATION_OFFSET.
# NOTE(review): presumably the training-set std / mean -- confirm.
_ELEVATION_SCALE = 330.8314960521203
_ELEVATION_OFFSET = 149.95293407563648

# Upper bound on the number of panels per row in the result grid.
_MAX_GRID_COLS = 6
# Size in inches of one grid cell.
_CELL_INCHES = 5


def _load_tiff(path, device, dtype):
    """Read an image file and return it as a ``(1, 1, ...)`` tensor."""
    return torch.from_numpy(imageio.imread(path))[None, None, :].to(dtype=dtype, device=device)


def _plot_results(ridge, basin, ground_truth, generated, out_path):
    """Draw conditions, ground truth and generated samples in one grid.

    Args:
        ridge: 2-D CPU tensor shown as 'Ridge Condition'.
        basin: 2-D CPU tensor shown as 'Basin Condition'.
        ground_truth: 2-D CPU tensor shown as 'Ground Truth'.
        generated: CPU tensor indexed by sample along the first dimension.
        out_path: File the figure is saved to before it is shown.
    """
    panels = [
        (ridge, 'Ridge Condition', 12),
        (basin, 'Basin Condition', 12),
        (ground_truth, 'Ground Truth', 12),
    ]
    panels.extend(
        (image, f'Generated {index}', 10)
        for index, image in enumerate(generated, start=1)
    )

    cols = min(_MAX_GRID_COLS, len(panels))
    rows = math.ceil(len(panels) / cols)

    # squeeze=False always yields a 2-D array of axes, so a single ravel()
    # covers every rows/cols combination without special cases.
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(cols * _CELL_INCHES + 0.1, rows * _CELL_INCHES),
        squeeze=False,
    )
    flat_axes = axes.ravel()

    # Hide the unused cells of the last row.
    for ax in flat_axes[len(panels):]:
        ax.set_visible(False)

    last_image = None
    for ax, (image, title, fontsize) in zip(flat_axes, panels):
        last_image = ax.imshow(image, cmap='gray')
        ax.set_title(title, fontsize=fontsize, pad=2)
        ax.set_axis_off()

    # NOTE: no vmin/vmax is passed to imshow, so every panel is autoscaled on
    # its own; the colorbar therefore reflects the last panel drawn only.
    cbar = fig.colorbar(last_image, ax=flat_axes.tolist(), shrink=0.8, location='right')
    cbar.set_label('Elevation', fontsize=14)

    plt.savefig(out_path, bbox_inches='tight', dpi=300)
    plt.show()


def main():
    """Load the checkpoint, sample several terrains and save a comparison grid."""
    device = 'cuda'
    dtype = torch.float16

    model = PLModule()
    model.model.load_state_dict(load_file('FlashScape.safetensors'))
    model.to(device=device, dtype=dtype)
    model.eval()

    test_ridge = _load_tiff('dataset_large/Ridge_11417648.tiff', device, dtype)
    test_basin = _load_tiff('dataset_large/Basins_11417648.tiff', device, dtype)
    # The ground truth is only displayed, so keep it on the CPU in float32
    # instead of rounding it through float16 on the GPU.
    gt = _load_tiff('dataset_large/11417648.tiff', 'cpu', torch.float32)

    water_level = 300.0
    num_steps = 10
    num_images = 4

    # Binarise the basin map at the water level, then repeat both conditions
    # once per requested sample (expand creates views, not copies).
    test_basin = (test_basin >= water_level).to(dtype)
    test_ridge = test_ridge.expand(num_images, -1, -1, -1)
    test_basin = test_basin.expand(num_images, -1, -1, -1)

    generated = model.inference_step(test_ridge, test_basin, water_level, num_steps)

    # De-normalise in float32: float16 has too few mantissa bits to represent
    # values in the hundreds/thousands without visible rounding.
    generated = generated.float() * _ELEVATION_SCALE + _ELEVATION_OFFSET

    _plot_results(
        ridge=test_ridge[0, 0].cpu().float(),
        basin=test_basin[0, 0].cpu().float(),
        ground_truth=gt[0, 0],
        generated=generated[:, 0].cpu(),
        out_path='result_grid.png',
    )


if __name__ == "__main__":
    main()