---
license: cc-by-nc-4.0
tags:
- diffusion
- retinal-fundus
- diabetic-retinopathy
- medical-imaging
- image-to-image
- conditional-generation
- stable-diffusion
datasets:
- usama10/retinal-dr-longitudinal
pipeline_tag: image-to-image
---
# Conditional Latent Diffusion Model for Retinal Future-State Synthesis
Trained model weights for predicting two-year follow-up retinal fundus images from baseline photographs and clinical metadata.
## Model Description
This model adapts Stable Diffusion 1.5 for longitudinal retinal image prediction. It consists of two components:
1. **Fine-tuned VAE** (`vae_best.pt`, 320 MB): the SD 1.5 VAE encoder/decoder fine-tuned on retinal fundus images with a combined L1 + SSIM + LPIPS + KL loss; achieves a reconstruction SSIM of 0.954.
2. **Conditional U-Net** (`diffusion_best.pt`, 13 GB): an 860M-parameter denoising U-Net with a 15-channel input (4 noisy latent channels + 4 baseline latent channels + 7 clinical feature maps), trained for 500 epochs with a cosine LR schedule, EMA weights, and classifier-free guidance dropout.
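The 15-channel conditioning input described above can be sketched as follows. This is a minimal illustration, not the repository's code: the function name and the choice of broadcasting each scalar clinical value into a constant spatial map are assumptions.

```python
import torch

def build_unet_input(noisy_latent, baseline_latent, clinical_features):
    """Concatenate 4 noisy latent + 4 baseline latent + 7 clinical channels.

    noisy_latent, baseline_latent: (B, 4, H, W); clinical_features: (B, 7).
    """
    b, _, h, w = noisy_latent.shape
    # Broadcast each scalar clinical feature to a constant H x W map.
    clinical_maps = clinical_features.view(b, 7, 1, 1).expand(b, 7, h, w)
    return torch.cat([noisy_latent, baseline_latent, clinical_maps], dim=1)
```

For a 256×256 fundus photograph, the SD 1.5 VAE produces 32×32 latents, so the U-Net would see a `(B, 15, 32, 32)` tensor.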
## Performance
| Metric | Value |
|--------|-------|
| SSIM | 0.762 |
| PSNR | 17.26 dB |
| LPIPS | 0.379 |
| FID | 107.28 |
Evaluated on 110 held-out test image pairs.
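For reference, PSNR reduces to a one-line formula over the mean squared error; a minimal NumPy sketch follows (the reported numbers above were not produced with this snippet).

```python
import numpy as np

def psnr(pred, target, data_range=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to [0, data_range]."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    mse = np.mean((pred - target) ** 2)
    return 10.0 * np.log10(data_range ** 2 / mse)

# A uniform 0.1 intensity error on a [0, 1] image gives 20 dB:
psnr(np.full((8, 8), 0.5), np.full((8, 8), 0.6))  # ≈ 20.0
```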
## Qualitative Results

Each row shows a different test patient. Columns: baseline fundus, ground-truth follow-up, our prediction, Regression U-Net, and Pix2Pix. Our diffusion model generates sharper, more realistic retinal textures compared to deterministic baselines.
## Training Dynamics

(a, b) Stage 1: VAE fine-tuning over 50 epochs reaching SSIM 0.954. (c-f) Stage 2: U-Net training over 500 epochs with cosine LR schedule and warmup.
## Guidance Scale Sweep

SSIM peaks at guidance scale 7.5, while FID increases monotonically with stronger guidance, reflecting the fidelity-diversity tradeoff.
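The guidance scale enters through the standard classifier-free guidance combination of conditional and unconditional noise predictions. A minimal sketch, assuming the usual formulation rather than the repository's exact code:

```python
import torch

def guided_noise(eps_uncond, eps_cond, guidance_scale=7.5):
    # eps = eps_uncond + s * (eps_cond - eps_uncond).
    # s = 1 recovers the purely conditional prediction; larger s pushes
    # samples toward the condition at the cost of diversity (higher FID).
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```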
## Usage
```python
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
# Load VAE
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae_state = torch.load("vae_best.pt", map_location="cpu")
if "model_state_dict" in vae_state:
    vae_state = vae_state["model_state_dict"]
vae.load_state_dict(vae_state, strict=False)

# Load U-Net (requires modified conv_in for 15 input channels)
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
# ... modify conv_in and load checkpoint
# See full inference code at the GitHub repository
```
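The `conv_in` modification mentioned above is left to the repository's inference code; one common pattern expands the first convolution from 4 to 15 input channels while preserving the pretrained weights. This is a sketch, and zero-initializing the new channels is an assumption, not necessarily what the checkpoint was trained with.

```python
import torch
import torch.nn as nn

def expand_conv_in(conv_in, new_in_channels=15):
    """Return a copy of `conv_in` that accepts `new_in_channels` input channels."""
    new = nn.Conv2d(new_in_channels, conv_in.out_channels,
                    kernel_size=conv_in.kernel_size,
                    stride=conv_in.stride, padding=conv_in.padding)
    with torch.no_grad():
        new.weight.zero_()                                    # extra channels start at zero,
        new.weight[:, :conv_in.in_channels] = conv_in.weight  # pretrained 4 channels kept
        new.bias.copy_(conv_in.bias)
    return new
```

After `unet.conv_in = expand_conv_in(unet.conv_in)` you would also update the model config (in diffusers, via `unet.register_to_config(in_channels=15)`) before loading the 15-channel checkpoint with `load_state_dict`.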
## Links
- **Code:** [github.com/Usama1002/retinal-diffusion](https://github.com/Usama1002/retinal-diffusion)
- **Dataset:** [huggingface.co/datasets/usama10/retinal-dr-longitudinal](https://huggingface.co/datasets/usama10/retinal-dr-longitudinal)
## Citation
```bibtex
@article{usama2026retinal,
  title={Conditional Latent Diffusion for Predictive Retinal Fundus Image Synthesis from Baseline Imaging and Clinical Metadata},
  author={Usama, Muhammad and Pazo, Emmanuel Eric and Li, Xiaorong and Liu, Juping},
  journal={Computers in Biology and Medicine (under review)},
  year={2026}
}
```
## License
CC BY-NC 4.0. Non-commercial research use only.