---
license: mit
tags:
- pytorch
- diffusers
- stable-diffusion
- latent-diffusion
- medical-imaging
- brain-mri
- multiple-sclerosis
- dataset-conditioning
---

# Brain MRI Synthesis with Stable Diffusion (fine-tuned with dataset prompts)

This model is a **fine-tuned version of Stable Diffusion v1-4** for **prompt-conditioned synthesis of brain MRI FLAIR slices**. It leverages **latent diffusion** and dataset-specific prompts to generate realistic 256x256 FLAIR scans with control over the source dataset's style or distribution.

## 🔍 Prompt Conditioning

The model introduces three special prompt tokens corresponding to the dataset of origin. During training, each image was paired with a prompt indicating its source:

- `"SHIFTS FLAIR MRI"`
- `"VH FLAIR MRI"`
- `"WMH2017 FLAIR MRI"`

These prompts were added as special tokens to the tokenizer, and their embeddings were fine-tuned alongside the U-Net, enabling dataset-specific synthesis.

## 🧠 Training Details

- **Base Model:** [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)
- **Architecture:** Latent Diffusion with U-Net + ResNet + Attention
- **Input resolution (latent):** 32x32
- **Output resolution (decoded):** 256x256 pixels
- **Datasets:** SHIFTS, VH, and WMH2017 (FLAIR MRI slices)
- **Channels:** 4 latent channels
- **Epochs:** 50
- **Batch size:** 8
- **Gradient accumulation:** 4 steps
- **Optimizer:** AdamW
  - Learning Rate: `1.0e-4`
  - Betas: (0.95, 0.999)
  - Weight Decay: `1.0e-6`
  - Epsilon: `1.0e-8`
- **LR Scheduler:** Cosine schedule with 500 warm-up steps
- **Noise Scheduler:** DDPM with:
  - `num_train_timesteps`: 1000
  - `beta_start`: 0.0001
  - `beta_end`: 0.02
  - `beta_schedule`: "linear"
- **Mixed Precision:** Disabled
- **Gradient Clipping:** max norm 1.0
- **Hardware:** NVIDIA A30 GPU with 4 dataloader workers

## 🧪 Usage

You can use this model via the `diffusers` library for conditional generation:

```python
from diffusers import DiffusionPipeline
import torch

# Load the model
pipe = DiffusionPipeline.from_pretrained("benetraco/latent_finetuning")
pipe.to("cuda")  # or "cpu"

# Generate a brain MRI image in SHIFTS style
prompt = "SHIFTS FLAIR MRI"
image = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=2.0).images[0]
image.show()
```
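
To compare the three dataset styles, the usage snippet can be wrapped in a small helper that maps a dataset name to its prompt and fixes the random seed. This is an illustrative sketch, not part of the released code: `DATASET_PROMPTS`, `prompt_for`, and `generate_slice` are hypothetical names, and it assumes the loaded pipeline accepts the `generator` argument that Stable Diffusion pipelines in `diffusers` usually expose.

```python
# Hypothetical helper names; only the three prompt strings come from this model card.
DATASET_PROMPTS = {
    "SHIFTS": "SHIFTS FLAIR MRI",
    "VH": "VH FLAIR MRI",
    "WMH2017": "WMH2017 FLAIR MRI",
}


def prompt_for(dataset: str) -> str:
    """Return the conditioning prompt for a dataset name (case-insensitive)."""
    key = dataset.strip().upper()
    if key not in DATASET_PROMPTS:
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of {sorted(DATASET_PROMPTS)}"
        )
    return DATASET_PROMPTS[key]


def generate_slice(pipe, dataset: str, seed: int = 0):
    """Generate one FLAIR slice in the given dataset style with a fixed seed."""
    import torch  # imported lazily so prompt_for works without torch installed

    # Assumption: the pipeline accepts `generator`, as Stable Diffusion pipelines do.
    generator = torch.Generator(device=pipe.device).manual_seed(seed)
    result = pipe(
        prompt=prompt_for(dataset),
        num_inference_steps=50,
        guidance_scale=2.0,
        generator=generator,
    )
    return result.images[0]
```

With `pipe` loaded as in the usage example, `{name: generate_slice(pipe, name, seed=42) for name in DATASET_PROMPTS}` would produce one image per dataset style from the same seed, so that visible differences are more likely to come from the prompt than from the sampled noise.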