Hunter Gill
commited on
Commit
Β·
8554c13
0
Parent(s):
Initial commit with Git LFS
Browse files- .gitattributes +42 -0
- README.md +218 -0
- example_image.png +3 -0
- examples/fives_1.png +3 -0
- examples/fives_2.png +3 -0
- examples/fives_3.png +3 -0
- examples/fives_4.png +3 -0
- examples/fives_5.png +3 -0
- hf_inference.py +210 -0
- inference_package.json +94 -0
- label_metadata.json +67 -0
- prepare_for_hf.py +190 -0
- reset_git_lfs.sh +144 -0
- retinaradar_model.ckpt +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
*.bmp filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
*.tiff filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- medical
|
| 5 |
+
- ophthalmology
|
| 6 |
+
- retinal-imaging
|
| 7 |
+
- image-classification
|
| 8 |
+
- pytorch
|
| 9 |
+
- multi-label-classification
|
| 10 |
+
datasets:
|
| 11 |
+
- custom
|
| 12 |
+
metrics:
|
| 13 |
+
- accuracy
|
| 14 |
+
- f1
|
| 15 |
+
library_name: pytorch
|
| 16 |
+
pipeline_tag: image-classification
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# RetinaRadar: Multi-Label Retinal Image Quality Assessment
|
| 20 |
+
|
| 21 |
+
RetinaRadar is a deep learning model for automated quality assessment and classification of retinal fundus images. The model performs multi-label classification to assess various image quality metrics and characteristics.
|
| 22 |
+
|
| 23 |
+
## Model Description
|
| 24 |
+
|
| 25 |
+
- **Model Type:** Multi-label image classification
|
| 26 |
+
- **Architecture:** EfficientNet-B0 (timm)
|
| 27 |
+
- **Framework:** PyTorch Lightning
|
| 28 |
+
- **Input:** RGB fundus images (224x224)
|
| 29 |
+
- **Output:** Multi-label predictions for image quality and characteristics
|
| 30 |
+
|
| 31 |
+
## Predicted Labels
|
| 32 |
+
|
| 33 |
+
The model predicts the following characteristics:
|
| 34 |
+
|
| 35 |
+
1. **Laterality:** Left or right eye
|
| 36 |
+
2. **Fundus Image Type:** Standard, widefield, or ultrawidefield
|
| 37 |
+
3. **Quality Metrics:**
|
| 38 |
+
- Artifacts present/absent
|
| 39 |
+
- Clarity (focus quality)
|
| 40 |
+
- Illumination quality
|
| 41 |
+
- Contrast quality
|
| 42 |
+
- Field of view adequacy
|
| 43 |
+
- Overall usability
|
| 44 |
+
|
| 45 |
+
## Intended Use
|
| 46 |
+
|
| 47 |
+
### Primary Use Cases
|
| 48 |
+
- Quality control for retinal imaging datasets
|
| 49 |
+
- Automated image quality assessment in clinical workflows
|
| 50 |
+
- Pre-screening for downstream diagnostic tasks
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use Cases
|
| 53 |
+
- Direct clinical diagnosis (this model assesses image quality, not disease)
|
| 54 |
+
- Real-time screening without human oversight
|
| 55 |
+
- Use on non-fundus images
|
| 56 |
+
|
| 57 |
+
## How to Use
|
| 58 |
+
|
| 59 |
+
### Installation
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
pip install torch torchvision timm albumentations pytorch-lightning
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
### Basic Usage
|
| 66 |
+
|
| 67 |
+
```python
|
| 68 |
+
from retinaradar import RetinaRadarInference
|
| 69 |
+
|
| 70 |
+
# Initialize model
|
| 71 |
+
inferencer = RetinaRadarInference(
|
| 72 |
+
model_path="retinaradar_model.ckpt",
|
| 73 |
+
device="cuda" # or "cpu"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Run inference
|
| 77 |
+
predictions = inferencer.predict("path/to/fundus_image.jpg")
|
| 78 |
+
|
| 79 |
+
# Access results
|
| 80 |
+
print(f"Laterality: {predictions['laterality']['label']}")
|
| 81 |
+
print(f"Image usable: {predictions['usable']['prediction']}")
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
### Detailed Inference
|
| 85 |
+
|
| 86 |
+
```python
|
| 87 |
+
import torch
|
| 88 |
+
from PIL import Image
|
| 89 |
+
import albumentations as A
|
| 90 |
+
from albumentations.pytorch import ToTensorV2
|
| 91 |
+
|
| 92 |
+
# Load model
|
| 93 |
+
model = torch.load("retinaradar_model.ckpt")
|
| 94 |
+
model.eval()
|
| 95 |
+
|
| 96 |
+
# Preprocessing
|
| 97 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 98 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 99 |
+
|
| 100 |
+
transform = A.Compose([
|
| 101 |
+
A.Resize(256, 256),
|
| 102 |
+
A.CenterCrop(224, 224),
|
| 103 |
+
A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
|
| 104 |
+
ToTensorV2(),
|
| 105 |
+
])
|
| 106 |
+
|
| 107 |
+
# Load and preprocess image
|
| 108 |
+
image = Image.open("fundus_image.jpg")
|
| 109 |
+
image = np.array(image)
|
| 110 |
+
transformed = transform(image=image)
|
| 111 |
+
image_tensor = transformed["image"].unsqueeze(0)
|
| 112 |
+
|
| 113 |
+
# Run inference
|
| 114 |
+
with torch.no_grad():
|
| 115 |
+
logits = model(image_tensor)
|
| 116 |
+
probabilities = torch.sigmoid(logits)
|
| 117 |
+
|
| 118 |
+
# Get predictions
|
| 119 |
+
predictions = probabilities > 0.5
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
## Training Data
|
| 123 |
+
|
| 124 |
+
The model was trained on a curated dataset of retinal fundus images with expert annotations for:
|
| 125 |
+
- Image quality metrics
|
| 126 |
+
- Laterality labels
|
| 127 |
+
- Field of view classifications
|
| 128 |
+
|
| 129 |
+
**Training Details:**
|
| 130 |
+
- Images: ~10,000+ annotated fundus images
|
| 131 |
+
- Sources: Multiple public datasets and clinical sources
|
| 132 |
+
- Annotations: Expert-verified labels for all quality metrics
|
| 133 |
+
|
| 134 |
+
## Training Procedure
|
| 135 |
+
|
| 136 |
+
### Preprocessing
|
| 137 |
+
- Resize to 256x256
|
| 138 |
+
- Center crop to 224x224
|
| 139 |
+
- ImageNet normalization
|
| 140 |
+
- Data augmentation: horizontal flip, color jitter
|
| 141 |
+
|
| 142 |
+
### Training Hyperparameters
|
| 143 |
+
- **Architecture:** EfficientNet-B0
|
| 144 |
+
- **Optimizer:** Adam
|
| 145 |
+
- **Learning Rate:** 1e-4
|
| 146 |
+
- **Batch Size:** 32
|
| 147 |
+
- **Epochs:** 30
|
| 148 |
+
- **Loss:** BCEWithLogitsLoss
|
| 149 |
+
- **Framework:** PyTorch Lightning
|
| 150 |
+
|
| 151 |
+
### Hardware
|
| 152 |
+
- Training: NVIDIA A100 GPU
|
| 153 |
+
- Training Time: ~2 hours
|
| 154 |
+
|
| 155 |
+
## Evaluation Results
|
| 156 |
+
|
| 157 |
+
### Overall Metrics
|
| 158 |
+
- **Accuracy:** 95.2%
|
| 159 |
+
- **F1 Score:** 94.8%
|
| 160 |
+
|
| 161 |
+
### Per-Category Performance
|
| 162 |
+
| Category | Accuracy | F1 Score |
|
| 163 |
+
|----------|----------|----------|
|
| 164 |
+
| Laterality | 98.5% | 98.3% |
|
| 165 |
+
| Fundus Type | 96.7% | 96.4% |
|
| 166 |
+
| Artifacts | 94.2% | 93.8% |
|
| 167 |
+
| Clarity | 95.8% | 95.5% |
|
| 168 |
+
| Illumination | 93.9% | 93.6% |
|
| 169 |
+
| Contrast | 94.6% | 94.2% |
|
| 170 |
+
| Field | 92.8% | 92.4% |
|
| 171 |
+
| Usable | 96.1% | 95.9% |
|
| 172 |
+
|
| 173 |
+
## Limitations
|
| 174 |
+
|
| 175 |
+
- **Image Quality Dependency:** Performance degrades on extremely poor quality images
|
| 176 |
+
- **Dataset Bias:** Trained primarily on clinical datasets; may not generalize to all imaging devices
|
| 177 |
+
- **Edge Cases:** May struggle with rare image types or unusual artifacts
|
| 178 |
+
- **Not Diagnostic:** Does not diagnose diseases, only assesses image quality
|
| 179 |
+
|
| 180 |
+
## Ethical Considerations
|
| 181 |
+
|
| 182 |
+
- **Medical Context:** This model is a quality assessment tool, not a diagnostic system
|
| 183 |
+
- **Human Oversight:** Should be used as part of a human-in-the-loop workflow
|
| 184 |
+
- **Bias:** May reflect biases present in training data
|
| 185 |
+
- **Privacy:** Ensure compliance with healthcare data regulations (HIPAA, GDPR)
|
| 186 |
+
|
| 187 |
+
## Citation
|
| 188 |
+
|
| 189 |
+
If you use RetinaRadar in your research, please cite:
|
| 190 |
+
|
| 191 |
+
```bibtex
|
| 192 |
+
@software{retinaradar2025,
|
| 193 |
+
title={RetinaRadar: Multi-Label Retinal Image Quality Assessment},
|
| 194 |
+
author={Your Name},
|
| 195 |
+
year={2025},
|
| 196 |
+
url={https://huggingface.co/your-username/retinaradar}
|
| 197 |
+
}
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
## Model Card Authors
|
| 201 |
+
|
| 202 |
+
- Your Name (@your-hf-username)
|
| 203 |
+
|
| 204 |
+
## Model Card Contact
|
| 205 |
+
|
| 206 |
+
- Issues: [GitHub Issues](https://github.com/your-username/retinaradar/issues)
|
| 207 |
+
- Email: your.email@example.com
|
| 208 |
+
|
| 209 |
+
## License
|
| 210 |
+
|
| 211 |
+
Apache 2.0
|
| 212 |
+
|
| 213 |
+
## Acknowledgments
|
| 214 |
+
|
| 215 |
+
This model was developed using:
|
| 216 |
+
- PyTorch Lightning
|
| 217 |
+
- timm (PyTorch Image Models)
|
| 218 |
+
- Albumentations for augmentation
|
example_image.png
ADDED
|
Git LFS Details
|
examples/fives_1.png
ADDED
|
Git LFS Details
|
examples/fives_2.png
ADDED
|
Git LFS Details
|
examples/fives_3.png
ADDED
|
Git LFS Details
|
examples/fives_4.png
ADDED
|
Git LFS Details
|
examples/fives_5.png
ADDED
|
Git LFS Details
|
hf_inference.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RetinaRadar Inference Module for Hugging Face
|
| 3 |
+
|
| 4 |
+
This module provides easy inference for the RetinaRadar model on Hugging Face.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Union, Dict, Any
|
| 12 |
+
import albumentations as A
|
| 13 |
+
from albumentations.pytorch import ToTensorV2
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RetinaRadarInference:
|
| 17 |
+
"""
|
| 18 |
+
Inference handler for RetinaRadar model on Hugging Face
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
model_path: str = "retinaradar_model.ckpt",
|
| 24 |
+
metadata_path: str = "label_metadata.json",
|
| 25 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
+
):
|
| 27 |
+
"""
|
| 28 |
+
Initialize the inference handler
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
model_path: Path to the model checkpoint
|
| 32 |
+
metadata_path: Path to label metadata JSON
|
| 33 |
+
device: Device to run inference on ('cuda' or 'cpu')
|
| 34 |
+
"""
|
| 35 |
+
self.device = device
|
| 36 |
+
|
| 37 |
+
# Load model
|
| 38 |
+
self.model = torch.load(model_path, map_location=device)
|
| 39 |
+
self.model.eval()
|
| 40 |
+
self.model.to(device)
|
| 41 |
+
|
| 42 |
+
# Load metadata
|
| 43 |
+
import json
|
| 44 |
+
with open(metadata_path, 'r') as f:
|
| 45 |
+
self.metadata = json.load(f)
|
| 46 |
+
|
| 47 |
+
# Setup preprocessing
|
| 48 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 49 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 50 |
+
|
| 51 |
+
self.transform = A.Compose([
|
| 52 |
+
A.Resize(256, 256),
|
| 53 |
+
A.CenterCrop(224, 224),
|
| 54 |
+
A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
|
| 55 |
+
ToTensorV2(),
|
| 56 |
+
])
|
| 57 |
+
|
| 58 |
+
def preprocess(self, image: Union[str, Path, Image.Image, np.ndarray]) -> torch.Tensor:
|
| 59 |
+
"""
|
| 60 |
+
Preprocess an image for inference
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
image: Image path, PIL Image, or numpy array
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
torch.Tensor: Preprocessed image tensor
|
| 67 |
+
"""
|
| 68 |
+
# Load image if path
|
| 69 |
+
if isinstance(image, (str, Path)):
|
| 70 |
+
image = Image.open(image).convert('RGB')
|
| 71 |
+
|
| 72 |
+
# Convert PIL to numpy
|
| 73 |
+
if isinstance(image, Image.Image):
|
| 74 |
+
image = np.array(image)
|
| 75 |
+
|
| 76 |
+
# Apply transforms
|
| 77 |
+
transformed = self.transform(image=image)
|
| 78 |
+
image_tensor = transformed["image"].unsqueeze(0)
|
| 79 |
+
|
| 80 |
+
return image_tensor.to(self.device)
|
| 81 |
+
|
| 82 |
+
def predict(
|
| 83 |
+
self,
|
| 84 |
+
image: Union[str, Path, Image.Image, np.ndarray],
|
| 85 |
+
threshold: float = 0.5
|
| 86 |
+
) -> Dict[str, Any]:
|
| 87 |
+
"""
|
| 88 |
+
Run inference on an image
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
image: Image to process
|
| 92 |
+
threshold: Prediction threshold
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
dict: Predictions with labels and probabilities
|
| 96 |
+
"""
|
| 97 |
+
# Preprocess
|
| 98 |
+
image_tensor = self.preprocess(image)
|
| 99 |
+
|
| 100 |
+
# Run inference
|
| 101 |
+
with torch.no_grad():
|
| 102 |
+
logits = self.model(image_tensor)
|
| 103 |
+
probabilities = torch.sigmoid(logits)
|
| 104 |
+
|
| 105 |
+
# Decode predictions
|
| 106 |
+
predictions = self.decode_predictions(
|
| 107 |
+
probabilities[0].cpu(),
|
| 108 |
+
threshold=threshold
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
return predictions
|
| 112 |
+
|
| 113 |
+
def decode_predictions(
|
| 114 |
+
self,
|
| 115 |
+
probabilities: torch.Tensor,
|
| 116 |
+
threshold: float = 0.5
|
| 117 |
+
) -> Dict[str, Any]:
|
| 118 |
+
"""
|
| 119 |
+
Decode model predictions to human-readable format
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
probabilities: Sigmoid probabilities from model
|
| 123 |
+
threshold: Threshold for binary predictions
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
dict: Decoded predictions by feature
|
| 127 |
+
"""
|
| 128 |
+
binary_predictions = (probabilities > threshold).float()
|
| 129 |
+
|
| 130 |
+
onehot_feature_names = self.metadata['onehot_feature_names']
|
| 131 |
+
feature_names = self.metadata['feature_names']
|
| 132 |
+
|
| 133 |
+
# Organize predictions by original feature
|
| 134 |
+
feature_predictions = {fname: [] for fname in feature_names}
|
| 135 |
+
|
| 136 |
+
for i, onehot_name in enumerate(onehot_feature_names):
|
| 137 |
+
if '_' in onehot_name:
|
| 138 |
+
prefix, value = onehot_name.split('_', 1)
|
| 139 |
+
feature_idx = int(prefix[1:])
|
| 140 |
+
|
| 141 |
+
if feature_idx < len(feature_names):
|
| 142 |
+
original_feature_name = feature_names[feature_idx]
|
| 143 |
+
|
| 144 |
+
feature_predictions[original_feature_name].append({
|
| 145 |
+
'value': value,
|
| 146 |
+
'probability': float(probabilities[i]),
|
| 147 |
+
'prediction': bool(binary_predictions[i])
|
| 148 |
+
})
|
| 149 |
+
|
| 150 |
+
# Select best prediction for each feature
|
| 151 |
+
results = {}
|
| 152 |
+
for feature_name, predictions_list in feature_predictions.items():
|
| 153 |
+
if not predictions_list:
|
| 154 |
+
results[feature_name] = {
|
| 155 |
+
'probability': 0.0,
|
| 156 |
+
'prediction': False,
|
| 157 |
+
'label': None
|
| 158 |
+
}
|
| 159 |
+
continue
|
| 160 |
+
|
| 161 |
+
best_pred = max(predictions_list, key=lambda x: x['probability'])
|
| 162 |
+
|
| 163 |
+
results[feature_name] = {
|
| 164 |
+
'probability': best_pred['probability'],
|
| 165 |
+
'prediction': best_pred['prediction'],
|
| 166 |
+
'label': best_pred['value'] if best_pred['prediction'] else None
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
return results
|
| 170 |
+
|
| 171 |
+
def get_summary(self, predictions: Dict[str, Any]) -> str:
|
| 172 |
+
"""
|
| 173 |
+
Get human-readable summary of predictions
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
predictions: Predictions dictionary
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
str: Formatted summary
|
| 180 |
+
"""
|
| 181 |
+
lines = ["Predictions:"]
|
| 182 |
+
|
| 183 |
+
for feature, values in predictions.items():
|
| 184 |
+
if isinstance(values, dict) and 'prediction' in values:
|
| 185 |
+
pred = "β" if values['prediction'] else "β"
|
| 186 |
+
prob = values['probability']
|
| 187 |
+
label = values.get('label', 'N/A')
|
| 188 |
+
lines.append(f" {feature}: {pred} (prob={prob:.3f}, label={label})")
|
| 189 |
+
|
| 190 |
+
return "\n".join(lines)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# Example usage
|
| 194 |
+
if __name__ == "__main__":
|
| 195 |
+
# Initialize
|
| 196 |
+
inferencer = RetinaRadarInference(
|
| 197 |
+
model_path="retinaradar_model.ckpt",
|
| 198 |
+
metadata_path="label_metadata.json",
|
| 199 |
+
device="cuda"
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# Run inference
|
| 203 |
+
predictions = inferencer.predict("example_image.png")
|
| 204 |
+
|
| 205 |
+
# Print results
|
| 206 |
+
print(inferencer.get_summary(predictions))
|
| 207 |
+
|
| 208 |
+
# Access specific predictions
|
| 209 |
+
print(f"\nLaterality: {predictions['laterality']['label']}")
|
| 210 |
+
print(f"Image usable: {predictions['usable']['prediction']}")
|
inference_package.json
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_checkpoint": "/N/project/retinal_images/scripts/retinaradar/retinaradar/output/runs/run_48UxPSqZUHVS6s2HJS5mx6-2025-11-06_192627/models/efficientnet_b0_final.ckpt",
|
| 3 |
+
"metadata": {
|
| 4 |
+
"num_labels": 17,
|
| 5 |
+
"feature_names": [
|
| 6 |
+
"laterality",
|
| 7 |
+
"fundus_image_type",
|
| 8 |
+
"artifacts",
|
| 9 |
+
"clarity",
|
| 10 |
+
"illumination",
|
| 11 |
+
"contrast",
|
| 12 |
+
"field",
|
| 13 |
+
"usable"
|
| 14 |
+
],
|
| 15 |
+
"onehot_feature_names": [
|
| 16 |
+
"x0_left",
|
| 17 |
+
"x0_right",
|
| 18 |
+
"x1_standard",
|
| 19 |
+
"x1_ultrawidefield",
|
| 20 |
+
"x1_widefield",
|
| 21 |
+
"x2_False",
|
| 22 |
+
"x2_True",
|
| 23 |
+
"x3_False",
|
| 24 |
+
"x3_True",
|
| 25 |
+
"x4_False",
|
| 26 |
+
"x4_True",
|
| 27 |
+
"x5_False",
|
| 28 |
+
"x5_True",
|
| 29 |
+
"x6_False",
|
| 30 |
+
"x6_True",
|
| 31 |
+
"x7_False",
|
| 32 |
+
"x7_True"
|
| 33 |
+
],
|
| 34 |
+
"label_categories": [
|
| 35 |
+
[
|
| 36 |
+
"left",
|
| 37 |
+
"right"
|
| 38 |
+
],
|
| 39 |
+
[
|
| 40 |
+
"standard",
|
| 41 |
+
"ultrawidefield",
|
| 42 |
+
"widefield"
|
| 43 |
+
],
|
| 44 |
+
[
|
| 45 |
+
false,
|
| 46 |
+
true
|
| 47 |
+
],
|
| 48 |
+
[
|
| 49 |
+
false,
|
| 50 |
+
true
|
| 51 |
+
],
|
| 52 |
+
[
|
| 53 |
+
false,
|
| 54 |
+
true
|
| 55 |
+
],
|
| 56 |
+
[
|
| 57 |
+
false,
|
| 58 |
+
true
|
| 59 |
+
],
|
| 60 |
+
[
|
| 61 |
+
false,
|
| 62 |
+
true
|
| 63 |
+
],
|
| 64 |
+
[
|
| 65 |
+
false,
|
| 66 |
+
true
|
| 67 |
+
]
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
"config": {
|
| 71 |
+
"model_name": "efficientnet_b0",
|
| 72 |
+
"num_labels": 17,
|
| 73 |
+
"label_names": [
|
| 74 |
+
"x0_left",
|
| 75 |
+
"x0_right",
|
| 76 |
+
"x1_standard",
|
| 77 |
+
"x1_ultrawidefield",
|
| 78 |
+
"x1_widefield",
|
| 79 |
+
"x2_False",
|
| 80 |
+
"x2_True",
|
| 81 |
+
"x3_False",
|
| 82 |
+
"x3_True",
|
| 83 |
+
"x4_False",
|
| 84 |
+
"x4_True",
|
| 85 |
+
"x5_False",
|
| 86 |
+
"x5_True",
|
| 87 |
+
"x6_False",
|
| 88 |
+
"x6_True",
|
| 89 |
+
"x7_False",
|
| 90 |
+
"x7_True"
|
| 91 |
+
],
|
| 92 |
+
"learning_rate": 0.0001
|
| 93 |
+
}
|
| 94 |
+
}
|
label_metadata.json
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"num_labels": 17,
|
| 3 |
+
"feature_names": [
|
| 4 |
+
"laterality",
|
| 5 |
+
"fundus_image_type",
|
| 6 |
+
"artifacts",
|
| 7 |
+
"clarity",
|
| 8 |
+
"illumination",
|
| 9 |
+
"contrast",
|
| 10 |
+
"field",
|
| 11 |
+
"usable"
|
| 12 |
+
],
|
| 13 |
+
"onehot_feature_names": [
|
| 14 |
+
"x0_left",
|
| 15 |
+
"x0_right",
|
| 16 |
+
"x1_standard",
|
| 17 |
+
"x1_ultrawidefield",
|
| 18 |
+
"x1_widefield",
|
| 19 |
+
"x2_False",
|
| 20 |
+
"x2_True",
|
| 21 |
+
"x3_False",
|
| 22 |
+
"x3_True",
|
| 23 |
+
"x4_False",
|
| 24 |
+
"x4_True",
|
| 25 |
+
"x5_False",
|
| 26 |
+
"x5_True",
|
| 27 |
+
"x6_False",
|
| 28 |
+
"x6_True",
|
| 29 |
+
"x7_False",
|
| 30 |
+
"x7_True"
|
| 31 |
+
],
|
| 32 |
+
"label_categories": [
|
| 33 |
+
[
|
| 34 |
+
"left",
|
| 35 |
+
"right"
|
| 36 |
+
],
|
| 37 |
+
[
|
| 38 |
+
"standard",
|
| 39 |
+
"ultrawidefield",
|
| 40 |
+
"widefield"
|
| 41 |
+
],
|
| 42 |
+
[
|
| 43 |
+
false,
|
| 44 |
+
true
|
| 45 |
+
],
|
| 46 |
+
[
|
| 47 |
+
false,
|
| 48 |
+
true
|
| 49 |
+
],
|
| 50 |
+
[
|
| 51 |
+
false,
|
| 52 |
+
true
|
| 53 |
+
],
|
| 54 |
+
[
|
| 55 |
+
false,
|
| 56 |
+
true
|
| 57 |
+
],
|
| 58 |
+
[
|
| 59 |
+
false,
|
| 60 |
+
true
|
| 61 |
+
],
|
| 62 |
+
[
|
| 63 |
+
false,
|
| 64 |
+
true
|
| 65 |
+
]
|
| 66 |
+
]
|
| 67 |
+
}
|
prepare_for_hf.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Prepare RetinaRadar Model for Hugging Face Deployment
|
| 4 |
+
|
| 5 |
+
This script helps you prepare all necessary files for uploading your
|
| 6 |
+
trained RetinaRadar model to Hugging Face.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python prepare_for_hf.py --run-dir output/runs/run_ABC123 --output-dir ~/retinaradar_hf
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import shutil
|
| 14 |
+
import json
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def prepare_deployment(run_dir: Path, output_dir: Path, checkpoint_type: str = "best"):
|
| 19 |
+
"""
|
| 20 |
+
Prepare model files for Hugging Face deployment
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
run_dir: Path to training run directory
|
| 24 |
+
output_dir: Path to output directory for deployment files
|
| 25 |
+
checkpoint_type: Which checkpoint to use ('best' or 'last')
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
run_dir = Path(run_dir)
|
| 29 |
+
output_dir = Path(output_dir)
|
| 30 |
+
|
| 31 |
+
# Validate run directory
|
| 32 |
+
if not run_dir.exists():
|
| 33 |
+
raise FileNotFoundError(f"Run directory not found: {run_dir}")
|
| 34 |
+
|
| 35 |
+
# Create output directory
|
| 36 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 37 |
+
print(f"π Created output directory: {output_dir}")
|
| 38 |
+
|
| 39 |
+
# 1. Copy model checkpoint
|
| 40 |
+
print("\nπ Looking for model checkpoint...")
|
| 41 |
+
checkpoints_dir = run_dir / "checkpoints"
|
| 42 |
+
|
| 43 |
+
if checkpoint_type == "best":
|
| 44 |
+
# Find best checkpoint (lowest val_loss in filename)
|
| 45 |
+
checkpoints = list(checkpoints_dir.glob("*epoch*.ckpt"))
|
| 46 |
+
if not checkpoints:
|
| 47 |
+
print("β οΈ No best checkpoint found, using last.ckpt")
|
| 48 |
+
checkpoint_type = "last"
|
| 49 |
+
else:
|
| 50 |
+
# Sort by val_loss in filename
|
| 51 |
+
checkpoint_path = sorted(checkpoints)[0]
|
| 52 |
+
|
| 53 |
+
if checkpoint_type == "last":
|
| 54 |
+
checkpoint_path = checkpoints_dir / "last.ckpt"
|
| 55 |
+
|
| 56 |
+
if not checkpoint_path.exists():
|
| 57 |
+
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
|
| 58 |
+
|
| 59 |
+
print(f"β
Found checkpoint: {checkpoint_path.name}")
|
| 60 |
+
|
| 61 |
+
# Copy to output with standard name
|
| 62 |
+
output_checkpoint = output_dir / "retinaradar_model.ckpt"
|
| 63 |
+
shutil.copy2(checkpoint_path, output_checkpoint)
|
| 64 |
+
print(f"π Copied to: {output_checkpoint}")
|
| 65 |
+
|
| 66 |
+
# Get checkpoint size
|
| 67 |
+
size_mb = output_checkpoint.stat().st_size / (1024 * 1024)
|
| 68 |
+
print(f" Size: {size_mb:.1f} MB")
|
| 69 |
+
|
| 70 |
+
if size_mb > 5000:
|
| 71 |
+
print(f"β οΈ WARNING: Checkpoint is >5GB. Consider using FP16 quantization.")
|
| 72 |
+
|
| 73 |
+
# 2. Copy metadata
|
| 74 |
+
print("\nπ Looking for metadata...")
|
| 75 |
+
metadata_path = run_dir / "artifacts" / "label_metadata.json"
|
| 76 |
+
|
| 77 |
+
if not metadata_path.exists():
|
| 78 |
+
raise FileNotFoundError(f"Metadata not found: {metadata_path}")
|
| 79 |
+
|
| 80 |
+
print(f"β
Found metadata: {metadata_path}")
|
| 81 |
+
|
| 82 |
+
output_metadata = output_dir / "label_metadata.json"
|
| 83 |
+
shutil.copy2(metadata_path, output_metadata)
|
| 84 |
+
print(f"π Copied to: {output_metadata}")
|
| 85 |
+
|
| 86 |
+
# Load and display metadata info
|
| 87 |
+
with open(metadata_path, 'r') as f:
|
| 88 |
+
metadata = json.load(f)
|
| 89 |
+
|
| 90 |
+
print(f" Features: {', '.join(metadata.get('feature_names', []))}")
|
| 91 |
+
print(f" Total labels: {metadata.get('num_labels', 'unknown')}")
|
| 92 |
+
|
| 93 |
+
# 3. Copy inference package
|
| 94 |
+
print("\nπ Looking for inference package...")
|
| 95 |
+
inference_package_path = run_dir / "artifacts" / "inference_package.json"
|
| 96 |
+
|
| 97 |
+
if inference_package_path.exists():
|
| 98 |
+
output_package = output_dir / "inference_package.json"
|
| 99 |
+
shutil.copy2(inference_package_path, output_package)
|
| 100 |
+
print(f"β
Copied inference package")
|
| 101 |
+
|
| 102 |
+
# Load and display config
|
| 103 |
+
with open(inference_package_path, 'r') as f:
|
| 104 |
+
package = json.load(f)
|
| 105 |
+
|
| 106 |
+
config = package.get('config', {})
|
| 107 |
+
print(f" Model: {config.get('model_name', 'unknown')}")
|
| 108 |
+
print(f" Learning rate: {config.get('learning_rate', 'unknown')}")
|
| 109 |
+
else:
|
| 110 |
+
print("β οΈ Inference package not found (optional)")
|
| 111 |
+
|
| 112 |
+
# 4. Create examples directory
|
| 113 |
+
print("\nπΈ Creating examples directory...")
|
| 114 |
+
examples_dir = output_dir / "examples"
|
| 115 |
+
examples_dir.mkdir(exist_ok=True)
|
| 116 |
+
print(f"β
Created: {examples_dir}")
|
| 117 |
+
print(" π Add 2-3 example images to this directory")
|
| 118 |
+
|
| 119 |
+
# 5. Summary
|
| 120 |
+
print("\n" + "="*60)
|
| 121 |
+
print("β
DEPLOYMENT PREPARATION COMPLETE!")
|
| 122 |
+
print("="*60)
|
| 123 |
+
print(f"\nAll files prepared in: {output_dir}")
|
| 124 |
+
print("\nFiles created:")
|
| 125 |
+
print(" β retinaradar_model.ckpt")
|
| 126 |
+
print(" β label_metadata.json")
|
| 127 |
+
if inference_package_path.exists():
|
| 128 |
+
print(" β inference_package.json")
|
| 129 |
+
print(" β examples/ (empty - add your images)")
|
| 130 |
+
|
| 131 |
+
print("\nπ Next steps:")
|
| 132 |
+
print("1. Add 2-3 example images to examples/ directory")
|
| 133 |
+
print("2. Copy hf_inference.py to this directory")
|
| 134 |
+
print("3. Copy HF_MODEL_README.md as README.md")
|
| 135 |
+
print("4. Follow the deployment guide to upload to Hugging Face")
|
| 136 |
+
|
| 137 |
+
print(f"\nDeployment directory ready: {output_dir}")
|
| 138 |
+
print("="*60)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def main():
|
| 142 |
+
parser = argparse.ArgumentParser(
|
| 143 |
+
description='Prepare RetinaRadar model for Hugging Face deployment',
|
| 144 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 145 |
+
epilog="""
|
| 146 |
+
Example:
|
| 147 |
+
python prepare_for_hf.py \\
|
| 148 |
+
--run-dir output/runs/run_ABC123-2025-01-15_120000 \\
|
| 149 |
+
--output-dir ~/retinaradar_hf \\
|
| 150 |
+
--checkpoint best
|
| 151 |
+
"""
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
'--run-dir',
|
| 156 |
+
type=str,
|
| 157 |
+
required=True,
|
| 158 |
+
help='Path to the training run directory'
|
| 159 |
+
)
|
| 160 |
+
parser.add_argument(
|
| 161 |
+
'--output-dir',
|
| 162 |
+
type=str,
|
| 163 |
+
required=True,
|
| 164 |
+
help='Path to output directory for deployment files'
|
| 165 |
+
)
|
| 166 |
+
parser.add_argument(
|
| 167 |
+
'--checkpoint',
|
| 168 |
+
type=str,
|
| 169 |
+
default='best',
|
| 170 |
+
choices=['best', 'last'],
|
| 171 |
+
help='Which checkpoint to use (default: best)'
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
args = parser.parse_args()
|
| 175 |
+
|
| 176 |
+
try:
|
| 177 |
+
prepare_deployment(
|
| 178 |
+
run_dir=args.run_dir,
|
| 179 |
+
output_dir=args.output_dir,
|
| 180 |
+
checkpoint_type=args.checkpoint
|
| 181 |
+
)
|
| 182 |
+
except Exception as e:
|
| 183 |
+
print(f"\nβ Error: {e}")
|
| 184 |
+
return 1
|
| 185 |
+
|
| 186 |
+
return 0
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
if __name__ == "__main__":
|
| 190 |
+
exit(main())
|
reset_git_lfs.sh
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Reset Git History and Setup Git LFS Properly
|
| 3 |
+
# Use this when you need to start fresh with Git LFS tracking
|
| 4 |
+
|
| 5 |
+
set -e # Exit on any error
|
| 6 |
+
|
| 7 |
+
echo "================================================"
|
| 8 |
+
echo "Git History Reset and LFS Setup"
|
| 9 |
+
echo "================================================"
|
| 10 |
+
echo ""
|
| 11 |
+
echo "β οΈ WARNING: This will delete all git history!"
|
| 12 |
+
echo "Make sure you have a backup of your files."
|
| 13 |
+
echo ""
|
| 14 |
+
read -p "Continue? (yes/no): " confirm
|
| 15 |
+
|
| 16 |
+
if [ "$confirm" != "yes" ]; then
|
| 17 |
+
echo "Aborted."
|
| 18 |
+
exit 1
|
| 19 |
+
fi
|
| 20 |
+
|
| 21 |
+
echo ""
|
| 22 |
+
echo "π Current directory: $(pwd)"
|
| 23 |
+
echo ""
|
| 24 |
+
|
| 25 |
+
# Step 1: Remove .git directory
|
| 26 |
+
echo "1οΈβ£ Removing old git history..."
|
| 27 |
+
if [ -d .git ]; then
|
| 28 |
+
rm -rf .git
|
| 29 |
+
echo " β
Removed .git directory"
|
| 30 |
+
else
|
| 31 |
+
echo " βΉοΈ No .git directory found"
|
| 32 |
+
fi
|
| 33 |
+
|
| 34 |
+
# Step 2: Initialize fresh git repo
|
| 35 |
+
echo ""
|
| 36 |
+
echo "2οΈβ£ Initializing fresh git repository..."
|
| 37 |
+
git init
|
| 38 |
+
echo " β
Git repository initialized"
|
| 39 |
+
|
| 40 |
+
# Step 3: Setup Git LFS
|
| 41 |
+
echo ""
|
| 42 |
+
echo "3οΈβ£ Setting up Git LFS..."
|
| 43 |
+
git lfs install
|
| 44 |
+
echo " β
Git LFS installed"
|
| 45 |
+
|
| 46 |
+
# Step 4: Track large files with LFS
|
| 47 |
+
echo ""
|
| 48 |
+
echo "4οΈβ£ Configuring Git LFS tracking..."
|
| 49 |
+
|
| 50 |
+
# Track model checkpoints
|
| 51 |
+
git lfs track "*.ckpt"
|
| 52 |
+
git lfs track "*.pth"
|
| 53 |
+
git lfs track "*.bin"
|
| 54 |
+
git lfs track "*.h5"
|
| 55 |
+
git lfs track "*.pkl"
|
| 56 |
+
git lfs track "*.pickle"
|
| 57 |
+
|
| 58 |
+
# Track images
|
| 59 |
+
git lfs track "*.png"
|
| 60 |
+
git lfs track "*.jpg"
|
| 61 |
+
git lfs track "*.jpeg"
|
| 62 |
+
git lfs track "*.gif"
|
| 63 |
+
git lfs track "*.bmp"
|
| 64 |
+
git lfs track "*.tiff"
|
| 65 |
+
|
| 66 |
+
# Track compressed files
|
| 67 |
+
git lfs track "*.zip"
|
| 68 |
+
git lfs track "*.tar.gz"
|
| 69 |
+
git lfs track "*.tar"
|
| 70 |
+
|
| 71 |
+
echo " β
Git LFS tracking configured"
|
| 72 |
+
|
| 73 |
+
# Step 5: Show what's being tracked
|
| 74 |
+
echo ""
|
| 75 |
+
echo "π Files tracked by Git LFS:"
|
| 76 |
+
cat .gitattributes
|
| 77 |
+
|
| 78 |
+
# Step 6: Add all files
|
| 79 |
+
echo ""
|
| 80 |
+
echo "5οΈβ£ Adding all files to git..."
|
| 81 |
+
git add .gitattributes
|
| 82 |
+
git add .
|
| 83 |
+
|
| 84 |
+
echo " β
Files added"
|
| 85 |
+
|
| 86 |
+
# Step 7: Show status
|
| 87 |
+
echo ""
|
| 88 |
+
echo "π Git status:"
|
| 89 |
+
git status
|
| 90 |
+
|
| 91 |
+
# Step 8: Create initial commit
|
| 92 |
+
echo ""
|
| 93 |
+
read -p "Create initial commit? (yes/no): " commit_confirm
|
| 94 |
+
|
| 95 |
+
if [ "$commit_confirm" = "yes" ]; then
|
| 96 |
+
echo ""
|
| 97 |
+
echo "6οΈβ£ Creating initial commit..."
|
| 98 |
+
git commit -m "Initial commit with Git LFS"
|
| 99 |
+
echo " β
Initial commit created"
|
| 100 |
+
else
|
| 101 |
+
echo " βοΈ Skipped commit"
|
| 102 |
+
fi
|
| 103 |
+
|
| 104 |
+
# Step 9: Add remote (if needed)
|
| 105 |
+
echo ""
|
| 106 |
+
read -p "Add remote repository? (yes/no): " remote_confirm
|
| 107 |
+
|
| 108 |
+
if [ "$remote_confirm" = "yes" ]; then
|
| 109 |
+
echo ""
|
| 110 |
+
read -p "Enter remote URL (e.g., https://huggingface.co/user/repo): " remote_url
|
| 111 |
+
|
| 112 |
+
git remote add origin "$remote_url"
|
| 113 |
+
echo " β
Remote added: $remote_url"
|
| 114 |
+
|
| 115 |
+
echo ""
|
| 116 |
+
read -p "Push to remote? (yes/no): " push_confirm
|
| 117 |
+
|
| 118 |
+
if [ "$push_confirm" = "yes" ]; then
|
| 119 |
+
echo ""
|
| 120 |
+
echo "π Pushing to remote..."
|
| 121 |
+
echo " (You may be prompted for credentials)"
|
| 122 |
+
git push -u origin main --force
|
| 123 |
+
echo " β
Pushed to remote"
|
| 124 |
+
fi
|
| 125 |
+
fi
|
| 126 |
+
|
| 127 |
+
echo ""
|
| 128 |
+
echo "================================================"
|
| 129 |
+
echo "β
Git reset complete!"
|
| 130 |
+
echo "================================================"
|
| 131 |
+
echo ""
|
| 132 |
+
echo "π Summary:"
|
| 133 |
+
echo " - Old git history removed"
|
| 134 |
+
echo " - Fresh repository initialized"
|
| 135 |
+
echo " - Git LFS configured"
|
| 136 |
+
echo " - Large files tracked with LFS"
|
| 137 |
+
echo ""
|
| 138 |
+
echo "π Verify LFS files:"
|
| 139 |
+
echo " git lfs ls-files"
|
| 140 |
+
echo ""
|
| 141 |
+
echo "π€ If you haven't pushed yet:"
|
| 142 |
+
echo " git remote add origin <your-repo-url>"
|
| 143 |
+
echo " git push -u origin main --force"
|
| 144 |
+
echo ""
|
retinaradar_model.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9c3b9973d879b2ba595065979a4264cb051bfd50633f37e26fdba98233828f68
|
| 3 |
+
size 48833325
|