wzf19947 committed on
Commit 5035fe7 · 1 Parent(s): c2b5bcd

first commit
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.axmodel filter=lfs diff=lfs merge=lfs -text
+ *.json filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -7,4 +7,112 @@ metrics:
  pipeline_tag: image-classification
  tags:
  - biology
- ---
+ ---
+
+ # Bird-Species-Classification
+
+ This project demonstrates a bird species classification model covering 1486 bird species.
+
+ The model is a ResNet-18 trained at 224x224 input resolution.
+
+ It has been converted to run on the Axera NPU using **w8a16** quantization.
+
+ Compatible with Pulsar2 version: 5.1
+
+ ## Conversion tool links
+
+ If you are interested in model conversion, you can export an axmodel with Pulsar2:
+
+ - [Pulsar2 Link, How to Convert ONNX to axmodel](https://pulsar2-docs.readthedocs.io/en/latest/pulsar2/introduction.html)
+
+ ## Supported platforms
+
+ - AX650
+ - [M4N-Dock(爱芯派Pro)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
+ - [M.2 Accelerator card](https://docs.m5stack.com/zh_CN/ai_hardware/LLM-8850_Card)
+
+ | Platform | Latency |
+ | -------- | ------- |
+ | AX650    | 0.57 ms |
+ | AX630C   | 2.52 ms |
+ | AX615    | 4.97 ms |
+
+ ## How to use
+
+ Download all files from this repository to the device:
+
+ ```
+ root@ax650:~/Bird-Species-Classification# tree
+ .
+ ├── README.md
+ ├── axmodel_infer.py
+ ├── bird_hwc.axmodel
+ ├── class_name.txt
+ ├── model
+ │   ├── AX615
+ │   │   ├── bird_615_npu1.axmodel
+ │   │   └── bird_615_npu2.axmodel
+ │   ├── AX620E
+ │   │   ├── bird_630_npu1.axmodel
+ │   │   └── bird_630_npu2.axmodel
+ │   └── AX650
+ │       └── bird_650_npu3.axmodel
+ ├── onnx_infer.py
+ ├── prediction_result_top5.png
+ ├── quant
+ │   ├── Bird.json
+ │   ├── README.md
+ │   └── bird.tar.gz
+ └── test_images
+     ├── 03111_2c0dfa5a-c4a0-47f8-ac89-6a289208050f.jpg
+     ├── 03332_01b365c3-a741-4f45-bac2-4345bc901ec6.jpg
+     ├── 03412_0ffc115b-43b4-4474-a373-24233f391de3.jpg
+     ├── 03615_0dfbf6ae-434d-4648-b5d2-08412546ea64.jpg
+     ├── 04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
+     ├── 04405_0c5a6785-0bc2-49d9-9702-b9e94ba9b686.jpg
+     └── 04593_3d74d5a7-15b1-4bb9-af6f-1bcd78485787.jpg
+
+ 6 directories, 21 files
+ ```
+
+ ### Python environment requirements
+
+ #### pyaxengine
+
+ https://github.com/AXERA-TECH/pyaxengine
+
+ ```
+ wget https://github.com/AXERA-TECH/pyaxengine/releases/download/0.1.3rc0/axengine-0.1.3-py3-none-any.whl
+ pip install axengine-0.1.3-py3-none-any.whl
+ ```
+
+ ## Inference on an AX650 host, such as M4N-Dock(爱芯派Pro)
+
+ ```
+ root@ax650:~/Bird-Species-Classification# python3 axmodel_infer.py --image test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
+ [INFO] Available providers: ['AxEngineExecutionProvider']
+ Loading ONNX model with providers: ['AxEngineExecutionProvider']
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Chip type: ChipType.MC50
+ [INFO] VNPU type: VNPUType.DISABLED
+ [INFO] Engine version: 2.12.0s
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 5.1-patch1 74996179
+
+ Image: test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg
+ Top-5 Predictions:
+ #1: 04251_Animalia_Chordata_Aves_Passeriformes_Tityridae_Tityra_semifasciata (0.9999)
+ #2: 04019_Animalia_Chordata_Aves_Passeriformes_Oriolidae_Sphecotheres_vieilloti (0.0000)
+ #3: 04250_Animalia_Chordata_Aves_Passeriformes_Tityridae_Pachyramphus_aglaiae (0.0000)
+ #4: 03917_Animalia_Chordata_Aves_Passeriformes_Malaconotidae_Dryoscopus_cubla (0.0000)
+ #5: 04194_Animalia_Chordata_Aves_Passeriformes_Sturnidae_Aplonis_panayensis (0.0000)
+ Result saved to: prediction_result_top5.png
+ ```
+
+ output:
+ ![](./prediction_result_top5.png)
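The scripts in this repo turn the model's raw logits into the Top-5 list shown above with a numerically stable softmax followed by an argsort. A minimal sketch of that postprocessing on toy logits (the class names here are placeholders, not entries from class_name.txt):

```python
import numpy as np

def softmax_topk(logits, classes, k=5):
    """Stable softmax over a (1, C) logits array, then the top-k (name, prob) pairs."""
    # Subtract the row max before exponentiating to avoid overflow
    exp = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    probs = (exp / exp.sum(axis=1, keepdims=True))[0]
    order = np.argsort(probs)[::-1][:k]  # indices of the k largest probabilities
    return [(classes[i], float(probs[i])) for i in order]

classes = [f"species_{i}" for i in range(6)]            # placeholder class names
logits = np.array([[0.1, 4.0, 0.3, 2.0, -1.0, 0.5]])   # toy logits, shape (1, 6)
top = softmax_topk(logits, classes, k=3)
```

The same subtraction-of-the-max trick appears in both `axmodel_infer.py` and `onnx_infer.py` below.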
axmodel_infer.py ADDED
@@ -0,0 +1,279 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Axera NPU Bird Classification Inference Script (Top-5 Enhanced)
+ Loads an exported axmodel for bird classification.
+ Runs on the AxEngine execution provider via pyaxengine.
+ """
+ import os
+ import argparse
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import axengine as axe
+ import matplotlib
+ matplotlib.use('Agg')
+ import matplotlib.pyplot as plt
+
+ # Ensure English fonts are used to avoid warnings
+ plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
+ plt.rcParams['axes.unicode_minus'] = False
+
+ class BirdPredictorONNX:
+     """Bird classification predictor based on pyaxengine"""
+
+     def __init__(self, class_name_file, model_file):
+         """
+         Initialize the predictor.
+         Defaults to AxEngineExecutionProvider.
+         """
+         self.rgb_mean = [0.5, 0.5, 0.5]
+         self.rgb_std = [0.5, 0.5, 0.5]
+         self.classes = self.load_classes(class_name_file)
+
+         providers = ['AxEngineExecutionProvider']
+         print(f"Loading ONNX model with providers: {providers}")
+
+         try:
+             self.session = axe.InferenceSession(model_file, providers=providers)
+         except Exception as e:
+             print(f"Failed to load model: {e}")
+             raise
+
+         self.input_name = self.session.get_inputs()[0].name
+         self.input_shape = self.session.get_inputs()[0].shape
+
+         self.transform = self.get_transform_params()
+
+     def load_classes(self, class_name_file):
+         with open(class_name_file, 'r', encoding='utf-8') as f:
+             classes = [line.strip() for line in f if line.strip()]
+         return classes
+
+     def get_transform_params(self):
+         mean = np.array(self.rgb_mean, dtype=np.float32).reshape(1, 3, 1, 1)
+         std = np.array(self.rgb_std, dtype=np.float32).reshape(1, 3, 1, 1)
+         return {'mean': mean, 'std': std}
+
+     def preprocess_image(self, image_path):
+         # The quantized axmodel consumes raw uint8 NCHW input,
+         # so no float normalization is applied here.
+         image = Image.open(image_path).convert('RGB')
+         image = image.resize((224, 224), Image.BILINEAR)
+
+         img_array = np.array(image, dtype=np.uint8)
+         img_array = img_array.transpose(2, 0, 1)
+         img_array = np.expand_dims(img_array, axis=0)
+
+         return img_array
+
+     def predict_image_topk(self, image_path, k=5):
+         input_data = self.preprocess_image(image_path)
+         outputs = self.session.run(None, {self.input_name: input_data})
+
+         logits = outputs[0]
+         # Numerically stable softmax
+         exp_scores = np.exp(logits - np.max(logits, axis=1, keepdims=True))
+         probabilities = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
+
+         probs_0 = probabilities[0]
+         top_k_indices = np.argsort(probs_0)[::-1][:k]
+
+         results = []
+         for idx in top_k_indices:
+             class_name = self.classes[idx]
+             conf = float(probs_0[idx])
+             results.append((class_name, conf))
+
+         return results
+
+     def predict_batch_topk(self, image_dir, k=5):
+         results = []
+         image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
+
+         files = sorted([f for f in os.listdir(image_dir) if any(f.lower().endswith(ext) for ext in image_extensions)])
+         print(f"Found {len(files)} images, starting inference (Top-{k})...")
+
+         for filename in files:
+             image_path = os.path.join(image_dir, filename)
+             try:
+                 top_k_results = self.predict_image_topk(image_path, k=k)
+                 results.append({
+                     'filename': filename,
+                     'path': image_path,
+                     'top_k': top_k_results
+                 })
+             except Exception as e:
+                 print(f"Error processing image {filename}: {e}")
+
+         return results
+
+     def _wrap_text(self, text, max_chars=25):
+         """
+         Helper function to wrap or truncate long text to fit in table cells.
+         Tries to break at underscores or hyphens first.
+         """
+         if len(text) <= max_chars:
+             return text
+
+         # Try to find a good breaking point (underscore or hyphen) near the limit
+         break_points = [i for i, char in enumerate(text[:max_chars]) if char in ['_', '-']]
+
+         if break_points:
+             # Break at the last found separator within the limit
+             split_idx = break_points[-1] + 1
+             return text[:split_idx] + "\n" + text[split_idx:]
+
+         # If no good break point, force a split in the middle
+         mid = max_chars // 2
+         return text[:mid] + "-\n" + text[mid:]
+
+     def visualize_prediction_topk(self, image_path, top_k_results, save_path=None):
+         image = cv2.imread(image_path)
+         if image is None:
+             raise ValueError(f"Cannot read image: {image_path}")
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
+
+         ax1.imshow(image)
+         ax1.set_title('Input Image', fontsize=14, fontweight='bold')
+         ax1.axis('off')
+
+         ax2.axis('off')
+
+         table_data = []
+         table_data.append(["Rank", "Class Name", "Confidence"])
+
+         processed_rows = []
+         for i, (cls_name, conf) in enumerate(top_k_results):
+             rank = f"#{i+1}"
+             conf_str = f"{conf:.4f} ({conf*100:.2f}%)"
+
+             # Wrap long class names
+             wrapped_name = self._wrap_text(cls_name, max_chars=28)
+             processed_rows.append([rank, wrapped_name, conf_str])
+
+         # Combine header and rows
+         full_table_data = [table_data[0]] + processed_rows
+
+         # Create table with specific column widths
+         # Col widths: Rank (10%), Name (60%), Conf (30%)
+         table = ax2.table(cellText=full_table_data[1:],
+                           colLabels=full_table_data[0],
+                           loc='center',
+                           cellLoc='left',  # Left alignment reads better with wrapped text
+                           colWidths=[0.1, 0.6, 0.3],
+                           bbox=[0.05, 0.1, 0.9, 0.75])  # Extra vertical space for wrapped rows
+
+         table.auto_set_font_size(False)
+
+         # Reduce font size if any names wrapped
+         base_font_size = 10
+         if any('\n' in row[1] for row in processed_rows):
+             base_font_size = 8
+
+         table.set_fontsize(base_font_size)
+
+         # Scale row height to accommodate wrapped text
+         row_scale = 1.8 if any('\n' in row[1] for row in processed_rows) else 1.5
+         table.scale(1, row_scale)
+
+         # Style the header
+         for i in range(3):
+             cell = table[(0, i)]
+             cell.set_text_props(fontweight='bold', color='white', ha='center')
+             cell.set_facecolor('#4472C4')
+
+         # Style body cells
+         for i in range(1, len(full_table_data)):
+             for j in range(3):
+                 cell = table[(i, j)]
+                 cell.set_facecolor('#ffffff' if i % 2 == 0 else '#f9f9f9')
+                 cell.set_edgecolor('#dddddd')
+                 cell.set_linewidth(1)
+
+                 # Alignment logic
+                 if j == 0:  # Rank
+                     cell.set_text_props(ha='center', va='center')
+                 elif j == 1:  # Name (left aligned, top aligned for wrapped text)
+                     cell.set_text_props(ha='left', va='top', wrap=True)
+                 else:  # Confidence
+                     cell.set_text_props(ha='center', va='center')
+
+         # Add file path text
+         display_path = image_path
+         if len(display_path) > 50:
+             display_path = "..." + display_path[-47:]
+
+         path_text = f"File Path:\n{display_path}"
+         ax2.text(0.5, 0.92, path_text,
+                  ha='center', va='center', fontsize=9, color='#555555',
+                  bbox=dict(boxstyle="round,pad=0.5", fc="#eeeeee", ec="#cccccc", alpha=0.8))
+
+         ax2.set_title('Top-5 Prediction Results', fontsize=14, fontweight='bold', pad=20)
+
+         plt.tight_layout()
+
+         out_path = save_path if save_path else 'prediction_result_top5.png'
+         plt.savefig(out_path, dpi=150, bbox_inches='tight')
+         plt.close()
+         print(f"Result saved to: {out_path}")
+
+ def main():
+     parser = argparse.ArgumentParser(description="Axera NPU Bird Classification (Top-5)")
+     parser.add_argument("-c", "--class_map_file",
+                         default="./class_name.txt",
+                         help="Path to class name file")
+     parser.add_argument("-m", "--model_file",
+                         default="./model/AX650/bird_650_npu3.axmodel",
+                         help="Path to axmodel file")
+     parser.add_argument("--image_dir",
+                         default="./test_images",
+                         help="Directory containing test images")
+     parser.add_argument("--image",
+                         help="Path to a single test image")
+     parser.add_argument("--top_k",
+                         type=int,
+                         default=5,
+                         help="Number of top predictions to show (default: 5)")
+
+     args = parser.parse_args()
+
+     predictor = BirdPredictorONNX(args.class_map_file, args.model_file)
+
+     if args.image and os.path.exists(args.image):
+         try:
+             top_k_results = predictor.predict_image_topk(args.image, k=args.top_k)
+
+             print(f"\nImage: {args.image}")
+             print(f"Top-{args.top_k} Predictions:")
+             for i, (cls_name, conf) in enumerate(top_k_results):
+                 print(f"#{i+1}: {cls_name} ({conf:.4f})")
+
+             predictor.visualize_prediction_topk(args.image, top_k_results)
+
+         except Exception as e:
+             print(f"Inference failed: {e}")
+
+     elif os.path.exists(args.image_dir):
+         results = predictor.predict_batch_topk(args.image_dir, k=args.top_k)
+
+         print(f"\nProcessed {len(results)} images:")
+         for res in results:
+             print(f"File: {res['filename']}")
+             for i, (cls_name, conf) in enumerate(res['top_k']):
+                 marker = "[1]" if i == 0 else "   "
+                 print(f"{marker} #{i+1}: {cls_name} ({conf:.4f})")
+
+         print("\nNote: Visualization saves only the last processed image in batch mode.")
+         if results:
+             last_res = results[-1]
+             predictor.visualize_prediction_topk(last_res['path'], last_res['top_k'], save_path='batch_last_result.png')
+
+     else:
+         print("Specified image or directory not found.")
+
+ if __name__ == "__main__":
+     main()
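Note that `preprocess_image` above feeds the quantized axmodel raw uint8 pixels in NCHW layout, with no float normalization (that step is presumably absorbed during quantization). A minimal, standalone sketch of just that layout transform, using a synthetic image instead of a real JPEG:

```python
import numpy as np
from PIL import Image

def preprocess_uint8_nchw(image, size=224):
    """Resize to size x size and return a (1, 3, size, size) uint8 NCHW batch."""
    image = image.convert('RGB').resize((size, size), Image.BILINEAR)
    arr = np.array(image, dtype=np.uint8)   # HWC uint8
    arr = arr.transpose(2, 0, 1)            # -> CHW
    return np.expand_dims(arr, axis=0)      # -> NCHW batch of 1

# Synthetic stand-in for a test image
img = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
batch = preprocess_uint8_nchw(img)
```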
class_name.txt ADDED
The diff for this file is too large to render. See raw diff
 
model/AX615/bird_615_npu1.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6f68795ef21bb848113545378feb9f7b4fa8eda552680fdf6f3b83f2f9d8bfcb
+ size 12148436
model/AX615/bird_615_npu2.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5639d48e44c218e6e1a4b2ad075c4249ae89981c12ff92cdbe358e85d1f0091a
+ size 12080120
model/AX620E/bird_630_npu1.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:21e4e53ab56cf74335a2ea37d675d3816600630fdaa6d8838492c92870d6c832
+ size 12285724
model/AX620E/bird_630_npu2.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:895d300404d0134497a8fa643a4bc2136f758a6c4d393da813c10274bdfafed4
+ size 12067712
model/AX650/bird_650_npu3.axmodel ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e544a71cb3ca619c8f064f8952193453c33e816467e03d2b5e7bed75aec99038
+ size 12209712
onnx_infer.py ADDED
@@ -0,0 +1,283 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ ONNX Runtime Bird Classification Inference Script (Top-5 Enhanced)
+ Loads an exported ONNX model for bird classification.
+ Defaults to CPU execution.
+ """
+ import os
+ import argparse
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import onnxruntime as ort
+ import matplotlib
+ matplotlib.use('Agg')
+ import matplotlib.pyplot as plt
+
+ # Ensure English fonts are used to avoid warnings
+ plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
+ plt.rcParams['axes.unicode_minus'] = False
+
+ class BirdPredictorONNX:
+     """Bird classification predictor based on ONNX Runtime"""
+
+     def __init__(self, class_name_file, model_file):
+         """
+         Initialize the predictor.
+         Defaults to CPUExecutionProvider.
+         """
+         self.rgb_mean = [0.5, 0.5, 0.5]
+         self.rgb_std = [0.5, 0.5, 0.5]
+         self.classes = self.load_classes(class_name_file)
+
+         providers = ['CPUExecutionProvider']
+         print(f"Loading ONNX model with providers: {providers}")
+
+         try:
+             self.session = ort.InferenceSession(model_file, providers=providers)
+         except Exception as e:
+             print(f"Failed to load model: {e}")
+             raise
+
+         self.input_name = self.session.get_inputs()[0].name
+         self.input_shape = self.session.get_inputs()[0].shape
+
+         self.transform = self.get_transform_params()
+
+     def load_classes(self, class_name_file):
+         with open(class_name_file, 'r', encoding='utf-8') as f:
+             classes = [line.strip() for line in f if line.strip()]
+         return classes
+
+     def get_transform_params(self):
+         mean = np.array(self.rgb_mean, dtype=np.float32).reshape(1, 3, 1, 1)
+         std = np.array(self.rgb_std, dtype=np.float32).reshape(1, 3, 1, 1)
+         return {'mean': mean, 'std': std}
+
+     def preprocess_image(self, image_path):
+         image = Image.open(image_path).convert('RGB')
+         image = image.resize((224, 224), Image.BILINEAR)
+
+         img_array = np.array(image, dtype=np.float32) / 255.0
+         img_array = img_array.transpose(2, 0, 1)
+         img_array = np.expand_dims(img_array, axis=0)
+
+         # Normalize with mean/std 0.5, mapping pixel values to [-1, 1]
+         mean = self.transform['mean']
+         std = self.transform['std']
+         img_array = (img_array - mean) / std
+
+         return img_array.astype(np.float32)
+
+     def predict_image_topk(self, image_path, k=5):
+         input_data = self.preprocess_image(image_path)
+         outputs = self.session.run(None, {self.input_name: input_data})
+
+         logits = outputs[0]
+         # Numerically stable softmax
+         exp_scores = np.exp(logits - np.max(logits, axis=1, keepdims=True))
+         probabilities = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
+
+         probs_0 = probabilities[0]
+         top_k_indices = np.argsort(probs_0)[::-1][:k]
+
+         results = []
+         for idx in top_k_indices:
+             class_name = self.classes[idx]
+             conf = float(probs_0[idx])
+             results.append((class_name, conf))
+
+         return results
+
+     def predict_batch_topk(self, image_dir, k=5):
+         results = []
+         image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
+
+         files = sorted([f for f in os.listdir(image_dir) if any(f.lower().endswith(ext) for ext in image_extensions)])
+         print(f"Found {len(files)} images, starting inference (Top-{k})...")
+
+         for filename in files:
+             image_path = os.path.join(image_dir, filename)
+             try:
+                 top_k_results = self.predict_image_topk(image_path, k=k)
+                 results.append({
+                     'filename': filename,
+                     'path': image_path,
+                     'top_k': top_k_results
+                 })
+             except Exception as e:
+                 print(f"Error processing image {filename}: {e}")
+
+         return results
+
+     def _wrap_text(self, text, max_chars=25):
+         """
+         Helper function to wrap or truncate long text to fit in table cells.
+         Tries to break at underscores or hyphens first.
+         """
+         if len(text) <= max_chars:
+             return text
+
+         # Try to find a good breaking point (underscore or hyphen) near the limit
+         break_points = [i for i, char in enumerate(text[:max_chars]) if char in ['_', '-']]
+
+         if break_points:
+             # Break at the last found separator within the limit
+             split_idx = break_points[-1] + 1
+             return text[:split_idx] + "\n" + text[split_idx:]
+
+         # If no good break point, force a split in the middle
+         mid = max_chars // 2
+         return text[:mid] + "-\n" + text[mid:]
+
+     def visualize_prediction_topk(self, image_path, top_k_results, save_path=None):
+         image = cv2.imread(image_path)
+         if image is None:
+             raise ValueError(f"Cannot read image: {image_path}")
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
+
+         ax1.imshow(image)
+         ax1.set_title('Input Image', fontsize=14, fontweight='bold')
+         ax1.axis('off')
+
+         ax2.axis('off')
+
+         table_data = []
+         table_data.append(["Rank", "Class Name", "Confidence"])
+
+         processed_rows = []
+         for i, (cls_name, conf) in enumerate(top_k_results):
+             rank = f"#{i+1}"
+             conf_str = f"{conf:.4f} ({conf*100:.2f}%)"
+
+             # Wrap long class names
+             wrapped_name = self._wrap_text(cls_name, max_chars=28)
+             processed_rows.append([rank, wrapped_name, conf_str])
+
+         # Combine header and rows
+         full_table_data = [table_data[0]] + processed_rows
+
+         # Create table with specific column widths
+         # Col widths: Rank (10%), Name (60%), Conf (30%)
+         table = ax2.table(cellText=full_table_data[1:],
+                           colLabels=full_table_data[0],
+                           loc='center',
+                           cellLoc='left',  # Left alignment reads better with wrapped text
+                           colWidths=[0.1, 0.6, 0.3],
+                           bbox=[0.05, 0.1, 0.9, 0.75])  # Extra vertical space for wrapped rows
+
+         table.auto_set_font_size(False)
+
+         # Reduce font size if any names wrapped
+         base_font_size = 10
+         if any('\n' in row[1] for row in processed_rows):
+             base_font_size = 8
+
+         table.set_fontsize(base_font_size)
+
+         # Scale row height to accommodate wrapped text
+         row_scale = 1.8 if any('\n' in row[1] for row in processed_rows) else 1.5
+         table.scale(1, row_scale)
+
+         # Style the header
+         for i in range(3):
+             cell = table[(0, i)]
+             cell.set_text_props(fontweight='bold', color='white', ha='center')
+             cell.set_facecolor('#4472C4')
+
+         # Style body cells
+         for i in range(1, len(full_table_data)):
+             for j in range(3):
+                 cell = table[(i, j)]
+                 cell.set_facecolor('#ffffff' if i % 2 == 0 else '#f9f9f9')
+                 cell.set_edgecolor('#dddddd')
+                 cell.set_linewidth(1)
+
+                 # Alignment logic
+                 if j == 0:  # Rank
+                     cell.set_text_props(ha='center', va='center')
+                 elif j == 1:  # Name (left aligned, top aligned for wrapped text)
+                     cell.set_text_props(ha='left', va='top', wrap=True)
+                 else:  # Confidence
+                     cell.set_text_props(ha='center', va='center')
+
+         # Add file path text
+         display_path = image_path
+         if len(display_path) > 50:
+             display_path = "..." + display_path[-47:]
+
+         path_text = f"File Path:\n{display_path}"
+         ax2.text(0.5, 0.92, path_text,
+                  ha='center', va='center', fontsize=9, color='#555555',
+                  bbox=dict(boxstyle="round,pad=0.5", fc="#eeeeee", ec="#cccccc", alpha=0.8))
+
+         ax2.set_title('Top-5 Prediction Results', fontsize=14, fontweight='bold', pad=20)
+
+         plt.tight_layout()
+
+         out_path = save_path if save_path else 'prediction_result_top5.png'
+         plt.savefig(out_path, dpi=150, bbox_inches='tight')
+         plt.close()
+         print(f"Result saved to: {out_path}")
+
+ def main():
+     parser = argparse.ArgumentParser(description="ONNX Runtime Bird Classification (Top-5)")
+     parser.add_argument("-c", "--class_map_file",
+                         default="./class_name.txt",
+                         help="Path to class name file")
+     parser.add_argument("-m", "--model_file",
+                         default="./bird_resnet18.onnx",
+                         help="Path to ONNX model file")
+     parser.add_argument("--image_dir",
+                         default="./test_images",
+                         help="Directory containing test images")
+     parser.add_argument("--image",
+                         help="Path to a single test image")
+     parser.add_argument("--top_k",
+                         type=int,
+                         default=5,
+                         help="Number of top predictions to show (default: 5)")
+
+     args = parser.parse_args()
+
+     predictor = BirdPredictorONNX(args.class_map_file, args.model_file)
+
+     if args.image and os.path.exists(args.image):
+         try:
+             top_k_results = predictor.predict_image_topk(args.image, k=args.top_k)
+
+             print(f"\nImage: {args.image}")
+             print(f"Top-{args.top_k} Predictions:")
+             for i, (cls_name, conf) in enumerate(top_k_results):
+                 print(f"#{i+1}: {cls_name} ({conf:.4f})")
+
+             predictor.visualize_prediction_topk(args.image, top_k_results)
+
+         except Exception as e:
+             print(f"Inference failed: {e}")
+
+     elif os.path.exists(args.image_dir):
+         results = predictor.predict_batch_topk(args.image_dir, k=args.top_k)
+
+         print(f"\nProcessed {len(results)} images:")
+         for res in results:
+             print(f"File: {res['filename']}")
+             for i, (cls_name, conf) in enumerate(res['top_k']):
+                 marker = "[1]" if i == 0 else "   "
+                 print(f"{marker} #{i+1}: {cls_name} ({conf:.4f})")
+
+         print("\nNote: Visualization saves only the last processed image in batch mode.")
+         if results:
+             last_res = results[-1]
+             predictor.visualize_prediction_topk(last_res['path'], last_res['top_k'], save_path='batch_last_result.png')
+
+     else:
+         print("Specified image or directory not found.")
+
+ if __name__ == "__main__":
+     main()
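Unlike the axmodel path, this CPU script normalizes float inputs with mean/std 0.5 in `preprocess_image`, which maps pixel values from [0, 255] into [-1, 1]. A small sketch of just that normalization step on synthetic batches:

```python
import numpy as np

MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(1, 3, 1, 1)
STD = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(1, 3, 1, 1)

def normalize(batch_uint8):
    """(1, 3, H, W) uint8 -> float32 in [-1, 1] via (x/255 - mean) / std."""
    x = batch_uint8.astype(np.float32) / 255.0
    return ((x - MEAN) / STD).astype(np.float32)

# Black and white synthetic batches hit the two extremes of the range
out_low = normalize(np.zeros((1, 3, 2, 2), dtype=np.uint8))
out_high = normalize(np.full((1, 3, 2, 2), 255, dtype=np.uint8))
```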
prediction_result_top5.png ADDED

Git LFS Details

  • SHA256: d7a0b69c8cf157174bfe12182be62ba476d96c760dc0ebc04976b0da86616098
  • Pointer size: 131 Bytes
  • Size of remote file: 626 kB
quant/Bird.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:218c30ea230c8b830402bccced57c2d6727f0f2d9d14e61d4315d9ce2e2e79ae
+ size 893
quant/bird.tar.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eb5cc7dbf1ccd7f9c2cafb24ec65eb461bea5c8a25ef83c846da421433114758
+ size 108199387
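quant/bird.tar.gz holds the quantization calibration data. If you want to inspect its contents, a generic way to unpack a .tar.gz in Python (the demo below builds a tiny synthetic archive as a stand-in; substitute the path to your downloaded quant/bird.tar.gz):

```python
import os
import tarfile
import tempfile

def extract_tar_gz(archive_path, dest_dir):
    """Extract a .tar.gz, refusing member paths that escape dest_dir."""
    with tarfile.open(archive_path, 'r:gz') as tar:
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(dest_dir, member.name))
            if not target.startswith(os.path.realpath(dest_dir)):
                raise ValueError(f"Unsafe path in archive: {member.name}")
        tar.extractall(dest_dir)

# Demo on a synthetic archive (stand-in for quant/bird.tar.gz)
tmp = tempfile.mkdtemp()
src = os.path.join(tmp, "data.txt")
with open(src, "w") as f:
    f.write("calib")
archive = os.path.join(tmp, "demo.tar.gz")
with tarfile.open(archive, "w:gz") as tar:
    tar.add(src, arcname="data.txt")
out_dir = os.path.join(tmp, "out")
os.makedirs(out_dir)
extract_tar_gz(archive, out_dir)
```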
test_images/03111_2c0dfa5a-c4a0-47f8-ac89-6a289208050f.jpg ADDED

Git LFS Details

  • SHA256: 1ae2bc22d3f8cd5d59822b0faf99deaadc4bde4cfc9be262524697214f3d6fe5
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
test_images/03332_01b365c3-a741-4f45-bac2-4345bc901ec6.jpg ADDED

Git LFS Details

  • SHA256: 83f927f95e9d8724cff65421ec25627415c00ede4298705c3f0a39b3d0d77bd9
  • Pointer size: 130 Bytes
  • Size of remote file: 89.9 kB
test_images/03412_0ffc115b-43b4-4474-a373-24233f391de3.jpg ADDED

Git LFS Details

  • SHA256: 4625ca6ee197480977fad86f38a04571adb517d04b42975d390c1f2e5af4ce5d
  • Pointer size: 130 Bytes
  • Size of remote file: 45.7 kB
test_images/03615_0dfbf6ae-434d-4648-b5d2-08412546ea64.jpg ADDED

Git LFS Details

  • SHA256: 1f1dbd4119a39d3969ac9f34b59f9407bbb59fc7c32cc3c27f7aa93a490dbeef
  • Pointer size: 130 Bytes
  • Size of remote file: 86.1 kB
test_images/04251_3a52191e-be71-4539-98ea-14a8f2347330.jpg ADDED

Git LFS Details

  • SHA256: 4049b07ff83ae7ff9085d4dcc0b8a867cece51423a45790a102d5aae584dbba5
  • Pointer size: 130 Bytes
  • Size of remote file: 34.1 kB
test_images/04405_0c5a6785-0bc2-49d9-9702-b9e94ba9b686.jpg ADDED

Git LFS Details

  • SHA256: e1de3d2581befe850dfadb4963945f227e1dc51548a18a552272774b5ec6dfb7
  • Pointer size: 130 Bytes
  • Size of remote file: 80 kB
test_images/04593_3d74d5a7-15b1-4bb9-af6f-1bcd78485787.jpg ADDED

Git LFS Details

  • SHA256: e92aee2984391a9a9fea5599525ea5aafda1bd37ad97e46973576ebe47c9cca7
  • Pointer size: 130 Bytes
  • Size of remote file: 98 kB