---
license: mit
tags:
- pytorch
- diffusers
- stable-diffusion
- latent-diffusion
- medical-imaging
- brain-mri
- multiple-sclerosis
- dataset-conditioning
---

# Brain MRI Synthesis with Stable Diffusion (fine-tuned with dataset prompts)

This model is a **fine-tuned version of Stable Diffusion v1-4** for **prompt-conditioned synthesis of brain MRI FLAIR slices**. It uses **latent diffusion** and dataset-specific prompts to generate realistic 256x256 FLAIR slices, with the prompt controlling which source dataset's style or distribution the output follows.

## 🔍 Prompt Conditioning

The model introduces three special prompt tokens corresponding to the dataset of origin. During training, each image was paired with a prompt indicating its source:

- `"SHIFTS FLAIR MRI"`
- `"VH FLAIR MRI"`
- `"WMH2017 FLAIR MRI"`

These prompts were added as special tokens to the tokenizer, and their embeddings were fine-tuned alongside the U-Net, enabling dataset-specific synthesis.
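
As a rough illustration of what registering the prompts involves, the toy sketch below uses a plain dictionary and a list of vectors in place of the real tokenizer and text-encoder embedding matrix. It assumes each full prompt string is registered as a single token; the starting vocabulary, the embedding size, and the helper `add_special_tokens` are hypothetical and are not taken from the training code. With the `transformers` library, the corresponding steps would typically be `tokenizer.add_tokens(...)` followed by `text_encoder.resize_token_embeddings(len(tokenizer))`.

```python
import random

# Toy stand-ins for a tokenizer vocabulary and an embedding matrix (hypothetical values).
vocab = {"<pad>": 0, "flair": 1, "mri": 2}
EMBED_DIM = 4
rng = random.Random(0)
embeddings = [[rng.gauss(0.0, 0.02) for _ in range(EMBED_DIM)] for _ in vocab]

# The three dataset prompts listed above.
DATASET_PROMPTS = ["SHIFTS FLAIR MRI", "VH FLAIR MRI", "WMH2017 FLAIR MRI"]


def add_special_tokens(vocab, embeddings, new_tokens, rng):
    """Give each unseen token a new id and a freshly initialised embedding row."""
    added = 0
    dim = len(embeddings[0])
    for token in new_tokens:
        if token in vocab:
            continue  # already registered, nothing to do
        vocab[token] = len(vocab)
        embeddings.append([rng.gauss(0.0, 0.02) for _ in range(dim)])
        added += 1
    return added


def encode(prompt):
    """Under the single-token assumption, a dataset prompt maps to exactly one id."""
    return [vocab[prompt]]


num_added = add_special_tokens(vocab, embeddings, DATASET_PROMPTS, rng)
print(num_added, len(vocab), encode("VH FLAIR MRI"))
```

Each dataset prompt then has its own id and its own embedding row, which is the quantity that can be updated during fine-tuning.
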

## 🧠 Training Details

- **Base Model:** [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)
- **Architecture:** Latent Diffusion with U-Net + ResNet + Attention
- **Input resolution (latent):** 32x32
- **Output resolution (decoded):** 256x256 pixels
- **Datasets:** SHIFTS, VH, and WMH2017 (FLAIR MRI slices)
- **Channels:** 4 latent channels
- **Epochs:** 50
- **Batch size:** 8
- **Gradient accumulation:** 4 steps
- **Optimizer:** AdamW
  - Learning Rate: `1.0e-4`
  - Betas: (0.95, 0.999)
  - Weight Decay: `1.0e-6`
  - Epsilon: `1.0e-8`
- **LR Scheduler:** Cosine schedule with 500 warm-up steps
- **Noise Scheduler:** DDPM with:
  - `num_train_timesteps`: 1000
  - `beta_start`: 0.0001
  - `beta_end`: 0.02
  - `beta_schedule`: "linear"
- **Mixed Precision:** Disabled
- **Gradient Clipping:** max norm 1.0
- **Hardware:** NVIDIA A30 GPU with 4 dataloader workers
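
The schedule-related values above can be made concrete with a small, dependency-free sketch. It follows the usual definitions of a linear DDPM beta schedule and of a cosine learning-rate schedule with linear warm-up; the exact implementation used for training is not stated here. `TOTAL_STEPS` is a hypothetical value (the card gives the number of epochs but not the number of optimizer steps), the effective batch size assumes a single GPU, and the latent size assumes the 8x spatial downsampling of the Stable Diffusion v1 VAE.

```python
import math

# Values taken from the list above.
NUM_TRAIN_TIMESTEPS = 1000
BETA_START, BETA_END = 1e-4, 0.02
BASE_LR = 1.0e-4
WARMUP_STEPS = 500
BATCH_SIZE, GRAD_ACCUM_STEPS = 8, 4
IMAGE_SIZE = 256

# Assumptions, not stated in the card.
TOTAL_STEPS = 10_000   # hypothetical number of optimizer steps
VAE_DOWNSAMPLE = 8     # spatial reduction of the Stable Diffusion v1 VAE


def linear_betas(num_steps, beta_start, beta_end):
    """Evenly spaced noise variances from beta_start to beta_end."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product of (1 - beta_s) for s <= t: signal kept after t steps."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def cosine_lr_with_warmup(step, base_lr, warmup_steps, total_steps):
    """Linear warm-up to base_lr, then half-cosine decay towards zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


betas = linear_betas(NUM_TRAIN_TIMESTEPS, BETA_START, BETA_END)
alpha_bars = cumulative_alphas(betas)
EFFECTIVE_BATCH = BATCH_SIZE * GRAD_ACCUM_STEPS   # images per optimizer step
LATENT_SIZE = IMAGE_SIZE // VAE_DOWNSAMPLE        # side length of the latent grid

print("effective batch:", EFFECTIVE_BATCH, "latent size:", LATENT_SIZE)
print("alpha_bar at final timestep:", alpha_bars[-1])
for s in (0, 250, 500, 5000, TOTAL_STEPS):
    print(s, cosine_lr_with_warmup(s, BASE_LR, WARMUP_STEPS, TOTAL_STEPS))
```

Under these assumptions, the learning rate peaks at `1.0e-4` when warm-up ends at step 500, and `alpha_bars` decreases monotonically, so almost no signal remains at the final timestep.
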

## 🧪 Usage

You can use this model via the `diffusers` library for conditional generation:

```python
from diffusers import DiffusionPipeline
import torch

# Load the model and move it to the GPU when one is available
pipe = DiffusionPipeline.from_pretrained("benetraco/latent_finetuning")
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe.to(device)

# Generate a brain MRI image in SHIFTS style
prompt = "SHIFTS FLAIR MRI"
image = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=2.0).images[0]

image.show()