Conditional Diffusion Model for Medical Image Generation

This repository contains a conditional diffusion model trained to generate 3D medical CT scan images from segmentation masks. The model is a U-Net score network trained with a score matching loss and sampled with an Euler-Maruyama solver.

Model Architecture

  • Base Model: U-Net with 5-level encoder-decoder
  • Input: 4-channel 256x256 CT scan images
  • Conditioning: Segmentation masks (4-channel 256x256)
  • Output: 4-channel 256x256 generated images
  • Sampling: Euler-Maruyama sampler with 250 steps
  • Training: Score matching loss with conditional generation
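
To illustrate the sampling bullet above, here is a minimal one-dimensional sketch of Euler-Maruyama integration of a reverse-time SDE with 250 steps. It assumes a variance-exploding SDE dx = Lambda^t dw with Lambda = 25, which the source does not state explicitly, and it uses a toy analytic score for standard-normal data in place of the trained U-Net. The function names are hypothetical, and this is not the repository's Euler_Maruyama_sampler.

```python
import math
import random

LAM = 25.0  # assumed noise-scale parameter of the SDE dx = LAM**t dw


def marginal_std(t):
    # Std of the perturbation kernel p(x_t | x_0) under the assumed SDE.
    return math.sqrt((LAM ** (2.0 * t) - 1.0) / (2.0 * math.log(LAM)))


def toy_score(x, t):
    # Exact score of p_t when the data are N(0, 1): p_t = N(0, 1 + std(t)^2).
    return -x / (1.0 + marginal_std(t) ** 2)


def euler_maruyama_sample(score_fn, num_steps=250, eps=1e-3, rng=None):
    rng = rng or random.Random()
    # Start from the approximate prior at t = 1.
    x = rng.gauss(0.0, 1.0) * marginal_std(1.0)
    dt = (1.0 - eps) / (num_steps - 1)
    mean_x = x
    for i in range(num_steps):
        t = 1.0 - i * dt  # time runs from 1 down to eps
        g = LAM ** t      # diffusion coefficient g(t)
        # Drift part of the reverse-time SDE: x + g(t)^2 * score * dt
        mean_x = x + (g ** 2) * score_fn(x, t) * dt
        # Diffusion part: Gaussian noise scaled by g(t) * sqrt(dt)
        x = mean_x + g * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    # Return the last drift-only update, skipping the final noise injection.
    return mean_x


rng = random.Random(0)
samples = [euler_maruyama_sample(toy_score, rng=rng) for _ in range(1000)]
```

Because the toy score is exact, the samples should be approximately standard normal. In the real model, the score would come from the conditional U-Net evaluated on the current image and the segmentation mask.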

Model Details

  • Training Data: 3,346 medical CT scan examples
  • Lambda Parameter: 25.0 (noise-scale parameter passed to marginal_prob_std and diffusion_coeff)
  • Embedding Dimension: 256
  • Channels: [32, 64, 128, 256, 512]
  • Activation: SiLU (Swish)
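
The Lambda value above parameterizes the noise schedule. As a rough sketch, assuming the model follows the common variance-exploding SDE dx = Lambda^t dw (the source does not state the SDE, and the repository's own marginal_prob_std and diffusion_coeff may differ), the two schedule functions could look like this. The names ve_diffusion_coeff and ve_marginal_std are hypothetical stand-ins that operate on plain floats rather than tensors.

```python
import math


def ve_diffusion_coeff(t, lam=25.0):
    """g(t) = lam**t, the diffusion coefficient of the assumed SDE dx = lam**t dw."""
    return lam ** t


def ve_marginal_std(t, lam=25.0):
    """Std of p(x_t | x_0) under the assumed SDE: sqrt((lam**(2t) - 1) / (2 ln lam))."""
    return math.sqrt((lam ** (2.0 * t) - 1.0) / (2.0 * math.log(lam)))
```

Under this assumption the noise level grows monotonically from zero at t = 0, and the diffusion coefficient equals Lambda itself at t = 1.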

Usage

Using the Hugging Face API

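# This is a custom PyTorch model, so fetch the checkpoint from the Hub
# and load it with the model code shown under Local Usage.
from huggingface_hub import hf_hub_download
import torch

# The repo id is a placeholder; the filename is assumed to match Local Usage
ckpt_path = hf_hub_download(
    repo_id="your-username/your-model-name",
    filename="ckpt_3D_v2.pth",
)
state_dict = torch.load(ckpt_path, map_location="cpu")

# Pass state_dict to score_model.load_state_dict(...) and sample as in Local Usage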

Local Usage

import torch
from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler

# Use the GPU when available so the model, mask, and sampler share one device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load model
Lambda = 25.0
marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=Lambda, device=device)
diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=Lambda, device=device)
score_model = UNet(marginal_prob_std=marginal_prob_std_fn).to(device)
score_model.load_state_dict(torch.load("ckpt_3D_v2.pth", map_location=device))
score_model.eval()

# Generate sample
# Placeholder conditioning: replace the random tensor with a real segmentation mask
conditioning_mask = torch.randn(1, 4, 256, 256, device=device)
samples = Euler_Maruyama_sampler(
    score_model,
    marginal_prob_std_fn,
    diffusion_coeff_fn,
    batch_size=1,
    x_shape=(4, 256, 256),
    num_steps=250,
    device=device,
    y=conditioning_mask
)

Training

The model was trained for 5,000 epochs with the following settings:

  • Learning rate: 2e-4 (with decay)
  • Batch size: 1
  • Optimizer: Adam
  • Loss: Score matching loss
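
The score matching loss listed above is commonly implemented as denoising score matching. A minimal scalar sketch follows, assuming a variance-exploding SDE dx = Lambda^t dw with Lambda = 25 and the usual std(t)^2 weighting. The source does not spell out the exact loss, so treat dsm_loss and its helpers as hypothetical illustrations rather than the training code.

```python
import math
import random

LAM = 25.0  # assumed noise-scale parameter of the SDE dx = LAM**t dw


def marginal_std(t):
    # Std of the perturbation kernel p(x_t | x_0) under the assumed SDE.
    return math.sqrt((LAM ** (2.0 * t) - 1.0) / (2.0 * math.log(LAM)))


def dsm_loss(score_fn, data, rng, eps=1e-5):
    """Monte Carlo estimate of E[(std(t) * score(x_t, t) + z)^2]."""
    total = 0.0
    for x0 in data:
        t = rng.uniform(eps, 1.0)  # random diffusion time
        z = rng.gauss(0.0, 1.0)    # noise used to perturb the clean sample
        std = marginal_std(t)
        x_t = x0 + std * z         # perturbed sample
        total += (std * score_fn(x_t, t) + z) ** 2
    return total / len(data)


# Toy check on N(0, 1) data: compare the exact score against a zero score.
data_rng = random.Random(0)
data = [data_rng.gauss(0.0, 1.0) for _ in range(5000)]
exact_score = lambda x, t: -x / (1.0 + marginal_std(t) ** 2)
zero_score = lambda x, t: 0.0
loss_exact = dsm_loss(exact_score, data, random.Random(1))
loss_zero = dsm_loss(zero_score, data, random.Random(1))
```

In this toy check the exact score attains a lower loss than the zero score, which is the property the training objective relies on. In the actual model, score_fn would be the U-Net applied to the noisy image, the time, and the conditioning mask.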

Dataset

The model was trained on medical CT scans paired with segmentation masks. The dataset contains 3,346 training examples with an 80/20 train/validation split.

Citation

If you use this model in your research, please cite:

@misc{conditional_diffusion_medical,
  title={Conditional Diffusion Model for Medical Image Generation},
  author={Your Name},
  year={2024},
  url={https://huggingface.co/your-username/your-model-name}
}

License

[Add your license here]

Contact

For questions or issues, please open an issue on this repository.