---
base_model: ResNet50
tags:
- image-classification
- diabetic-retinopathy
- onnx
license: mit
---
# ResNet50-APTOS-DR-ONNX Model

This repository contains a ResNet50 model, originally trained for Diabetic Retinopathy (DR) detection on the APTOS dataset, exported to ONNX format for efficient inference.

## Model Overview

- **Architecture**: ResNet50
- **Task**: Diabetic Retinopathy Classification (5 classes: No DR, Mild DR, Moderate DR, Severe DR, Proliferative DR)
- **Format**: ONNX (Opset 18)

## Usage (ONNX Inference)

To run inference you need the `onnxruntime` library. The example below also uses `numpy`, `Pillow`, and `torchvision` for image loading and preprocessing:

```python
import onnxruntime as ort
import numpy as np
from PIL import Image
from torchvision import transforms

ONNX_MODEL_PATH = "mithu-vit.onnx" # Path to the downloaded ONNX model
CLASSES = ["No DR", "Mild DR", "Moderate DR", "Severe DR", "Proliferative DR"]

# Image preprocessing (matching the training pipeline)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def predict_image(image_path):
    img = Image.open(image_path).convert('RGB')
    input_tensor = preprocess(img)
    input_numpy = input_tensor.unsqueeze(0).numpy() # Add batch dimension

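    # Creating a session is relatively expensive; for repeated calls,
    # build it once outside this function and reuse it.
    session = ort.InferenceSession(ONNX_MODEL_PATH)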
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

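    outputs = session.run([output_name], {input_name: input_numpy})
    logits = outputs[0][0]  # First output, first (only) item in the batch
    # Numerically stable softmax: subtract the max logit before exponentiating
    exp_logits = np.exp(logits - np.max(logits))
    probs = exp_logits / np.sum(exp_logits)
    pred_index = int(np.argmax(probs))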

    print(f"Predicted Class: {CLASSES[pred_index]} (Class {pred_index})")
    print(f"Confidence: {probs[pred_index] * 100:.2f}%")
    print("All Probabilities:")
    for i, p in enumerate(probs):
        print(f"  {CLASSES[i]}: {p*100:.2f}%")

# Example usage:
# predict_image("path/to/your/image.jpg")
```
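
If you would rather not install `torchvision` just for preprocessing, the same scaling and normalization can be done in NumPy. The following is a minimal sketch, assuming the image has already been loaded as an RGB `uint8` array and resized to 224x224; the function name `preprocess_array` is hypothetical and not part of this repository.

```python
import numpy as np

# ImageNet statistics, the same values used by transforms.Normalize above
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_array(rgb_uint8):
    """Turn an HxWx3 uint8 RGB array into a 1x3xHxW float32 batch.

    Hypothetical helper: assumes the array was already resized to 224x224.
    """
    if rgb_uint8.ndim != 3 or rgb_uint8.shape[2] != 3:
        raise ValueError("expected an HxWx3 RGB array")
    scaled = rgb_uint8.astype(np.float32) / 255.0          # like ToTensor: [0, 255] -> [0, 1]
    normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD   # like Normalize: per-channel
    chw = np.transpose(normalized, (2, 0, 1))              # HWC -> CHW
    return np.ascontiguousarray(chw[np.newaxis, ...])      # add batch dimension
```

The input array could come from `np.asarray(Image.open(path).convert("RGB").resize((224, 224), Image.BILINEAR))`. Pillow's resize is not guaranteed to be bit-identical to `transforms.Resize`, so small differences in the output probabilities are possible.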

## Fine-tuning

The original model was trained using PyTorch. To fine-tune it on a custom dataset or a slightly different task, start from the original PyTorch weights if they are available. Otherwise, the ONNX model would first have to be converted into a framework that supports training, since ONNX is primarily an inference format.

Steps for fine-tuning generally involve:
1. **Load the pre-trained model**: Start with the original PyTorch model or a version compatible with transfer learning.
2. **Prepare your dataset**: Ensure your images are properly labeled and preprocessed (resized to 224x224, normalized with ImageNet stats).
3. **Modify the head**: Replace the final classification layer to match the number of classes in your new dataset.
4. **Define optimizer and loss function**: Choose appropriate settings for your fine-tuning task.
5. **Train**: Fine-tune the model, typically with a lower learning rate than the initial training. A common approach is to train only the new head first, then unfreeze earlier layers if the task needs more adaptation.
6. **Export to ONNX**: After fine-tuning, export your updated model to ONNX format following similar steps to the original export process.

### Recommended Frameworks for Fine-tuning
- [PyTorch](https://pytorch.org/)
- [TensorFlow/Keras](https://www.tensorflow.org/)