#!/usr/bin/env python3 """ 模型详细检查脚本 - 检查模型的关键点配置 """ import os import sys import logging import numpy as np import cv2 # 配置日志 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) def inspect_model(): """检查模型详细信息""" base_model_path = "models/kunin-mice-pose.v0.1.5n.pt" # 选择模型路径(与SingleMouseProcessor保持一致) model_path = select_model_path(base_model_path) try: from ultralytics import YOLO logger.info("正在加载模型...") model = YOLO(model_path, task="pose") # 检查模型配置 logger.info("📊 模型详细信息:") # 获取模型名称信息 if hasattr(model, 'model') and hasattr(model.model, 'names'): logger.info(f" - 类别数量: {len(model.model.names)}") logger.info(f" - 类别名称: {model.model.names}") if hasattr(model, 'model'): if hasattr(model.model, 'kpt_shape'): logger.info(f" - 关键点形状: {model.model.kpt_shape}") elif hasattr(model.model, 'yaml') and 'kpt_shape' in model.model.yaml: logger.info(f" - 关键点形状: {model.model.yaml['kpt_shape']}") # 检查模型的yaml配置 if hasattr(model.model, 'yaml'): yaml_data = model.model.yaml logger.info(f" - YAML配置:") for key in ['nc', 'kpt_shape', 'names']: if key in yaml_data: logger.info(f" {key}: {yaml_data[key]}") # 使用真实的小鼠图像进行测试(如果有的话) test_images = [ "test_mouse.jpg", "test_mouse.png", "mouse_test.jpg", "sample.jpg" ] test_image = None for img_path in test_images: if os.path.exists(img_path): test_image = img_path break if test_image: logger.info(f"📸 使用测试图像: {test_image}") results = model(test_image, conf=0.1, verbose=True) # 降低置信度阈值 else: logger.info("📸 使用随机图像进行测试...") # 创建一个更像小鼠的测试图像 test_img = np.random.randint(50, 200, (640, 640, 3), dtype=np.uint8) # 添加一些简单的形状 cv2.circle(test_img, (320, 320), 50, (100, 100, 100), -1) cv2.ellipse(test_img, (320, 320), (80, 40), 0, 0, 360, (150, 150, 150), -1) results = model(test_img, conf=0.1, verbose=True) # 降低置信度阈值 result = results[0] logger.info("🔍 推理结果分析:") logger.info(f" - 原始输出形状: {result.orig_shape}") if hasattr(result, 'boxes') and result.boxes is not None: logger.info(f" - 检测框数量: {len(result.boxes)}") if len(result.boxes) > 0: logger.info(f" - 置信度范围: {result.boxes.conf.min():.3f} - {result.boxes.conf.max():.3f}") else: logger.info(" - 没有检测到目标") if hasattr(result, 'keypoints') and result.keypoints is not None: kpts_data = result.keypoints.data logger.info(f" - 关键点数据形状: {kpts_data.shape}") logger.info(f" - 关键点数据类型: {kpts_data.dtype}") if len(kpts_data.shape) == 3: n_detections, n_keypoints, coords = kpts_data.shape logger.info(f" - 检测数量: {n_detections}") logger.info(f" - 每个检测的关键点数: {n_keypoints}") logger.info(f" - 每个关键点的坐标数: {coords}") # 计算实际的关键点数量 actual_keypoints = coords // 3 if coords % 3 == 0 else coords // 2 logger.info(f" - 推断的关键点数量: {actual_keypoints}") if n_detections > 0: kpts = kpts_data[0] # 第一个检测结果 visible_kpts = (kpts[:, 2] > 0.3).sum() if coords == 3 else (kpts[:, 0] > 0).sum() logger.info(f" - 可见关键点数量: {visible_kpts}") # 显示前几个关键点的值 logger.info(" - 前5个关键点:") for i in range(min(5, n_keypoints)): if coords == 3: x, y, conf = kpts[i] logger.info(f" {i}: ({x:.1f}, {y:.1f}, {conf:.3f})") else: x, y = kpts[i] logger.info(f" {i}: ({x:.1f}, {y:.1f})") return True except Exception as e: logger.error(f"❌ 模型检查失败: {str(e)}") import traceback logger.error(traceback.format_exc()) return False def select_model_path(base_model_path: str) -> str: """根据GPU情况选择模型路径""" try: import torch # 检测是否有NVIDIA GPU if torch.cuda.is_available(): nvidia_gpu_found = False for i in range(torch.cuda.device_count()): gpu_name = torch.cuda.get_device_name(i).lower() if 'nvidia' in gpu_name: nvidia_gpu_found = True break if nvidia_gpu_found: # 构建.engine模型路径 engine_path = base_model_path.replace('.pt', '.engine') if os.path.exists(engine_path): logger.info(f"🚀 检测到NVIDIA GPU,使用TensorRT模型: {engine_path}") return engine_path else: logger.info(f"⚠️ NVIDIA GPU已检测到,但TensorRT模型不存在: {engine_path}") logger.info(f"📍 使用PyTorch模型: {base_model_path}") return base_model_path else: logger.info(f"📍 检测到GPU但非NVIDIA,使用PyTorch模型: {base_model_path}") return base_model_path else: logger.info(f"📍 未检测到GPU,使用CPU模式,PyTorch模型: {base_model_path}") return base_model_path except Exception as e: logger.warning(f"⚠️ GPU检测失败,使用默认模型: {str(e)}") return base_model_path def main(): """主函数""" logger.info("🔍 开始模型详细检查...") success = inspect_model() if success: logger.info("✅ 模型检查完成!") else: logger.error("❌ 模型检查失败") if __name__ == "__main__": main()