---
license: mit
tags:
  - pytorch
  - diffusers
  - stable-diffusion
  - latent-diffusion
  - medical-imaging
  - brain-mri
  - multiple-sclerosis
  - dataset-conditioning
---

# Brain MRI Synthesis with Stable Diffusion (Fine-Tuned with Dataset Prompts)
This model is a fine-tuned version of Stable Diffusion v1-4 for prompt-conditioned brain MRI synthesis, trained on 2D FLAIR slices from the SHIFTS, VH, and WMH2017 datasets.
It uses latent diffusion to generate realistic 256×256 FLAIR scans from 32×32 latent representations, and dataset-specific prompt tokens give control over the dataset style of the output.

## 🔍 Prompt Conditioning

Each training image was paired with a specific dataset prompt:
  
- "SHIFTS FLAIR MRI"
- "VH FLAIR MRI"
- "WMH2017 FLAIR MRI"

These prompts were added to the tokenizer as new tokens and their embeddings were trained jointly with the model,
so generation can be conditioned on a dataset and follows that dataset's image distribution.
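
When the dataset style is chosen programmatically, it can help to keep the mapping from dataset name to prompt in one place. The sketch below is illustrative only: `DATASET_PROMPTS` and `prompt_for` are hypothetical names, not part of the released model or of any library. The prompt strings are the three listed above.

```python
# Hypothetical helper: maps a dataset name to the conditioning prompt listed above.
DATASET_PROMPTS = {
    "SHIFTS": "SHIFTS FLAIR MRI",
    "VH": "VH FLAIR MRI",
    "WMH2017": "WMH2017 FLAIR MRI",
}


def prompt_for(dataset: str) -> str:
    """Return the conditioning prompt for a dataset name (case-insensitive)."""
    key = dataset.strip().upper()
    if key not in DATASET_PROMPTS:
        known = ", ".join(sorted(DATASET_PROMPTS))
        raise ValueError(f"Unknown dataset {dataset!r}; expected one of: {known}")
    return DATASET_PROMPTS[key]
```

Rejecting unknown names early avoids silently sampling with a prompt the model never saw during fine-tuning.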

## 🧠 Training Details

- Base model: [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4)
- Architecture: Latent Diffusion (U-Net with ResNet and attention blocks)
- Latent resolution: 32×32 (decoded to 256×256)
- Latent channels: 4
- Datasets: SHIFTS, VH, WMH2017 (FLAIR MRI)
- Epochs: 50
- Batch size: 8
- Gradient accumulation: 4
- Optimizer: AdamW
  - LR: 1.0e-4
  - Betas: (0.95, 0.999)
  - Weight decay: 1.0e-6
  - Epsilon: 1.0e-8
- LR Scheduler: Cosine decay with 500 warm-up steps
- Noise Scheduler: DDPM
  - Timesteps: 1000
  - Beta schedule: linear (β_start=0.0001, β_end=0.02)
- Gradient Clipping: Max norm 1.0
- Mixed Precision: Disabled
- Hardware: Single NVIDIA A30 GPU
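
Two of the settings above interact in ways that are easy to misread, so here is a small sketch of the arithmetic. It assumes that warm-up is linear from zero, that the cosine decay ends at zero, and that steps are counted in optimizer updates. The card states none of these, and `total_steps` is a hypothetical value chosen only for illustration.

```python
import math

BASE_LR = 1.0e-4
WARMUP_STEPS = 500
BATCH_SIZE = 8
GRAD_ACCUM = 4

# Samples contributing to each optimizer update on a single GPU.
EFFECTIVE_BATCH = BATCH_SIZE * GRAD_ACCUM


def lr_at(step, total_steps, base_lr=BASE_LR, warmup_steps=WARMUP_STEPS):
    """Linear warm-up to base_lr, then cosine decay to zero (assumed shape)."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


total_steps = 10_000  # hypothetical; the total step count is not stated above
for step in (0, 250, 500, 5_250, 10_000):
    print(step, f"{lr_at(step, total_steps):.2e}")
```

With a batch size of 8 and 4 accumulation steps, each optimizer update sees 32 slices, which is the batch size the learning rate applies to.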

## ✍️ Fine-Tuning Strategy

The text encoder, U-Net, and special prompt embeddings were trained jointly.
Images were encoded into a 32×32 latent space with the VAE, and the diffusion model was trained on these latents.
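
To make the latent-space setup concrete, the sketch below reproduces the linear DDPM beta schedule from the training details and the standard forward-noising coefficients, using only the Python standard library. The function names are illustrative, not taken from the training code. The downsampling factor of 8 is the one used by the Stable Diffusion v1 VAE, which matches the 256×256 to 32×32 reduction stated here.

```python
import math

NUM_TIMESTEPS = 1000
BETA_START, BETA_END = 1.0e-4, 0.02
VAE_DOWNSAMPLE = 8  # Stable Diffusion v1 VAE shrinks each spatial side by 8


def linear_betas(n=NUM_TIMESTEPS, start=BETA_START, end=BETA_END):
    """Evenly spaced betas from start to end, inclusive."""
    return [start + (end - start) * i / (n - 1) for i in range(n)]


def alphas_cumprod(betas):
    """Running product of (1 - beta_t), usually written alpha_bar_t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def noising_coeffs(t, acp):
    """Return (signal, noise) scales so that x_t = signal * x_0 + noise * eps."""
    return math.sqrt(acp[t]), math.sqrt(1.0 - acp[t])


latent_side = 256 // VAE_DOWNSAMPLE
acp = alphas_cumprod(linear_betas())
for t in (0, 499, 999):
    signal, noise = noising_coeffs(t, acp)
    print(t, round(signal, 4), round(noise, 4))
```

At the last timestep the signal scale is close to zero, which is why sampling can start from pure Gaussian noise in the latent space.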

## 🧪 Inference (Guided Sampling)
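
The sampling loop in this section combines two U-Net predictions per step with classifier-free guidance. As a minimal illustration of that one line, the sketch below applies the same formula to plain Python lists. `guided_noise` is a hypothetical name used only here.

```python
def guided_noise(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance: move from the unconditional prediction
    toward (and, for scales above 1, past) the text-conditioned one."""
    return [
        u + guidance_scale * (c - u)
        for u, c in zip(noise_uncond, noise_text)
    ]


uncond = [0.0, 1.0]
text = [1.0, 1.0]
for scale in (0.0, 1.0, 5.0):
    print(scale, guided_noise(uncond, text, scale))
```

A scale of 0 ignores the prompt, a scale of 1 returns the conditioned prediction unchanged, and larger values amplify the difference between the two predictions.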

```python
from diffusers import StableDiffusionPipeline
import torch
from torchvision.utils import save_image

device = "cuda"
pipe = StableDiffusionPipeline.from_pretrained("benetraco/latent_finetuning", torch_dtype=torch.float32).to(device)
# Number of denoising steps used at inference time.
pipe.scheduler.set_timesteps(999)

@torch.no_grad()
def get_embeddings(prompt):
    # Pad to the text encoder's 77-token context length, then encode.
    tokens = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=77).to(device)
    return pipe.text_encoder(**tokens).last_hidden_state

@torch.no_grad()
def sample(prompt, guidance_scale=2.0, seed=42):
    torch.manual_seed(seed)
    # Start from Gaussian noise in the 4-channel, 32x32 latent space.
    latent = torch.randn(1, 4, 32, 32).to(device) * pipe.scheduler.init_noise_sigma
    text_emb = get_embeddings(prompt)
    uncond_emb = get_embeddings("")  # empty prompt for the unconditional branch

    for t in pipe.scheduler.timesteps:
        latent_in = pipe.scheduler.scale_model_input(latent, t)
        noise_uncond = pipe.unet(latent_in, t, encoder_hidden_states=uncond_emb).sample
        noise_text = pipe.unet(latent_in, t, encoder_hidden_states=text_emb).sample
        # Classifier-free guidance: push the prediction toward the prompt.
        noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
        latent = pipe.scheduler.step(noise, t, latent).prev_sample

    # Undo the latent scaling, then decode to image space with the VAE.
    latent = latent / pipe.vae.config.scaling_factor
    decoded = pipe.vae.decode(latent).sample
    # Map from [-1, 1] to [0, 1] before saving.
    image = ((decoded + 1.0) / 2.0).clamp(0, 1)
    save_image(image, f"{prompt.replace(' ', '_')}_g{guidance_scale}.png")

sample("SHIFTS FLAIR MRI", guidance_scale=5.0)