#!/usr/bin/env python3
# Uploaded to Hugging Face by tan200224 in commit e18319b ("Add inference.py", verified).
"""
Inference script for the conditional diffusion model.
This script provides easy-to-use functions for generating medical images.
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
class ConditionalDiffusionInference:
    """Wrapper class for easy inference with the conditional diffusion model.

    Loads a ``UNet`` score model from a state-dict checkpoint and draws samples
    with ``Euler_Maruyama_sampler`` conditioned on mask tensors.
    """

    # (channels, height, width) of every sample produced by the sampler.
    IMAGE_SHAPE = (4, 256, 256)

    def __init__(self, model_path, device='cuda'):
        """
        Initialize the inference model.

        Args:
            model_path: Path to the trained model checkpoint (a state dict
                accepted by ``UNet.load_state_dict``).
            device: Device to run inference on ('cuda' or 'cpu'). If a CUDA
                device is requested but CUDA is unavailable, CPU is used.
        """
        # Fall back to CPU rather than failing inside torch.load / .to().
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            print(f"CUDA is not available; falling back to CPU instead of {device}")
            device = 'cpu'
        self.device = device
        # Forwarded as ``Lambda`` to marginal_prob_std / diffusion_coeff;
        # presumably the SDE noise-scale used in training -- must match the checkpoint.
        self.Lambda = 25.0
        # These closures read self.Lambda / self.device at call time.
        self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
        self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)
        self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
        # NOTE(security): torch.load unpickles the file and can execute arbitrary
        # code -- only load trusted checkpoints (newer PyTorch offers weights_only=True).
        state_dict = torch.load(model_path, map_location=self.device)
        self.score_model.load_state_dict(state_dict)
        self.score_model.to(self.device)
        self.score_model.eval()
        print(f"Model loaded successfully on {self.device}")

    def _prepare_masks(self, masks):
        """Return ``masks`` as a float32 (B, C, H, W) tensor on ``self.device``.

        Accepts tensors, NumPy arrays or nested sequences. A single unbatched
        (C, H, W) mask gets a leading batch dimension of 1.

        Raises:
            ValueError: if the mask is neither 3-D nor 4-D.
        """
        masks = torch.as_tensor(masks)
        if masks.dim() == 3:
            masks = masks.unsqueeze(0)
        if masks.dim() != 4:
            raise ValueError(
                f"Expected a mask of shape (C, H, W) or (B, C, H, W), got {tuple(masks.shape)}"
            )
        # Cast tensors too (not only non-tensor inputs) so the model always sees float32.
        return masks.to(device=self.device, dtype=torch.float32)

    @staticmethod
    def _to_numpy_batch(x):
        """Return ``x`` (tensor or array-like) as a (B, C, H, W) NumPy array for plotting."""
        if isinstance(x, torch.Tensor):
            # detach(): tensors requiring grad cannot be converted with .numpy().
            x = x.detach().to(device='cpu', dtype=torch.float32).numpy()
        x = np.asarray(x)
        if x.ndim == 3:
            x = x[np.newaxis]
        return x

    def generate_image(self, conditioning_mask, num_steps=250, eps=1e-3):
        """
        Generate a medical image based on a conditioning mask.

        Args:
            conditioning_mask: Conditioning mask of shape (1, 4, 256, 256) or
                (4, 256, 256); tensor, NumPy array or nested sequence.
            num_steps: Number of sampling steps
            eps: Smallest time step for numerical stability

        Returns:
            Generated image tensor of shape (1, 4, 256, 256), clamped to [0, 1].

        Raises:
            ValueError: if the mask holds more than one sample (use generate_batch).
        """
        conditioning_mask = self._prepare_masks(conditioning_mask)
        if conditioning_mask.shape[0] != 1:
            raise ValueError(
                f"generate_image expects a single mask, got a batch of "
                f"{conditioning_mask.shape[0]}; use generate_batch instead"
            )
        return self.generate_batch(conditioning_mask, num_steps=num_steps, eps=eps)

    def generate_batch(self, conditioning_masks, num_steps=250, eps=1e-3):
        """
        Generate multiple images based on conditioning masks.

        Args:
            conditioning_masks: Conditioning masks of shape (B, 4, 256, 256), or a
                single (4, 256, 256) mask; tensor, NumPy array or nested sequence.
            num_steps: Number of sampling steps
            eps: Smallest time step for numerical stability

        Returns:
            Generated images tensor of shape (B, 4, 256, 256), clamped to [0, 1].
        """
        conditioning_masks = self._prepare_masks(conditioning_masks)
        with torch.no_grad():
            samples = Euler_Maruyama_sampler(
                self.score_model,
                self.marginal_prob_std_fn,
                self.diffusion_coeff_fn,
                batch_size=conditioning_masks.shape[0],
                x_shape=self.IMAGE_SHAPE,
                num_steps=num_steps,
                device=self.device,
                eps=eps,
                y=conditioning_masks,
            )
        return samples.clamp(0, 1)

    def visualize_generation(self, conditioning_mask, generated_image, save_path=None):
        """
        Visualize the conditioning mask and generated image, one channel per column.

        Only the first sample of each batch is plotted.

        Args:
            conditioning_mask: Conditioning mask, (B, C, H, W) or (C, H, W);
                tensor (any device) or array-like.
            generated_image: Generated image, same accepted forms as the mask.
            save_path: Optional path to save the visualization
        """
        masks = self._to_numpy_batch(conditioning_mask)
        images = self._to_numpy_batch(generated_image)
        num_cols = max(masks.shape[1], images.shape[1])
        # squeeze=False keeps axes 2-D even when there is a single channel.
        fig, axes = plt.subplots(2, num_cols, figsize=(4 * num_cols, 8), squeeze=False)
        rows = (("Conditioning Mask", masks), ("Generated Image", images))
        for row, (label, data) in enumerate(rows):
            for col in range(num_cols):
                ax = axes[row, col]
                ax.axis('off')
                if col < data.shape[1]:
                    ax.imshow(data[0, col], cmap='gray')
                    ax.set_title(f'{label} {col+1}')
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Visualization saved to {save_path}")
        plt.show()
def main(model_path="ckpt_3D_v2.pth", save_path="generation_result.png"):
    """Example usage of the inference model.

    Args:
        model_path: Checkpoint to load (update with your model path).
        save_path: Where to save the mask/generation visualization.
    """
    # Use CUDA when present so the example also runs on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inference_model = ConditionalDiffusionInference(model_path, device=device)
    # Create a random conditioning mask (replace with your actual mask)
    conditioning_mask = torch.randn(1, 4, 256, 256)
    # Generate image
    print("Generating image...")
    generated_image = inference_model.generate_image(conditioning_mask)
    # Visualize results
    inference_model.visualize_generation(conditioning_mask, generated_image, save_path)
    print("Generation complete!")


if __name__ == "__main__":
    main()