---
license: apache-2.0
tags:
- medical-imaging
- glaucoma
- federated-learning
- semantic-segmentation
- ophthalmology
- computer-vision
- healthcare
- mask2former
- swin-transformer
datasets:
- chaksu
- refuge
- g1020
- rim-one-dl
- messidor
- origa
library_name: transformers
pipeline_tag: image-segmentation
---
# Federated Learning for Glaucoma Segmentation: Model Checkpoints
## Overview
This repository contains trained model checkpoints from the research project **"A Federated Learning-based Optic Disc and Cup Segmentation Model for Glaucoma Monitoring in Color Fundus Photographs"**.
### Key Information
- **Task**: Automated optic disc and cup segmentation for glaucoma assessment
- **Architecture**: Mask2Former with Swin Transformer backbone
- **Pre-training**: ADE20K semantic segmentation dataset
- **Training Data**: 5,550 color fundus photographs from 9 datasets across 7 countries
- **Approach**: Privacy-preserving federated learning with site-specific fine-tuning
## Clinical Context
Glaucoma is a leading cause of irreversible blindness worldwide, affecting 3.54% of the population aged 40-80 and projected to impact 111.8 million people by 2040. A key indicator of glaucoma severity is the vertical cup-to-disc ratio (CDR), with ratios ≥0.6 suggestive of glaucoma.
This work addresses the need for accurate automated segmentation while preserving patient data privacy across multiple clinical sites, enabling HIPAA/GDPR-compliant multi-institutional collaboration.
## Models Included
This repository contains **22 trained models** organized into four categories:
### Baseline Models
- **Central Model** (1 model): Trained on pooled multi-site data, representing upper bound performance
- **Local Models** (9 models): Site-specific models trained on individual datasets, representing lower bound performance
### Federated Learning Models
- **Pipeline 1** (1 model): Global Validation
- **Pipeline 2** (1 model): Weighted Global Validation
- **Pipeline 3** (1 model): Onsite Validation
- **Pipeline 4** (9 models): Fine-Tuned Onsite Validation
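The pipelines above differ mainly in how per-site model updates are validated and combined into a global model. The repository does not include the aggregation code itself; as an illustration only, here is a minimal sketch of size-weighted parameter averaging, the FedAvg-style mechanism underlying weighted aggregation. The `weighted_average` function and the toy parameter vectors are hypothetical, not code from this project:

```python
# Minimal sketch of size-weighted federated averaging (FedAvg-style).
# Real federated training would average full model state dicts, not flat lists.

def weighted_average(site_params, site_sizes):
    """Average per-site parameter vectors, weighted by each site's dataset size."""
    total = sum(site_sizes)
    merged = [0.0] * len(site_params[0])
    for params, size in zip(site_params, site_sizes):
        for i, p in enumerate(params):
            merged[i] += p * (size / total)
    return merged

# Two toy sites: the larger site (300 images) pulls the average toward its values
merged = weighted_average([[1.0, 2.0], [3.0, 4.0]], site_sizes=[100, 300])
print(merged)  # [2.5, 3.5]
```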
## Usage
### Download a Specific Model
```python
from huggingface_hub import hf_hub_download

# Download the central model
model_path = hf_hub_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    filename="models/central/best_model.pt"
)

# Download the fine-tuned model for a specific dataset
model_path = hf_hub_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    filename="models/pipeline4/chaksu/best_model.pt"
)
```
### Download All Models
```python
from huggingface_hub import snapshot_download

# Download the entire models directory
local_dir = snapshot_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    allow_patterns="models/**"
)
print(f"Models downloaded to: {local_dir}")
```
### Load and Perform Inference
```python
import torch
from PIL import Image
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor

# Load the preprocessor
processor = Mask2FormerImageProcessor.from_pretrained(
    "facebook/mask2former-swin-base-ade-semantic"
)

# Load the model architecture; ignore_mismatched_sizes lets the 4-class head
# replace the ADE20K classification head before the trained weights are loaded
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-base-ade-semantic",
    num_labels=4,  # background, unlabeled, optic disc, optic cup
    ignore_mismatched_sizes=True
)

# Load the trained checkpoint weights
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# Perform inference on a fundus image
image = Image.open("fundus_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Post-process into a per-pixel class map at the original image resolution
predicted_segmentation = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
```
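Since a vertical cup-to-disc ratio (CDR) ≥0.6 is the clinical indicator of interest, the predicted class map can be reduced to a CDR estimate. Below is a minimal sketch, assuming label ids 2 (optic disc) and 3 (optic cup) following the class order listed above; verify these ids against the checkpoint's actual label mapping before clinical use:

```python
import numpy as np

def vertical_cdr(seg, disc_id=2, cup_id=3):
    """Vertical cup-to-disc ratio from a 2-D integer label map.

    Assumed ids: 2 = optic disc, 3 = optic cup. The disc region is taken to
    include the cup pixels, since the cup lies inside the disc anatomically.
    """
    disc = np.isin(seg, (disc_id, cup_id))
    cup = seg == cup_id
    if not disc.any() or not cup.any():
        return None  # one of the structures was not detected
    disc_height = disc.any(axis=1).sum()  # vertical extent in pixel rows
    cup_height = cup.any(axis=1).sum()
    return cup_height / disc_height

# Toy 6x4 label map: disc spans 4 rows, cup the middle 2 rows
seg = np.zeros((6, 4), dtype=int)
seg[1:5, 1:3] = 2
seg[2:4, 1:3] = 3
print(vertical_cdr(seg))  # 0.5
```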
## Datasets
Training was performed across 9 public datasets spanning 7 countries, comprising a total of 5,550 color fundus photographs from at least 917 patients:
| Dataset | Total Images | Test Images | Country | Characteristics |
|---------|-------------|-------------|---------|-----------------|
| Chaksu | 1,345 | 135 | India | Multi-center research dataset |
| REFUGE | 1,200 | 120 | China | Glaucoma challenge dataset |
| G1020 | 1,020 | 102 | Germany | Benchmark retinal fundus dataset |
| RIM-ONE DL | 485 | 49 | Spain | Glaucoma assessment dataset |
| MESSIDOR | 460 | 46 | France | Diabetic retinopathy screening |
| ORIGA | 650 | 65 | Singapore | Multi-ethnic Asian population |
| Bin Rushed | 195 | 20 | Saudi Arabia | RIGA dataset collection |
| DRISHTI-GS | 101 | 11 | India | Optic nerve head segmentation |
| Magrabi | 94 | 10 | Saudi Arabia | RIGA dataset collection |
**Data Split**: Each dataset was divided into training (80%), validation (10%), and testing (10%) subsets. For datasets with multiple expert annotations, the STAPLE (Simultaneous Truth and Performance Level Estimation) method was used to generate consensus segmentation labels.
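The 80/10/10 split described above can be sketched as a shuffled partition per dataset. This is an illustrative helper, not the authors' split code, and the actual splits may have been stratified or fixed by the dataset providers:

```python
import random

def split_80_10_10(items, seed=0):
    """Shuffle and partition a dataset into train/val/test (80%/10%/10%)."""
    rng = random.Random(seed)  # fixed seed for a reproducible split
    items = list(items)
    rng.shuffle(items)
    n_train = int(0.8 * len(items))
    n_val = int(0.1 * len(items))
    return (items[:n_train],
            items[n_train:n_train + n_val],
            items[n_train + n_val:])

# e.g. REFUGE's 1,200 images -> 960 train, 120 val, 120 test
train, val, test = split_80_10_10(range(1200))
print(len(train), len(val), len(test))  # 960 120 120
```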
## Model Architecture
- **Base Model**: Mask2Former
- **Backbone**: Swin Transformer (Swin-Base)
- **Pre-training**: ADE20K semantic segmentation dataset
- **Input Resolution**: 512×512 pixels
- **Output Classes**: 4 (background, unlabeled, optic disc, optic cup)
- **Optimizer**: AdamW (learning rate: 2×10⁻⁵)
- **Loss Function**: Multi-class cross-entropy
- **Early Stopping**: Patience of 7 epochs/rounds
## Training Configuration
### Common Hyperparameters
- Batch size: 8
- Learning rate: 2×10⁻⁵
- Optimizer: AdamW with weight decay
- Maximum epochs: 100 (with early stopping)
- Early stopping patience: 7 epochs/rounds
- Input size: 512×512 pixels (normalized)
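The early-stopping rule above (patience of 7 on validation loss, counted in epochs for local training or rounds for federated training) can be sketched as follows. `EarlyStopping` is an illustrative helper, not code from this repository:

```python
class EarlyStopping:
    """Stop training once validation loss has not improved for `patience` steps."""

    def __init__(self, patience=7):
        self.patience = patience
        self.best = float("inf")
        self.bad_steps = 0

    def step(self, val_loss):
        """Record one epoch/round; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_steps = 0  # improvement resets the counter
        else:
            self.bad_steps += 1
        return self.bad_steps >= self.patience

stopper = EarlyStopping(patience=7)
losses = [0.9, 0.8, 0.7] + [0.7] * 7  # loss plateaus after the third epoch
stops = [stopper.step(loss) for loss in losses]
print(stops.index(True))  # 9: stops after 7 consecutive non-improving epochs
```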
---