#!/usr/bin/env python3
"""
Inference script for the conditional diffusion model.
This script provides easy-to-use functions for generating medical images.
"""
import torch
import numpy as np
import matplotlib.pyplot as plt

from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
class ConditionalDiffusionInference:
    """Wrapper class for easy inference with the conditional diffusion model."""

    def __init__(self, model_path, device='cuda'):
        """
        Initialize the inference model.

        Args:
            model_path: Path to the trained model checkpoint
            device: Device to run inference on ('cuda' or 'cpu')
        """
        self.device = device
        # SDE noise-scale hyperparameter; must match the value used at training
        # time or the marginal std / diffusion coefficient will be wrong.
        self.Lambda = 25.0
        # Bind Lambda/device once so the samplers only need a time argument.
        self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
        self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)
        self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
        # map_location keeps CUDA-trained checkpoints loadable on CPU-only hosts.
        self.score_model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.score_model.to(self.device)
        self.score_model.eval()
        print(f"Model loaded successfully on {self.device}")

    def _prepare_mask(self, mask):
        """Coerce a conditioning mask to a float32 tensor of shape (B, C, H, W) on self.device.

        Accepts array-likes or tensors; a 3-D input is treated as a single
        unbatched mask and gains a leading batch dimension.
        """
        if not isinstance(mask, torch.Tensor):
            mask = torch.tensor(mask, dtype=torch.float32)
        if mask.dim() == 3:
            mask = mask.unsqueeze(0)
        return mask.to(self.device)

    def generate_image(self, conditioning_mask, num_steps=250, eps=1e-3):
        """
        Generate a medical image based on a conditioning mask.

        Args:
            conditioning_mask: Conditioning mask tensor of shape (1, 4, 256, 256)
            num_steps: Number of sampling steps
            eps: Smallest time step for numerical stability

        Returns:
            Generated image tensor of shape (1, 4, 256, 256)
        """
        # Single-image generation is just a batch of one; delegating also fixes
        # the previous batch_size=1 mismatch when a batched mask was passed in.
        return self.generate_batch(conditioning_mask, num_steps=num_steps, eps=eps)

    def generate_batch(self, conditioning_masks, num_steps=250, eps=1e-3):
        """
        Generate multiple images based on conditioning masks.

        Args:
            conditioning_masks: Conditioning mask tensor of shape (B, 4, 256, 256)
            num_steps: Number of sampling steps
            eps: Smallest time step for numerical stability

        Returns:
            Generated images tensor of shape (B, 4, 256, 256), clamped to [0, 1]
        """
        conditioning_masks = self._prepare_mask(conditioning_masks)
        batch_size = conditioning_masks.shape[0]
        with torch.no_grad():
            samples = Euler_Maruyama_sampler(
                self.score_model,
                self.marginal_prob_std_fn,
                self.diffusion_coeff_fn,
                batch_size=batch_size,
                x_shape=(4, 256, 256),
                num_steps=num_steps,
                device=self.device,
                eps=eps,
                y=conditioning_masks
            )
        # Clamp to valid image intensity range.
        return samples.clamp(0, 1)

    def visualize_generation(self, conditioning_mask, generated_image, save_path=None):
        """
        Visualize the conditioning mask and generated image.

        Args:
            conditioning_mask: Conditioning mask tensor of shape (1, 4, H, W)
            generated_image: Generated image tensor of shape (1, 4, H, W)
            save_path: Optional path to save the visualization
        """
        fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        # Top row: the four conditioning-mask channels.
        for i in range(4):
            axes[0, i].imshow(conditioning_mask[0, i].cpu().numpy(), cmap='gray')
            axes[0, i].set_title(f'Conditioning Mask {i+1}')
            axes[0, i].axis('off')
        # Bottom row: the four generated-image channels.
        for i in range(4):
            axes[1, i].imshow(generated_image[0, i].cpu().numpy(), cmap='gray')
            axes[1, i].set_title(f'Generated Image {i+1}')
            axes[1, i].axis('off')
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Visualization saved to {save_path}")
        plt.show()
def main():
    """Example usage of the inference model."""
    model_path = "ckpt_3D_v2.pth"  # Update with your model path
    # Fall back to CPU when no GPU is present so the example runs anywhere.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inference_model = ConditionalDiffusionInference(model_path, device=device)
    # Create a random conditioning mask (replace with your actual mask).
    conditioning_mask = torch.randn(1, 4, 256, 256)
    # Generate image
    print("Generating image...")
    generated_image = inference_model.generate_image(conditioning_mask)
    # Visualize results
    inference_model.visualize_generation(conditioning_mask, generated_image, "generation_result.png")
    print("Generation complete!")


if __name__ == "__main__":
    main()