add large model
- README.md +74 -29
- axmodel_infer.py +27 -13
- model/bird-l/AX615/bird_615_npu2.axmodel +3 -0
- model/bird-l/AX620E/bird_630_npu1.axmodel +3 -0
- model/bird-l/AX620E/bird_630_npu2.axmodel +3 -0
- model/bird-l/AX650/bird_650_npu3.axmodel +3 -0
- onnx_infer.py +26 -12
- prediction_result_top5.png +2 -2
- quant/{Bird.json → bird-l.json} +2 -2
- quant/bird-m.json +3 -0
- quant/bird-s.json +3 -0
README.md
CHANGED
|
@@ -13,8 +13,6 @@ tags:
|
|
| 13 |
|
| 14 |
This project is only a demo of a Bird Species Classification model covering 1400+ species of birds.
|
| 15 |
|
| 16 |
-
The model is trained with 224x224 resolution.
|
| 17 |
-
|
| 18 |
This model has been converted to run on the Axera NPU using **w8a16** quantization.
|
| 19 |
|
| 20 |
This model has been optimized with the following LoRA:
|
|
@@ -23,6 +21,11 @@ Compatible with Pulsar2 version: 5.1
|
|
| 23 |
|
| 24 |
## Convert tools links:
|
| 25 |
|
| 26 |
For those who are interested in model conversion, you can try to export axmodel through
|
| 27 |
|
| 28 |
- [Pulsar2 Link, How to Convert ONNX to axmodel](https://pulsar2-docs.readthedocs.io/en/latest/pulsar2/introduction.html)
|
|
@@ -38,39 +41,62 @@ For those who are interested in model conversion, you can try to export axmodel
|
|
| 38 |
- [Module-LLM](https://docs.m5stack.com/zh_CN/module/Module-LLM)
|
| 39 |
- [LLM630 Compute Kit](https://docs.m5stack.com/zh_CN/core/LLM630%20Compute%20Kit)
|
| 40 |
|
| 41 |
-
| Models | Platforms | latency
|
| 42 |
-
| -------------| -------------| ------------- | --------------| --------------|
|
| 43 |
-
| | AX650 | 0.19ms
|
| 44 |
-
| bird-s | AX630C | 0.54ms
|
| 45 |
-
| | AX615 | 0.87ms
|
| 46 |
-
| | AX650 | 0.58ms
|
| 47 |
-
| bird-m | AX630C | 2.52ms
|
| 48 |
-
| | AX615 | 4.97ms
|
| 49 |
|
| 50 |
## How to use
|
| 51 |
|
| 52 |
Download all files from this repository to the device
|
| 53 |
|
| 54 |
```
|
| 55 |
-
root@ax650:~/Bird-Species-Classification# tree
|
| 56 |
.
|
| 57 |
├── README.md
|
| 58 |
├── axmodel_infer.py
|
| 59 |
├── class_name.txt
|
|
|
|
| 60 |
├── model
|
| 61 |
-
│ ├──
|
| 62 |
-
│ │ ├──
|
| 63 |
-
│ │ └── bird_615_npu2.axmodel
|
| 64 |
-
│ ├── AX620E
|
| 65 |
-
│ │ ├── bird_630_npu1.axmodel
|
| 66 |
-
│ │ └── bird_630_npu2.axmodel
|
| 67 |
-
│ └── AX650
|
| 68 |
-
│ └── bird_650_npu3.axmodel
|
| 69 |
├── onnx_infer.py
|
| 70 |
├── prediction_result_top5.png
|
| 71 |
├── quant
|
| 72 |
-
│ ├──
|
| 73 |
-
│ ├──
|
|
|
|
| 74 |
│ └── bird.tar.gz
|
| 75 |
└── test_images
|
| 76 |
├── 03111_2c0dfa5a-c4a0-47f8-ac89-6a289208050f.jpg
|
|
@@ -81,7 +107,7 @@ root@ax650:~/Bird-Species-Classification# tree
|
|
| 81 |
├── 04405_0c5a6785-0bc2-49d9-9702-b9e94ba9b686.jpg
|
| 82 |
└── 04593_3d74d5a7-15b1-4bb9-af6f-1bcd78485787.jpg
|
| 83 |
|
| 84 |
-
|
| 85 |
|
| 86 |
```
|
| 87 |
|
|
@@ -96,27 +122,46 @@ wget https://github.com/AXERA-TECH/pyaxengine/releases/download/0.1.3rc0/axengin
|
|
| 96 |
pip install axengine-0.1.3-py3-none-any.whl
|
| 97 |
```
|
| 98 |
|
| 99 |
## Inference with AX650 Host, such as M4N-Dock(爱芯派Pro)
|
| 100 |
|
| 101 |
```
|
| 102 |
-
root@ax650:~/
|
| 103 |
[INFO] Available providers: ['AxEngineExecutionProvider']
|
|
|
|
| 104 |
Loading ONNX model with providers: ['AxEngineExecutionProvider']
|
| 105 |
[INFO] Using provider: AxEngineExecutionProvider
|
| 106 |
[INFO] Chip type: ChipType.MC50
|
| 107 |
[INFO] VNPU type: VNPUType.DISABLED
|
| 108 |
[INFO] Engine version: 2.12.0s
|
| 109 |
[INFO] Model type: 2 (triple core)
|
| 110 |
-
[INFO] Compiler version: 5.1-patch1
|
| 111 |
|
| 112 |
Image: test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
|
| 113 |
Top-5 Predictions:
|
| 114 |
-
#1: 04251_Animalia_Chordata_Aves_Passeriformes_Tityridae_Tityra_semifasciata (0.
|
| 115 |
-
#2:
|
| 116 |
-
#3:
|
| 117 |
-
#4:
|
| 118 |
-
#5:
|
| 119 |
Result saved to: prediction_result_top5.png
|
|
|
|
| 120 |
```
|
| 121 |
|
| 122 |
output:
|
|
|
|
| 13 |
|
| 14 |
This project is only a demo of a Bird Species Classification model covering 1400+ species of birds.
|
| 15 |
|
| 16 |
This model has been converted to run on the Axera NPU using **w8a16** quantization.
|
| 17 |
|
| 18 |
This model has been optimized with the following LoRA:
|
|
|
|
| 21 |
|
| 22 |
## Convert tools links:
|
| 23 |
|
| 24 |
+
Convert the model from ONNX to axmodel with a command such as:
|
| 25 |
+
```
|
| 26 |
+
pulsar2 build --config ./quant/bird-l.json
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
For those who are interested in model conversion, you can try to export axmodel through
|
| 30 |
|
| 31 |
- [Pulsar2 Link, How to Convert ONNX to axmodel](https://pulsar2-docs.readthedocs.io/en/latest/pulsar2/introduction.html)
|
|
|
|
| 41 |
- [Module-LLM](https://docs.m5stack.com/zh_CN/module/Module-LLM)
|
| 42 |
- [LLM630 Compute Kit](https://docs.m5stack.com/zh_CN/core/LLM630%20Compute%20Kit)
|
| 43 |
|
| 44 |
+
| Models | Platforms | latency | Top1 Accuracy | Top5 Accuracy | CMM size(MB) |
|
| 45 |
+
| -------------| -------------| -------------| --------------| --------------| --------------|
|
| 46 |
+
| | AX650 | 0.19ms | | | |
|
| 47 |
+
| bird-s | AX630C | 0.54ms | 44% | 66% | 1.07 |
|
| 48 |
+
| | AX615 | 0.87ms | | | |
|
| 49 |
+
| | AX650 | 0.58ms | | | |
|
| 50 |
+
| bird-m | AX630C | 2.52ms | 59% | 79% | 12.2 |
|
| 51 |
+
| | AX615 | 4.97ms | | | |
|
| 52 |
+
| | AX650 | 5.60ms | | | |
|
| 53 |
+
| bird-l | AX630C | 35.2ms | 86% | 95% | 29.6 |
|
| 54 |
+
| | AX615 | 64.1ms | | | |
|
| 55 |
|
| 56 |
## How to use
|
| 57 |
|
| 58 |
Download all files from this repository to the device
|
| 59 |
|
| 60 |
```
|
| 61 |
+
(base) root@ax650:~/Bird-Species-Classification# tree
|
| 62 |
.
|
| 63 |
├── README.md
|
| 64 |
├── axmodel_infer.py
|
| 65 |
├── class_name.txt
|
| 66 |
+
├── config.json
|
| 67 |
├── model
|
| 68 |
+
│ ├── bird-l
|
| 69 |
+
│ │ ├── AX615
|
| 70 |
+
│ │ │ └── bird_615_npu2.axmodel
|
| 71 |
+
│ │ ├── AX620E
|
| 72 |
+
│ │ │ ├── bird_630_npu1.axmodel
|
| 73 |
+
│ │ │ └── bird_630_npu2.axmodel
|
| 74 |
+
│ │ └── AX650
|
| 75 |
+
│ │ └── bird_650_npu3.axmodel
|
| 76 |
+
│ ├── bird-m
|
| 77 |
+
│ │ ├── AX615
|
| 78 |
+
│ │ │ ├── bird_615_npu1.axmodel
|
| 79 |
+
│ │ │ └── bird_615_npu2.axmodel
|
| 80 |
+
│ │ ├── AX620E
|
| 81 |
+
│ │ │ ├── bird_630_npu1.axmodel
|
| 82 |
+
│ │ │ └── bird_630_npu2.axmodel
|
| 83 |
+
│ │ └── AX650
|
| 84 |
+
│ │ └── bird_650_npu3.axmodel
|
| 85 |
+
│ └── bird-s
|
| 86 |
+
│ ├── AX615
|
| 87 |
+
│ │ ├── bird_615_npu1.axmodel
|
| 88 |
+
│ │ └── bird_615_npu2.axmodel
|
| 89 |
+
│ ├── AX630C
|
| 90 |
+
│ │ ├── bird_630_npu1.axmodel
|
| 91 |
+
│ │ └── bird_630_npu2.axmodel
|
| 92 |
+
│ └── AX650
|
| 93 |
+
│ └── bird_650_npu3.axmodel
|
| 94 |
├── onnx_infer.py
|
| 95 |
├── prediction_result_top5.png
|
| 96 |
├── quant
|
| 97 |
+
│ ├── bird-l.json
|
| 98 |
+
│ ├── bird-m.json
|
| 99 |
+
│ ├── bird-s.json
|
| 100 |
│ └── bird.tar.gz
|
| 101 |
└── test_images
|
| 102 |
├── 03111_2c0dfa5a-c4a0-47f8-ac89-6a289208050f.jpg
|
|
|
|
| 107 |
├── 04405_0c5a6785-0bc2-49d9-9702-b9e94ba9b686.jpg
|
| 108 |
└── 04593_3d74d5a7-15b1-4bb9-af6f-1bcd78485787.jpg
|
| 109 |
|
| 110 |
+
15 directories, 31 files
|
| 111 |
|
| 112 |
```
|
| 113 |
|
|
|
|
| 122 |
pip install axengine-0.1.3-py3-none-any.whl
|
| 123 |
```
|
| 124 |
|
| 125 |
+
## Inference with ONNX model
|
| 126 |
+
```
|
| 127 |
+
root@ebba5440b03c:/home/# python onnx_infer.py -m Birdmodel_inat_bird-l.onnx --image test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
|
| 128 |
+
build predictor with Birdmodel_inat_bird-l.onnx...
|
| 129 |
+
Loading ONNX model with providers: ['CPUExecutionProvider']
|
| 130 |
+
|
| 131 |
+
Image: test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
|
| 132 |
+
Top-5 Predictions:
|
| 133 |
+
#1: 04251_Animalia_Chordata_Aves_Passeriformes_Tityridae_Tityra_semifasciata (0.9231)
|
| 134 |
+
#2: 04019_Animalia_Chordata_Aves_Passeriformes_Oriolidae_Sphecotheres_vieilloti (0.0021)
|
| 135 |
+
#3: 03233_Animalia_Chordata_Aves_Anseriformes_Anatidae_Cairina_moschata (0.0006)
|
| 136 |
+
#4: 04219_Animalia_Chordata_Aves_Passeriformes_Thraupidae_Paroaria_capitata (0.0004)
|
| 137 |
+
#5: 03912_Animalia_Chordata_Aves_Passeriformes_Laniidae_Lanius_minor (0.0003)
|
| 138 |
+
Result saved to: prediction_result_top5.png
|
| 139 |
+
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
## Inference with AX650 Host, such as M4N-Dock(爱芯派Pro)
|
| 143 |
|
| 144 |
```
|
| 145 |
+
root@ax650:~/bird# python3 axmodel_infer.py -m bird_650_npu3.axmodel -i test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
|
| 146 |
[INFO] Available providers: ['AxEngineExecutionProvider']
|
| 147 |
+
build predictor with bird_650_npu3.axmodel...
|
| 148 |
Loading ONNX model with providers: ['AxEngineExecutionProvider']
|
| 149 |
[INFO] Using provider: AxEngineExecutionProvider
|
| 150 |
[INFO] Chip type: ChipType.MC50
|
| 151 |
[INFO] VNPU type: VNPUType.DISABLED
|
| 152 |
[INFO] Engine version: 2.12.0s
|
| 153 |
[INFO] Model type: 2 (triple core)
|
| 154 |
+
[INFO] Compiler version: 5.1-patch1 8c5871d5
|
| 155 |
|
| 156 |
Image: test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
|
| 157 |
Top-5 Predictions:
|
| 158 |
+
#1: 04251_Animalia_Chordata_Aves_Passeriformes_Tityridae_Tityra_semifasciata (0.9137)
|
| 159 |
+
#2: 04019_Animalia_Chordata_Aves_Passeriformes_Oriolidae_Sphecotheres_vieilloti (0.0022)
|
| 160 |
+
#3: 03233_Animalia_Chordata_Aves_Anseriformes_Anatidae_Cairina_moschata (0.0006)
|
| 161 |
+
#4: 03912_Animalia_Chordata_Aves_Passeriformes_Laniidae_Lanius_minor (0.0005)
|
| 162 |
+
#5: 04219_Animalia_Chordata_Aves_Passeriformes_Thraupidae_Paroaria_capitata (0.0004)
|
| 163 |
Result saved to: prediction_result_top5.png
|
| 164 |
+
|
| 165 |
```
|
| 166 |
|
| 167 |
output:
|
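The tree listing in the updated README places each compiled model under `model/<size>/<chip dir>/bird_<chip>_npu<N>.axmodel`. As a rough sketch of how a script might resolve that path (the helper name `axmodel_path` and the lookup table are illustrative and not part of the repository; the directory names are copied from the tree listing, including the `AX630C` directory that appears only under `bird-s`):

```python
from pathlib import PurePosixPath

# Hypothetical helper: the repository does not ship this function.
# Directory names and file-name prefixes are taken from the README tree listing.
CHIP_PREFIX = {"AX615": "615", "AX620E": "630", "AX630C": "630", "AX650": "650"}
MODEL_SIZES = ("bird-s", "bird-m", "bird-l")


def axmodel_path(size: str, chip_dir: str, npu: int, root: str = "model") -> str:
    """Build the relative path of a compiled model, e.g. bird-l on AX650 with 3 NPU cores."""
    if size not in MODEL_SIZES:
        raise ValueError(f"unknown model size: {size!r}")
    if chip_dir not in CHIP_PREFIX:
        raise ValueError(f"unknown chip directory: {chip_dir!r}")
    name = f"bird_{CHIP_PREFIX[chip_dir]}_npu{npu}.axmodel"
    return str(PurePosixPath(root) / size / chip_dir / name)


print(axmodel_path("bird-l", "AX650", 3))
```

The helper only builds a string; it does not check that the file exists, and not every size/chip/NPU combination is present in the tree (for example, `bird-l/AX615` lists only the `npu2` build).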
axmodel_infer.py
CHANGED
|
@@ -21,16 +21,17 @@ plt.rcParams['axes.unicode_minus'] = False
|
|
| 21 |
|
| 22 |
class BirdPredictorONNX:
|
| 23 |
"""Bird classification predictor based on ONNX Runtime"""
|
| 24 |
-
|
| 25 |
-
def __init__(self, class_name_file, model_file):
|
| 26 |
"""
|
| 27 |
Initialize the predictor.
|
| 28 |
Defaults to AxEngineExecutionProvider.
|
| 29 |
"""
|
| 30 |
-
self.rgb_mean =
|
| 31 |
-
self.rgb_std =
|
|
|
|
| 32 |
self.classes = self.load_classes(class_name_file)
|
| 33 |
-
|
| 34 |
providers = ['AxEngineExecutionProvider']
|
| 35 |
print(f"Loading ONNX model with providers: {providers}")
|
| 36 |
|
|
@@ -57,8 +58,8 @@ class BirdPredictorONNX:
|
|
| 57 |
|
| 58 |
def preprocess_image(self, image_path):
|
| 59 |
image = Image.open(image_path).convert('RGB')
|
| 60 |
-
image = image.resize((
|
| 61 |
-
|
| 62 |
img_array = np.array(image, dtype=np.uint8)
|
| 63 |
img_array = img_array.transpose(2, 0, 1)
|
| 64 |
img_array = np.expand_dims(img_array, axis=0)
|
|
@@ -226,12 +227,25 @@ def main():
|
|
| 226 |
default="./class_name.txt",
|
| 227 |
help="Path to configuration file")
|
| 228 |
parser.add_argument("-m", "--model_file",
|
| 229 |
-
default="./
|
| 230 |
help="Path to ONNX model file")
|
| 231 |
parser.add_argument("--image_dir",
|
| 232 |
default="./test_images",
|
| 233 |
help="Directory containing test images")
|
| 234 |
-
parser.add_argument("--image",
|
| 235 |
help="Path to a single test image")
|
| 236 |
parser.add_argument("--top_k",
|
| 237 |
type=int,
|
|
@@ -239,10 +253,10 @@ def main():
|
|
| 239 |
help="Number of top predictions to show (default: 5)")
|
| 240 |
|
| 241 |
args = parser.parse_args()
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
predictor = BirdPredictorONNX(args.class_map_file, args.model_file)
|
| 245 |
-
|
| 246 |
if args.image and os.path.exists(args.image):
|
| 247 |
try:
|
| 248 |
top_k_results = predictor.predict_image_topk(args.image, k=args.top_k)
|
|
|
|
| 21 |
|
| 22 |
class BirdPredictorONNX:
|
| 23 |
"""Bird classification predictor based on ONNX Runtime"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, class_name_file, model_file, mean, std, image_size=224):
|
| 26 |
"""
|
| 27 |
Initialize the predictor.
|
| 28 |
Defaults to AxEngineExecutionProvider.
|
| 29 |
"""
|
| 30 |
+
self.rgb_mean = mean
|
| 31 |
+
self.rgb_std = std
|
| 32 |
+
self.image_size = image_size
|
| 33 |
self.classes = self.load_classes(class_name_file)
|
| 34 |
+
print(f"build predictor with {model_file}...")
|
| 35 |
providers = ['AxEngineExecutionProvider']
|
| 36 |
print(f"Loading ONNX model with providers: {providers}")
|
| 37 |
|
|
|
|
| 58 |
|
| 59 |
def preprocess_image(self, image_path):
|
| 60 |
image = Image.open(image_path).convert('RGB')
|
| 61 |
+
image = image.resize((int(self.image_size), int(self.image_size)), Image.BICUBIC)
|
| 62 |
+
|
| 63 |
img_array = np.array(image, dtype=np.uint8)
|
| 64 |
img_array = img_array.transpose(2, 0, 1)
|
| 65 |
img_array = np.expand_dims(img_array, axis=0)
|
|
|
|
| 227 |
default="./class_name.txt",
|
| 228 |
help="Path to configuration file")
|
| 229 |
parser.add_argument("-m", "--model_file",
|
| 230 |
+
default="./bird_650_npu3.axmodel",
|
| 231 |
help="Path to ONNX model file")
|
| 232 |
+
parser.add_argument("-imgsz", "--image_size",
|
| 233 |
+
default=384,
|
| 234 |
+
help="Input image size")
|
| 235 |
+
parser.add_argument("-mean", "--mean",
|
| 236 |
+
type=float,
|
| 237 |
+
nargs='+',
|
| 238 |
+
default=[0.485, 0.456, 0.406],
|
| 239 |
+
help="Mean normalization values")
|
| 240 |
+
parser.add_argument("-std", "--std",
|
| 241 |
+
type=float,
|
| 242 |
+
nargs='+',
|
| 243 |
+
default=[0.229, 0.224, 0.225],
|
| 244 |
+
help="Standard deviation normalization values")
|
| 245 |
parser.add_argument("--image_dir",
|
| 246 |
default="./test_images",
|
| 247 |
help="Directory containing test images")
|
| 248 |
+
parser.add_argument("-i", "--image",
|
| 249 |
help="Path to a single test image")
|
| 250 |
parser.add_argument("--top_k",
|
| 251 |
type=int,
|
|
|
|
| 253 |
help="Number of top predictions to show (default: 5)")
|
| 254 |
|
| 255 |
args = parser.parse_args()
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
predictor = BirdPredictorONNX(args.class_map_file, args.model_file, args.mean, args.std, args.image_size)
|
| 259 |
+
|
| 260 |
if args.image and os.path.exists(args.image):
|
| 261 |
try:
|
| 262 |
top_k_results = predictor.predict_image_topk(args.image, k=args.top_k)
|
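The new `-imgsz` option is declared without a `type`, so a value passed on the command line arrives as a string (the `int(...)` calls in `preprocess_image` absorb that), and `nargs='+'` accepts any number of mean or std values. A stricter variant, offered only as a sketch and not as what this commit does, could let argparse validate both; the `build_parser` name and the `metavar` labels are assumptions for illustration:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Sketch of the preprocessing options with stricter validation."""
    parser = argparse.ArgumentParser(description="Preprocessing options (illustrative)")
    # type=int rejects non-numeric sizes at parse time instead of inside resize().
    parser.add_argument("-imgsz", "--image_size", type=int, default=384,
                        help="Input image size in pixels (square)")
    # nargs=3 enforces exactly one value per RGB channel.
    parser.add_argument("-mean", "--mean", type=float, nargs=3,
                        default=[0.485, 0.456, 0.406], metavar=("R", "G", "B"),
                        help="Per-channel mean normalization values")
    parser.add_argument("-std", "--std", type=float, nargs=3,
                        default=[0.229, 0.224, 0.225], metavar=("R", "G", "B"),
                        help="Per-channel standard deviation normalization values")
    return parser


args = build_parser().parse_args(["-imgsz", "224", "-mean", "0.5", "0.5", "0.5"])
print(args.image_size, args.mean, args.std)
```

Whether a fixed count of three is appropriate depends on the models always taking RGB input, which the diff suggests (`convert('RGB')`) but does not state outright.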
model/bird-l/AX615/bird_615_npu2.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c8d71be9d5145d0256465968e0d15cf2de585d6c44fd545be33bd3c663feda29
|
| 3 |
+
size 23482611
|
model/bird-l/AX620E/bird_630_npu1.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:12d4e9dd34ffa17c4c06ea35883855234adcee8b243d8cc40d3d11025de508db
|
| 3 |
+
size 24419641
|
model/bird-l/AX620E/bird_630_npu2.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:70cb2a12d16c05d6e186af8b656c0209bfb5c62de67fd7efc5006620683712e6
|
| 3 |
+
size 24269265
|
model/bird-l/AX650/bird_650_npu3.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bfdb20082d3972977699a6fdf1c97100016cdcccbd9591462cc2095e63ad1c04
|
| 3 |
+
size 23236891
|
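The `.axmodel` entries above are Git LFS pointer files rather than the binaries themselves: three `key value` lines giving the spec version, the SHA-256 object ID, and the size in bytes. A small sketch of reading one (the function name `parse_lfs_pointer` is made up for illustration; the sample text is the `bird_650_npu3.axmodel` pointer from this commit):

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse the 'key value' lines of a Git LFS pointer file into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    # The oid line has the form '<hash algorithm>:<hex digest>'.
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "oid": digest,
        "size": int(fields["size"]),
    }


POINTER = """version https://git-lfs.github.com/spec/v1
oid sha256:bfdb20082d3972977699a6fdf1c97100016cdcccbd9591462cc2095e63ad1c04
size 23236891
"""

info = parse_lfs_pointer(POINTER)
print(f"{info['hash_algo']} object, {info['size'] / 1e6:.1f} MB")
```

The `size` field is the size of the real file stored in LFS, so it gives a quick way to compare the compiled models without downloading them.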
onnx_infer.py
CHANGED
|
@@ -21,16 +21,17 @@ plt.rcParams['axes.unicode_minus'] = False
|
|
| 21 |
|
| 22 |
class BirdPredictorONNX:
|
| 23 |
"""Bird classification predictor based on ONNX Runtime"""
|
| 24 |
-
|
| 25 |
-
def __init__(self, class_name_file, model_file):
|
| 26 |
"""
|
| 27 |
Initialize the predictor.
|
| 28 |
Defaults to CPUExecutionProvider.
|
| 29 |
"""
|
| 30 |
-
self.rgb_mean =
|
| 31 |
-
self.rgb_std =
|
|
|
|
| 32 |
self.classes = self.load_classes(class_name_file)
|
| 33 |
-
|
| 34 |
providers = ['CPUExecutionProvider']
|
| 35 |
print(f"Loading ONNX model with providers: {providers}")
|
| 36 |
|
|
@@ -57,7 +58,7 @@ class BirdPredictorONNX:
|
|
| 57 |
|
| 58 |
def preprocess_image(self, image_path):
|
| 59 |
image = Image.open(image_path).convert('RGB')
|
| 60 |
-
image = image.resize((
|
| 61 |
|
| 62 |
img_array = np.array(image, dtype=np.float32) / 255.0
|
| 63 |
img_array = img_array.transpose(2, 0, 1)
|
|
@@ -230,12 +231,25 @@ def main():
|
|
| 230 |
default="./class_name.txt",
|
| 231 |
help="Path to configuration file")
|
| 232 |
parser.add_argument("-m", "--model_file",
|
| 233 |
-
default="./
|
| 234 |
help="Path to ONNX model file")
|
| 235 |
parser.add_argument("--image_dir",
|
| 236 |
default="./test_images",
|
| 237 |
help="Directory containing test images")
|
| 238 |
-
parser.add_argument("--image",
|
| 239 |
help="Path to a single test image")
|
| 240 |
parser.add_argument("--top_k",
|
| 241 |
type=int,
|
|
@@ -243,10 +257,10 @@ def main():
|
|
| 243 |
help="Number of top predictions to show (default: 5)")
|
| 244 |
|
| 245 |
args = parser.parse_args()
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
predictor = BirdPredictorONNX(args.class_map_file, args.model_file)
|
| 249 |
-
|
| 250 |
if args.image and os.path.exists(args.image):
|
| 251 |
try:
|
| 252 |
top_k_results = predictor.predict_image_topk(args.image, k=args.top_k)
|
|
|
|
| 21 |
|
| 22 |
class BirdPredictorONNX:
|
| 23 |
"""Bird classification predictor based on ONNX Runtime"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, class_name_file, model_file, mean, std, image_size=224):
|
| 26 |
"""
|
| 27 |
Initialize the predictor.
|
| 28 |
Defaults to CPUExecutionProvider.
|
| 29 |
"""
|
| 30 |
+
self.rgb_mean = mean
|
| 31 |
+
self.rgb_std = std
|
| 32 |
+
self.image_size = image_size
|
| 33 |
self.classes = self.load_classes(class_name_file)
|
| 34 |
+
print(f"build predictor with {model_file}...")
|
| 35 |
providers = ['CPUExecutionProvider']
|
| 36 |
print(f"Loading ONNX model with providers: {providers}")
|
| 37 |
|
|
|
|
| 58 |
|
| 59 |
def preprocess_image(self, image_path):
|
| 60 |
image = Image.open(image_path).convert('RGB')
|
| 61 |
+
image = image.resize((int(self.image_size), int(self.image_size)), Image.BICUBIC)
|
| 62 |
|
| 63 |
img_array = np.array(image, dtype=np.float32) / 255.0
|
| 64 |
img_array = img_array.transpose(2, 0, 1)
|
|
|
|
| 231 |
default="./class_name.txt",
|
| 232 |
help="Path to configuration file")
|
| 233 |
parser.add_argument("-m", "--model_file",
|
| 234 |
+
default="./bird-l.onnx",
|
| 235 |
help="Path to ONNX model file")
|
| 236 |
+
parser.add_argument("-imgsz", "--image_size",
|
| 237 |
+
default=384,
|
| 238 |
+
help="Input image size")
|
| 239 |
+
parser.add_argument("-mean", "--mean",
|
| 240 |
+
type=float,
|
| 241 |
+
nargs='+',
|
| 242 |
+
default=[0.485, 0.456, 0.406],
|
| 243 |
+
help="Mean normalization values")
|
| 244 |
+
parser.add_argument("-std", "--std",
|
| 245 |
+
type=float,
|
| 246 |
+
nargs='+',
|
| 247 |
+
default=[0.229, 0.224, 0.225],
|
| 248 |
+
help="Standard deviation normalization values")
|
| 249 |
parser.add_argument("--image_dir",
|
| 250 |
default="./test_images",
|
| 251 |
help="Directory containing test images")
|
| 252 |
+
parser.add_argument("-i", "--image",
|
| 253 |
help="Path to a single test image")
|
| 254 |
parser.add_argument("--top_k",
|
| 255 |
type=int,
|
|
|
|
| 257 |
help="Number of top predictions to show (default: 5)")
|
| 258 |
|
| 259 |
args = parser.parse_args()
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
predictor = BirdPredictorONNX(args.class_map_file, args.model_file, args.mean, args.std, args.image_size)
|
| 263 |
+
|
| 264 |
if args.image and os.path.exists(args.image):
|
| 265 |
try:
|
| 266 |
top_k_results = predictor.predict_image_topk(args.image, k=args.top_k)
|
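Both scripts print a Top-5 list with scores that look like probabilities. The ranking code lies outside this diff, so the following is only a guess at the usual approach (softmax over the raw outputs, then the k largest entries). `softmax` and `top_k` are illustrative names, the class names are placeholders, and the sketch uses only the standard library rather than NumPy:

```python
import heapq
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, class_names, k=5):
    """Return the k (class_name, probability) pairs with the highest probability."""
    probs = softmax(logits)
    best = heapq.nlargest(k, range(len(probs)), key=probs.__getitem__)
    return [(class_names[i], probs[i]) for i in best]


# Placeholder logits and names, not taken from the model.
example_logits = [2.0, 0.5, 1.0, -1.0]
example_names = ["class_a", "class_b", "class_c", "class_d"]
for rank, (name, prob) in enumerate(top_k(example_logits, example_names, k=3), start=1):
    print(f"#{rank}: {name} ({prob:.4f})")
```

If the model already ends in a softmax layer, the `softmax` call would be redundant and only the `heapq.nlargest` step would apply.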
prediction_result_top5.png
CHANGED
|
Git LFS Details
|
quant/{Bird.json → bird-l.json}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d7a41de236d68ddfe7f15614ba8137fc7f7acbb891387d121b225d09f41a1240
|
| 3 |
+
size 899
|
quant/bird-m.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f70945c77a5e91a6bc21eed005fbc8004464703c5472cac1753e99f234b53517
|
| 3 |
+
size 886
|
quant/bird-s.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4d9f55be272bb41c5d32136e3cff83778b48603076b2b6ce1c29a5a1f523e9ff
|
| 3 |
+
size 899
|