tan200224 commited on
Commit
e18319b
·
verified ·
1 Parent(s): ab50f76

Add inference.py

Browse files
Files changed (1) hide show
  1. inference.py +162 -0
inference.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Inference script for the conditional diffusion model.
4
+ This script provides easy-to-use functions for generating medical images.
5
+ """
6
+
7
+ import torch
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
11
+
12
+
13
+ class ConditionalDiffusionInference:
14
+ """Wrapper class for easy inference with the conditional diffusion model."""
15
+
16
+ def __init__(self, model_path, device='cuda'):
17
+ """
18
+ Initialize the inference model.
19
+
20
+ Args:
21
+ model_path: Path to the trained model checkpoint
22
+ device: Device to run inference on ('cuda' or 'cpu')
23
+ """
24
+ self.device = device
25
+ self.Lambda = 25.0
26
+
27
+ # Initialize the model
28
+ self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
29
+ self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)
30
+
31
+ self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
32
+ self.score_model.load_state_dict(torch.load(model_path, map_location=self.device))
33
+ self.score_model.to(self.device)
34
+ self.score_model.eval()
35
+
36
+ print(f"Model loaded successfully on {self.device}")
37
+
38
+ def generate_image(self, conditioning_mask, num_steps=250, eps=1e-3):
39
+ """
40
+ Generate a medical image based on a conditioning mask.
41
+
42
+ Args:
43
+ conditioning_mask: Conditioning mask tensor of shape (1, 4, 256, 256)
44
+ num_steps: Number of sampling steps
45
+ eps: Smallest time step for numerical stability
46
+
47
+ Returns:
48
+ Generated image tensor of shape (1, 4, 256, 256)
49
+ """
50
+ if not isinstance(conditioning_mask, torch.Tensor):
51
+ conditioning_mask = torch.tensor(conditioning_mask, dtype=torch.float32)
52
+
53
+ if conditioning_mask.dim() == 3:
54
+ conditioning_mask = conditioning_mask.unsqueeze(0)
55
+
56
+ conditioning_mask = conditioning_mask.to(self.device)
57
+
58
+ with torch.no_grad():
59
+ samples = Euler_Maruyama_sampler(
60
+ self.score_model,
61
+ self.marginal_prob_std_fn,
62
+ self.diffusion_coeff_fn,
63
+ batch_size=1,
64
+ x_shape=(4, 256, 256),
65
+ num_steps=num_steps,
66
+ device=self.device,
67
+ eps=eps,
68
+ y=conditioning_mask
69
+ )
70
+
71
+ return samples.clamp(0, 1)
72
+
73
+ def generate_batch(self, conditioning_masks, num_steps=250, eps=1e-3):
74
+ """
75
+ Generate multiple images based on conditioning masks.
76
+
77
+ Args:
78
+ conditioning_masks: Conditioning mask tensor of shape (B, 4, 256, 256)
79
+ num_steps: Number of sampling steps
80
+ eps: Smallest time step for numerical stability
81
+
82
+ Returns:
83
+ Generated images tensor of shape (B, 4, 256, 256)
84
+ """
85
+ if not isinstance(conditioning_masks, torch.Tensor):
86
+ conditioning_masks = torch.tensor(conditioning_masks, dtype=torch.float32)
87
+
88
+ if conditioning_masks.dim() == 3:
89
+ conditioning_masks = conditioning_masks.unsqueeze(0)
90
+
91
+ conditioning_masks = conditioning_masks.to(self.device)
92
+ batch_size = conditioning_masks.shape[0]
93
+
94
+ with torch.no_grad():
95
+ samples = Euler_Maruyama_sampler(
96
+ self.score_model,
97
+ self.marginal_prob_std_fn,
98
+ self.diffusion_coeff_fn,
99
+ batch_size=batch_size,
100
+ x_shape=(4, 256, 256),
101
+ num_steps=num_steps,
102
+ device=self.device,
103
+ eps=eps,
104
+ y=conditioning_masks
105
+ )
106
+
107
+ return samples.clamp(0, 1)
108
+
109
+ def visualize_generation(self, conditioning_mask, generated_image, save_path=None):
110
+ """
111
+ Visualize the conditioning mask and generated image.
112
+
113
+ Args:
114
+ conditioning_mask: Conditioning mask tensor
115
+ generated_image: Generated image tensor
116
+ save_path: Optional path to save the visualization
117
+ """
118
+ fig, axes = plt.subplots(2, 4, figsize=(16, 8))
119
+
120
+ # Plot conditioning mask
121
+ for i in range(4):
122
+ axes[0, i].imshow(conditioning_mask[0, i].cpu().numpy(), cmap='gray')
123
+ axes[0, i].set_title(f'Conditioning Mask {i+1}')
124
+ axes[0, i].axis('off')
125
+
126
+ # Plot generated image
127
+ for i in range(4):
128
+ axes[1, i].imshow(generated_image[0, i].cpu().numpy(), cmap='gray')
129
+ axes[1, i].set_title(f'Generated Image {i+1}')
130
+ axes[1, i].axis('off')
131
+
132
+ plt.tight_layout()
133
+
134
+ if save_path:
135
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
136
+ print(f"Visualization saved to {save_path}")
137
+
138
+ plt.show()
139
+
140
+
141
+ def main():
142
+ """Example usage of the inference model."""
143
+
144
+ # Initialize the model
145
+ model_path = "ckpt_3D_v2.pth" # Update with your model path
146
+ inference_model = ConditionalDiffusionInference(model_path, device='cuda')
147
+
148
+ # Create a random conditioning mask (replace with your actual mask)
149
+ conditioning_mask = torch.randn(1, 4, 256, 256)
150
+
151
+ # Generate image
152
+ print("Generating image...")
153
+ generated_image = inference_model.generate_image(conditioning_mask)
154
+
155
+ # Visualize results
156
+ inference_model.visualize_generation(conditioning_mask, generated_image, "generation_result.png")
157
+
158
+ print("Generation complete!")
159
+
160
+
161
+ if __name__ == "__main__":
162
+ main()