---
tags:
- vision
- image-classification
- onnx
- mobilevit
- medical
datasets:
- rohithgowdax/processed-dr
library_name: transformers
widget:
- src: https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg
  example_title: Example Eye Scan
---

# Mithu-ViT: Diabetic Retinopathy Classifier

This is a **MobileViT (Small)** model fine-tuned on the [Processed Diabetic Retinopathy dataset](https://www.kaggle.com/datasets/rohithgowdax/processed-dr).

It classifies retina scans into 5 severity levels:
- **0**: No DR
- **1**: Mild
- **2**: Moderate
- **3**: Severe
- **4**: Proliferative DR

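The severity levels listed above can be kept as a small lookup table when post-processing predictions. This is a minimal sketch: `DR_LABELS` and `describe` are hypothetical names, and the label strings stored in the model's own `id2label` config may be worded differently.

```python
# Hypothetical lookup table mirroring the severity list in this card.
DR_LABELS = {
    0: "No DR",
    1: "Mild",
    2: "Moderate",
    3: "Severe",
    4: "Proliferative DR",
}


def describe(class_index: int) -> str:
    """Return a readable label for a predicted class index."""
    if class_index not in DR_LABELS:
        raise ValueError(f"expected a class index in 0..4, got {class_index}")
    return f"{class_index}: {DR_LABELS[class_index]}"


print(describe(2))  # prints "2: Moderate"
```

Raising on an unknown index, rather than returning a default, makes a mismatched label set fail loudly.
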
## Model Details
- **Architecture**: MobileViT-Small (Apple)
- **Format**: PyTorch (`pytorch_model.bin`) and ONNX (`mithu-vit.onnx`)
- **Resolution**: 256x256
- **License**: Apache 2.0

## Usage (PyTorch)

```python
from transformers import MobileViTForImageClassification, MobileViTImageProcessor
from PIL import Image
import torch

# 1. Load Model
model = MobileViTForImageClassification.from_pretrained("Shadow0482/mithu-mobilevit-dr")
processor = MobileViTImageProcessor.from_pretrained("Shadow0482/mithu-mobilevit-dr")

# 2. Load Image
image = Image.open("path_to_eye_scan.jpg").convert("RGB")

# 3. Predict
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():  # inference only, no gradients needed
    outputs = model(**inputs)

# Pick the highest-scoring class and map it to its label
predicted_idx = outputs.logits.argmax(-1).item()
print("Predicted Class:", model.config.id2label[predicted_idx])
```

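The PyTorch example reports only the top class. To see how confident the model is across all five levels, the raw logits can be passed through a softmax. The sketch below uses only the standard library so it runs without the model. The logit values are made up for illustration and are not real model output. With the real model, `torch.softmax(outputs.logits, dim=-1)` performs the same computation.

```python
import math


def softmax(logits):
    """Convert a list of raw logits into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for illustration only; not real model output.
example_logits = [0.2, 1.5, 3.1, 0.4, -0.7]
probs = softmax(example_logits)

# Index of the most probable class
top = max(range(len(probs)), key=probs.__getitem__)
print(f"class {top} with probability {probs[top]:.2f}")
```
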
## Usage (ONNX)

```python
import onnxruntime as ort
import numpy as np
from PIL import Image

# 1. Start Session
session = ort.InferenceSession("mithu-vit.onnx")

# 2. Prepare Input
# Force 3-channel RGB so grayscale or RGBA files do not break the transpose
img = Image.open("test.jpg").convert("RGB").resize((256, 256))
# HWC uint8 -> CHW float32 scaled to [0, 1]
img_data = np.array(img).transpose(2, 0, 1).astype(np.float32) / 255.0
# Add the batch dimension: (1, 3, 256, 256)
img_data = np.expand_dims(img_data, axis=0)

# 3. Run
outputs = session.run(None, {"pixel_values": img_data})
print("Logits:", outputs[0])
print("Predicted Class Index:", int(np.argmax(outputs[0], axis=-1)[0]))
```

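The manual ONNX preprocessing scales pixels to [0, 1] and keeps RGB channel order. The `MobileViTImageProcessor` used in the PyTorch example can also flip channels to BGR through its `do_flip_channel_order` option, so it may be worth comparing against the processor config shipped with the model before relying on ONNX outputs. The following is a sketch only: `to_model_input` and `flip_to_bgr` are hypothetical names, and whether this model expects BGR input is an assumption to verify, not something this card states.

```python
import numpy as np


def to_model_input(rgb_hwc, flip_to_bgr=False):
    """Turn an HxWx3 uint8 RGB array into a 1x3xHxW float32 batch in [0, 1]."""
    if rgb_hwc.ndim != 3 or rgb_hwc.shape[2] != 3:
        raise ValueError(f"expected HxWx3 input, got shape {rgb_hwc.shape}")
    chw = rgb_hwc.astype(np.float32).transpose(2, 0, 1) / 255.0
    if flip_to_bgr:
        chw = chw[::-1]  # reverse the channel axis: RGB -> BGR
    # Add the batch dimension and return a contiguous copy
    return np.ascontiguousarray(chw[np.newaxis, ...])


# Synthetic 256x256 pure-red image standing in for a real scan.
dummy = np.zeros((256, 256, 3), dtype=np.uint8)
dummy[..., 0] = 255
batch = to_model_input(dummy, flip_to_bgr=True)
print(batch.shape, batch.dtype)
```

Running both channel orders on a few images and comparing the logits with the PyTorch path is a quick way to see which one the export expects.
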