---
license: apache-2.0
tags:
- medical-imaging
- glaucoma
- federated-learning
- semantic-segmentation
- ophthalmology
- computer-vision
- healthcare
- mask2former
- swin-transformer
datasets:
- chaksu
- refuge
- g1020
- rim-one-dl
- messidor
- origa
library_name: transformers
pipeline_tag: image-segmentation
---
# Federated Learning for Glaucoma Segmentation: Model Checkpoints
## Overview
This repository contains trained model checkpoints from the research project: **"A Federated Learning-based Optic Disc and Cup Segmentation Model for Glaucoma Monitoring in Color Fundus Photographs"**
### Key Information
- **Task**: Automated optic disc and cup segmentation for glaucoma assessment
- **Architecture**: Mask2Former with Swin Transformer backbone
- **Pre-training**: ADE20K semantic segmentation dataset
- **Training Data**: 5,550 color fundus photographs from 9 datasets across 7 countries
- **Approach**: Privacy-preserving federated learning with site-specific fine-tuning
## Clinical Context
Glaucoma is a leading cause of irreversible blindness worldwide, affecting 3.54% of the population aged 40-80 and projected to impact 111.8 million people by 2040. A key indicator of glaucoma severity is the vertical cup-to-disc ratio (CDR), with ratios ≥0.6 suggestive of glaucoma.
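The vertical CDR is the ratio of the vertical cup diameter to the vertical disc diameter: for example, a cup measuring 1.2 mm vertically within a 1.9 mm disc gives 1.2/1.9 ≈ 0.63, above the 0.6 threshold.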
This work addresses the need for accurate automated segmentation while preserving patient data privacy across multiple clinical sites, enabling HIPAA/GDPR-compliant multi-institutional collaboration.
## Models Included
This repository contains **22 trained models**, organized into the following categories:
### Baseline Models
- **Central Model** (1 model): Trained on pooled multi-site data, representing upper bound performance
- **Local Models** (9 models): Site-specific models trained on individual datasets, representing lower bound performance
### Federated Learning Models
- **Pipeline 1** (1 model): Global Validation
- **Pipeline 2** (1 model): Weighted Global Validation
- **Pipeline 3** (1 model): Onsite Validation
- **Pipeline 4** (9 models): Fine-Tuned Onsite Validation
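
The per-pipeline aggregation and validation rules are detailed in the paper. As a rough illustration of the shared mechanism, the sketch below shows FedAvg-style aggregation, in which each site's locally trained weights are combined in proportion to its sample count (similar in spirit to the weighting of Pipeline 2); `site_state_dicts` and `site_sizes` are hypothetical names, not APIs from this repository:

```python
import torch

def federated_average(site_state_dicts, site_sizes):
    """Combine per-site model weights, weighting each site by its sample count."""
    total = float(sum(site_sizes))
    global_state = {}
    for key, ref in site_state_dicts[0].items():
        if torch.is_floating_point(ref):
            global_state[key] = sum(
                sd[key] * (n / total) for sd, n in zip(site_state_dicts, site_sizes)
            )
        else:
            # Integer buffers (e.g., BatchNorm counters) are copied, not averaged
            global_state[key] = ref.clone()
    return global_state

# One federated round: sites train locally, then the server aggregates, e.g.
# global_state = federated_average(local_states, [1345, 1200, 1020, ...])
# model.load_state_dict(global_state)
```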
## Usage
### Download Specific Model
```python
from huggingface_hub import hf_hub_download

# Download the central model
model_path = hf_hub_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    filename="models/central/best_model.pt",
)

# Download the fine-tuned model for a specific dataset (e.g., Chaksu)
model_path = hf_hub_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    filename="models/pipeline4/chaksu/best_model.pt",
)
```
### Download All Models
```python
from huggingface_hub import snapshot_download

# Download the entire models directory
local_dir = snapshot_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    allow_patterns="models/**",
)
print(f"Models downloaded to: {local_dir}")
```
### Load and Perform Inference
```python
import torch
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor
from PIL import Image

# Load the preprocessor
processor = Mask2FormerImageProcessor.from_pretrained(
    "facebook/mask2former-swin-base-ade-semantic"
)

# Load the model architecture; the classification head is re-sized from
# ADE20K's 150 classes to 4, so mismatched head weights must be ignored
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-base-ade-semantic",
    num_labels=4,  # background, unlabeled, optic disc, optic cup
    ignore_mismatched_sizes=True,
)

# Load the trained weights
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# Perform inference on a fundus image
image = Image.open("fundus_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Post-process into a per-pixel class map at the original resolution
predicted_segmentation = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]  # PIL size is (W, H); target is (H, W)
)[0]
```
## Datasets
Training was performed across 9 public datasets spanning 7 countries, comprising a total of 5,550 color fundus photographs from at least 917 patients:
| Dataset | Total Images | Test Images | Country | Characteristics |
|---------|-------------|-------------|---------|-----------------|
| Chaksu | 1,345 | 135 | India | Multi-center research dataset |
| REFUGE | 1,200 | 120 | China | Glaucoma challenge dataset |
| G1020 | 1,020 | 102 | Germany | Benchmark retinal fundus dataset |
| RIM-ONE DL | 485 | 49 | Spain | Glaucoma assessment dataset |
| MESSIDOR | 460 | 46 | France | Diabetic retinopathy screening |
| ORIGA | 650 | 65 | Singapore | Multi-ethnic Asian population |
| Bin Rushed | 195 | 20 | Saudi Arabia | RIGA dataset collection |
| DRISHTI-GS | 101 | 11 | India | Optic nerve head segmentation |
| Magrabi | 94 | 10 | Saudi Arabia | RIGA dataset collection |
**Data Split**: Each dataset was divided into training (80%), validation (10%), and testing (10%) subsets. For datasets with multiple expert annotations, the STAPLE (Simultaneous Truth and Performance Level Estimation) method was used to generate consensus segmentation labels.
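
For reference, STAPLE treats the true segmentation as a hidden variable and estimates it jointly with each annotator's sensitivity and specificity by expectation-maximization. A minimal numpy sketch of the binary case (run per structure, e.g., on the cup masks); `masks` is a hypothetical list of same-shape binary annotator masks:

```python
import numpy as np

def staple_binary(masks, n_iter=50, tol=1e-6):
    """Binary STAPLE (Warfield et al., 2004): EM consensus over rater masks."""
    D = np.stack([m.reshape(-1).astype(float) for m in masks])  # (raters, pixels)
    pi = D.mean()                      # prior probability of foreground
    p = np.full(D.shape[0], 0.9)       # per-rater sensitivity
    q = np.full(D.shape[0], 0.9)       # per-rater specificity
    W = np.full(D.shape[1], pi)
    for _ in range(n_iter):
        # E-step: posterior that each pixel is truly foreground
        a = np.prod(np.where(D == 1, p[:, None], 1 - p[:, None]), axis=0)
        b = np.prod(np.where(D == 1, 1 - q[:, None], q[:, None]), axis=0)
        W_new = pi * a / (pi * a + (1 - pi) * b + 1e-12)
        # M-step: re-estimate rater sensitivities and specificities
        p = (D @ W_new) / (W_new.sum() + 1e-12)
        q = ((1 - D) @ (1 - W_new)) / ((1 - W_new).sum() + 1e-12)
        if np.abs(W_new - W).max() < tol:
            W = W_new
            break
        W = W_new
    return (W >= 0.5).reshape(masks[0].shape)  # consensus label
```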
## Model Architecture
- **Base Model**: Mask2Former
- **Backbone**: Swin Transformer (Swin-Base)
- **Pre-training**: ADE20K semantic segmentation dataset
- **Input Resolution**: 512×512 pixels
- **Output Classes**: 4 (background, unlabeled, optic disc, optic cup)
- **Optimizer**: AdamW (learning rate: 2×10⁻⁵)
- **Loss Function**: Multi-class cross-entropy
- **Early Stopping**: Patience of 7 epochs/rounds
## Training Configuration
### Common Hyperparameters
- Batch size: 8
- Learning rate: 2×10⁻⁵
- Optimizer: AdamW with weight decay
- Maximum epochs: 100 (with early stopping)
- Early stopping patience: 7 epochs/rounds
- Input size: 512×512 pixels (normalized)
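
A minimal sketch of how these settings fit together (the weight-decay value and the `train_one_epoch`/`evaluate` helpers are assumptions, not taken from the paper):

```python
import torch

# AdamW with weight decay; the decay value here is an assumed placeholder
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

best_val, patience, bad_epochs = float("inf"), 7, 0
for epoch in range(100):                     # maximum epochs
    train_one_epoch(model, optimizer)        # hypothetical helpers
    val_loss = evaluate(model)
    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        bad_epochs += 1
        if bad_epochs >= patience:           # early stopping after 7 stagnant epochs
            break
```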
---