---
license: mit
tags:
- pytorch
- diffusers
- stable-diffusion
- latent-diffusion
- medical-imaging
- brain-mri
- multiple-sclerosis
- dataset-conditioning
---

# Brain MRI Synthesis with Stable Diffusion (Fine-Tuned with Dataset Prompts)

This model is a fine-tuned version of Stable Diffusion v1-4 for prompt-conditioned brain MRI synthesis, trained on 2D FLAIR slices from the SHIFTS, VH, and WMH2017 datasets. It generates realistic 256×256 FLAIR scans by denoising 32×32 latent representations, and dataset-specific prompt tokens control which dataset's visual style is reproduced.

## 🔍 Prompt Conditioning

Each training image was paired with the prompt for its source dataset:

- "SHIFTS FLAIR MRI"
- "VH FLAIR MRI"
- "WMH2017 FLAIR MRI"

These prompts were added as new tokens to the tokenizer and trained jointly with the model, enabling generation conditioned on each dataset's image distribution.

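The card does not say exactly how the tokens were registered. With 🤗 Transformers, the usual route is `tokenizer.add_tokens(...)` followed by `text_encoder.resize_token_embeddings(len(tokenizer))`, which appends freshly initialized rows to the text encoder's token-embedding table while keeping the pretrained rows. The self-contained PyTorch toy below mimics that step; `vocab` and `add_prompt_tokens` are hypothetical illustrative names, and treating each full prompt string as a single token is an assumption.

```python
import torch
import torch.nn as nn

# Hypothetical toy vocabulary standing in for the CLIP tokenizer's vocab
vocab = {"<pad>": 0, "brain": 1, "scan": 2}
embedding = nn.Embedding(len(vocab), 8)  # toy token-embedding table (dim 8)


def add_prompt_tokens(vocab, embedding, new_tokens):
    """Register new tokens and grow the embedding table: pretrained rows
    are copied over, new rows keep their random initialization."""
    for tok in new_tokens:
        if tok not in vocab:
            vocab[tok] = len(vocab)
    grown = nn.Embedding(len(vocab), embedding.embedding_dim)
    with torch.no_grad():
        grown.weight[: embedding.num_embeddings] = embedding.weight
    return grown  # all rows remain trainable, so new tokens learn jointly


prompts = ["SHIFTS FLAIR MRI", "VH FLAIR MRI", "WMH2017 FLAIR MRI"]
old_weight = embedding.weight.detach().clone()
embedding = add_prompt_tokens(vocab, embedding, prompts)

print(embedding.num_embeddings)  # → 6 (3 original + 3 new)
print(vocab["VH FLAIR MRI"])     # → 4
```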

## 🧠 Training Details

- Base model: [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4)
- Architecture: Latent diffusion (U-Net with ResNet and attention blocks)
- Latent resolution: 32×32, decoded to 256×256 by the VAE (8× spatial downsampling)
- Latent channels: 4
- Datasets: SHIFTS, VH, WMH2017 (FLAIR MRI)
- Epochs: 50
- Batch size: 8
- Gradient accumulation: 4 steps (effective batch size 32)
- Optimizer: AdamW
- Learning rate: 1.0e-4
- Betas: (0.95, 0.999)
- Weight decay: 1.0e-6
- Epsilon: 1.0e-8
- LR scheduler: cosine decay with 500 warm-up steps
- Noise scheduler: DDPM
- Timesteps: 1000
- Beta schedule: linear (β_start = 0.0001, β_end = 0.02)
- Gradient clipping: max norm 1.0
- Mixed precision: disabled
- Hardware: single NVIDIA A30 GPU

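A minimal sketch of how these optimizer settings fit together, assuming a plain PyTorch training loop. The model is a hypothetical `nn.Linear` stand-in for the U-Net and text encoder, the loss is a placeholder, and `TOTAL_STEPS` is an assumed value because the card does not report the number of optimizer steps. `warmup_cosine` follows the same formula as `diffusers.optimization.get_cosine_schedule_with_warmup` (linear warm-up, then a half-cosine decay to zero); whether the authors used that helper is not stated.

```python
import math
import torch
import torch.nn as nn

BATCH_SIZE = 8
GRAD_ACCUM = 4
EFFECTIVE_BATCH = BATCH_SIZE * GRAD_ACCUM  # 32 samples per optimizer step
WARMUP_STEPS = 500
TOTAL_STEPS = 10_000  # hypothetical: the card does not report the total step count


def warmup_cosine(step, warmup=WARMUP_STEPS, total=TOTAL_STEPS):
    """LR multiplier: linear warm-up from 0 to 1, then cosine decay to 0."""
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


torch.manual_seed(0)
model = nn.Linear(16, 16)  # hypothetical stand-in for the U-Net + text encoder
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1.0e-4, betas=(0.95, 0.999), weight_decay=1.0e-6, eps=1.0e-8
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)

for micro_step in range(2 * GRAD_ACCUM):  # two optimizer steps, for illustration
    x = torch.randn(BATCH_SIZE, 16)
    # Placeholder loss, divided by GRAD_ACCUM so accumulated gradients average
    loss = ((model(x) - x) ** 2).mean() / GRAD_ACCUM
    loss.backward()  # gradients accumulate across micro-batches
    if (micro_step + 1) % GRAD_ACCUM == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

print(optimizer.param_groups[0]["lr"])  # ≈ 4e-07, still in warm-up (1e-4 * 2 / 500)
```

Clipping is applied once per optimizer step, after all four micro-batches have accumulated, so the max-norm of 1.0 constrains the gradient of the effective batch of 32.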

## ✍️ Fine-Tuning Strategy

The text encoder, U-Net, and special prompt embeddings were trained jointly. Images were encoded by the VAE into a 4×32×32 latent space, and the diffusion model was trained on these latents.

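A minimal sketch of one latent-diffusion training step under the DDPM settings listed above (1000 timesteps, linear β from 0.0001 to 0.02). It assumes the standard ε-prediction objective used by Stable Diffusion v1 and that VAE latents are multiplied by the VAE scaling factor before noising, as the inference code's division by `pipe.vae.config.scaling_factor` implies. Random tensors stand in for the VAE latents and text embeddings, and `toy_denoiser` is a hypothetical placeholder for the U-Net so the block runs without downloading weights.

```python
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(0.0001, 0.02, T)             # linear beta schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # cumulative product: abar_t


def add_noise(x0, noise, t):
    """DDPM forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise


def training_loss(denoiser, latents, text_emb):
    noise = torch.randn_like(latents)
    t = torch.randint(0, T, (latents.shape[0],))  # one random timestep per sample
    noisy = add_noise(latents, noise, t)
    pred = denoiser(noisy, t, text_emb)           # the U-Net predicts the added noise
    return F.mse_loss(pred, noise)


def toy_denoiser(x, t, cond):
    # Hypothetical placeholder for pipe.unet(x, t, encoder_hidden_states=cond).sample
    return torch.zeros_like(x)


torch.manual_seed(0)
latents = torch.randn(8, 4, 32, 32) * 0.18215  # stand-in for scaled VAE latents
text_emb = torch.randn(8, 77, 768)             # stand-in for CLIP text-encoder output
loss = training_loss(toy_denoiser, latents, text_emb)
print(loss.item())  # close to 1.0, since the placeholder always predicts zero noise
```

In the real setup, gradients from this loss would flow into both the U-Net and the text encoder (including the new prompt-token embeddings), which is what training them jointly means.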

## 🧪 Inference (Guided Sampling)

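The sampler uses classifier-free guidance: at every denoising step the U-Net is evaluated once with the dataset prompt $c$ and once with the empty prompt $\varnothing$, and the two noise predictions for the latent $z_t$ are combined with the guidance scale $w$ (`guidance_scale` in the code):

$$
\hat{\epsilon} = \epsilon_\theta(z_t, t, \varnothing) + w \big( \epsilon_\theta(z_t, t, c) - \epsilon_\theta(z_t, t, \varnothing) \big)
$$

With $w = 1$ this reduces to the purely conditional prediction; larger values push samples more strongly toward the prompted dataset's style, typically at some cost in diversity.
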

```python
import torch
from diffusers import StableDiffusionPipeline
from torchvision.utils import save_image

device = "cuda"

pipe = StableDiffusionPipeline.from_pretrained(
    "benetraco/latent_finetuning", torch_dtype=torch.float32
).to(device)
pipe.scheduler.set_timesteps(999)  # number of denoising steps at inference


@torch.no_grad()
def get_embeddings(prompt):
    # Tokenize to CLIP's fixed context length (77 tokens) and encode with the fine-tuned text encoder
    tokens = pipe.tokenizer(
        prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True
    ).to(device)
    return pipe.text_encoder(**tokens).last_hidden_state


@torch.no_grad()
def sample(prompt, guidance_scale=2.0, seed=42):
    torch.manual_seed(seed)
    # Start from Gaussian noise in the 4x32x32 latent space (decoded to 256x256)
    latent = torch.randn(1, 4, 32, 32).to(device) * pipe.scheduler.init_noise_sigma
    text_emb = get_embeddings(prompt)
    uncond_emb = get_embeddings("")  # empty prompt for the unconditional branch

    for t in pipe.scheduler.timesteps:
        latent_in = pipe.scheduler.scale_model_input(latent, t)
        noise_uncond = pipe.unet(latent_in, t, encoder_hidden_states=uncond_emb).sample
        noise_text = pipe.unet(latent_in, t, encoder_hidden_states=text_emb).sample
        # Classifier-free guidance: move the prediction toward the prompt-conditioned one
        noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
        latent = pipe.scheduler.step(noise, t, latent).prev_sample

    # Undo the VAE latent scaling, decode, and map pixels from [-1, 1] to [0, 1]
    decoded = pipe.vae.decode(latent / pipe.vae.config.scaling_factor).sample
    image = ((decoded + 1.0) / 2.0).clamp(0, 1)
    save_image(image, f"{prompt.replace(' ', '_')}_g{guidance_scale}.png")
    return image


sample("SHIFTS FLAIR MRI", guidance_scale=5.0)
```