---
license: cc-by-nc-4.0
tags:
- diffusion
- retinal-fundus
- diabetic-retinopathy
- medical-imaging
- image-to-image
- conditional-generation
- stable-diffusion
datasets:
- usama10/retinal-dr-longitudinal
pipeline_tag: image-to-image
---

# Conditional Latent Diffusion Model for Retinal Future-State Synthesis

Trained model weights for predicting two-year follow-up retinal fundus images from baseline photographs and clinical metadata.

## Model Description

This model adapts Stable Diffusion 1.5 for longitudinal retinal image prediction. It consists of two components:

1. **Fine-tuned VAE** (`vae_best.pt`, 320 MB): SD 1.5 VAE encoder/decoder fine-tuned on retinal fundus images with L1 + SSIM + LPIPS + KL loss. Achieves SSIM 0.954 on reconstruction.
2. **Conditional U-Net** (`diffusion_best.pt`, 13 GB): 860M-parameter denoising U-Net with 15-channel input (4 noisy latent + 4 baseline latent + 7 clinical feature maps). Trained for 500 epochs with cosine LR schedule, EMA, and classifier-free guidance dropout.

## Performance

| Metric | Value |
|--------|-------|
| SSIM | 0.762 |
| PSNR | 17.26 dB |
| LPIPS | 0.379 |
| FID | 107.28 |

Evaluated on 110 held-out test image pairs.

## Qualitative Results

![Qualitative comparison](fig_qualitative_comparison.png)

Each row shows a different test patient. Columns: baseline fundus, ground-truth follow-up, our prediction, Regression U-Net, and Pix2Pix. Our diffusion model generates sharper, more realistic retinal textures compared to deterministic baselines.

## Training Dynamics

![Training dynamics](fig_training_dynamics.png)

(a, b) Stage 1: VAE fine-tuning over 50 epochs reaching SSIM 0.954. (c-f) Stage 2: U-Net training over 500 epochs with cosine LR schedule and warmup.

## Guidance Scale Sweep

![Guidance scale](fig_guidance_scale.png)

SSIM peaks at guidance scale 7.5, while FID increases monotonically with stronger guidance, reflecting the fidelity-diversity tradeoff.
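For readers unfamiliar with the guidance scale, the following is a minimal sketch of how classifier-free guidance could be applied at sampling time to a 15-channel input of the kind described in the model description. It is not the repository's inference code. The names `build_unet_input`, `cfg_noise`, and `ToyDenoiser` are hypothetical, the 32x32 latent size is arbitrary, and zeroing the conditioning channels for the unconditional branch is an assumption about how the guidance dropout was implemented.

```python
import torch


def build_unet_input(noisy, baseline, clinical):
    """Concatenate 4 noisy-latent, 4 baseline-latent, and 7 clinical channels into 15."""
    return torch.cat([noisy, baseline, clinical], dim=1)


def cfg_noise(denoiser, noisy, baseline, clinical, t, scale=7.5):
    """Combine conditional and unconditional noise predictions with a guidance scale."""
    cond_in = build_unet_input(noisy, baseline, clinical)
    # ASSUMPTION: the unconditional branch replaces all conditioning with zeros.
    uncond_in = build_unet_input(
        noisy, torch.zeros_like(baseline), torch.zeros_like(clinical)
    )
    eps_cond = denoiser(cond_in, t)
    eps_uncond = denoiser(uncond_in, t)
    # scale = 1 recovers the conditional prediction; larger values push
    # the estimate further along the conditioning direction.
    return eps_uncond + scale * (eps_cond - eps_uncond)


class ToyDenoiser(torch.nn.Module):
    """Hypothetical stand-in for the real U-Net: one conv from 15 to 4 channels."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(15, 4, kernel_size=3, padding=1)

    def forward(self, x, t):
        # The timestep is ignored here; a real denoiser would embed it.
        return self.conv(x)


torch.manual_seed(0)
noisy = torch.randn(1, 4, 32, 32)     # current noisy latent
baseline = torch.randn(1, 4, 32, 32)  # VAE latent of the baseline image
clinical = torch.randn(1, 7, 32, 32)  # clinical features broadcast to maps
t = torch.tensor([500])

with torch.no_grad():
    eps = cfg_noise(ToyDenoiser(), noisy, baseline, clinical, t, scale=7.5)
print(eps.shape)
```

In this formulation a higher `scale` weights the conditioning more heavily, which is the knob varied in the sweep above.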
## Usage

```python
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel

# Load VAE
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae_state = torch.load("vae_best.pt", map_location="cpu")
if "model_state_dict" in vae_state:
    vae_state = vae_state["model_state_dict"]
vae.load_state_dict(vae_state, strict=False)

# Load U-Net (requires modified conv_in for 15 input channels)
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
# ... modify conv_in and load checkpoint
# See full inference code at the GitHub repository
```

## Links

- **Code:** [github.com/Usama1002/retinal-diffusion](https://github.com/Usama1002/retinal-diffusion)
- **Dataset:** [huggingface.co/datasets/usama10/retinal-dr-longitudinal](https://huggingface.co/datasets/usama10/retinal-dr-longitudinal)

## Citation

```bibtex
@article{usama2026retinal,
  title={Conditional Latent Diffusion for Predictive Retinal Fundus Image Synthesis from Baseline Imaging and Clinical Metadata},
  author={Usama, Muhammad and Pazo, Emmanuel Eric and Li, Xiaorong and Liu, Juping},
  journal={Computers in Biology and Medicine (under review)},
  year={2026}
}
```

## License

CC BY-NC 4.0. Non-commercial research use only.