Conditional Diffusion Model for Medical Image Generation
This repository contains a conditional diffusion model that generates 3D medical CT scan images conditioned on segmentation masks. It is a score-based diffusion model whose score network is a U-Net.
Model Architecture
- Base Model: U-Net with 5-level encoder-decoder
- Input: 4-channel 256x256 CT scan images
- Conditioning: Segmentation masks (4-channel 256x256)
- Output: 4-channel 256x256 generated images
- Sampling: Euler-Maruyama sampler with 250 steps
- Training objective: score matching loss, conditioned on the segmentation mask
Model Details
- Training Data: 3,346 medical CT scan examples
- Lambda: 25.0 (noise-scale parameter of the diffusion process, passed to marginal_prob_std and diffusion_coeff)
- Embedding Dimension: 256
- Channels: [32, 64, 128, 256, 512]
- Activation: SiLU (Swish)
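The Lambda entry is easiest to read through the noise schedule it controls. The sketch below assumes the repository follows the variance-exploding SDE formulation of Song et al.'s score-based generative modeling tutorial, which the names marginal_prob_std and diffusion_coeff appear to come from: dx = Λ^t dw, so the diffusion coefficient is g(t) = Λ^t and the perturbation kernel has standard deviation sqrt((Λ^(2t) - 1) / (2 ln Λ)). These formulas are an assumption, not read from this repository's model.py.

```python
import math

LAMBDA = 25.0  # value from the model card; the formulas below are assumed


def marginal_prob_std(t: float, lam: float = LAMBDA) -> float:
    """Std of p(x_t | x_0) under the assumed SDE dx = lam**t dw."""
    return math.sqrt((lam ** (2 * t) - 1.0) / (2.0 * math.log(lam)))


def diffusion_coeff(t: float, lam: float = LAMBDA) -> float:
    """Diffusion coefficient g(t) = lam**t of the assumed SDE."""
    return lam ** t


# Print how noise level and diffusion strength grow with time t in [0, 1]
for t in (0.0, 0.25, 0.5, 1.0):
    print(f"t={t:.2f}  std={marginal_prob_std(t):.3f}  g={diffusion_coeff(t):.3f}")
```

Under this assumption the noise standard deviation grows from 0 at t = 0 to about 9.85 at t = 1, and g(t) rises from 1 to 25, so a larger Lambda means the sampler starts from a wider prior.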
Usage
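Loading from the Hugging Face Hub
The weights are a plain PyTorch state dict, so they are loaded with the UNet class from this repository's model.py rather than through a transformers Auto class:
from huggingface_hub import hf_hub_download
import torch
from model import UNet, marginal_prob_std
# Download the checkpoint (assumes it was uploaded as ckpt_3D_v2.pth)
ckpt_path = hf_hub_download(repo_id="your-username/your-model-name", filename="ckpt_3D_v2.pth")
# Rebuild the network with the training noise schedule (Lambda = 25.0) and load the weights
marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=25.0, device='cuda')
model = UNet(marginal_prob_std=marginal_prob_std_fn).to('cuda')
model.load_state_dict(torch.load(ckpt_path, map_location='cuda'))
model.eval()
# Generate images with Euler_Maruyama_sampler as shown under Local Usage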
Local Usage
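import torch
from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
device = 'cuda'
Lambda = 25.0  # noise-scale parameter used during training
# Noise schedule functions bound to the training value of Lambda
marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=Lambda, device=device)
diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=Lambda, device=device)
# Load the trained weights onto the GPU and switch to inference mode
score_model = UNet(marginal_prob_std=marginal_prob_std_fn).to(device)
score_model.load_state_dict(torch.load("ckpt_3D_v2.pth", map_location=device))
score_model.eval()
# Conditioning mask of shape (batch, 4, 256, 256). Random values are only a
# placeholder; pass a real segmentation mask here.
conditioning_mask = torch.randn(1, 4, 256, 256, device=device)
# Generate one sample with 250 Euler-Maruyama steps
with torch.no_grad():
    samples = Euler_Maruyama_sampler(
        score_model,
        marginal_prob_std_fn,
        diffusion_coeff_fn,
        batch_size=1,
        x_shape=(4, 256, 256),
        num_steps=250,
        device=device,
        y=conditioning_mask,
    )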
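To show what the Euler_Maruyama_sampler call does without the trained network, here is a minimal sketch of an Euler-Maruyama integrator for the reverse-time SDE on 1-D toy data. It assumes a variance-exploding schedule with g(t) = Λ^t and Λ = 25 (an assumption about model.py, not confirmed by it), and it replaces the U-Net with toy_score, a hypothetical closed-form score for data x_0 ~ N(y, 1), so that the scalar y plays the role of the segmentation mask. The function names, the seed argument, and the eps = 1e-3 end time are illustrative choices, not the repository's code.

```python
import numpy as np

LAMBDA = 25.0  # noise-scale parameter from the model card


def marginal_prob_std(t, lam=LAMBDA):
    # Std of p(x_t | x_0) under the assumed SDE dx = lam**t dw
    return np.sqrt((lam ** (2 * t) - 1.0) / (2.0 * np.log(lam)))


def diffusion_coeff(t, lam=LAMBDA):
    # Diffusion coefficient g(t) = lam**t of the assumed SDE
    return lam ** t


def toy_score(x, t, y, data_std=1.0):
    # Hypothetical stand-in for the U-Net: exact score of x_t when x_0 ~ N(y, data_std**2)
    return -(x - y) / (data_std ** 2 + marginal_prob_std(t) ** 2)


def euler_maruyama_sampler(score_fn, y, batch_size, num_steps=250, eps=1e-3, seed=0):
    rng = np.random.default_rng(seed)
    # Start from the prior N(0, marginal_prob_std(1)**2) at t = 1
    x = rng.standard_normal(batch_size) * marginal_prob_std(1.0)
    time_steps = np.linspace(1.0, eps, num_steps)
    step_size = time_steps[0] - time_steps[1]
    mean_x = x
    for t in time_steps:
        g = diffusion_coeff(t)
        # Reverse-time drift g(t)^2 * score, integrated backwards by one step
        mean_x = x + g ** 2 * score_fn(x, t, y) * step_size
        # Diffusion term: sqrt(step) * g(t) * standard normal noise
        x = mean_x + np.sqrt(step_size) * g * rng.standard_normal(batch_size)
    # Return the last mean so no noise is injected at the final step
    return mean_x


# Condition on y = 3.0 and draw 20,000 independent 1-D samples
samples = euler_maruyama_sampler(toy_score, y=3.0, batch_size=20000)
print(f"mean={samples.mean():.3f}  std={samples.std():.3f}")
```

With 250 steps the samples should land close to the conditional data distribution (mean near 3.0, standard deviation near 1.0). In the real model the same loop runs over (4, 256, 256) tensors, and the U-Net receives the mask y alongside x and t.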
Training
The model was trained for 5000 epochs with:
- Learning rate: 2e-4 (with decay)
- Batch size: 1
- Optimizer: Adam
- Loss: Score matching loss
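The loss entry does not spell out the exact objective. A common choice for score-based models of this kind is the σ-weighted denoising score matching loss: perturb clean data to x_t = x_0 + σ(t)z and train the network so that σ(t)·s(x_t, t, y) ≈ -z. The sketch below assumes that objective and the same Λ^t noise schedule with Λ = 25, uses 1-D toy data, and substitutes a hypothetical closed-form toy_score for the U-Net; it illustrates the idea and is not the training code that produced ckpt_3D_v2.pth.

```python
import numpy as np

LAMBDA = 25.0  # noise-scale parameter from the model card


def marginal_prob_std(t, lam=LAMBDA):
    # Std of p(x_t | x_0) under the assumed SDE dx = lam**t dw
    return np.sqrt((lam ** (2 * t) - 1.0) / (2.0 * np.log(lam)))


def toy_score(x, t, y, data_std=1.0):
    # Hypothetical stand-in for the U-Net: exact score of x_t when x_0 ~ N(y, data_std**2)
    return -(x - y) / (data_std ** 2 + marginal_prob_std(t) ** 2)


def dsm_loss(score_fn, x0, y, rng, eps=1e-5):
    # One diffusion time per example, drawn uniformly from [eps, 1]
    t = rng.uniform(eps, 1.0, size=x0.shape[0])
    z = rng.standard_normal(x0.shape)
    std = marginal_prob_std(t)
    # Perturb the clean data with noise of the marginal std at time t
    x_t = x0 + std * z
    # Sigma-weighted denoising score matching: mean of (std * s(x_t, t, y) + z)^2
    return np.mean((std * score_fn(x_t, t, y) + z) ** 2)


rng = np.random.default_rng(0)
x0 = 3.0 + rng.standard_normal(50000)  # toy data conditioned on y = 3.0
exact = dsm_loss(toy_score, x0, 3.0, rng)
zero = dsm_loss(lambda x, t, y: np.zeros_like(x), x0, 3.0, rng)
print(f"exact score loss={exact:.3f}  zero score loss={zero:.3f}")
```

The exact score reaches a clearly lower loss than a zero score, which scores about 1 (the variance of z); that gap is the signal Adam follows during training. For image batches, std would need shape (batch, 1, 1, 1) so it broadcasts over channels and pixels, and the squared error is typically summed over pixels before averaging over the batch.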
Dataset
The model was trained on medical CT scans paired with segmentation masks. The dataset contains 3,346 training examples, with an 80/20 train/validation split.
Citation
If you use this model in your research, please cite:
@misc{conditional_diffusion_medical,
title={Conditional Diffusion Model for Medical Image Generation},
author={Your Name},
year={2024},
url={https://huggingface.co/your-username/your-model-name}
}
License
[Add your license here]
Contact
For questions or issues, please open an issue on this repository.