File size: 5,919 Bytes
e18319b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
"""

Inference script for the conditional diffusion model.

This script provides easy-to-use functions for generating medical images.

"""

import torch
import numpy as np
import matplotlib.pyplot as plt
from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler


class ConditionalDiffusionInference:
    """Wrapper class for easy inference with the conditional diffusion model.

    Loads a trained score network once, then exposes helpers to sample images
    conditioned on a mask and to plot a mask next to its generated sample.
    """

    # Per-sample shape handed to the sampler as ``x_shape``.
    IMAGE_SHAPE = (4, 256, 256)

    def __init__(self, model_path, device='cuda'):
        """
        Initialize the inference model.

        Args:
            model_path: Path to the trained model checkpoint
            device: Device to run inference on ('cuda' or 'cpu')
        """
        self.device = device
        # Forwarded as ``Lambda`` to marginal_prob_std and diffusion_coeff.
        self.Lambda = 25.0

        # Bind Lambda and device so the sampler can call these with ``t`` only.
        self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
        self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)

        self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
        # NOTE(security): torch.load unpickles the checkpoint; only load files
        # from a trusted source.
        self.score_model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.score_model.to(self.device)
        self.score_model.eval()

        print(f"Model loaded successfully on {self.device}")

    @staticmethod
    def _as_batched_tensor(data):
        """Return ``data`` as a tensor, adding a leading batch axis to 3-D input."""
        if not isinstance(data, torch.Tensor):
            data = torch.tensor(data, dtype=torch.float32)
        if data.dim() == 3:
            data = data.unsqueeze(0)
        return data

    def _prepare_mask(self, mask):
        """Normalise a conditioning mask to a 4-D tensor on ``self.device``.

        Raises:
            ValueError: If the mask is neither 3-D nor 4-D.
        """
        mask = self._as_batched_tensor(mask)
        if mask.dim() != 4:
            # Fail early with a clear message instead of deep inside the sampler.
            raise ValueError(
                "conditioning mask must be 3-D (C, H, W) or 4-D (B, C, H, W), "
                f"got shape {tuple(mask.shape)}"
            )
        return mask.to(self.device)

    def _sample(self, mask, batch_size, num_steps, eps):
        """Run the Euler-Maruyama sampler and clamp the result to [0, 1]."""
        with torch.no_grad():
            samples = Euler_Maruyama_sampler(
                self.score_model,
                self.marginal_prob_std_fn,
                self.diffusion_coeff_fn,
                batch_size=batch_size,
                x_shape=self.IMAGE_SHAPE,
                num_steps=num_steps,
                device=self.device,
                eps=eps,
                y=mask,
            )
        return samples.clamp(0, 1)

    def generate_image(self, conditioning_mask, num_steps=250, eps=1e-3):
        """
        Generate a medical image based on a conditioning mask.

        Args:
            conditioning_mask: Conditioning mask of shape (1, 4, 256, 256);
                a 3-D (4, 256, 256) mask or an array-like is also accepted.
            num_steps: Number of sampling steps
            eps: Smallest time step for numerical stability

        Returns:
            Generated image tensor of shape (1, 4, 256, 256)

        Raises:
            ValueError: If the mask is neither 3-D nor 4-D.
        """
        mask = self._prepare_mask(conditioning_mask)
        # Single-image API: the sampler is always asked for one sample.
        return self._sample(mask, 1, num_steps, eps)

    def generate_batch(self, conditioning_masks, num_steps=250, eps=1e-3):
        """
        Generate multiple images based on conditioning masks.

        Args:
            conditioning_masks: Conditioning masks of shape (B, 4, 256, 256);
                a single 3-D mask or an array-like is also accepted.
            num_steps: Number of sampling steps
            eps: Smallest time step for numerical stability

        Returns:
            Generated images tensor of shape (B, 4, 256, 256)

        Raises:
            ValueError: If the masks are neither 3-D nor 4-D.
        """
        masks = self._prepare_mask(conditioning_masks)
        # One sample per mask in the batch.
        return self._sample(masks, masks.shape[0], num_steps, eps)

    def visualize_generation(self, conditioning_mask, generated_image, save_path=None):
        """
        Visualize the conditioning mask and generated image.

        The first four channels of the first batch element of each input are
        drawn: masks on the top row, generated channels on the bottom row.

        Args:
            conditioning_mask: Conditioning mask (tensor or array-like, 3-D or 4-D)
            generated_image: Generated image (tensor or array-like, 3-D or 4-D)
            save_path: Optional path to save the visualization
        """
        # Accept the same input forms as generate_image / generate_batch.
        mask = self._as_batched_tensor(conditioning_mask).detach().cpu()
        image = self._as_batched_tensor(generated_image).detach().cpu()

        fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        try:
            rows = (
                (mask, 'Conditioning Mask'),
                (image, 'Generated Image'),
            )
            for row, (tensor, label) in enumerate(rows):
                for i in range(4):
                    ax = axes[row, i]
                    ax.imshow(tensor[0, i].numpy(), cmap='gray')
                    ax.set_title(f'{label} {i+1}')
                    ax.axis('off')

            fig.tight_layout()

            if save_path:
                fig.savefig(save_path, dpi=150, bbox_inches='tight')
                print(f"Visualization saved to {save_path}")

            plt.show()
        finally:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)


def main():
    """Example usage of the inference model.

    Loads the checkpoint, samples one image for a random placeholder mask and
    saves the mask/sample plot to ``generation_result.png``.
    """

    # Initialize the model
    model_path = "ckpt_3D_v2.pth"  # Update with your model path
    # Fall back to CPU so the example also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inference_model = ConditionalDiffusionInference(model_path, device=device)

    # Create a random conditioning mask (replace with your actual mask)
    conditioning_mask = torch.randn(1, 4, 256, 256)

    # Generate image
    print("Generating image...")
    generated_image = inference_model.generate_image(conditioning_mask)

    # Visualize results
    inference_model.visualize_generation(conditioning_mask, generated_image, "generation_result.png")

    print("Generation complete!")


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()