---
license: mit
tags:
- pytorch
- diffusers
- stable-diffusion
- latent-diffusion
- medical-imaging
- brain-mri
- multiple-sclerosis
- dataset-conditioning
---

# Brain MRI Synthesis with Stable Diffusion (Fine-Tuned with Dataset Prompts)

This model is a fine-tuned version of Stable Diffusion v1-4 for prompt-conditioned brain MRI synthesis, trained on 2D FLAIR slices from the SHIFTS, VH, and WMH2017 datasets. It uses latent diffusion to generate realistic 256×256 scans from 32×32 latent representations, and includes dataset-specific prompt tokens that give control over the visual style of the generated images.

## 🔍 Prompt Conditioning

Each training image was paired with a dataset-specific prompt:

- "SHIFTS FLAIR MRI"
- "VH FLAIR MRI"
- "WMH2017 FLAIR MRI"

These prompts were added as new tokens to the tokenizer and trained jointly with the model, enabling conditional generation aligned with each dataset's distribution.

## 🧠 Training Details

- Base model: [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4)
- Architecture: Latent Diffusion (U-Net with ResNet blocks and attention)
- Latent resolution: 32×32 (decoded to 256×256)
- Latent channels: 4
- Datasets: SHIFTS, VH, WMH2017 (FLAIR MRI)
- Epochs: 50
- Batch size: 8
- Gradient accumulation: 4
- Optimizer: AdamW
- LR: 1.0e-4
- Betas: (0.95, 0.999)
- Weight decay: 1.0e-6
- Epsilon: 1.0e-8
- LR scheduler: cosine decay with 500 warm-up steps
- Noise scheduler: DDPM
- Timesteps: 1000
- Beta schedule: linear (β_start = 0.0001, β_end = 0.02)
- Gradient clipping: max norm 1.0
- Mixed precision: disabled
- Hardware: single NVIDIA A30 GPU

## ✍️ Fine-Tuning Strategy

The text encoder, U-Net, and special prompt embeddings were trained jointly.
Images were encoded into the 32×32 latent space with the VAE, and the diffusion objective was optimized in that latent space.

## 🧪 Inference (Guided Sampling)

```python
import torch
from diffusers import StableDiffusionPipeline
from torchvision.utils import save_image

pipe = StableDiffusionPipeline.from_pretrained(
    "benetraco/latent_finetuning", torch_dtype=torch.float32
).to("cuda")
pipe.scheduler.set_timesteps(999)

def get_embeddings(prompt):
    tokens = pipe.tokenizer(
        prompt, return_tensors="pt", padding="max_length",
        max_length=77, truncation=True,
    ).to("cuda")
    with torch.no_grad():
        return pipe.text_encoder(**tokens).last_hidden_state

def sample(prompt, guidance_scale=2.0, seed=42):
    torch.manual_seed(seed)
    latent = torch.randn(1, 4, 32, 32, device="cuda") * pipe.scheduler.init_noise_sigma
    text_emb = get_embeddings(prompt)
    uncond_emb = get_embeddings("")

    for t in pipe.scheduler.timesteps:
        latent_in = pipe.scheduler.scale_model_input(latent, t)
        with torch.no_grad():
            noise_uncond = pipe.unet(latent_in, t, encoder_hidden_states=uncond_emb).sample
            noise_text = pipe.unet(latent_in, t, encoder_hidden_states=text_emb).sample
        # Classifier-free guidance: push the prediction toward the text-conditioned direction.
        noise = noise_uncond + guidance_scale * (noise_text - noise_uncond)
        latent = pipe.scheduler.step(noise, t, latent).prev_sample

    latent = latent / pipe.vae.config.scaling_factor
    with torch.no_grad():
        decoded = pipe.vae.decode(latent).sample
    image = ((decoded + 1.0) / 2.0).clamp(0, 1)
    save_image(image, f"{prompt.replace(' ', '_')}_g{guidance_scale}.png")

sample("SHIFTS FLAIR MRI", guidance_scale=5.0)
```
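The cosine LR schedule with 500 warm-up steps listed in the training details can be sketched as a small standalone function. This is an illustration of the schedule's shape, not the exact training code; `total_steps` is a placeholder value, not a figure from the actual run:

```python
import math

def lr_with_warmup(step, base_lr=1.0e-4, warmup_steps=500, total_steps=10_000):
    """Linear warm-up to base_lr, then cosine decay to zero.

    Mirrors the shape of a standard cosine-with-warmup schedule;
    total_steps here is a placeholder, not the run's real step count.
    """
    if step < warmup_steps:
        # Warm-up phase: LR rises linearly from 0 to base_lr.
        return base_lr * step / warmup_steps
    # Decay phase: cosine from base_lr down to 0 over the remaining steps.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

In practice this behavior is typically obtained via `get_cosine_schedule_with_warmup` from the `diffusers`/`transformers` optimization utilities rather than hand-rolled.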
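The DDPM noise schedule from the training details (linear betas from 0.0001 to 0.02 over 1000 timesteps) determines how much of the original signal survives at each step. A dependency-free sketch of the cumulative signal level, with a scalar illustration of the forward noising process (the actual training operates on 4×32×32 latent tensors):

```python
import math

T = 1000
BETA_START, BETA_END = 1.0e-4, 0.02

# Linear beta schedule, matching the configuration above.
betas = [BETA_START + (BETA_END - BETA_START) * i / (T - 1) for i in range(T)]

# alpha_bar_t = prod_{s<=t} (1 - beta_s): fraction of the original
# signal variance remaining at timestep t.
alpha_bars = []
running = 1.0
for b in betas:
    running *= 1.0 - b
    alpha_bars.append(running)

def noisy_sample(x0, eps, t):
    """Forward process q(x_t | x_0): sqrt(alpha_bar_t)*x0 + sqrt(1 - alpha_bar_t)*eps."""
    a = alpha_bars[t]
    return math.sqrt(a) * x0 + math.sqrt(1.0 - a) * eps
```

By the final timestep almost no signal remains (alpha_bar is on the order of 1e-5), which is why sampling can start from pure Gaussian noise, as in the inference snippet above.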