---
license: mit
language:
- en
base_model:
- sakshamkr1/ResNet50-APTOS-DR
pipeline_tag: image-classification
tags:
- onnx
---
# ResNet50-APTOS-DR (ONNX) 
**5-class Diabetic Retinopathy classifier** ready for edge devices.

**Original model**: sakshamkr1/ResNet50-APTOS-DR  
**Format**: ONNX  
**Input shape**: (batch, 3, 224, 224) RGB fundus image  
**Output**: 5 classes (APTOS 2019)
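
The model expects ImageNet-style preprocessing, as used in the Colab snippet below. A minimal NumPy/PIL sketch of the same pipeline (resize, scale to [0, 1], normalize, NCHW layout); the interpolation may differ from torchvision's by a pixel or two:

```python
import numpy as np
from PIL import Image

# ImageNet statistics, matching the normalization in the Colab snippet
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(img: Image.Image) -> np.ndarray:
    """Resize to 224x224, scale to [0, 1], normalize, return (1, 3, 224, 224)."""
    img = img.convert("RGB").resize((224, 224), Image.BILINEAR)
    arr = np.asarray(img, dtype=np.float32) / 255.0   # HWC, values in [0, 1]
    arr = (arr - MEAN) / STD                          # channel-wise normalization
    return arr.transpose(2, 0, 1)[None, ...]          # HWC -> NCHW with batch dim

x = preprocess(Image.new("RGB", (512, 512), (120, 60, 30)))
```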

### Classes
- 0: No DR  
- 1: Mild DR  
- 2: Moderate DR  
- 3: Severe DR  
- 4: Proliferative DR  
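
The model emits five raw logits in the class order above; a numerically stable softmax turns them into probabilities. A minimal sketch (`predict_label` is a hypothetical helper name, and the logits below are illustrative, not real model output):

```python
import numpy as np

CLASSES = ["No DR", "Mild DR", "Moderate DR", "Severe DR", "Proliferative DR"]

def predict_label(logits):
    """Map raw model logits to (class name, probability vector)."""
    z = np.asarray(logits, dtype=np.float64)
    e = np.exp(z - z.max())        # subtract the max logit to avoid overflow
    probs = e / e.sum()
    return CLASSES[int(np.argmax(probs))], probs

# Illustrative logits only -- not real model output
label, probs = predict_label([0.2, 0.1, 2.5, 0.3, 0.0])
```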

### Performance
- Model size: ~105 MB (single file)  
- RAM usage: ~150-220 MB  
- Speed: ~0.8–1.5 seconds per image on CPU
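
The latency figure above can be reproduced with a small timing helper; `mean_latency` is a hypothetical name, and the lambda is a placeholder workload to swap for the real `session.run(...)` call once the session from the snippet below is loaded:

```python
import time

def mean_latency(fn, n_runs=5):
    """Average wall-clock seconds per call of fn over n_runs calls."""
    t0 = time.perf_counter()
    for _ in range(n_runs):
        fn()
    return (time.perf_counter() - t0) / n_runs

# Stand-in workload; replace with e.g.
#   lambda: session.run(None, {"input": input_tensor})
avg = mean_latency(lambda: sum(range(100_000)))
```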

### Quick test (Google Colab)

```python
# ============================
# 1. Install dependencies
# ============================
!pip install -q onnxruntime huggingface_hub pillow torchvision matplotlib

# ============================
# 2. Download the ONNX model
# ============================
from huggingface_hub import hf_hub_download

print("📥 Downloading iris-vit.onnx ...")
model_path = hf_hub_download(
    repo_id="Shadow0482/iris-onnx",
    filename="iris-vit.onnx"
)
print(f"✅ Model downloaded: {model_path}")

# ============================
# 3. Load model & define inference
# ============================
import onnxruntime as ort
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from google.colab import files

# Load ONNX session (CPU is fine & fast for this ~105 MB model)
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

# Preprocessing (exactly what the model expects)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

print("✅ Model loaded successfully!")

# ============================
# 4. Upload a fundus image & run inference
# ============================
print("\n📤 Please upload a color fundus/retina image (JPG/PNG)...")
uploaded = files.upload()

if uploaded:
    img_path = list(uploaded.keys())[0]
    img = Image.open(img_path).convert("RGB")
    
    # Preprocess
    input_tensor = transform(img).unsqueeze(0).numpy().astype(np.float32)
    
    # Inference (query the input name rather than hard-coding "input")
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_tensor})[0][0]
    
    # Softmax (subtract the max logit for numerical stability)
    exp_scores = np.exp(outputs - np.max(outputs))
    probs = exp_scores / np.sum(exp_scores)
    pred_idx = np.argmax(probs)
    
    classes = ["No DR", "Mild DR", "Moderate DR", "Severe DR", "Proliferative DR"]
    
    print(f"\n🎯 **Prediction:** {classes[pred_idx]}")
    print(f"   Confidence: {probs[pred_idx]*100:.1f}%")
    print("\n📊 Full probabilities:")
    for name, p in zip(classes, probs):
        print(f"   {name:20} → {p*100:5.1f}%")
    
    # Show image
    plt.figure(figsize=(8, 6))
    plt.imshow(img)
    plt.title(f"Predicted: {classes[pred_idx]} ({probs[pred_idx]*100:.1f}%)", fontsize=14)
    plt.axis("off")
    plt.show()
```

**License**: MIT  
Made for low-resource diabetic retinopathy screening ❤️