first commit
- .gitattributes +5 -0
- README.md +109 -1
- axmodel_infer.py +279 -0
- class_name.txt +0 -0
- model/AX615/bird_615_npu1.axmodel +3 -0
- model/AX615/bird_615_npu2.axmodel +3 -0
- model/AX620E/bird_630_npu1.axmodel +3 -0
- model/AX620E/bird_630_npu2.axmodel +3 -0
- model/AX650/bird_650_npu3.axmodel +3 -0
- onnx_infer.py +283 -0
- prediction_result_top5.png +3 -0
- quant/Bird.json +3 -0
- quant/bird.tar.gz +3 -0
- test_images/03111_2c0dfa5a-c4a0-47f8-ac89-6a289208050f.jpg +3 -0
- test_images/03332_01b365c3-a741-4f45-bac2-4345bc901ec6.jpg +3 -0
- test_images/03412_0ffc115b-43b4-4474-a373-24233f391de3.jpg +3 -0
- test_images/03615_0dfbf6ae-434d-4648-b5d2-08412546ea64.jpg +3 -0
- test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg +3 -0
- test_images/04405_0c5a6785-0bc2-49d9-9702-b9e94ba9b686.jpg +3 -0
- test_images/04593_3d74d5a7-15b1-4bb9-af6f-1bcd78485787.jpg +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.axmodel filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -7,4 +7,112 @@ metrics:
|
|
| 7 |
pipeline_tag: image-classification
|
| 8 |
tags:
|
| 9 |
- biology
|
| 10 |
-
---
|
| 7 |
pipeline_tag: image-classification
|
| 8 |
tags:
|
| 9 |
- biology
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# Bird-Species-Classification
|
| 13 |
+
|
| 14 |
+
This project is a demo of a bird species classification model covering 1486 bird species.
|
| 15 |
+
|
| 16 |
+
The model uses a ResNet18 backbone trained at 224x224 input resolution.
|
| 17 |
+
|
| 18 |
+
This model has been converted to run on the Axera NPU using **w8a16** quantization.
|
| 19 |
+
|
| 20 |
+
This model has been optimized with the following LoRA:
|
| 21 |
+
|
| 22 |
+
Compatible with Pulsar2 version: 5.1
|
| 23 |
+
|
| 24 |
+
## Conversion tool links
|
| 25 |
+
|
| 26 |
+
If you are interested in model conversion, you can export the axmodel with:
|
| 27 |
+
|
| 28 |
+
- [Pulsar2 Link, How to Convert ONNX to axmodel](https://pulsar2-docs.readthedocs.io/en/latest/pulsar2/introduction.html)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
## Supported Platforms
|
| 32 |
+
|
| 33 |
+
- AX650
|
| 34 |
+
- [M4N-Dock(爱芯派Pro)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
|
| 35 |
+
- [M.2 Accelerator card](https://docs.m5stack.com/zh_CN/ai_hardware/LLM-8850_Card)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
| Platform | Latency |
|
| 39 |
+
| -------------| ------------- |
|
| 40 |
+
| AX650 | 0.57ms |
|
| 41 |
+
| AX630C | 2.52ms |
|
| 42 |
+
| AX615 | 4.97ms |
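
As a rough guide, the latencies above can be converted into an upper bound on single-stream throughput. The sketch below assumes the figures are per-image NPU inference times and ignores image decoding, resizing, and post-processing, none of which the table covers; `max_fps` is an illustrative helper, not part of this repository.

```python
# Latency figures copied from the table above (milliseconds per image).
latency_ms = {"AX650": 0.57, "AX630C": 2.52, "AX615": 4.97}

def max_fps(ms_per_image: float) -> float:
    """Upper bound on images per second for back-to-back inference."""
    return 1000.0 / ms_per_image

throughput = {chip: max_fps(ms) for chip, ms in latency_ms.items()}
for chip, fps in throughput.items():
    print(f"{chip}: about {fps:.0f} images/s")
```

End-to-end rates on a real device will be lower once preprocessing and file I/O are included.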
|
| 43 |
+
|
| 44 |
+
## How to use
|
| 45 |
+
|
| 46 |
+
Download all files from this repository to the device.
|
| 47 |
+
|
| 48 |
+
```
|
| 49 |
+
root@ax650:~/Bird-Species-Classification# tree
|
| 50 |
+
.
|
| 51 |
+
├── README.md
|
| 52 |
+
├── axmodel_infer.py
|
| 53 |
+
├── bird_hwc.axmodel
|
| 54 |
+
├── class_name.txt
|
| 55 |
+
├── model
|
| 56 |
+
│ ├── AX615
|
| 57 |
+
│ │ ├── bird_615_npu1.axmodel
|
| 58 |
+
│ │ └── bird_615_npu2.axmodel
|
| 59 |
+
│ ├── AX620E
|
| 60 |
+
│ │ ├── bird_630_npu1.axmodel
|
| 61 |
+
│ │ └── bird_630_npu2.axmodel
|
| 62 |
+
│ └── AX650
|
| 63 |
+
│ └── bird_650_npu3.axmodel
|
| 64 |
+
├── onnx_infer.py
|
| 65 |
+
├── prediction_result_top5.png
|
| 66 |
+
├── quant
|
| 67 |
+
│ ├── Bird.json
|
| 68 |
+
│ ├── README.md
|
| 69 |
+
│ └── bird.tar.gz
|
| 70 |
+
└── test_images
|
| 71 |
+
├── 03111_2c0dfa5a-c4a0-47f8-ac89-6a289208050f.jpg
|
| 72 |
+
├── 03332_01b365c3-a741-4f45-bac2-4345bc901ec6.jpg
|
| 73 |
+
├── 03412_0ffc115b-43b4-4474-a373-24233f391de3.jpg
|
| 74 |
+
├── 03615_0dfbf6ae-434d-4648-b5d2-08412546ea64.jpg
|
| 75 |
+
├── 04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
|
| 76 |
+
├── 04405_0c5a6785-0bc2-49d9-9702-b9e94ba9b686.jpg
|
| 77 |
+
└── 04593_3d74d5a7-15b1-4bb9-af6f-1bcd78485787.jpg
|
| 78 |
+
|
| 79 |
+
6 directories, 21 files
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
### Python environment requirements
|
| 84 |
+
|
| 85 |
+
#### pyaxengine
|
| 86 |
+
|
| 87 |
+
https://github.com/AXERA-TECH/pyaxengine
|
| 88 |
+
|
| 89 |
+
```
|
| 90 |
+
wget https://github.com/AXERA-TECH/pyaxengine/releases/download/0.1.3rc0/axengine-0.1.3-py3-none-any.whl
|
| 91 |
+
pip install axengine-0.1.3-py3-none-any.whl
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
## Inference on an AX650 host, such as M4N-Dock(爱芯派Pro)
|
| 95 |
+
|
| 96 |
+
```
|
| 97 |
+
root@ax650:~/Bird-Species-Classification# python3 axmodel_infer.py --image test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
|
| 98 |
+
[INFO] Available providers: ['AxEngineExecutionProvider']
|
| 99 |
+
Loading ONNX model with providers: ['AxEngineExecutionProvider']
|
| 100 |
+
[INFO] Using provider: AxEngineExecutionProvider
|
| 101 |
+
[INFO] Chip type: ChipType.MC50
|
| 102 |
+
[INFO] VNPU type: VNPUType.DISABLED
|
| 103 |
+
[INFO] Engine version: 2.12.0s
|
| 104 |
+
[INFO] Model type: 2 (triple core)
|
| 105 |
+
[INFO] Compiler version: 5.1-patch1 74996179
|
| 106 |
+
|
| 107 |
+
Image: test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
|
| 108 |
+
Top-5 Predictions:
|
| 109 |
+
#1: 04251_Animalia_Chordata_Aves_Passeriformes_Tityridae_Tityra_semifasciata (0.9999)
|
| 110 |
+
#2: 04019_Animalia_Chordata_Aves_Passeriformes_Oriolidae_Sphecotheres_vieilloti (0.0000)
|
| 111 |
+
#3: 04250_Animalia_Chordata_Aves_Passeriformes_Tityridae_Pachyramphus_aglaiae (0.0000)
|
| 112 |
+
#4: 03917_Animalia_Chordata_Aves_Passeriformes_Malaconotidae_Dryoscopus_cubla (0.0000)
|
| 113 |
+
#5: 04194_Animalia_Chordata_Aves_Passeriformes_Sturnidae_Aplonis_panayensis (0.0000)
|
| 114 |
+
Result saved to: prediction_result_top5.png
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
Output:
|
| 118 |
+

|
axmodel_infer.py
ADDED
|
@@ -0,0 +1,279 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
AXEngine Bird Classification Inference Script (Top-5 Enhanced)
|
| 5 |
+
Loads a compiled axmodel for bird classification.
|
| 6 |
+
Runs on the Axera NPU via AxEngineExecutionProvider.
|
| 7 |
+
"""
|
| 8 |
+
import os
|
| 9 |
+
import argparse
|
| 10 |
+
import numpy as np
|
| 11 |
+
import cv2
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import axengine as axe
|
| 14 |
+
import matplotlib
|
| 15 |
+
matplotlib.use('Agg')
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
|
| 18 |
+
# Ensure English fonts are used to avoid warnings
|
| 19 |
+
plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
|
| 20 |
+
plt.rcParams['axes.unicode_minus'] = False
|
| 21 |
+
|
| 22 |
+
class BirdPredictorONNX:
|
| 23 |
+
"""Bird classification predictor based on AXEngine (pyaxengine)"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, class_name_file, model_file):
|
| 26 |
+
"""
|
| 27 |
+
Initialize the predictor.
|
| 28 |
+
Defaults to AxEngineExecutionProvider.
|
| 29 |
+
"""
|
| 30 |
+
self.rgb_mean = [0.5,0.5,0.5]
|
| 31 |
+
self.rgb_std = [0.5,0.5,0.5]
|
| 32 |
+
self.classes = self.load_classes(class_name_file)
|
| 33 |
+
|
| 34 |
+
providers = ['AxEngineExecutionProvider']
|
| 35 |
+
print(f"Loading ONNX model with providers: {providers}")
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
self.session = axe.InferenceSession(model_file, providers=providers)
|
| 39 |
+
except Exception as e:
|
| 40 |
+
print(f"Failed to load model: {e}")
|
| 41 |
+
raise
|
| 42 |
+
|
| 43 |
+
self.input_name = self.session.get_inputs()[0].name
|
| 44 |
+
self.input_shape = self.session.get_inputs()[0].shape
|
| 45 |
+
|
| 46 |
+
self.transform = self.get_transform_params()
|
| 47 |
+
|
| 48 |
+
def load_classes(self,class_name_file):
|
| 49 |
+
with open(class_name_file, 'r', encoding='utf-8') as f:
|
| 50 |
+
classes = [line.strip() for line in f.readlines() if line.strip()]
|
| 51 |
+
return classes
|
| 52 |
+
|
| 53 |
+
def get_transform_params(self):
|
| 54 |
+
mean = np.array(self.rgb_mean, dtype=np.float32).reshape(1, 3, 1, 1)
|
| 55 |
+
std = np.array(self.rgb_std, dtype=np.float32).reshape(1, 3, 1, 1)
|
| 56 |
+
return {'mean': mean, 'std': std}
|
| 57 |
+
|
| 58 |
+
def preprocess_image(self, image_path):
|
| 59 |
+
image = Image.open(image_path).convert('RGB')
|
| 60 |
+
image = image.resize((224, 224), Image.BILINEAR)
|
| 61 |
+
|
| 62 |
+
img_array = np.array(image, dtype=np.uint8)
|
| 63 |
+
img_array = img_array.transpose(2, 0, 1)
|
| 64 |
+
img_array = np.expand_dims(img_array, axis=0)
|
| 65 |
+
|
| 66 |
+
return img_array
|
| 67 |
+
|
| 68 |
+
def predict_image_topk(self, image_path, k=5):
|
| 69 |
+
input_data = self.preprocess_image(image_path)
|
| 70 |
+
outputs = self.session.run(None, {self.input_name: input_data})
|
| 71 |
+
|
| 72 |
+
logits = outputs[0]
|
| 73 |
+
exp_scores = np.exp(logits - np.max(logits, axis=1, keepdims=True))
|
| 74 |
+
probabilities = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
|
| 75 |
+
|
| 76 |
+
probs_0 = probabilities[0]
|
| 77 |
+
top_k_indices = np.argsort(probs_0)[::-1][:k]
|
| 78 |
+
|
| 79 |
+
results = []
|
| 80 |
+
for idx in top_k_indices:
|
| 81 |
+
class_name = self.classes[idx]
|
| 82 |
+
conf = float(probs_0[idx])
|
| 83 |
+
results.append((class_name, conf))
|
| 84 |
+
|
| 85 |
+
return results
|
| 86 |
+
|
| 87 |
+
def predict_batch_topk(self, image_dir, k=5):
|
| 88 |
+
results = []
|
| 89 |
+
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
|
| 90 |
+
|
| 91 |
+
files = sorted([f for f in os.listdir(image_dir) if any(f.lower().endswith(ext) for ext in image_extensions)])
|
| 92 |
+
print(f"Found {len(files)} images, starting inference (Top-{k})...")
|
| 93 |
+
|
| 94 |
+
for filename in files:
|
| 95 |
+
image_path = os.path.join(image_dir, filename)
|
| 96 |
+
try:
|
| 97 |
+
top_k_results = self.predict_image_topk(image_path, k=k)
|
| 98 |
+
results.append({
|
| 99 |
+
'filename': filename,
|
| 100 |
+
'path': image_path,
|
| 101 |
+
'top_k': top_k_results
|
| 102 |
+
})
|
| 103 |
+
except Exception as e:
|
| 104 |
+
print(f"Error processing image {filename}: {str(e)}")
|
| 105 |
+
|
| 106 |
+
return results
|
| 107 |
+
|
| 108 |
+
def _wrap_text(self, text, max_chars=25):
|
| 109 |
+
"""
|
| 110 |
+
Helper function to wrap or truncate long text to fit in table cells.
|
| 111 |
+
Tries to break at underscores or hyphens first.
|
| 112 |
+
"""
|
| 113 |
+
if len(text) <= max_chars:
|
| 114 |
+
return text
|
| 115 |
+
|
| 116 |
+
# Try to find a good breaking point (underscore or hyphen) near the limit
|
| 117 |
+
break_points = [i for i, char in enumerate(text[:max_chars]) if char in ['_', '-']]
|
| 118 |
+
|
| 119 |
+
if break_points:
|
| 120 |
+
# Break at the last found separator within the limit
|
| 121 |
+
split_idx = break_points[-1] + 1
|
| 122 |
+
return text[:split_idx] + "\n" + text[split_idx:]
|
| 123 |
+
|
| 124 |
+
# If no good break point, just force split in the middle
|
| 125 |
+
mid = max_chars // 2
|
| 126 |
+
return text[:mid] + "-\n" + text[mid:]
|
| 127 |
+
|
| 128 |
+
def visualize_prediction_topk(self, image_path, top_k_results, save_path=None):
|
| 129 |
+
image = cv2.imread(image_path)
|
| 130 |
+
if image is None:
|
| 131 |
+
raise ValueError(f"Cannot read image: {image_path}")
|
| 132 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 133 |
+
|
| 134 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
|
| 135 |
+
|
| 136 |
+
ax1.imshow(image)
|
| 137 |
+
ax1.set_title('Input Image', fontsize=14, fontweight='bold')
|
| 138 |
+
ax1.axis('off')
|
| 139 |
+
|
| 140 |
+
ax2.axis('off')
|
| 141 |
+
|
| 142 |
+
table_data = []
|
| 143 |
+
table_data.append(["Rank", "Class Name", "Confidence"])
|
| 144 |
+
|
| 145 |
+
processed_rows = []
|
| 146 |
+
for i, (cls_name, conf) in enumerate(top_k_results):
|
| 147 |
+
rank = f"#{i+1}"
|
| 148 |
+
conf_str = f"{conf:.4f} ({conf*100:.2f}%)"
|
| 149 |
+
|
| 150 |
+
# Process long class names
|
| 151 |
+
wrapped_name = self._wrap_text(cls_name, max_chars=28) # Increased limit slightly but allow wrapping
|
| 152 |
+
processed_rows.append([rank, wrapped_name, conf_str])
|
| 153 |
+
|
| 154 |
+
# Combine header and rows
|
| 155 |
+
full_table_data = [table_data[0]] + processed_rows
|
| 156 |
+
|
| 157 |
+
# Create table with specific column widths
|
| 158 |
+
# Col widths: Rank (10%), Name (60%), Conf (30%)
|
| 159 |
+
table = ax2.table(cellText=full_table_data[1:],
|
| 160 |
+
colLabels=full_table_data[0],
|
| 161 |
+
loc='center',
|
| 162 |
+
cellLoc='left', # Left align for text content usually looks better with wraps
|
| 163 |
+
colWidths=[0.1, 0.6, 0.3],
|
| 164 |
+
bbox=[0.05, 0.1, 0.9, 0.75]) # Adjusted bbox to give more vertical space
|
| 165 |
+
|
| 166 |
+
table.auto_set_font_size(False)
|
| 167 |
+
|
| 168 |
+
# Dynamically adjust font size if names are very long/wrapped
|
| 169 |
+
base_font_size = 10
|
| 170 |
+
if any('\n' in row[1] for row in processed_rows):
|
| 171 |
+
base_font_size = 8 # Reduce font if wrapping occurred
|
| 172 |
+
|
| 173 |
+
table.set_fontsize(base_font_size)
|
| 174 |
+
|
| 175 |
+
# Scale row height to accommodate wrapped text
|
| 176 |
+
# Base scale 1.5, increase if wrapped
|
| 177 |
+
row_scale = 1.8 if any('\n' in row[1] for row in processed_rows) else 1.5
|
| 178 |
+
table.scale(1, row_scale)
|
| 179 |
+
|
| 180 |
+
# Style the header
|
| 181 |
+
for i in range(3):
|
| 182 |
+
cell = table[(0, i)]
|
| 183 |
+
cell.set_text_props(fontweight='bold', color='white', ha='center')
|
| 184 |
+
cell.set_facecolor('#4472C4')
|
| 185 |
+
if i == 1: # Center the header of the name column
|
| 186 |
+
cell.set_text_props(ha='center')
|
| 187 |
+
|
| 188 |
+
# Style body cells
|
| 189 |
+
for i in range(1, len(full_table_data)):
|
| 190 |
+
for j in range(3):
|
| 191 |
+
cell = table[(i, j)]
|
| 192 |
+
cell.set_facecolor('#ffffff' if i % 2 == 0 else '#f9f9f9')
|
| 193 |
+
cell.set_edgecolor('#dddddd')
|
| 194 |
+
cell.set_linewidth(1)
|
| 195 |
+
|
| 196 |
+
# Alignment logic
|
| 197 |
+
if j == 0: # Rank
|
| 198 |
+
cell.set_text_props(ha='center', va='center')
|
| 199 |
+
elif j == 1: # Name (Left aligned, top aligned for wrapped text)
|
| 200 |
+
cell.set_text_props(ha='left', va='top', wrap=True)
|
| 201 |
+
else: # Confidence
|
| 202 |
+
cell.set_text_props(ha='center', va='center')
|
| 203 |
+
|
| 204 |
+
# Add File Path Text
|
| 205 |
+
display_path = image_path
|
| 206 |
+
if len(display_path) > 50:
|
| 207 |
+
display_path = "..." + display_path[-47:]
|
| 208 |
+
|
| 209 |
+
path_text = f"File Path:\n{display_path}"
|
| 210 |
+
ax2.text(0.5, 0.92, path_text,
|
| 211 |
+
ha='center', va='center', fontsize=9, color='#555555',
|
| 212 |
+
bbox=dict(boxstyle="round,pad=0.5", fc="#eeeeee", ec="#cccccc", alpha=0.8))
|
| 213 |
+
|
| 214 |
+
ax2.set_title('Top-5 Prediction Results', fontsize=14, fontweight='bold', pad=20)
|
| 215 |
+
|
| 216 |
+
plt.tight_layout()
|
| 217 |
+
|
| 218 |
+
out_path = save_path if save_path else 'prediction_result_top5.png'
|
| 219 |
+
plt.savefig(out_path, dpi=150, bbox_inches='tight')
|
| 220 |
+
plt.close()
|
| 221 |
+
print(f"Result saved to: {out_path}")
|
| 222 |
+
|
| 223 |
+
def main():
|
| 224 |
+
parser = argparse.ArgumentParser(description="AXEngine Bird Classification (Top-5)")
|
| 225 |
+
parser.add_argument("-c", "--class_map_file",
|
| 226 |
+
default="./class_name.txt",
|
| 227 |
+
help="Path to class name file")
|
| 228 |
+
parser.add_argument("-m", "--model_file",
|
| 229 |
+
default="./model/AX650/bird_650_npu3.axmodel",
|
| 230 |
+
help="Path to axmodel file")
|
| 231 |
+
parser.add_argument("--image_dir",
|
| 232 |
+
default="./test_images",
|
| 233 |
+
help="Directory containing test images")
|
| 234 |
+
parser.add_argument("--image",
|
| 235 |
+
help="Path to a single test image")
|
| 236 |
+
parser.add_argument("--top_k",
|
| 237 |
+
type=int,
|
| 238 |
+
default=5,
|
| 239 |
+
help="Number of top predictions to show (default: 5)")
|
| 240 |
+
|
| 241 |
+
args = parser.parse_args()
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
predictor = BirdPredictorONNX(args.class_map_file, args.model_file)
|
| 245 |
+
|
| 246 |
+
if args.image and os.path.exists(args.image):
|
| 247 |
+
try:
|
| 248 |
+
top_k_results = predictor.predict_image_topk(args.image, k=args.top_k)
|
| 249 |
+
|
| 250 |
+
print(f"\nImage: {args.image}")
|
| 251 |
+
print(f"Top-{args.top_k} Predictions:")
|
| 252 |
+
for i, (cls_name, conf) in enumerate(top_k_results):
|
| 253 |
+
print(f"#{i+1}: {cls_name} ({conf:.4f})")
|
| 254 |
+
|
| 255 |
+
predictor.visualize_prediction_topk(args.image, top_k_results)
|
| 256 |
+
|
| 257 |
+
except Exception as e:
|
| 258 |
+
print(f"Inference failed: {e}")
|
| 259 |
+
|
| 260 |
+
elif os.path.exists(args.image_dir):
|
| 261 |
+
results = predictor.predict_batch_topk(args.image_dir, k=args.top_k)
|
| 262 |
+
|
| 263 |
+
print(f"\nProcessed {len(results)} images:")
|
| 264 |
+
for res in results:
|
| 265 |
+
print(f"File: {res['filename']}")
|
| 266 |
+
for i, (cls_name, conf) in enumerate(res['top_k']):
|
| 267 |
+
marker = "[1]" if i == 0 else " "
|
| 268 |
+
print(f"{marker} #{i+1}: {cls_name} ({conf:.4f})")
|
| 269 |
+
|
| 270 |
+
print("\nNote: Visualization saves only the last processed image in batch mode.")
|
| 271 |
+
if results:
|
| 272 |
+
last_res = results[-1]
|
| 273 |
+
predictor.visualize_prediction_topk(last_res['path'], last_res['top_k'], save_path='batch_last_result.png')
|
| 274 |
+
|
| 275 |
+
else:
|
| 276 |
+
print("Specified image or directory not found.")
|
| 277 |
+
|
| 278 |
+
if __name__ == "__main__":
|
| 279 |
+
main()
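
The two scripts in this commit differ mainly in preprocessing: `axmodel_infer.py` passes the raw `uint8` NCHW tensor to the NPU, while `onnx_infer.py` scales to `[0, 1]` and normalizes with a mean and std of 0.5 before inference. The sketch below reproduces both paths on a synthetic image. It assumes, although the source does not state it, that the axmodel applies the equivalent normalization internally. The function names are illustrative.

```python
import numpy as np

MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(1, 3, 1, 1)
STD = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(1, 3, 1, 1)

def to_nchw_uint8(hwc: np.ndarray) -> np.ndarray:
    """axmodel path: HWC -> CHW, add a batch axis, keep uint8."""
    return np.expand_dims(hwc.transpose(2, 0, 1), axis=0)

def to_nchw_float(hwc: np.ndarray) -> np.ndarray:
    """ONNX path: same layout, then scale to [0, 1] and normalize."""
    scaled = to_nchw_uint8(hwc).astype(np.float32) / 255.0
    return ((scaled - MEAN) / STD).astype(np.float32)

# Synthetic 224x224 RGB image (already resized) stands in for a real photo.
demo = np.zeros((224, 224, 3), dtype=np.uint8)
demo[..., 1] = 128   # green channel
demo[..., 2] = 255   # blue channel

x_npu = to_nchw_uint8(demo)    # shape (1, 3, 224, 224), values 0..255
x_onnx = to_nchw_float(demo)   # shape (1, 3, 224, 224), values in [-1, 1]
```

With mean and std both 0.5, the float path maps pixel values 0..255 linearly onto the range -1 to 1.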
|
class_name.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
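
Although the file is not rendered here, the inference log in the README shows that each label looks like `04251_Animalia_Chordata_Aves_Passeriformes_Tityridae_Tityra_semifasciata`. Here is a minimal sketch for splitting such a label into taxonomic ranks. It assumes every line follows this eight-field, underscore-separated layout, which is inferred only from the five labels visible in the log. `BirdLabel` and `parse_label` are hypothetical names.

```python
from typing import NamedTuple

class BirdLabel(NamedTuple):
    class_id: str
    kingdom: str
    phylum: str
    class_: str
    order: str
    family: str
    genus: str
    species: str

    @property
    def binomial(self) -> str:
        """Scientific name built from genus and species."""
        return f"{self.genus} {self.species}"

def parse_label(line: str) -> BirdLabel:
    # maxsplit=7 keeps any extra underscores inside the last field.
    parts = line.strip().split("_", 7)
    if len(parts) != 8:
        raise ValueError(f"Unexpected label format: {line!r}")
    return BirdLabel(*parts)

label = parse_label(
    "04251_Animalia_Chordata_Aves_Passeriformes_Tityridae_Tityra_semifasciata"
)
print(label.class_id, label.binomial)
```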
|
|
|
model/AX615/bird_615_npu1.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f68795ef21bb848113545378feb9f7b4fa8eda552680fdf6f3b83f2f9d8bfcb
|
| 3 |
+
size 12148436
|
model/AX615/bird_615_npu2.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5639d48e44c218e6e1a4b2ad075c4249ae89981c12ff92cdbe358e85d1f0091a
|
| 3 |
+
size 12080120
|
model/AX620E/bird_630_npu1.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:21e4e53ab56cf74335a2ea37d675d3816600630fdaa6d8838492c92870d6c832
|
| 3 |
+
size 12285724
|
model/AX620E/bird_630_npu2.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:895d300404d0134497a8fa643a4bc2136f758a6c4d393da813c10274bdfafed4
|
| 3 |
+
size 12067712
|
model/AX650/bird_650_npu3.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e544a71cb3ca619c8f064f8952193453c33e816467e03d2b5e7bed75aec99038
|
| 3 |
+
size 12209712
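
Each `.axmodel` entry above is a Git LFS pointer (spec version, SHA-256 object id, and byte size) rather than the model itself. After downloading, the real file can be checked against its pointer. The sketch below assumes the pointer text is available as a string; `parse_pointer` and `matches_pointer` are hypothetical helper names.

```python
import hashlib

def parse_pointer(text: str) -> dict:
    """Parse the 'key value' lines of a Git LFS pointer into a dict."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {"algo": algo, "digest": digest, "size": int(fields["size"])}

def matches_pointer(data: bytes, pointer: dict) -> bool:
    """True if the payload has the size and hash recorded in the pointer."""
    if pointer["algo"] != "sha256" or len(data) != pointer["size"]:
        return False
    return hashlib.sha256(data).hexdigest() == pointer["digest"]

# Pointer for model/AX650/bird_650_npu3.axmodel, copied from the diff above.
pointer = parse_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:e544a71cb3ca619c8f064f8952193453c33e816467e03d2b5e7bed75aec99038\n"
    "size 12209712\n"
)
```

A downloaded model could then be checked with `matches_pointer(open(path, "rb").read(), pointer)`, where `path` is wherever the file was saved.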
|
onnx_infer.py
ADDED
|
@@ -0,0 +1,283 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
ONNX Runtime Bird Classification Inference Script (Top-5 Enhanced)
|
| 5 |
+
Loads an exported ONNX model for bird classification.
|
| 6 |
+
Defaults to CPU execution.
|
| 7 |
+
"""
|
| 8 |
+
import os
|
| 9 |
+
import argparse
|
| 10 |
+
import numpy as np
|
| 11 |
+
import cv2
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import onnxruntime as ort
|
| 14 |
+
import matplotlib
|
| 15 |
+
matplotlib.use('Agg')
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
|
| 18 |
+
# Ensure English fonts are used to avoid warnings
|
| 19 |
+
plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
|
| 20 |
+
plt.rcParams['axes.unicode_minus'] = False
|
| 21 |
+
|
| 22 |
+
class BirdPredictorONNX:
|
| 23 |
+
"""Bird classification predictor based on ONNX Runtime"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, class_name_file, model_file):
|
| 26 |
+
"""
|
| 27 |
+
Initialize the predictor.
|
| 28 |
+
Defaults to CPUExecutionProvider.
|
| 29 |
+
"""
|
| 30 |
+
self.rgb_mean = [0.5,0.5,0.5]
|
| 31 |
+
self.rgb_std = [0.5,0.5,0.5]
|
| 32 |
+
self.classes = self.load_classes(class_name_file)
|
| 33 |
+
|
| 34 |
+
providers = ['CPUExecutionProvider']
|
| 35 |
+
print(f"Loading ONNX model with providers: {providers}")
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
self.session = ort.InferenceSession(model_file, providers=providers)
|
| 39 |
+
except Exception as e:
|
| 40 |
+
print(f"Failed to load model: {e}")
|
| 41 |
+
raise
|
| 42 |
+
|
| 43 |
+
self.input_name = self.session.get_inputs()[0].name
|
| 44 |
+
self.input_shape = self.session.get_inputs()[0].shape
|
| 45 |
+
|
| 46 |
+
self.transform = self.get_transform_params()
|
| 47 |
+
|
| 48 |
+
def load_classes(self,class_name_file):
|
| 49 |
+
with open(class_name_file, 'r', encoding='utf-8') as f:
|
| 50 |
+
classes = [line.strip() for line in f.readlines() if line.strip()]
|
| 51 |
+
return classes
|
| 52 |
+
|
| 53 |
+
def get_transform_params(self):
|
| 54 |
+
mean = np.array(self.rgb_mean, dtype=np.float32).reshape(1, 3, 1, 1)
|
| 55 |
+
std = np.array(self.rgb_std, dtype=np.float32).reshape(1, 3, 1, 1)
|
| 56 |
+
return {'mean': mean, 'std': std}
|
| 57 |
+
|
| 58 |
+
def preprocess_image(self, image_path):
|
| 59 |
+
image = Image.open(image_path).convert('RGB')
|
| 60 |
+
image = image.resize((224, 224), Image.BILINEAR)
|
| 61 |
+
|
| 62 |
+
img_array = np.array(image, dtype=np.float32) / 255.0
|
| 63 |
+
img_array = img_array.transpose(2, 0, 1)
|
| 64 |
+
img_array = np.expand_dims(img_array, axis=0)
|
| 65 |
+
|
| 66 |
+
mean = self.transform['mean']
|
| 67 |
+
std = self.transform['std']
|
| 68 |
+
img_array = (img_array - mean) / std
|
| 69 |
+
|
| 70 |
+
return img_array.astype(np.float32)
|
| 71 |
+
|
| 72 |
+
def predict_image_topk(self, image_path, k=5):
|
| 73 |
+
input_data = self.preprocess_image(image_path)
|
| 74 |
+
outputs = self.session.run(None, {self.input_name: input_data})
|
| 75 |
+
|
| 76 |
+
logits = outputs[0]
|
| 77 |
+
exp_scores = np.exp(logits - np.max(logits, axis=1, keepdims=True))
|
| 78 |
+
probabilities = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
|
| 79 |
+
|
| 80 |
+
probs_0 = probabilities[0]
|
| 81 |
+
top_k_indices = np.argsort(probs_0)[::-1][:k]
|
| 82 |
+
|
| 83 |
+
results = []
|
| 84 |
+
for idx in top_k_indices:
|
| 85 |
+
class_name = self.classes[idx]
|
| 86 |
+
conf = float(probs_0[idx])
|
| 87 |
+
results.append((class_name, conf))
|
| 88 |
+
|
| 89 |
+
return results
|
| 90 |
+
|
| 91 |
+
def predict_batch_topk(self, image_dir, k=5):
|
| 92 |
+
results = []
|
| 93 |
+
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
|
| 94 |
+
|
| 95 |
+
files = sorted([f for f in os.listdir(image_dir) if any(f.lower().endswith(ext) for ext in image_extensions)])
|
| 96 |
+
print(f"Found {len(files)} images, starting inference (Top-{k})...")
|
| 97 |
+
|
| 98 |
+
for filename in files:
|
| 99 |
+
image_path = os.path.join(image_dir, filename)
|
| 100 |
+
try:
|
| 101 |
+
top_k_results = self.predict_image_topk(image_path, k=k)
|
| 102 |
+
results.append({
|
| 103 |
+
'filename': filename,
|
| 104 |
+
'path': image_path,
|
| 105 |
+
'top_k': top_k_results
|
| 106 |
+
})
|
| 107 |
+
except Exception as e:
|
| 108 |
+
print(f"Error processing image {filename}: {str(e)}")
|
| 109 |
+
|
| 110 |
+
return results
|
| 111 |
+
|
| 112 |
+
def _wrap_text(self, text, max_chars=25):
|
| 113 |
+
"""
|
| 114 |
+
Helper function to wrap or truncate long text to fit in table cells.
|
| 115 |
+
Tries to break at underscores or hyphens first.
|
| 116 |
+
"""
|
| 117 |
+
if len(text) <= max_chars:
|
| 118 |
+
return text
|
| 119 |
+
|
| 120 |
+
# Try to find a good breaking point (underscore or hyphen) near the limit
|
| 121 |
+
break_points = [i for i, char in enumerate(text[:max_chars]) if char in ['_', '-']]
|
| 122 |
+
|
| 123 |
+
if break_points:
|
| 124 |
+
# Break at the last found separator within the limit
|
| 125 |
+
split_idx = break_points[-1] + 1
|
| 126 |
+
return text[:split_idx] + "\n" + text[split_idx:]
|
| 127 |
+
|
| 128 |
+
# If no good break point, just force split in the middle
|
| 129 |
+
mid = max_chars // 2
|
| 130 |
+
return text[:mid] + "-\n" + text[mid:]
|
| 131 |
+
|
| 132 |
+
def visualize_prediction_topk(self, image_path, top_k_results, save_path=None):
|
| 133 |
+
image = cv2.imread(image_path)
|
| 134 |
+
if image is None:
|
| 135 |
+
raise ValueError(f"Cannot read image: {image_path}")
|
| 136 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 137 |
+
|
| 138 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
|
| 139 |
+
|
| 140 |
+
ax1.imshow(image)
|
| 141 |
+
ax1.set_title('Input Image', fontsize=14, fontweight='bold')
|
| 142 |
+
ax1.axis('off')
|
| 143 |
+
|
| 144 |
+
ax2.axis('off')
|
| 145 |
+
|
| 146 |
+
table_data = []
|
| 147 |
+
table_data.append(["Rank", "Class Name", "Confidence"])
|
| 148 |
+
|
| 149 |
+
processed_rows = []
|
| 150 |
+
for i, (cls_name, conf) in enumerate(top_k_results):
|
| 151 |
+
rank = f"#{i+1}"
|
| 152 |
+
conf_str = f"{conf:.4f} ({conf*100:.2f}%)"
|
| 153 |
+
|
| 154 |
+
# Process long class names
|
| 155 |
+
wrapped_name = self._wrap_text(cls_name, max_chars=28) # Increased limit slightly but allow wrapping
|
| 156 |
+
processed_rows.append([rank, wrapped_name, conf_str])
|
| 157 |
+
|
| 158 |
+
# Combine header and rows
|
| 159 |
+
full_table_data = [table_data[0]] + processed_rows
|
| 160 |
+
|
| 161 |
+
# Create table with specific column widths
|
| 162 |
+
# Col widths: Rank (10%), Name (60%), Conf (30%)
|
| 163 |
+
table = ax2.table(cellText=full_table_data[1:],
|
| 164 |
+
                          colLabels=full_table_data[0],
                          loc='center',
                          cellLoc='left',  # Left-align cell text; reads better when names wrap
                          colWidths=[0.1, 0.6, 0.3],
                          bbox=[0.05, 0.1, 0.9, 0.75])  # Adjusted bbox to give more vertical space

        table.auto_set_font_size(False)

        # Shrink the font when any class name was wrapped onto several lines
        has_wrapped = any('\n' in row[1] for row in processed_rows)
        base_font_size = 8 if has_wrapped else 10
        table.set_fontsize(base_font_size)

        # Scale row height to accommodate wrapped text (1.5 by default, 1.8 if wrapped)
        row_scale = 1.8 if has_wrapped else 1.5
        table.scale(1, row_scale)

        # Style the header row: bold white centered text on a blue background
        for i in range(3):
            cell = table[(0, i)]
            cell.set_text_props(fontweight='bold', color='white', ha='center')
            cell.set_facecolor('#4472C4')

        # Style body cells with alternating row colors and light borders
        for i in range(1, len(full_table_data)):
            for j in range(3):
                cell = table[(i, j)]
                cell.set_facecolor('#ffffff' if i % 2 == 0 else '#f9f9f9')
                cell.set_edgecolor('#dddddd')
                cell.set_linewidth(1)

                # Alignment logic
                if j == 1:  # Name (left aligned, top aligned for wrapped text)
                    cell.set_text_props(ha='left', va='top', wrap=True)
                else:  # Rank and Confidence
                    cell.set_text_props(ha='center', va='center')

        # Add the file path above the table, truncated to at most 50 characters
        display_path = image_path
        if len(display_path) > 50:
            display_path = "..." + display_path[-47:]

        path_text = f"File Path:\n{display_path}"
        ax2.text(0.5, 0.92, path_text,
                 ha='center', va='center', fontsize=9, color='#555555',
                 bbox=dict(boxstyle="round,pad=0.5", fc="#eeeeee", ec="#cccccc", alpha=0.8))

        # Title reflects the number of rows actually shown (header row excluded)
        ax2.set_title(f'Top-{len(full_table_data) - 1} Prediction Results',
                      fontsize=14, fontweight='bold', pad=20)

        plt.tight_layout()

        out_path = save_path if save_path else 'prediction_result_top5.png'
        plt.savefig(out_path, dpi=150, bbox_inches='tight')
        plt.close()
        print(f"Result saved to: {out_path}")

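The file-path label above keeps the last 47 characters of a long path and prefixes them with `...`, so the displayed string never exceeds 50 characters. A minimal sketch of that rule as a standalone function follows; `shorten_path` and its `max_len` parameter are hypothetical names used for illustration and are not part of `onnx_infer.py`.

```python
def shorten_path(path, max_len=50):
    """Return `path` unchanged if it fits, else its tail prefixed with '...'.

    Hypothetical helper mirroring the truncation in the visualization code.
    Assumes max_len > 3 so that at least one path character is kept.
    """
    if len(path) <= max_len:
        return path
    # Reserve three characters for the ellipsis, keep the end of the path
    return "..." + path[-(max_len - 3):]
```

Keeping the tail rather than the head preserves the file name, which is usually the part a reader needs in order to match the figure to an image.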
def main():
    parser = argparse.ArgumentParser(description="ONNX Runtime Bird Classification (Top-5)")
    parser.add_argument("-c", "--class_map_file",
                        default="./class_name.txt",
                        help="Path to class name mapping file")
    parser.add_argument("-m", "--model_file",
                        default="./bird_resnet18.onnx",
                        help="Path to ONNX model file")
    parser.add_argument("--image_dir",
                        default="./test_images",
                        help="Directory containing test images")
    parser.add_argument("--image",
                        help="Path to a single test image")
    parser.add_argument("--top_k",
                        type=int,
                        default=5,
                        help="Number of top predictions to show (default: 5)")

    args = parser.parse_args()

    predictor = BirdPredictorONNX(args.class_map_file, args.model_file)

    # Single-image mode takes precedence. If --image is omitted or the file
    # does not exist, fall back to batch mode over --image_dir.
    if args.image and os.path.exists(args.image):
        try:
            top_k_results = predictor.predict_image_topk(args.image, k=args.top_k)

            print(f"\nImage: {args.image}")
            print(f"Top-{args.top_k} Predictions:")
            for i, (cls_name, conf) in enumerate(top_k_results):
                print(f"#{i+1}: {cls_name} ({conf:.4f})")

            predictor.visualize_prediction_topk(args.image, top_k_results)

        except Exception as e:
            print(f"Inference failed: {e}")

    elif os.path.exists(args.image_dir):
        results = predictor.predict_batch_topk(args.image_dir, k=args.top_k)

        print(f"\nProcessed {len(results)} images:")
        for res in results:
            print(f"File: {res['filename']}")
            for i, (cls_name, conf) in enumerate(res['top_k']):
                # Flag the top-1 prediction of each image
                marker = "[1]" if i == 0 else " "
                print(f"{marker} #{i+1}: {cls_name} ({conf:.4f})")

        print("\nNote: Visualization saves only the last processed image in batch mode.")
        if results:
            last_res = results[-1]
            predictor.visualize_prediction_topk(last_res['path'], last_res['top_k'], save_path='batch_last_result.png')

    else:
        print("Specified image or directory not found.")


if __name__ == "__main__":
    main()
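The dispatch in `main()` prefers `--image` over `--image_dir`, and it falls back to the directory without any message when the given image path does not exist. The sketch below isolates that precedence so it can be checked without loading a model; `select_input` and its return values are hypothetical and are not part of `onnx_infer.py`.

```python
import os
import tempfile


def select_input(image, image_dir):
    """Mirror the input precedence used in main().

    Hypothetical helper for illustration. Returns a (mode, path) tuple:
    ("single", image) if the image exists, ("batch", image_dir) if only
    the directory exists, and ("missing", None) otherwise.
    """
    if image and os.path.exists(image):
        return ("single", image)
    if os.path.exists(image_dir):
        return ("batch", image_dir)
    return ("missing", None)


# Usage: a non-existent image path falls back to batch mode.
with tempfile.TemporaryDirectory() as tmp_dir:
    mode, path = select_input(os.path.join(tmp_dir, "missing.jpg"), tmp_dir)
    print(mode)
```

One consequence worth noting: a typo in `--image` leads to a batch run over `./test_images` rather than an error, as long as that default directory exists.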
prediction_result_top5.png
ADDED
Git LFS Details
quant/Bird.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:218c30ea230c8b830402bccced57c2d6727f0f2d9d14e61d4315d9ce2e2e79ae
size 893
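The three lines committed for `quant/Bird.json` are a Git LFS pointer, not the JSON itself: a spec version, the SHA-256 object ID of the real file, and its size in bytes. A minimal sketch of reading such a pointer is shown below; `parse_lfs_pointer` is a hypothetical helper and it assumes the simple `key value` layout seen in this commit.

```python
def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into a small dictionary.

    Hypothetical helper for illustration. Assumes each line has the form
    "<key> <value>" and that the "oid" value looks like "<algo>:<hex digest>".
    """
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value

    hash_algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": hash_algo,
        "digest": digest,
        "size": int(fields["size"]),  # size of the real file, in bytes
    }


pointer_text = """version https://git-lfs.github.com/spec/v1
oid sha256:218c30ea230c8b830402bccced57c2d6727f0f2d9d14e61d4315d9ce2e2e79ae
size 893
"""
info = parse_lfs_pointer(pointer_text)
```

A check like this can help detect a clone made without Git LFS installed, where the model and quantization files are still pointers of a few hundred bytes instead of the real content.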
quant/bird.tar.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb5cc7dbf1ccd7f9c2cafb24ec65eb461bea5c8a25ef83c846da421433114758
size 108199387
test_images/03111_2c0dfa5a-c4a0-47f8-ac89-6a289208050f.jpg
ADDED
Git LFS Details
test_images/03332_01b365c3-a741-4f45-bac2-4345bc901ec6.jpg
ADDED
Git LFS Details
test_images/03412_0ffc115b-43b4-4474-a373-24233f391de3.jpg
ADDED
Git LFS Details
test_images/03615_0dfbf6ae-434d-4648-b5d2-08412546ea64.jpg
ADDED
Git LFS Details
test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
ADDED
Git LFS Details
test_images/04405_0c5a6785-0bc2-49d9-9702-b9e94ba9b686.jpg
ADDED
Git LFS Details
test_images/04593_3d74d5a7-15b1-4bb9-af6f-1bcd78485787.jpg
ADDED
Git LFS Details