---
license: mit
tags:
- pytorch
- diffusers
- stable-diffusion
- latent-diffusion
- medical-imaging
- brain-mri
- multiple-sclerosis
- dataset-conditioning
---
# Brain MRI Synthesis with Stable Diffusion (fine-tuned with dataset prompts)
This model is a **fine-tuned version of Stable Diffusion v1-4** for **prompt-conditioned synthesis of brain MRI FLAIR slices**. It uses **latent diffusion** and dataset-specific prompts to generate realistic 256x256 FLAIR slices, with the prompt selecting which source dataset's style or distribution the output follows.
## 🔍 Prompt Conditioning
The model introduces three special prompt tokens, one for each source dataset. During training, each image was paired with the prompt naming its source:
- `"SHIFTS FLAIR MRI"`
- `"VH FLAIR MRI"`
- `"WMH2017 FLAIR MRI"`
These prompts were added as special tokens to the tokenizer, and their embeddings were fine-tuned alongside the U-Net, enabling dataset-specific synthesis.
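
As a rough illustration of this step, the sketch below registers the three prompts as added tokens and grows the text encoder's embedding table so that each one gets its own trainable row. It assumes the tokenizer and text encoder are the CLIP components shipped with the base model and that each full prompt string is registered as a single token. The helper names `add_dataset_tokens` and `load_base_components` are hypothetical, and the actual training script may differ.

```python
DATASET_PROMPTS = ["SHIFTS FLAIR MRI", "VH FLAIR MRI", "WMH2017 FLAIR MRI"]


def add_dataset_tokens(tokenizer, text_encoder, prompts=DATASET_PROMPTS):
    """Register each prompt as an added token and resize the embedding table.

    Hypothetical helper: expects Hugging Face `transformers` objects.
    Returns the number of tokens that were new to the vocabulary.
    """
    num_added = tokenizer.add_tokens(list(prompts))
    # One embedding row per vocabulary entry; the new rows are the ones
    # that would be fine-tuned alongside the U-Net.
    text_encoder.resize_token_embeddings(len(tokenizer))
    return num_added


def load_base_components(model_id="CompVis/stable-diffusion-v1-4"):
    """Load the base tokenizer and text encoder (downloads weights on first use)."""
    # Imported here so the helper above can be used without `transformers` installed.
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    return tokenizer, text_encoder
```

Calling `add_dataset_tokens(*load_base_components())` would apply the same registration to the base components. The published pipeline already includes the extended tokenizer, so this is only needed when reproducing the setup from the base model.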
## 🧠 Training Details
- **Base Model:** [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)
- **Architecture:** Latent diffusion with a U-Net built from ResNet and attention blocks
- **Input resolution (latent):** 32x32
- **Output resolution (decoded):** 256x256 pixels
- **Datasets:** SHIFTS, VH, and WMH2017 (FLAIR MRI slices)
- **Channels:** 4 latent channels
- **Epochs:** 50
- **Batch size:** 8
- **Gradient accumulation:** 4 steps
- **Optimizer:** AdamW
  - Learning rate: `1.0e-4`
  - Betas: `(0.95, 0.999)`
  - Weight decay: `1.0e-6`
  - Epsilon: `1.0e-8`
- **LR Scheduler:** Cosine schedule with 500 warm-up steps
- **Noise Scheduler:** DDPM with:
  - `num_train_timesteps`: 1000
  - `beta_start`: 0.0001
  - `beta_end`: 0.02
  - `beta_schedule`: "linear"
- **Mixed Precision:** Disabled
- **Gradient Clipping:** max norm 1.0
- **Hardware:** NVIDIA A30 GPU with 4 dataloader workers
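
Several of the numbers above follow from one another. The plain-Python sketch below recomputes the latent size, the effective batch size, the linear beta schedule, and a warm-up plus cosine learning-rate curve. It assumes the 8x spatial downsampling of the Stable Diffusion v1 VAE, single-GPU training, and a half-cosine decay to zero after warm-up. `total_steps` is a hypothetical argument, since the number of optimizer steps is not stated here.

```python
import math

# Values taken from the list above
IMAGE_SIZE = 256
BATCH_SIZE = 8
GRAD_ACCUM_STEPS = 4
BASE_LR = 1.0e-4
WARMUP_STEPS = 500
NUM_TRAIN_TIMESTEPS = 1000
BETA_START, BETA_END = 0.0001, 0.02

# Assumption: the Stable Diffusion v1 VAE downsamples each spatial side by 8
VAE_DOWNSAMPLE = 8

latent_size = IMAGE_SIZE // VAE_DOWNSAMPLE       # 256 / 8 = 32
effective_batch = BATCH_SIZE * GRAD_ACCUM_STEPS  # 8 * 4 = 32 samples per optimizer step


def linear_betas(n=NUM_TRAIN_TIMESTEPS, start=BETA_START, end=BETA_END):
    """Evenly spaced noise variances, as in beta_schedule="linear"."""
    return [start + (end - start) * t / (n - 1) for t in range(n)]


def alphas_cumprod(betas):
    """Running product of (1 - beta_t): the share of signal variance left at step t."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def lr_at(step, total_steps, base_lr=BASE_LR, warmup=WARMUP_STEPS):
    """Linear warm-up, then a half-cosine decay to zero (assumed curve shape)."""
    if step < warmup:
        return base_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, total_steps - warmup)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

With these settings the last entry of `alphas_cumprod(linear_betas())` is well below `1e-4`, so the final training timestep is almost pure noise, and `lr_at` reaches the full `1.0e-4` at step 500.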
## 🧪 Usage
You can use this model via the `diffusers` library for conditional generation:
```python
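import torch
from diffusers import DiffusionPipeline

# Load the fine-tuned pipeline from the Hugging Face Hub
pipe = DiffusionPipeline.from_pretrained("benetraco/latent_finetuning")

# Run on the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(device)

# Generate a brain MRI slice in the SHIFTS style
prompt = "SHIFTS FLAIR MRI"
image = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=2.0).images[0]
image.show()  # open the generated image in the default viewer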