#!/usr/bin/env python3
"""
Inference script for the conditional diffusion model.

This script provides easy-to-use functions for generating medical images.
"""

import torch
import numpy as np
import matplotlib.pyplot as plt

from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler


class ConditionalDiffusionInference:
    """Wrapper class for easy inference with the conditional diffusion model."""

    # (C, H, W) shape of one sample requested from the sampler.
    IMAGE_SHAPE = (4, 256, 256)

    def __init__(self, model_path, device='cuda'):
        """
        Initialize the inference model.

        Args:
            model_path: Path to the trained model checkpoint (a state dict
                accepted by ``UNet.load_state_dict``).
            device: Device to run inference on ('cuda' or 'cpu').
        """
        self.device = device
        # Schedule parameter forwarded to marginal_prob_std / diffusion_coeff;
        # presumably must match the value used in training — TODO confirm.
        self.Lambda = 25.0

        # Bind the schedule parameters once so the sampler can call f(t).
        self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
        self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)

        self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=self.device)
        self.score_model.load_state_dict(state_dict)
        self.score_model.to(self.device)
        self.score_model.eval()
        print(f"Model loaded successfully on {self.device}")

    def _prepare_masks(self, masks):
        """Return masks as a float32 (B, C, H, W) tensor on ``self.device``.

        Raises:
            ValueError: If the masks are not 3D (C, H, W) or 4D (B, C, H, W),
                or if the batch is empty.
        """
        if not isinstance(masks, torch.Tensor):
            masks = torch.tensor(masks, dtype=torch.float32)
        if masks.dim() == 3:
            masks = masks.unsqueeze(0)
        if masks.dim() != 4:
            raise ValueError(
                f"Expected conditioning mask of shape (C, H, W) or (B, C, H, W), "
                f"got {tuple(masks.shape)}"
            )
        if masks.shape[0] == 0:
            raise ValueError("Received an empty batch of conditioning masks")
        # Cast as well as move: float64/int/bool tensors would otherwise reach
        # the model with a dtype different from the float32 used for non-tensor input.
        return masks.to(device=self.device, dtype=torch.float32)

    def _sample(self, masks, num_steps, eps):
        """Run the Euler-Maruyama sampler for a prepared mask batch and clamp to [0, 1]."""
        with torch.no_grad():
            samples = Euler_Maruyama_sampler(
                self.score_model,
                self.marginal_prob_std_fn,
                self.diffusion_coeff_fn,
                batch_size=masks.shape[0],
                x_shape=self.IMAGE_SHAPE,
                num_steps=num_steps,
                device=self.device,
                eps=eps,
                y=masks,
            )
        return samples.clamp(0, 1)

    def generate_image(self, conditioning_mask, num_steps=250, eps=1e-3):
        """
        Generate a medical image based on a conditioning mask.

        Args:
            conditioning_mask: Conditioning mask of shape (1, 4, 256, 256) or
                (4, 256, 256); a tensor or anything ``torch.tensor`` accepts.
            num_steps: Number of sampling steps.
            eps: Smallest time step for numerical stability.

        Returns:
            Generated image tensor of shape (1, 4, 256, 256), clamped to [0, 1].

        Raises:
            ValueError: If the mask has an invalid rank or holds more than one
                sample (use ``generate_batch`` for batches).
        """
        masks = self._prepare_masks(conditioning_mask)
        if masks.shape[0] != 1:
            raise ValueError(
                f"generate_image expects a single mask, got a batch of "
                f"{masks.shape[0]}; use generate_batch instead"
            )
        return self._sample(masks, num_steps, eps)

    def generate_batch(self, conditioning_masks, num_steps=250, eps=1e-3):
        """
        Generate multiple images based on conditioning masks.

        Args:
            conditioning_masks: Conditioning masks of shape (B, 4, 256, 256),
                or a single (4, 256, 256) mask.
            num_steps: Number of sampling steps.
            eps: Smallest time step for numerical stability.

        Returns:
            Generated images tensor of shape (B, 4, 256, 256), clamped to [0, 1].

        Raises:
            ValueError: If the masks have an invalid rank or the batch is empty.
        """
        masks = self._prepare_masks(conditioning_masks)
        return self._sample(masks, num_steps, eps)

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a NumPy array, detaching and moving torch tensors to CPU."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def visualize_generation(self, conditioning_mask, generated_image, save_path=None):
        """
        Visualize the conditioning mask and generated image, one channel per column.

        Only the first sample of each batch is shown. Inputs may be tensors
        (on any device) or array-likes, batched (B, C, H, W) or unbatched (C, H, W).

        Args:
            conditioning_mask: Conditioning mask.
            generated_image: Generated image.
            save_path: Optional path to save the visualization.
        """
        mask_np = self._to_numpy(conditioning_mask)
        image_np = self._to_numpy(generated_image)
        # Accept unbatched inputs the same way generate_image does.
        if mask_np.ndim == 3:
            mask_np = mask_np[np.newaxis]
        if image_np.ndim == 3:
            image_np = image_np[np.newaxis]

        n_cols = max(mask_np.shape[1], image_np.shape[1])
        # squeeze=False keeps axes 2D even when there is a single channel.
        fig, axes = plt.subplots(2, n_cols, figsize=(4 * n_cols, 8), squeeze=False)
        for ax in axes.flat:
            ax.axis('off')

        # Plot conditioning mask
        for i, channel in enumerate(mask_np[0]):
            axes[0, i].imshow(channel, cmap='gray')
            axes[0, i].set_title(f'Conditioning Mask {i+1}')

        # Plot generated image
        for i, channel in enumerate(image_np[0]):
            axes[1, i].imshow(channel, cmap='gray')
            axes[1, i].set_title(f'Generated Image {i+1}')

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Visualization saved to {save_path}")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def main():
    """Example usage of the inference model."""
    # Initialize the model
    model_path = "ckpt_3D_v2.pth"  # Update with your model path
    # Fall back to CPU so the example also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inference_model = ConditionalDiffusionInference(model_path, device=device)

    # Create a random conditioning mask (replace with your actual mask)
    conditioning_mask = torch.randn(1, 4, 256, 256)

    # Generate image
    print("Generating image...")
    generated_image = inference_model.generate_image(conditioning_mask)

    # Visualize results
    inference_model.visualize_generation(conditioning_mask, generated_image, "generation_result.png")
    print("Generation complete!")


if __name__ == "__main__":
    main()