---
license: mit
tags:
- pytorch
- diffusers
- stable-diffusion
- latent-diffusion
- medical-imaging
- brain-mri
- multiple-sclerosis
- dataset-conditioning
---

# Brain MRI Synthesis with Stable Diffusion (fine-tuned with dataset prompts)

This model is a **fine-tuned version of Stable Diffusion v1-4** for **prompt-conditioned synthesis of brain MRI FLAIR slices**. It uses **latent diffusion** with dataset-specific prompts to generate realistic 256x256 FLAIR slices in the style of a chosen source dataset.

## 🔍 Prompt Conditioning

The model introduces three special prompt tokens corresponding to the dataset of origin. During training, each image was paired with a prompt indicating its source:

- `"SHIFTS FLAIR MRI"`
- `"VH FLAIR MRI"`
- `"WMH2017 FLAIR MRI"`

These prompts were added as special tokens to the tokenizer, and their embeddings were fine-tuned alongside the U-Net, enabling dataset-specific synthesis.
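
The token registration described above might look roughly like the following sketch. It is not the training code behind this model: the card does not say whether the dataset names or the full prompt strings were registered, so `DATASET_TOKENS` is an assumption, and `prompt_for` and `load_text_stack_with_dataset_tokens` are hypothetical helper names. The `add_tokens` and `resize_token_embeddings` calls are standard `transformers` APIs.

```python
# Hypothetical sketch; the actual training script is not part of this card.
BASE_MODEL = "CompVis/stable-diffusion-v1-4"

# Assumption: one new token per dataset name. The card only lists the full prompts.
DATASET_TOKENS = ["SHIFTS", "VH", "WMH2017"]


def prompt_for(dataset: str) -> str:
    """Build the prompt string paired with images from the given dataset."""
    if dataset not in DATASET_TOKENS:
        raise ValueError(f"unknown dataset: {dataset!r}")
    return f"{dataset} FLAIR MRI"


def load_text_stack_with_dataset_tokens():
    """Load the base tokenizer and text encoder, then register the dataset tokens."""
    # Imported lazily so prompt_for() works without transformers installed.
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder")

    # Returns how many of the tokens were actually new to the vocabulary.
    num_added = tokenizer.add_tokens(DATASET_TOKENS)
    # Grow the embedding matrix so every new token id has a trainable row.
    text_encoder.resize_token_embeddings(len(tokenizer))
    return tokenizer, text_encoder, num_added
```

With this setup, the new embedding rows would be left trainable while the U-Net is fine-tuned, which is what lets each prompt steer generation toward its dataset.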

## 🧠 Training Details

- **Base Model:** [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)
- **Architecture:** Latent Diffusion with U-Net + ResNet + Attention
- **Input resolution (latent):** 32x32
- **Output resolution (decoded):** 256x256 pixels
- **Datasets:** SHIFTS, VH, and WMH2017 (FLAIR MRI slices)
- **Channels:** 4 latent channels
- **Epochs:** 50
- **Batch size:** 8
- **Gradient accumulation:** 4 steps
- **Optimizer:** AdamW
  - Learning Rate: `1.0e-4`
  - Betas: (0.95, 0.999)
  - Weight Decay: `1.0e-6`
  - Epsilon: `1.0e-8`
- **LR Scheduler:** Cosine schedule with 500 warm-up steps
- **Noise Scheduler:** DDPM with:
  - `num_train_timesteps`: 1000
  - `beta_start`: 0.0001
  - `beta_end`: 0.02
  - `beta_schedule`: "linear"
- **Mixed Precision:** Disabled
- **Gradient Clipping:** max norm 1.0
- **Hardware:** NVIDIA A30 GPU with 4 dataloader workers
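
To make the schedule settings above concrete, here is a small dependency-free sketch of the linear DDPM beta schedule and a cosine learning-rate schedule with linear warm-up. The numeric constants come from the list above; the function names are illustrative, `total_steps` is a hypothetical parameter (the card does not state the number of optimizer steps), and the effective batch size assumes a single GPU. The warm-up and decay shape follows the usual cosine-with-warm-up formula and may differ in detail from what was used in training.

```python
import math

# Values taken from the training details above.
BATCH_SIZE = 8
GRAD_ACCUM_STEPS = 4
BASE_LR = 1.0e-4
WARMUP_STEPS = 500
NUM_TRAIN_TIMESTEPS = 1000
BETA_START, BETA_END = 0.0001, 0.02

# Images per optimizer step, assuming a single GPU.
EFFECTIVE_BATCH = BATCH_SIZE * GRAD_ACCUM_STEPS  # 32


def linear_betas(num_steps=NUM_TRAIN_TIMESTEPS, beta_start=BETA_START, beta_end=BETA_END):
    """Evenly spaced noise variances from beta_start to beta_end (inclusive)."""
    delta = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * delta for i in range(num_steps)]


def alphas_cumprod(betas):
    """Cumulative product of (1 - beta_t): the fraction of signal left at step t."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def cosine_lr(step, total_steps, base_lr=BASE_LR, warmup_steps=WARMUP_STEPS):
    """Linear warm-up to base_lr, then half-cosine decay to zero at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
```

For example, `alphas_cumprod(linear_betas())[-1]` is close to zero, meaning the final timestep is almost pure noise, and `cosine_lr(500, total_steps)` equals the peak learning rate for any `total_steps` above 500.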

## 🧪 Usage

You can use this model via the `diffusers` library for conditional generation:

```python
from diffusers import DiffusionPipeline
import torch

# Load the model
pipe = DiffusionPipeline.from_pretrained("benetraco/latent_finetuning")
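# Use the GPU when one is available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe.to(device)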

# Generate a brain MRI image in SHIFTS style
prompt = "SHIFTS FLAIR MRI"
image = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=2.0).images[0]

image.show()
```