Shadow0482 commited on
Commit
74cc0a1
·
verified ·
1 Parent(s): 1a6f73c

Add README.md with usage and fine-tuning instructions

Browse files
Files changed (1) hide show
  1. README.md +79 -0
README.md ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ---
3
+ base_model: ResNet50
4
+ tags:
5
+ - image-classification
6
+ - diabetic-retinopathy
7
+ - onnx
8
+ license: mit
9
+ ---
10
+ # ResNet50-APTOS-DR-ONNX Model
11
+
12
+ This repository contains a ResNet50 model, originally trained for Diabetic Retinopathy (DR) detection on the APTOS dataset, exported to ONNX format for efficient inference.
13
+
14
+ ## Model Overview
15
+
16
+ - **Architecture**: ResNet50
17
+ - **Task**: Diabetic Retinopathy Classification (5 classes: No DR, Mild DR, Moderate DR, Severe DR, Proliferative DR)
18
+ - **Format**: ONNX (Opset 18)
19
+
20
+ ## Usage (ONNX Inference)
21
+
22
+ To use this model for inference, you will need the `onnxruntime` library. Below is a basic example:
23
+
24
+ ```python
25
+ import onnxruntime as ort
26
+ import numpy as np
27
+ from PIL import Image
28
+ from torchvision import transforms
29
+
30
+ ONNX_MODEL_PATH = "mithu-vit.onnx" # Path to the downloaded ONNX model
31
+ CLASSES = ["No DR", "Mild DR", "Moderate DR", "Severe DR", "Proliferative DR"]
32
+
33
+ # Image preprocessing (matching the training pipeline)
34
+ preprocess = transforms.Compose([
35
+ transforms.Resize((224, 224)),
36
+ transforms.ToTensor(),
37
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
38
+ ])
39
+
40
+ def predict_image(image_path):
41
+ img = Image.open(image_path).convert('RGB')
42
+ input_tensor = preprocess(img)
43
+ input_numpy = input_tensor.unsqueeze(0).numpy() # Add batch dimension
44
+
45
+ session = ort.InferenceSession(ONNX_MODEL_PATH)
46
+ input_name = session.get_inputs()[0].name
47
+ output_name = session.get_outputs()[0].name
48
+
49
+ outputs = session.run([output_name], {input_name: input_numpy})
50
+ logits = outputs[0][0]
51
+ probs = np.exp(logits) / np.sum(np.exp(logits))
52
+ pred_index = np.argmax(probs)
53
+
54
+ print(f"Predicted Class: {CLASSES[pred_index]} (Class {pred_index})")
55
+ print(f"Confidence: {probs[pred_index] * 100:.2f}%")
56
+ print("All Probabilities:")
57
+ for i, p in enumerate(probs):
58
+ print(f" {CLASSES[i]}: {p*100:.2f}%")
59
+
60
+ # Example usage:
61
+ # predict_image("path/to/your/image.jpg")
62
+ ```
63
+
64
+ ## Fine-tuning
65
+
66
+ The original model was trained using PyTorch. If you wish to fine-tune this model on a custom dataset or for a slightly different task, you can use the original PyTorch weights (if available) or adapt the ONNX model for further training in a suitable framework.
67
+
68
+ Steps for fine-tuning generally involve:
69
+ 1. **Load the pre-trained model**: Start with the original PyTorch model or a version compatible with transfer learning.
70
+ 2. **Prepare your dataset**: Ensure your images are properly labeled and preprocessed (resized to 224x224, normalized with ImageNet stats).
71
+ 3. **Modify the head**: Replace the final classification layer to match the number of classes in your new dataset.
72
+ 4. **Define optimizer and loss function**: Choose appropriate settings for your fine-tuning task.
73
+ 5. **Train**: Fine-tune the model, typically with a lower learning rate than initial training, focusing on training the new head and potentially unfreezing earlier layers for more granular adjustments.
74
+ 6. **Export to ONNX**: After fine-tuning, export your updated model to ONNX format following similar steps to the original export process.
75
+
76
+ ### Recommended Frameworks for Fine-tuning:
77
+ - [PyTorch](https://pytorch.org/)
78
+ - [TensorFlow/Keras](https://www.tensorflow.org/)
79
+