# File size: 5,919 Bytes
# Revision: e18319b
#!/usr/bin/env python3
"""
Inference script for the conditional diffusion model.
This script provides easy-to-use functions for generating medical images.
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
class ConditionalDiffusionInference:
    """Wrapper class for easy inference with the conditional diffusion model.

    Attributes:
        device: Device string the model and all tensors are placed on.
        Lambda: Noise-schedule parameter forwarded to ``marginal_prob_std``
            and ``diffusion_coeff``.
        marginal_prob_std_fn: ``marginal_prob_std`` bound to ``Lambda``/``device``.
        diffusion_coeff_fn: ``diffusion_coeff`` bound to ``Lambda``/``device``.
        score_model: The loaded ``UNet`` score network, in eval mode.
    """

    # Per-sample (channels, height, width) shape passed to the sampler as
    # ``x_shape``. Kept as a class attribute so subclasses can override it.
    image_shape = (4, 256, 256)

    def __init__(self, model_path, device='cuda'):
        """
        Initialize the inference model.

        Args:
            model_path: Path to the trained model checkpoint (a state dict
                accepted by ``UNet.load_state_dict``).
            device: Device to run inference on ('cuda' or 'cpu').
        """
        self.device = device
        self.Lambda = 25.0
        # Bind the SDE helpers to this instance's Lambda/device so the score
        # model and the sampler always use the same noise schedule.
        self.marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=self.Lambda, device=self.device)
        self.diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=self.Lambda, device=self.device)
        self.score_model = UNet(marginal_prob_std=self.marginal_prob_std_fn)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources (or pass weights_only=True on torch >= 1.13).
        state_dict = torch.load(model_path, map_location=self.device)
        self.score_model.load_state_dict(state_dict)
        self.score_model.to(self.device)
        self.score_model.eval()
        print(f"Model loaded successfully on {self.device}")

    def _prepare_masks(self, masks):
        """Convert conditioning masks to a float32 (B, C, H, W) tensor on ``self.device``.

        Args:
            masks: Tensor, NumPy array or nested sequence of shape (C, H, W)
                or (B, C, H, W).

        Returns:
            A 4-D float32 tensor on ``self.device``.

        Raises:
            ValueError: If the input is not 3-D or 4-D.
        """
        if not isinstance(masks, torch.Tensor):
            masks = torch.tensor(masks, dtype=torch.float32)
        if masks.dim() == 3:
            masks = masks.unsqueeze(0)  # add the missing batch axis
        if masks.dim() != 4:
            raise ValueError(
                "Expected a conditioning mask of shape (C, H, W) or (B, C, H, W), "
                f"got shape {tuple(masks.shape)}"
            )
        # Cast as well as move: integer/bool/float64 masks would otherwise reach
        # the float32 model with a mismatched dtype.
        return masks.to(device=self.device, dtype=torch.float32)

    def generate_image(self, conditioning_mask, num_steps=250, eps=1e-3):
        """
        Generate a medical image based on a single conditioning mask.

        Args:
            conditioning_mask: Conditioning mask of shape (4, 256, 256) or
                (1, 4, 256, 256); tensors, arrays and sequences are accepted.
            num_steps: Number of sampling steps.
            eps: Smallest time step for numerical stability.

        Returns:
            Generated image tensor of shape (1, 4, 256, 256), clamped to [0, 1].

        Raises:
            ValueError: If the mask is not 3-D/4-D or has a batch size other
                than 1 (use ``generate_batch`` for several masks).
        """
        mask = self._prepare_masks(conditioning_mask)
        if mask.shape[0] != 1:
            raise ValueError(
                f"generate_image expects a single mask (batch size 1), got batch "
                f"size {mask.shape[0]}; use generate_batch instead"
            )
        return self.generate_batch(mask, num_steps=num_steps, eps=eps)

    def generate_batch(self, conditioning_masks, num_steps=250, eps=1e-3):
        """
        Generate multiple images based on conditioning masks.

        Args:
            conditioning_masks: Conditioning masks of shape (B, 4, 256, 256)
                (a single (4, 256, 256) mask is treated as B=1).
            num_steps: Number of sampling steps.
            eps: Smallest time step for numerical stability.

        Returns:
            Generated images tensor of shape (B, 4, 256, 256), clamped to [0, 1].

        Raises:
            ValueError: If the masks are not 3-D or 4-D.
        """
        masks = self._prepare_masks(conditioning_masks)
        with torch.no_grad():
            samples = Euler_Maruyama_sampler(
                self.score_model,
                self.marginal_prob_std_fn,
                self.diffusion_coeff_fn,
                batch_size=masks.shape[0],  # one sample per conditioning mask
                x_shape=self.image_shape,
                num_steps=num_steps,
                device=self.device,
                eps=eps,
                y=masks,
            )
        return samples.clamp(0, 1)

    @staticmethod
    def _first_sample_channels(x):
        """Return the first sample of *x* as a float (C, H, W) NumPy array on the CPU.

        Accepts tensors (on any device) or arrays shaped (C, H, W) or (B, C, H, W).
        """
        t = torch.as_tensor(x).detach().cpu()
        if t.dim() == 4:
            t = t[0]
        return t.float().numpy()

    def visualize_generation(self, conditioning_mask, generated_image, save_path=None):
        """
        Visualize the conditioning mask and generated image, one channel per column.

        Args:
            conditioning_mask: Conditioning mask tensor/array, (C, H, W) or (B, C, H, W);
                only the first sample is shown.
            generated_image: Generated image tensor/array, (C, H, W) or (B, C, H, W);
                only the first sample is shown.
            save_path: Optional path to save the visualization.

        Returns:
            The matplotlib Figure, so callers can close it when done.
        """
        masks = self._first_sample_channels(conditioning_mask)
        images = self._first_sample_channels(generated_image)
        n_cols = max(masks.shape[0], images.shape[0])
        # squeeze=False keeps ``axes`` 2-D even when there is a single channel.
        fig, axes = plt.subplots(2, n_cols, figsize=(4 * n_cols, 8), squeeze=False)
        rows = ((masks, 'Conditioning Mask'), (images, 'Generated Image'))
        for row, (channels, label) in enumerate(rows):
            for col in range(n_cols):
                ax = axes[row, col]
                if col < channels.shape[0]:
                    ax.imshow(channels[col], cmap='gray')
                    ax.set_title(f'{label} {col+1}')
                ax.axis('off')
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Visualization saved to {save_path}")
        plt.show()
        return fig
def main():
    """Example usage of the inference model.

    Loads the checkpoint, samples one image from a random placeholder mask,
    then shows the result and saves it to ``generation_result.png``.
    """
    # Initialize the model
    model_path = "ckpt_3D_v2.pth"  # Update with your model path
    # Fall back to CPU so the example also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inference_model = ConditionalDiffusionInference(model_path, device=device)
    # Create a random conditioning mask (replace with your actual mask)
    conditioning_mask = torch.randn(1, 4, 256, 256)
    # Generate image
    print("Generating image...")
    generated_image = inference_model.generate_image(conditioning_mask)
    # Visualize results
    inference_model.visualize_generation(conditioning_mask, generated_image, "generation_result.png")
    print("Generation complete!")
# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
    main()