---
base_model: ResNet50
tags:
- image-classification
- diabetic-retinopathy
- onnx
license: mit
---

# ResNet50-APTOS-DR-ONNX Model

This repository contains a ResNet50 model trained for Diabetic Retinopathy (DR) detection on the APTOS dataset and exported to ONNX format for efficient inference.

## Model Overview

- **Architecture**: ResNet50
- **Task**: Diabetic Retinopathy classification (5 classes: No DR, Mild DR, Moderate DR, Severe DR, Proliferative DR)
- **Format**: ONNX (opset 18)

## Usage (ONNX Inference)

Inference requires the `onnxruntime` library. The example below also uses `numpy`, `Pillow`, and `torchvision` to reproduce the training-time preprocessing:

```python
import numpy as np
import onnxruntime as ort
from PIL import Image
from torchvision import transforms

ONNX_MODEL_PATH = "mithu-vit.onnx"  # Path to the downloaded ONNX model
CLASSES = ["No DR", "Mild DR", "Moderate DR", "Severe DR", "Proliferative DR"]

# Image preprocessing (matching the training pipeline)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create the session once and reuse it; rebuilding it on every call is slow.
session = ort.InferenceSession(ONNX_MODEL_PATH, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name


def softmax(logits):
    # Subtract the max logit to avoid overflow in np.exp; the result is unchanged.
    exp = np.exp(logits - np.max(logits))
    return exp / exp.sum()


def predict_image(image_path):
    img = Image.open(image_path).convert("RGB")
    input_tensor = preprocess(img)
    input_numpy = input_tensor.unsqueeze(0).numpy()  # Add batch dimension: (1, 3, 224, 224), float32

    outputs = session.run([output_name], {input_name: input_numpy})
    logits = outputs[0][0]
    probs = softmax(logits)
    pred_index = int(np.argmax(probs))

    print(f"Predicted Class: {CLASSES[pred_index]} (Class {pred_index})")
    print(f"Confidence: {probs[pred_index] * 100:.2f}%")
    print("All Probabilities:")
    for i, p in enumerate(probs):
        print(f"  {CLASSES[i]}: {p * 100:.2f}%")
    return CLASSES[pred_index], probs


# Example usage:
# predict_image("path/to/your/image.jpg")
```
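
If you would rather not install `torchvision` at inference time, the same preprocessing can be reproduced with Pillow and NumPy. This is a sketch that assumes torchvision's defaults: `Resize((224, 224))` on a PIL image resamples bilinearly, `ToTensor` scales pixels to [0, 1] and reorders them to channels-first, and `Normalize` applies the ImageNet mean and standard deviation per channel. The helper names `preprocess_array`, `load_image`, and `softmax` are illustrative and not part of this repository; outputs should match the torchvision pipeline up to floating-point rounding.

```python
import numpy as np

# ImageNet statistics, as used by transforms.Normalize in the example above.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
INPUT_SIZE = 224


def preprocess_array(rgb):
    """Turn an already-resized (224, 224, 3) uint8 RGB array into a (1, 3, 224, 224) float32 batch."""
    if rgb.shape != (INPUT_SIZE, INPUT_SIZE, 3):
        raise ValueError(f"expected ({INPUT_SIZE}, {INPUT_SIZE}, 3), got {rgb.shape}")
    x = rgb.astype(np.float32) / 255.0                    # like ToTensor: scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD                # like Normalize: per channel, broadcast over H and W
    x = np.ascontiguousarray(np.transpose(x, (2, 0, 1)))  # HWC -> CHW, contiguous for onnxruntime
    return x[np.newaxis, ...]                             # add batch dimension


def load_image(image_path):
    """Open an image as RGB and resize it bilinearly (torchvision's default for PIL images)."""
    from PIL import Image  # imported lazily so the NumPy helpers work without Pillow

    img = Image.open(image_path).convert("RGB")
    img = img.resize((INPUT_SIZE, INPUT_SIZE), Image.BILINEAR)
    return np.asarray(img, dtype=np.uint8)


def softmax(logits):
    """Numerically stable softmax: shifting by the max logit avoids overflow without changing the result."""
    exp = np.exp(logits - np.max(logits))
    return exp / exp.sum()
```

For example, `batch = preprocess_array(load_image("path/to/your/image.jpg"))` produces an array you can pass to `InferenceSession.run` as in the example above, then apply `softmax` to `outputs[0][0]`.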

## Fine-tuning

The original model was trained with PyTorch. To fine-tune it on a custom dataset or for a related task, start from the original PyTorch weights if they are available. The ONNX export is intended for inference; to train from it, you would first need to convert it into a framework that supports training.

Fine-tuning generally involves the following steps:

1. **Load the pre-trained model**: Start from the original PyTorch model, or another version suitable for transfer learning.
2. **Prepare your dataset**: Label your images and preprocess them the same way as for inference (resize to 224x224, normalize with the ImageNet mean and standard deviation).
3. **Modify the head**: Replace the final classification layer with one whose output size matches the number of classes in your dataset.
4. **Define the optimizer and loss function**: Choose settings appropriate to your task (for example, cross-entropy loss for single-label classification).
5. **Train**: Use a lower learning rate than in the original training. A common approach is to train only the new head first, then optionally unfreeze earlier layers for finer adjustments.
6. **Export to ONNX**: Export the fine-tuned model to ONNX (for example, with `torch.onnx.export`), keeping the same input shape and opset (18) as this model.

### Recommended Frameworks for Fine-tuning

- [PyTorch](https://pytorch.org/)
- [TensorFlow/Keras](https://www.tensorflow.org/)