---
tags:
- vision
- image-classification
- onnx
- mobilevit
- medical
datasets:
- rohithgowdax/processed-dr
library_name: transformers
widget:
- src: https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg
  example_title: Example Eye Scan
---

# Mithu-ViT: Diabetic Retinopathy Classifier

This is a **MobileViT (Small)** model fine-tuned on the [Processed Diabetic Retinopathy dataset](https://www.kaggle.com/datasets/rohithgowdax/processed-dr).

It classifies retinal scans into five diabetic retinopathy (DR) severity levels:
- **0**: No DR
- **1**: Mild
- **2**: Moderate
- **3**: Severe
- **4**: Proliferative DR
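
As a minimal sketch, the class indices above can be turned into readable names in plain Python. The names `DR_SEVERITY` and `severity_name` are hypothetical helpers for illustration and are not shipped with the model; the labels are copied from the list above.

```python
# Hypothetical lookup table; labels copied from the severity list in this card.
DR_SEVERITY = {
    0: "No DR",
    1: "Mild",
    2: "Moderate",
    3: "Severe",
    4: "Proliferative DR",
}


def severity_name(class_id: int) -> str:
    """Return the severity label for a predicted class index (0-4)."""
    if class_id not in DR_SEVERITY:
        raise ValueError(f"class_id must be one of {sorted(DR_SEVERITY)}, got {class_id}")
    return DR_SEVERITY[class_id]


print(severity_name(2))
```

A lookup like this is mainly useful with the ONNX export, which returns raw logits without the `id2label` mapping that the `transformers` config may provide.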

## Model Details
- **Architecture**: MobileViT-Small (Apple)
- **Format**: PyTorch (`pytorch_model.bin`) and ONNX (`mithu-vit.onnx`)
- **Resolution**: 256x256
- **License**: Apache 2.0

## Usage (PyTorch)

```python
from transformers import MobileViTForImageClassification, MobileViTImageProcessor
from PIL import Image
import torch

# 1. Load Model
model = MobileViTForImageClassification.from_pretrained("Shadow0482/mithu-mobilevit-dr")
processor = MobileViTImageProcessor.from_pretrained("Shadow0482/mithu-mobilevit-dr")

# 2. Load Image
image = Image.open("path_to_eye_scan.jpg").convert("RGB")

# 3. Predict
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

predicted_id = outputs.logits.argmax(-1).item()
print("Predicted Class:", model.config.id2label[predicted_id])
```

## Usage (ONNX)

```python
import onnxruntime as ort
import numpy as np
from PIL import Image

# 1. Start Session
session = ort.InferenceSession("mithu-vit.onnx")

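# 2. Prepare Input
# Force 3-channel RGB so grayscale or RGBA files do not break the transpose below.
img = Image.open("test.jpg").convert("RGB").resize((256, 256))
# HWC uint8 -> CHW float32 scaled to [0, 1], then add a batch dimension.
img_data = np.array(img).transpose(2, 0, 1).astype(np.float32) / 255.0
img_data = np.expand_dims(img_data, axis=0)
# Note: MobileViTImageProcessor flips RGB to BGR by default (do_flip_channel_order).
# If this model's processor does so, apply img_data = img_data[:, ::-1].copy() to match.

# 3. Run
outputs = session.run(None, {"pixel_values": img_data})
print("Logits:", outputs[0])
print("Predicted class index:", int(np.argmax(outputs[0], axis=-1)[0]))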
```
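
The ONNX session returns raw logits, not probabilities. The sketch below shows one way to convert a row of five logits into probabilities with a numerically stable softmax, using only the standard library. The function names and the example logits are made up for illustration; they are not output from this model.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a flat list of floats."""
    peak = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_class(logits):
    """Return (index, probability) of the most likely class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs[best]


# Made-up logits for illustration only (one value per severity level).
example_logits = [0.2, 1.5, 3.1, 0.4, -0.7]
class_id, prob = top_class(example_logits)
print(f"class {class_id} with probability {prob:.3f}")
```

With the ONNX example above, the input would be `outputs[0][0].tolist()`, assuming the first output has shape `(1, 5)`. The returned index corresponds to the severity levels listed at the top of this card.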