| | |
| | |
| | """ |
| | ONNX Runtime Bird Classification Inference Script (Top-5 Enhanced) |
| | Loads an exported ONNX model for bird classification. |
| | Defaults to CPU execution. |
| | """ |
| | import os |
| | import argparse |
| | import numpy as np |
| | import cv2 |
| | from PIL import Image |
| | import axengine as axe |
| | import matplotlib |
| | matplotlib.use('Agg') |
| | import matplotlib.pyplot as plt |
| |
|
| | |
# Prefer widely-available sans-serif fonts so table/axis labels render
# consistently; disabling unicode_minus keeps the minus sign renderable
# with these fonts.
plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
plt.rcParams['axes.unicode_minus'] = False
| |
|
class BirdPredictorONNX:
    """Bird classification predictor backed by an axengine InferenceSession.

    Loads a compiled .axmodel plus a class-name list, then runs Top-K
    classification on single images or whole directories, with an optional
    table-style matplotlib visualization of the results.
    """

    def __init__(self, class_name_file, model_file):
        """
        Initialize the predictor.

        Args:
            class_name_file: Text file with one class name per line.
            model_file: Path to the compiled model (.axmodel).

        Raises:
            Exception: re-raised after logging if the model fails to load.
        """
        # Normalization constants are kept for reference only: preprocessing
        # feeds raw uint8 NCHW data, so the model is assumed to normalize
        # internally (typical for quantized NPU models) — TODO confirm.
        self.rgb_mean = [0.5, 0.5, 0.5]
        self.rgb_std = [0.5, 0.5, 0.5]
        self.classes = self.load_classes(class_name_file)

        providers = ['AxEngineExecutionProvider']
        print(f"Loading ONNX model with providers: {providers}")

        try:
            self.session = axe.InferenceSession(model_file, providers=providers)
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise

        self.input_name = self.session.get_inputs()[0].name
        self.input_shape = self.session.get_inputs()[0].shape

        self.transform = self.get_transform_params()

    def load_classes(self, class_name_file):
        """Read class names from *class_name_file* (one per line, blanks skipped)."""
        with open(class_name_file, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f if line.strip()]

    def get_transform_params(self):
        """Return mean/std arrays shaped (1, 3, 1, 1) for NCHW broadcasting."""
        mean = np.array(self.rgb_mean, dtype=np.float32).reshape(1, 3, 1, 1)
        std = np.array(self.rgb_std, dtype=np.float32).reshape(1, 3, 1, 1)
        return {'mean': mean, 'std': std}

    def preprocess_image(self, image_path):
        """Load an image and convert it to a (1, 3, 224, 224) uint8 NCHW array.

        Note: no mean/std normalization is applied here — the raw uint8
        tensor is handed straight to the session (see __init__ note).
        """
        image = Image.open(image_path).convert('RGB')
        image = image.resize((224, 224), Image.BILINEAR)

        img_array = np.array(image, dtype=np.uint8)
        img_array = img_array.transpose(2, 0, 1)  # HWC -> CHW
        img_array = np.expand_dims(img_array, axis=0)  # add batch dim

        return img_array

    def predict_image_topk(self, image_path, k=5):
        """Run inference on one image and return the top-k predictions.

        Args:
            image_path: Path to the input image.
            k: Number of predictions to return (default 5).

        Returns:
            List of (class_name, confidence) tuples, best first.
        """
        input_data = self.preprocess_image(image_path)
        outputs = self.session.run(None, {self.input_name: input_data})

        # Numerically stable softmax over the logits.
        logits = outputs[0]
        exp_scores = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        probabilities = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

        probs_0 = probabilities[0]
        top_k_indices = np.argsort(probs_0)[::-1][:k]

        return [(self.classes[idx], float(probs_0[idx])) for idx in top_k_indices]

    def predict_batch_topk(self, image_dir, k=5):
        """Run top-k inference over every supported image in *image_dir*.

        Returns a list of dicts with 'filename', 'path' and 'top_k' keys.
        Per-image failures are logged and skipped, not raised.
        """
        results = []
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

        files = sorted([f for f in os.listdir(image_dir) if any(f.lower().endswith(ext) for ext in image_extensions)])
        print(f"Found {len(files)} images, starting inference (Top-{k})...")

        for filename in files:
            image_path = os.path.join(image_dir, filename)
            try:
                top_k_results = self.predict_image_topk(image_path, k=k)
                results.append({
                    'filename': filename,
                    'path': image_path,
                    'top_k': top_k_results
                })
            except Exception as e:
                # BUGFIX: report which file failed instead of "(unknown)".
                print(f"Error processing image {filename}: {str(e)}")

        return results

    def _wrap_text(self, text, max_chars=25):
        """
        Wrap long text so it fits in a table cell.

        Prefers breaking after the last '_' or '-' within the first
        *max_chars* characters; otherwise hyphenates at the midpoint.
        """
        if len(text) <= max_chars:
            return text

        break_points = [i for i, char in enumerate(text[:max_chars]) if char in ['_', '-']]

        if break_points:
            # Break just after the last natural separator.
            split_idx = break_points[-1] + 1
            return text[:split_idx] + "\n" + text[split_idx:]

        # No separator: hard-hyphenate at the midpoint.
        mid = max_chars // 2
        return text[:mid] + "-\n" + text[mid:]

    def visualize_prediction_topk(self, image_path, top_k_results, save_path=None):
        """Save a side-by-side figure: the input image and a top-k result table.

        Args:
            image_path: Image to display (read via cv2, converted to RGB).
            top_k_results: List of (class_name, confidence) tuples.
            save_path: Output PNG path; defaults to 'prediction_result_top5.png'.

        Raises:
            ValueError: if cv2 cannot read *image_path*.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Cannot read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

        ax1.imshow(image)
        ax1.set_title('Input Image', fontsize=14, fontweight='bold')
        ax1.axis('off')

        ax2.axis('off')

        table_data = []
        table_data.append(["Rank", "Class Name", "Confidence"])

        processed_rows = []
        for i, (cls_name, conf) in enumerate(top_k_results):
            rank = f"#{i+1}"
            conf_str = f"{conf:.4f} ({conf*100:.2f}%)"

            wrapped_name = self._wrap_text(cls_name, max_chars=28)
            processed_rows.append([rank, wrapped_name, conf_str])

        full_table_data = [table_data[0]] + processed_rows

        table = ax2.table(cellText=full_table_data[1:],
                          colLabels=full_table_data[0],
                          loc='center',
                          cellLoc='left',
                          colWidths=[0.1, 0.6, 0.3],
                          bbox=[0.05, 0.1, 0.9, 0.75])

        table.auto_set_font_size(False)

        # Shrink the font when any class name had to wrap onto two lines.
        base_font_size = 10
        if any('\n' in row[1] for row in processed_rows):
            base_font_size = 8

        table.set_fontsize(base_font_size)

        # Taller rows when wrapped names are present.
        row_scale = 1.8 if any('\n' in row[1] for row in processed_rows) else 1.5
        table.scale(1, row_scale)

        # Style the header row (row 0 holds the column labels).
        for i in range(3):
            cell = table[(0, i)]
            cell.set_text_props(fontweight='bold', color='white', ha='center')
            cell.set_facecolor('#4472C4')
            if i == 1:
                cell.set_text_props(ha='center')

        # Zebra-stripe and align the data rows.
        for i in range(1, len(full_table_data)):
            for j in range(3):
                cell = table[(i, j)]
                cell.set_facecolor('#ffffff' if i % 2 == 0 else '#f9f9f9')
                cell.set_edgecolor('#dddddd')
                cell.set_linewidth(1)

                if j == 0:
                    cell.set_text_props(ha='center', va='center')
                elif j == 1:
                    cell.set_text_props(ha='left', va='top', wrap=True)
                else:
                    cell.set_text_props(ha='center', va='center')

        # Truncate long paths from the left so the filename stays visible.
        display_path = image_path
        if len(display_path) > 50:
            display_path = "..." + display_path[-47:]

        path_text = f"File Path:\n{display_path}"
        ax2.text(0.5, 0.92, path_text,
                 ha='center', va='center', fontsize=9, color='#555555',
                 bbox=dict(boxstyle="round,pad=0.5", fc="#eeeeee", ec="#cccccc", alpha=0.8))

        # BUGFIX: title reflects the actual k instead of hard-coded "Top-5".
        ax2.set_title(f'Top-{len(top_k_results)} Prediction Results', fontsize=14, fontweight='bold', pad=20)

        plt.tight_layout()

        out_path = save_path if save_path else 'prediction_result_top5.png'
        plt.savefig(out_path, dpi=150, bbox_inches='tight')
        plt.close()
        print(f"Result saved to: {out_path}")
| |
|
def main():
    """CLI entry point: parse args, build the predictor, run single or batch inference."""
    parser = argparse.ArgumentParser(description="ONNX Runtime Bird Classification (Top-5)")
    parser.add_argument("-c", "--class_map_file",
                        default="./class_name.txt",
                        help="Path to class name list file (one class per line)")
    parser.add_argument("-m", "--model_file",
                        default="./model/AX650/bird_650_npu3.axmodel",
                        help="Path to ONNX model file")
    parser.add_argument("--image_dir",
                        default="./test_images",
                        help="Directory containing test images")
    parser.add_argument("--image",
                        help="Path to a single test image")
    parser.add_argument("--top_k",
                        type=int,
                        default=5,
                        help="Number of top predictions to show (default: 5)")

    args = parser.parse_args()

    predictor = BirdPredictorONNX(args.class_map_file, args.model_file)

    if args.image:
        # BUGFIX: an explicit --image that doesn't exist used to fall
        # through silently to batch mode; report it and stop instead.
        if not os.path.exists(args.image):
            print(f"Specified image not found: {args.image}")
            return
        try:
            top_k_results = predictor.predict_image_topk(args.image, k=args.top_k)

            print(f"\nImage: {args.image}")
            print(f"Top-{args.top_k} Predictions:")
            for i, (cls_name, conf) in enumerate(top_k_results):
                print(f"#{i+1}: {cls_name} ({conf:.4f})")

            predictor.visualize_prediction_topk(args.image, top_k_results)

        except Exception as e:
            print(f"Inference failed: {e}")

    elif os.path.exists(args.image_dir):
        results = predictor.predict_batch_topk(args.image_dir, k=args.top_k)

        print(f"\nProcessed {len(results)} images:")
        for res in results:
            print(f"File: {res['filename']}")
            for i, (cls_name, conf) in enumerate(res['top_k']):
                marker = "[1]" if i == 0 else " "
                print(f"{marker} #{i+1}: {cls_name} ({conf:.4f})")

        print("\nNote: Visualization saves only the last processed image in batch mode.")
        if results:
            last_res = results[-1]
            predictor.visualize_prediction_topk(last_res['path'], last_res['top_k'], save_path='batch_last_result.png')

    else:
        print("Specified image or directory not found.")


if __name__ == "__main__":
    main()