#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Bird Classification Inference Script (Top-K) for the axengine runtime.

Loads a compiled model (e.g. ``.axmodel``) through ``axengine.InferenceSession``
using the ``AxEngineExecutionProvider`` and classifies either a single image or
every image in a directory. Top-k predictions are printed to stdout and, for one
image, rendered to a PNG (image + prediction table) with matplotlib.
"""
import os
import argparse

import numpy as np
import cv2
from PIL import Image
import axengine as axe
import matplotlib

matplotlib.use('Agg')  # Non-interactive backend: figures are only saved to disk.
import matplotlib.pyplot as plt  # noqa: E402  (must come after matplotlib.use)

# Ensure English fonts are used to avoid warnings
plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
plt.rcParams['axes.unicode_minus'] = False


class BirdPredictorONNX:
    """Bird classification predictor built on the axengine inference runtime.

    The name is kept for backward compatibility; the session is an
    ``axengine.InferenceSession`` (ONNX-Runtime-style API), not onnxruntime.
    """

    def __init__(self, class_name_file, model_file):
        """
        Initialize the predictor. Defaults to AxEngineExecutionProvider.

        Args:
            class_name_file: Text file with one class name per line; blank
                lines are ignored. Line order must match the model's output
                indices.
            model_file: Path to the compiled model passed to
                ``axengine.InferenceSession``.

        Raises:
            Exception: Whatever ``axengine.InferenceSession`` raises when the
                model cannot be loaded (re-raised after printing a message).
        """
        # NOTE(review): mean/std are stored and exposed via self.transform, but
        # preprocess_image feeds raw uint8 pixels. Presumably normalization is
        # compiled into the model; confirm before relying on these values.
        self.rgb_mean = [0.5, 0.5, 0.5]
        self.rgb_std = [0.5, 0.5, 0.5]
        self.classes = self.load_classes(class_name_file)

        providers = ['AxEngineExecutionProvider']
        print(f"Loading ONNX model with providers: {providers}")
        try:
            self.session = axe.InferenceSession(model_file, providers=providers)
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise

        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        self.input_shape = model_input.shape
        self.transform = self.get_transform_params()

    def load_classes(self, class_name_file):
        """Return the whitespace-stripped, non-blank lines of *class_name_file*."""
        with open(class_name_file, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f if line.strip()]

    def get_transform_params(self):
        """Return mean/std as float32 arrays shaped (1, 3, 1, 1) for NCHW broadcasting."""
        mean = np.array(self.rgb_mean, dtype=np.float32).reshape(1, 3, 1, 1)
        std = np.array(self.rgb_std, dtype=np.float32).reshape(1, 3, 1, 1)
        return {'mean': mean, 'std': std}

    def preprocess_image(self, image_path):
        """Load an image and convert it to the model input tensor.

        The image is converted to RGB and resized to 224x224 with bilinear
        interpolation. It is returned as a ``uint8`` array of shape
        ``(1, 3, 224, 224)`` (batch, channel, height, width). No mean/std
        normalization is applied here.
        """
        # The context manager guarantees the file handle PIL opened is released.
        with Image.open(image_path) as img:
            image = img.convert('RGB').resize((224, 224), Image.BILINEAR)
        img_array = np.array(image, dtype=np.uint8)
        img_array = img_array.transpose(2, 0, 1)  # HWC -> CHW
        return np.expand_dims(img_array, axis=0)  # add batch axis of size 1

    def predict_image_topk(self, image_path, k=5):
        """Run inference on one image and return its top-k predictions.

        Args:
            image_path: Path to the image file.
            k: Number of predictions to return. It is clamped to the number of
                model outputs; values below 1 yield an empty list.

        Returns:
            List of ``(class_name, probability)`` tuples, highest probability
            first. Probabilities are a softmax over the model's first output.

        Raises:
            IndexError: If a selected output index has no entry in the class
                name file.
        """
        input_data = self.preprocess_image(image_path)
        outputs = self.session.run(None, {self.input_name: input_data})

        # Exactly one image is fed, so flatten the first output into a single
        # logit vector. Casting to float32 keeps the max-subtraction below from
        # wrapping around if the runtime ever returns an integer dtype.
        # NOTE(review): assumes the output holds one logit per class.
        logits = np.asarray(outputs[0], dtype=np.float32).reshape(-1)
        exp_scores = np.exp(logits - np.max(logits))  # numerically stable softmax
        probabilities = exp_scores / np.sum(exp_scores)

        k = max(0, min(int(k), probabilities.size))
        top_k_indices = np.argsort(probabilities)[::-1][:k]

        results = []
        for idx in top_k_indices:
            if idx >= len(self.classes):
                raise IndexError(
                    f"Model output index {int(idx)} has no class name "
                    f"({len(self.classes)} classes loaded)"
                )
            results.append((self.classes[idx], float(probabilities[idx])))
        return results

    def predict_batch_topk(self, image_dir, k=5):
        """Run top-k inference on every image file in *image_dir*.

        Files are selected by extension (case-insensitive) and processed in
        sorted filename order. An image that fails is reported and skipped, so
        one bad file does not abort the batch.

        Returns:
            List of dicts with keys ``'filename'``, ``'path'`` and ``'top_k'``
            (the list returned by ``predict_image_topk``).
        """
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
        files = sorted(
            name for name in os.listdir(image_dir)
            if name.lower().endswith(image_extensions)
            and os.path.isfile(os.path.join(image_dir, name))
        )
        print(f"Found {len(files)} images, starting inference (Top-{k})...")

        results = []
        for filename in files:
            image_path = os.path.join(image_dir, filename)
            try:
                top_k_results = self.predict_image_topk(image_path, k=k)
            except Exception as e:
                # Best-effort batch: report the failure and move on.
                print(f"Error processing image {filename}: {str(e)}")
                continue
            results.append({
                'filename': filename,
                'path': image_path,
                'top_k': top_k_results,
            })
        return results

    def _wrap_text(self, text, max_chars=25):
        """
        Wrap long text so that it fits in table cells.

        Each line breaks right after the last '_' or '-' found within the first
        *max_chars* characters. If there is no such separator, the text is
        hard-split after ``max_chars // 2`` characters and the split is marked
        with a trailing '-'. The remainder is wrapped the same way, so every
        produced line is at most *max_chars* characters long.

        Text that already fits is returned unchanged, as is any text when
        *max_chars* < 2 (too narrow to wrap).
        """
        if len(text) <= max_chars or max_chars < 2:
            return text

        window = text[:max_chars]
        split_idx = max(window.rfind('_'), window.rfind('-')) + 1
        if split_idx:
            head, joiner = text[:split_idx], "\n"
        else:
            # No separator available: force a split and mark it with a hyphen.
            split_idx = max_chars // 2
            head, joiner = text[:split_idx], "-\n"
        return head + joiner + self._wrap_text(text[split_idx:], max_chars)

    def _draw_prediction_table(self, ax, top_k_results):
        """Render non-empty *top_k_results* as a styled Rank/Class Name/Confidence table on *ax*."""
        header = ["Rank", "Class Name", "Confidence"]
        rows = [
            [
                f"#{rank}",
                self._wrap_text(cls_name, max_chars=28),
                f"{conf:.4f} ({conf*100:.2f}%)",
            ]
            for rank, (cls_name, conf) in enumerate(top_k_results, start=1)
        ]

        # Col widths: Rank (10%), Name (60%), Conf (30%)
        table = ax.table(cellText=rows,
                         colLabels=header,
                         loc='center',
                         cellLoc='left',
                         colWidths=[0.1, 0.6, 0.3],
                         bbox=[0.05, 0.1, 0.9, 0.75])
        table.auto_set_font_size(False)

        # The tallest class-name cell (in text lines) sets font size and row height.
        max_lines = max(row[1].count('\n') + 1 for row in rows)
        table.set_fontsize(8 if max_lines > 1 else 10)
        # Scale 1.5 for single-line rows, plus 0.3 per extra wrapped line (1.8 for two lines).
        table.scale(1, 1.5 + 0.3 * (max_lines - 1))

        # Header row: bold white text on blue.
        for col in range(len(header)):
            cell = table[(0, col)]
            cell.set_text_props(fontweight='bold', color='white', ha='center')
            cell.set_facecolor('#4472C4')

        # Body rows: zebra striping; names are top-left aligned so wrapped text reads naturally.
        alignments = {
            0: dict(ha='center', va='center'),          # Rank
            1: dict(ha='left', va='top', wrap=True),    # Class name
            2: dict(ha='center', va='center'),          # Confidence
        }
        for row_idx in range(1, len(rows) + 1):
            for col, props in alignments.items():
                cell = table[(row_idx, col)]
                cell.set_facecolor('#ffffff' if row_idx % 2 == 0 else '#f9f9f9')
                cell.set_edgecolor('#dddddd')
                cell.set_linewidth(1)
                cell.set_text_props(**props)

    def visualize_prediction_topk(self, image_path, top_k_results, save_path=None):
        """Save a figure with the input image beside a table of its predictions.

        Args:
            image_path: Image to display (read with OpenCV).
            top_k_results: ``(class_name, probability)`` pairs, best first.
            save_path: Output PNG path; defaults to ``'prediction_result_top5.png'``.

        Raises:
            ValueError: If OpenCV cannot read *image_path*.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Cannot read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
        try:
            ax1.imshow(image)
            ax1.set_title('Input Image', fontsize=14, fontweight='bold')
            ax1.axis('off')
            ax2.axis('off')

            if top_k_results:
                self._draw_prediction_table(ax2, top_k_results)
            else:
                # Axes.table cannot be built from zero rows of cell text.
                ax2.text(0.5, 0.5, "No predictions", ha='center', va='center', fontsize=12)

            # Show only the tail of long paths so the label fits above the table.
            display_path = image_path
            if len(display_path) > 50:
                display_path = "..." + display_path[-47:]
            path_text = f"File Path:\n{display_path}"
            ax2.text(0.5, 0.92, path_text, ha='center', va='center', fontsize=9,
                     color='#555555',
                     bbox=dict(boxstyle="round,pad=0.5", fc="#eeeeee", ec="#cccccc", alpha=0.8))

            ax2.set_title(f'Top-{len(top_k_results)} Prediction Results',
                          fontsize=14, fontweight='bold', pad=20)

            fig.tight_layout()
            out_path = save_path if save_path else 'prediction_result_top5.png'
            fig.savefig(out_path, dpi=150, bbox_inches='tight')
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)
        print(f"Result saved to: {out_path}")


def _run_single(predictor, image_path, top_k):
    """Predict one image, print its top-k list and save the visualization."""
    try:
        top_k_results = predictor.predict_image_topk(image_path, k=top_k)
        print(f"\nImage: {image_path}")
        print(f"Top-{len(top_k_results)} Predictions:")
        for i, (cls_name, conf) in enumerate(top_k_results):
            print(f"#{i+1}: {cls_name} ({conf:.4f})")
        predictor.visualize_prediction_topk(image_path, top_k_results)
    except Exception as e:
        print(f"Inference failed: {e}")


def _run_batch(predictor, image_dir, top_k):
    """Predict every image in *image_dir*, print results and visualize the last one."""
    results = predictor.predict_batch_topk(image_dir, k=top_k)
    print(f"\nProcessed {len(results)} images:")
    for res in results:
        print(f"File: {res['filename']}")
        for i, (cls_name, conf) in enumerate(res['top_k']):
            marker = "[1]" if i == 0 else " "
            print(f"{marker} #{i+1}: {cls_name} ({conf:.4f})")

    print("\nNote: Visualization saves only the last processed image in batch mode.")
    if results:
        last_res = results[-1]
        predictor.visualize_prediction_topk(last_res['path'], last_res['top_k'],
                                            save_path='batch_last_result.png')


def main():
    """Parse CLI arguments and run single-image (--image) or batch (--image_dir) inference."""
    parser = argparse.ArgumentParser(description="ONNX Runtime Bird Classification (Top-5)")
    parser.add_argument("-c", "--class_map_file", default="./class_name.txt",
                        help="Path to class name file (one class per line)")
    parser.add_argument("-m", "--model_file", default="./model/AX650/bird_650_npu3.axmodel",
                        help="Path to ONNX model file")
    parser.add_argument("--image_dir", default="./test_images",
                        help="Directory containing test images")
    parser.add_argument("--image", help="Path to a single test image")
    parser.add_argument("--top_k", type=int, default=5,
                        help="Number of top predictions to show (default: 5)")
    args = parser.parse_args()

    if args.top_k < 1:
        parser.error("--top_k must be a positive integer")

    # Validate paths before loading the model. An explicit --image that does
    # not exist is reported instead of silently falling back to batch mode.
    if args.image and not os.path.exists(args.image):
        print(f"Specified image not found: {args.image}")
        return
    if not args.image and not os.path.exists(args.image_dir):
        print("Specified image or directory not found.")
        return

    predictor = BirdPredictorONNX(args.class_map_file, args.model_file)

    if args.image:
        _run_single(predictor, args.image, args.top_k)
    else:
        _run_batch(predictor, args.image_dir, args.top_k)


if __name__ == "__main__":
    main()