Fgdfgfthgr commited on
Commit
4936d6c
·
verified ·
1 Parent(s): c8c6825

Delete mass_generate_examples.py

Browse files
Files changed (1) hide show
  1. mass_generate_examples.py +0 -122
mass_generate_examples.py DELETED
@@ -1,122 +0,0 @@
1
- import torch
2
- import math
3
-
4
- import torch.utils.data
5
-
6
- import imageio.v3 as imageio
7
- import lightning.pytorch as pl
8
- import matplotlib.pyplot as plt
9
-
10
- from network_diffusion_unet import ConditionalUNetDiT
11
- from safetensors.torch import load_file
12
-
13
- class PLModule(pl.LightningModule):
14
- def __init__(self):
15
- super().__init__()
16
- self.model = ConditionalUNetDiT(8, 16)
17
-
18
- @torch.no_grad()
19
- def inference_step(self, ridge_map, basin_map, water_level, num_steps=50):
20
- device = self.device
21
- b = ridge_map.shape[0]
22
- x = torch.randn_like(ridge_map, device=device, dtype=torch.float16)
23
- water_level = torch.tensor((water_level,), device=device, dtype=torch.float16).expand(b, )
24
- time = torch.linspace(0, 1, num_steps + 1, device=device, dtype=torch.float16)
25
-
26
- for i in range(num_steps):
27
- t = torch.full((b,), time[i], device=device, dtype=torch.float16)
28
- dt = torch.full((b, 1, 1, 1), time[i + 1] - time[i], device=device, dtype=torch.float16)
29
-
30
- v = self.model(x, ridge_map, basin_map, water_level, t)
31
-
32
- x = x + dt * v
33
-
34
- return x
35
-
36
- if __name__ == "__main__":
37
- #model = PLModule.load_from_checkpoint('FlashScape.ckpt').to(device='cuda', dtype=torch.float16)
38
- model = PLModule()
39
- model.model.load_state_dict(load_file('FlashScape.safetensors'))
40
- model.to(device='cuda', dtype=torch.float16)
41
- model.eval()
42
-
43
- test_ridge = torch.from_numpy(imageio.imread('dataset_large/Ridge_11417648.tiff'))[None, None, :].to(dtype=torch.float16, device='cuda')
44
- test_basin = torch.from_numpy(imageio.imread('dataset_large/Basins_11417648.tiff'))[None, None, :].to(dtype=torch.float16, device='cuda')
45
- gt = torch.from_numpy(imageio.imread('dataset_large/11417648.tiff'))[None, None, :].to(dtype=torch.float16, device='cuda')
46
- water_level = 0.0
47
- num_steps = 50
48
- num_images = 16
49
-
50
- test_basin = (test_basin >= water_level).to(torch.float16)
51
- test_ridge = test_ridge.expand(num_images, -1, -1, -1)
52
- test_basin = test_basin.expand(num_images, -1, -1, -1)
53
- generated = model.inference_step(test_ridge, test_basin, water_level, num_steps)
54
- # Back to original range
55
- generated = generated * 330.8314960521203 + 149.95293407563648
56
-
57
- # Prepare images for visualization
58
- ridge_display = test_ridge[0, 0].cpu().float()
59
- basin_display = test_basin[0, 0].cpu().float()
60
- gt_display = gt[0, 0].cpu().float()
61
- generated_display = generated[:, 0].cpu() # Remove channel dim
62
-
63
- # Calculate optimal grid layout
64
- total_images = num_images + 3 # condition1+ condition2 + gt + generated images
65
- image_size = ridge_display.shape[0] # assuming square images
66
-
67
- # Determine optimal number of columns (aim for roughly 4:3 aspect ratio)
68
- max_cols = min(6, total_images) # Maximum 6 columns for readability
69
- cols = min(max_cols, total_images)
70
- rows = math.ceil(total_images / cols)
71
-
72
- # Calculate figure size based on image dimensions and grid layout
73
- base_height_per_image = 5 # inches per image height
74
- base_width_per_image = 5 # inches per image width
75
-
76
- fig_width = cols * base_width_per_image + 0.1 # +1 for colorbar space
77
- fig_height = rows * base_height_per_image
78
-
79
- # Create figure with subplots
80
- fig, axes = plt.subplots(rows, cols, figsize=(fig_width, fig_height))
81
-
82
- # Flatten axes array for easier indexing
83
- if rows > 1 and cols > 1:
84
- axes = axes.flatten()
85
- elif rows == 1 and cols > 1:
86
- axes = axes
87
- elif rows > 1 and cols == 1:
88
- axes = axes[:, 0]
89
- else:
90
- axes = [axes]
91
-
92
- # Hide unused subplots
93
- for i in range(total_images, len(axes)):
94
- axes[i].set_visible(False)
95
-
96
- # Plot condition image
97
- im0 = axes[0].imshow(ridge_display, cmap='gray')
98
- axes[0].set_title('Ridge Condition', fontsize=12, pad=2)
99
- axes[0].set_axis_off()
100
-
101
- # Plot condition image
102
- im1 = axes[1].imshow(basin_display, cmap='gray')
103
- axes[1].set_title(f'Basin Condition at level {water_level}', fontsize=12, pad=2)
104
- axes[1].set_axis_off()
105
-
106
- # Plot ground truth image
107
- im2 = axes[2].imshow(gt_display, cmap='gray')
108
- axes[2].set_title('Ground Truth', fontsize=12, pad=2)
109
- axes[2].set_axis_off()
110
-
111
- # Plot generated images
112
- for i in range(num_images):
113
- im = axes[i + 3].imshow(generated_display[i], cmap='gray')
114
- axes[i + 3].set_title(f'Generated {i + 1}', fontsize=10, pad=2)
115
- axes[i + 3].set_axis_off()
116
-
117
- # Add colorbar
118
- cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8, location='right')
119
- cbar.set_label('Elevation', fontsize=14)
120
-
121
- plt.savefig('result_grid.png', bbox_inches='tight', dpi=300)
122
- plt.show()