|
|
|
|
|
"""
|
|
|
Inference script for the conditional diffusion model.
|
|
|
This script provides easy-to-use functions for generating medical images.
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
|
|
|
|
|
|
|
|
|
class ConditionalDiffusionInference:
    """Wrapper class for easy inference with the conditional diffusion model.

    Loads a trained ``UNet`` score model from a checkpoint and draws samples
    with ``Euler_Maruyama_sampler``, passing a conditioning mask to the sampler
    through its ``y`` argument. Generated images are clamped to ``[0, 1]``.
    """

    # Per-sample shape handed to the sampler when the caller does not supply one.
    DEFAULT_X_SHAPE = (4, 256, 256)

    def __init__(self, model_path, device='cuda', Lambda=25.0):
        """
        Initialize the inference model.

        Args:
            model_path: Path to the trained model checkpoint (a state dict
                passed to ``UNet.load_state_dict``).
            device: Device to run inference on ('cuda' or 'cpu'). If a CUDA
                device is requested but CUDA is unavailable, falls back to 'cpu'.
            Lambda: Noise-schedule parameter forwarded to ``marginal_prob_std``
                and ``diffusion_coeff``; presumably must match the value used
                during training (default 25.0, as before).
        """
        # Fail soft instead of crashing inside torch.load when no GPU is present.
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            print("CUDA is not available; falling back to CPU")
            device = 'cpu'
        self.device = device
        self.Lambda = Lambda

        # The lambdas read self.Lambda / self.device at call time, so they
        # follow any later change to those attributes.
        self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
        self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)

        self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
        # NOTE(security): torch.load unpickles the file and can execute arbitrary
        # code from an untrusted checkpoint. Only a state dict is needed here, so
        # on torch >= 1.13 consider passing weights_only=True.
        state_dict = torch.load(model_path, map_location=self.device)
        self.score_model.load_state_dict(state_dict)
        self.score_model.to(self.device)
        self.score_model.eval()

        print(f"Model loaded successfully on {self.device}")

    def _prepare_conditioning(self, masks):
        """Return ``masks`` as a float32 (B, C, H, W) tensor on ``self.device``.

        Accepts tensors, NumPy arrays or nested lists of rank 3 (C, H, W) or
        rank 4 (B, C, H, W). The model weights are float32, so the mask is cast
        to float32 regardless of its input dtype.

        Raises:
            ValueError: if the input is not rank 3/4 or the batch is empty.
        """
        masks = torch.as_tensor(masks)
        if masks.dim() == 3:
            masks = masks.unsqueeze(0)
        if masks.dim() != 4:
            raise ValueError(
                "Expected a conditioning mask of shape (C, H, W) or (B, C, H, W), "
                f"got shape {tuple(masks.shape)}"
            )
        if masks.shape[0] == 0:
            raise ValueError("Conditioning mask batch is empty")
        return masks.to(device=self.device, dtype=torch.float32)

    def _sample(self, masks, num_steps, eps, x_shape):
        """Run the Euler-Maruyama sampler for one sample per mask and clamp to [0, 1]."""
        if x_shape is None:
            x_shape = self.DEFAULT_X_SHAPE
        with torch.no_grad():
            samples = Euler_Maruyama_sampler(
                self.score_model,
                self.marginal_prob_std_fn,
                self.diffusion_coeff_fn,
                batch_size=masks.shape[0],
                x_shape=tuple(x_shape),
                num_steps=num_steps,
                device=self.device,
                eps=eps,
                y=masks,
            )
        return samples.clamp(0, 1)

    def generate_image(self, conditioning_mask, num_steps=250, eps=1e-3, x_shape=None):
        """
        Generate a medical image based on a conditioning mask.

        Args:
            conditioning_mask: Conditioning mask of shape (1, 4, 256, 256) or
                (4, 256, 256); tensors, arrays and nested lists are accepted.
            num_steps: Number of sampling steps
            eps: Smallest time step for numerical stability
            x_shape: Per-sample shape of the generated image; defaults to
                ``DEFAULT_X_SHAPE`` (4, 256, 256).

        Returns:
            Generated image tensor of shape (1, 4, 256, 256), clamped to [0, 1].

        Raises:
            ValueError: if the mask is malformed or holds more than one sample
                (use ``generate_batch`` for that).
        """
        mask = self._prepare_conditioning(conditioning_mask)
        # The sampler draws batch_size samples; a larger mask batch would not
        # line up with a single generated image.
        if mask.shape[0] != 1:
            raise ValueError(
                f"generate_image expects a single mask, got a batch of {mask.shape[0]}; "
                "use generate_batch instead"
            )
        return self._sample(mask, num_steps, eps, x_shape)

    def generate_batch(self, conditioning_masks, num_steps=250, eps=1e-3, x_shape=None):
        """
        Generate multiple images based on conditioning masks.

        Args:
            conditioning_masks: Conditioning masks of shape (B, 4, 256, 256) or
                a single (4, 256, 256) mask.
            num_steps: Number of sampling steps
            eps: Smallest time step for numerical stability
            x_shape: Per-sample shape of the generated images; defaults to
                ``DEFAULT_X_SHAPE`` (4, 256, 256).

        Returns:
            Generated images tensor of shape (B, 4, 256, 256), clamped to [0, 1].

        Raises:
            ValueError: if the masks are malformed or the batch is empty.
        """
        masks = self._prepare_conditioning(conditioning_masks)
        return self._sample(masks, num_steps, eps, x_shape)

    @staticmethod
    def _to_display_array(x):
        """Return ``x`` as a float32 CPU NumPy array with at least 4 dims (B, C, H, W)."""
        t = torch.as_tensor(x).detach().cpu().float()
        while t.dim() < 4:
            t = t.unsqueeze(0)
        return t.numpy()

    def visualize_generation(self, conditioning_mask, generated_image, save_path=None):
        """
        Visualize the conditioning mask and generated image.

        The top row shows every channel of the first mask in the batch, the
        bottom row every channel of the first generated image.

        Args:
            conditioning_mask: Conditioning mask as a tensor (any device) or
                array, shaped (C, H, W) or (B, C, H, W).
            generated_image: Generated image in the same layouts.
            save_path: Optional path to save the visualization
        """
        panels = (
            ('Conditioning Mask', self._to_display_array(conditioning_mask)),
            ('Generated Image', self._to_display_array(generated_image)),
        )
        n_cols = max(max(data.shape[1] for _, data in panels), 1)
        # squeeze=False keeps `axes` 2-D even when there is a single channel.
        fig, axes = plt.subplots(2, n_cols, figsize=(4 * n_cols, 8), squeeze=False)

        for row, (label, data) in enumerate(panels):
            for col in range(n_cols):
                ax = axes[row, col]
                ax.axis('off')
                if col < data.shape[1]:
                    ax.imshow(data[0, col], cmap='gray')
                    ax.set_title(f'{label} {col+1}')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Visualization saved to {save_path}")

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
def main():
    """Demo: load the checkpoint, sample from a random mask, and plot the result."""
    # Relative path, so it is resolved against the current working directory.
    checkpoint = "ckpt_3D_v2.pth"
    model = ConditionalDiffusionInference(checkpoint, device='cuda')

    # Random noise of shape (1, 4, 256, 256) stands in for a real conditioning mask.
    demo_mask = torch.randn(1, 4, 256, 256)

    print("Generating image...")
    result = model.generate_image(demo_mask)

    model.visualize_generation(demo_mask, result, save_path="generation_result.png")

    print("Generation complete!")
|
|
|
|
|
|
|
|
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()