---
license: apache-2.0
tags:
- medical-imaging
- glaucoma
- federated-learning
- semantic-segmentation
- ophthalmology
- computer-vision
- healthcare
- mask2former
- swin-transformer
datasets:
- chaksu
- refuge
- g1020
- rim-one-dl
- messidor
- origa
library_name: transformers
pipeline_tag: image-segmentation
---

# Federated Learning for Glaucoma Segmentation: Model Checkpoints

## Overview

This repository contains trained model checkpoints from the research project **"A Federated Learning-based Optic Disc and Cup Segmentation Model for Glaucoma Monitoring in Color Fundus Photographs"**.

### Key Information

- **Task**: Automated optic disc and cup segmentation for glaucoma assessment
- **Architecture**: Mask2Former with Swin Transformer backbone
- **Pre-training**: ADE20K semantic segmentation dataset
- **Training Data**: 5,550 color fundus photographs from 9 datasets across 7 countries
- **Approach**: Privacy-preserving federated learning with site-specific fine-tuning

## Clinical Context

Glaucoma is a leading cause of irreversible blindness worldwide, affecting 3.54% of the population aged 40-80 and projected to impact 111.8 million people by 2040. A key indicator of glaucoma severity is the vertical cup-to-disc ratio (CDR), with ratios ≥0.6 suggestive of glaucoma.

This work addresses the need for accurate automated segmentation while preserving patient data privacy across multiple clinical sites, enabling HIPAA/GDPR-compliant multi-institutional collaboration.

## Models Included

This repository contains **22 trained models**, organized into the following categories:

### Baseline Models

- **Central Model** (1 model): Trained on pooled multi-site data, representing upper-bound performance
- **Local Models** (9 models): Site-specific models trained on individual datasets, representing lower-bound performance

### Federated Learning Models

- **Pipeline 1** (1 model): Global Validation
- **Pipeline 2** (1 model): Weighted Global Validation
- **Pipeline 3** (1 model): Onsite Validation
- **Pipeline 4** (9 models): Fine-Tuned Onsite Validation, one model per dataset
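
For orientation: federated training of this kind alternates local training at each site with a server-side aggregation of the sites' model weights; Pipeline 2, for example, weights each site by its dataset size. The following is a minimal FedAvg-style sketch of such an aggregation step, not the exact logic of any pipeline; `site_models` and the per-site sample counts are hypothetical.

```python
import copy

def aggregate_weights(state_dicts, sample_counts):
    """FedAvg-style aggregation: parameter-wise average of the site models,
    weighted by the number of training samples at each site."""
    total = float(sum(sample_counts))
    global_state = copy.deepcopy(state_dicts[0])
    for key in global_state:
        # Integer buffers (e.g., BatchNorm step counters) would need special
        # handling in a real implementation; this sketch averages everything.
        global_state[key] = sum(
            sd[key].float() * (n / total)
            for sd, n in zip(state_dicts, sample_counts)
        )
    return global_state

# Hypothetical usage with nine site models:
# global_state = aggregate_weights(
#     [m.state_dict() for m in site_models],
#     sample_counts=per_site_training_sizes,
# )
```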

## Usage

### Download Specific Model

```python
from huggingface_hub import hf_hub_download

# Download central model
model_path = hf_hub_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    filename="models/central/best_model.pt"
)

# Download fine-tuned model for a specific dataset
model_path = hf_hub_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    filename="models/pipeline4/chaksu/best_model.pt"
)
```

### Download All Models

```python
from huggingface_hub import snapshot_download

# Download the entire models directory
local_dir = snapshot_download(
    repo_id="sud11111/Federated-Learning-Glaucoma",
    allow_patterns="models/**"
)
print(f"Models downloaded to: {local_dir}")
```

### Load and Perform Inference

```python
import torch
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor
from PIL import Image

# Load preprocessor
processor = Mask2FormerImageProcessor.from_pretrained(
    "facebook/mask2former-swin-base-ade-semantic"
)

# Load model architecture; the classification head is resized from ADE20K's
# 150 classes to 4, so size mismatches must be ignored here
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-base-ade-semantic",
    num_labels=4,  # background, unlabeled, optic disc, optic cup
    ignore_mismatched_sizes=True
)

# Load trained weights
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# Perform inference on a fundus image
image = Image.open("fundus_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Post-process to an (H, W) label map at the original image size
predicted_segmentation = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]  # PIL .size is (W, H)
)[0]
```
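
The vertical CDR described above can then be read off the predicted label map. A minimal sketch, assuming the class ordering noted in the comments (2 = optic disc, 3 = optic cup); verify these indices against the checkpoint's label mapping before relying on them.

```python
import torch

def vertical_cdr(mask: torch.Tensor, disc_id: int = 2, cup_id: int = 3) -> float:
    """Vertical cup-to-disc ratio from an (H, W) integer label map."""
    def vertical_extent(region: torch.Tensor) -> int:
        rows = torch.nonzero(torch.any(region, dim=1))
        return 0 if rows.numel() == 0 else int(rows.max() - rows.min() + 1)

    disc_h = vertical_extent((mask == disc_id) | (mask == cup_id))  # the cup lies inside the disc
    cup_h = vertical_extent(mask == cup_id)
    return cup_h / disc_h if disc_h else float("nan")

cdr = vertical_cdr(predicted_segmentation)
print(f"Vertical CDR: {cdr:.2f} (>= 0.6 is suggestive of glaucoma)")
```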

## Datasets

Training was performed across 9 public datasets spanning 7 countries, comprising a total of 5,550 color fundus photographs from at least 917 patients:

| Dataset | Total Images | Test Images | Country | Characteristics |
|---------|--------------|-------------|---------|-----------------|
| Chaksu | 1,345 | 135 | India | Multi-center research dataset |
| REFUGE | 1,200 | 120 | China | Glaucoma challenge dataset |
| G1020 | 1,020 | 102 | Germany | Benchmark retinal fundus dataset |
| RIM-ONE DL | 485 | 49 | Spain | Glaucoma assessment dataset |
| MESSIDOR | 460 | 46 | France | Diabetic retinopathy screening |
| ORIGA | 650 | 65 | Singapore | Multi-ethnic Asian population |
| Bin Rushed | 195 | 20 | Saudi Arabia | RIGA dataset collection |
| DRISHTI-GS | 101 | 11 | India | Optic nerve head segmentation |
| Magrabi | 94 | 10 | Saudi Arabia | RIGA dataset collection |

**Data Split**: Each dataset was divided into training (80%), validation (10%), and testing (10%) subsets. For datasets with multiple expert annotations, the STAPLE (Simultaneous Truth and Performance Level Estimation) method was used to generate consensus segmentation labels.
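
For reference, STAPLE forms a consensus by expectation-maximization, iteratively re-weighting each rater according to its estimated sensitivity and specificity (Warfield et al., 2004). Below is a minimal numpy sketch of the binary, single-structure case; it is illustrative only, and the multi-class disc/cup setting would use the algorithm's multi-label extension.

```python
import numpy as np

def binary_staple(votes: np.ndarray, n_iter: int = 50) -> np.ndarray:
    """votes: (R, N) binary decisions from R raters over N pixels.
    Returns the per-pixel posterior probability of foreground."""
    p = np.full(votes.shape[0], 0.99)   # estimated rater sensitivities
    q = np.full(votes.shape[0], 0.99)   # estimated rater specificities
    prior = votes.mean()                # prior foreground probability
    for _ in range(n_iter):
        # E-step: posterior that each pixel is truly foreground
        a = prior * np.prod(np.where(votes == 1, p[:, None], 1 - p[:, None]), axis=0)
        b = (1 - prior) * np.prod(np.where(votes == 0, q[:, None], 1 - q[:, None]), axis=0)
        w = a / (a + b + 1e-12)
        # M-step: re-estimate each rater's sensitivity and specificity
        p = (votes * w).sum(axis=1) / (w.sum() + 1e-12)
        q = ((1 - votes) * (1 - w)).sum(axis=1) / ((1 - w).sum() + 1e-12)
    return w

# Consensus mask for hypothetical per-expert binary masks of equal shape:
# consensus = binary_staple(np.stack([m.ravel() for m in expert_masks])) > 0.5
```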

## Model Architecture

- **Base Model**: Mask2Former
- **Backbone**: Swin Transformer (Swin-Base)
- **Pre-training**: ADE20K semantic segmentation dataset
- **Input Resolution**: 512×512 pixels
- **Output Classes**: 4 (background, unlabeled, optic disc, optic cup)
- **Optimizer**: AdamW (learning rate: 2×10⁻⁵)
- **Loss Function**: Multi-class cross-entropy
- **Early Stopping**: Patience of 7 epochs (local/central training) or 7 rounds (federated training)
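
When re-creating this architecture, the class mapping can be made explicit so that downstream tooling labels outputs correctly. The index order below follows the "Output Classes" listing above but is an assumption; confirm it against the training code.

```python
from transformers import Mask2FormerForUniversalSegmentation

# Assumed index order, matching the "Output Classes" listing above
id2label = {0: "background", 1: "unlabeled", 2: "optic disc", 3: "optic cup"}

model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-base-ade-semantic",
    id2label=id2label,
    label2id={name: idx for idx, name in id2label.items()},
    ignore_mismatched_sizes=True,  # head is resized from ADE20K's 150 classes
)
```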

## Training Configuration

### Common Hyperparameters

- Batch size: 8
- Learning rate: 2×10⁻⁵
- Optimizer: AdamW with weight decay
- Maximum epochs: 100 (with early stopping)
- Early stopping patience: 7 epochs/rounds
- Input size: 512×512 pixels (normalized)
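
Roughly, these settings correspond to a PyTorch loop like the one below. This is a sketch, not the project's training script; `train_one_epoch`, `evaluate`, and the weight-decay value are placeholders.

```python
import torch

# Weight-decay value is assumed; the card only states "AdamW with weight decay"
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

best_val, patience, bad_epochs = float("inf"), 7, 0
for epoch in range(100):                      # maximum epochs
    train_one_epoch(model, optimizer)         # placeholder training step
    val_loss = evaluate(model)                # placeholder validation step
    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        bad_epochs += 1
        if bad_epochs >= patience:            # stop after 7 non-improving epochs
            break
```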

---