# single-mouse-webrtc-pose / inspect_model.py
# Source: Hakureirm — "Add NVIDIA GPU detection and TensorRT engine model support" (c924202)
#!/usr/bin/env python3
"""
模型详细检查脚本 - 检查模型的关键点配置
"""
import os
import sys
import logging
import numpy as np
import cv2
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def inspect_model():
    """Load the pose model and log its configuration plus a test inference.

    Mirrors SingleMouseProcessor's model-path selection (PyTorch ``.pt`` vs
    TensorRT ``.engine``), logs class/keypoint metadata, then runs a
    low-confidence inference on a local test image (or a synthetic frame when
    none exists) and reports the shapes of the detection/keypoint outputs.

    Returns:
        bool: True when inspection completed, False if loading or inference raised.
    """
    base_model_path = "models/kunin-mice-pose.v0.1.5n.pt"
    # Keep path selection consistent with SingleMouseProcessor.
    model_path = select_model_path(base_model_path)
    try:
        from ultralytics import YOLO
        logger.info("正在加载模型...")
        model = YOLO(model_path, task="pose")

        logger.info("📊 模型详细信息:")
        # Class metadata, when exposed by the wrapped model.
        if hasattr(model, 'model') and hasattr(model.model, 'names'):
            logger.info(f" - 类别数量: {len(model.model.names)}")
            logger.info(f" - 类别名称: {model.model.names}")
        if hasattr(model, 'model'):
            # Keypoint shape may live on the model object or in its YAML config.
            if hasattr(model.model, 'kpt_shape'):
                logger.info(f" - 关键点形状: {model.model.kpt_shape}")
            elif hasattr(model.model, 'yaml') and 'kpt_shape' in model.model.yaml:
                logger.info(f" - 关键点形状: {model.model.yaml['kpt_shape']}")
            # Dump the relevant YAML keys for reference.
            if hasattr(model.model, 'yaml'):
                yaml_data = model.model.yaml
                logger.info(f" - YAML配置:")
                for key in ['nc', 'kpt_shape', 'names']:
                    if key in yaml_data:
                        logger.info(f" {key}: {yaml_data[key]}")

        # Prefer a real mouse image when one is present in the working dir.
        test_images = [
            "test_mouse.jpg",
            "test_mouse.png",
            "mouse_test.jpg",
            "sample.jpg"
        ]
        test_image = next((p for p in test_images if os.path.exists(p)), None)
        if test_image:
            logger.info(f"📸 使用测试图像: {test_image}")
            results = model(test_image, conf=0.1, verbose=True)  # lowered confidence threshold
        else:
            logger.info("📸 使用随机图像进行测试...")
            # Synthesize a noisy frame with a roughly mouse-shaped blob.
            test_img = np.random.randint(50, 200, (640, 640, 3), dtype=np.uint8)
            cv2.circle(test_img, (320, 320), 50, (100, 100, 100), -1)
            cv2.ellipse(test_img, (320, 320), (80, 40), 0, 0, 360, (150, 150, 150), -1)
            results = model(test_img, conf=0.1, verbose=True)  # lowered confidence threshold

        result = results[0]
        logger.info("🔍 推理结果分析:")
        logger.info(f" - 原始输出形状: {result.orig_shape}")

        if hasattr(result, 'boxes') and result.boxes is not None:
            logger.info(f" - 检测框数量: {len(result.boxes)}")
            if len(result.boxes) > 0:
                logger.info(f" - 置信度范围: {result.boxes.conf.min():.3f} - {result.boxes.conf.max():.3f}")
            else:
                logger.info(" - 没有检测到目标")

        if hasattr(result, 'keypoints') and result.keypoints is not None:
            kpts_data = result.keypoints.data
            logger.info(f" - 关键点数据形状: {kpts_data.shape}")
            logger.info(f" - 关键点数据类型: {kpts_data.dtype}")
            # Expected layout: (n_detections, n_keypoints, coords) with coords 2 or 3.
            if len(kpts_data.shape) == 3:
                n_detections, n_keypoints, coords = kpts_data.shape
                logger.info(f" - 检测数量: {n_detections}")
                logger.info(f" - 每个检测的关键点数: {n_keypoints}")
                logger.info(f" - 每个关键点的坐标数: {coords}")
                # BUG FIX: the keypoint count is axis 1 of the tensor. The old
                # `coords // 3 if coords % 3 == 0 else coords // 2` always
                # produced 1, because `coords` is the per-keypoint coordinate
                # count (2 or 3), not a flattened layout.
                logger.info(f" - 推断的关键点数量: {n_keypoints}")
                if n_detections > 0:
                    kpts = kpts_data[0]  # first detection only
                    # With (x, y, conf) use the confidence channel; with (x, y)
                    # fall back to "x > 0" as a visibility proxy.
                    visible_kpts = (kpts[:, 2] > 0.3).sum() if coords == 3 else (kpts[:, 0] > 0).sum()
                    logger.info(f" - 可见关键点数量: {visible_kpts}")
                    # Show the first few keypoint values for a quick sanity check.
                    logger.info(" - 前5个关键点:")
                    for i in range(min(5, n_keypoints)):
                        if coords == 3:
                            x, y, conf = kpts[i]
                            logger.info(f" {i}: ({x:.1f}, {y:.1f}, {conf:.3f})")
                        else:
                            x, y = kpts[i]
                            logger.info(f" {i}: ({x:.1f}, {y:.1f})")
        return True
    except Exception as e:
        logger.error(f"❌ 模型检查失败: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return False
def select_model_path(base_model_path: str) -> str:
    """Pick the model file to load based on the available GPU hardware.

    Returns the TensorRT ``.engine`` counterpart of *base_model_path* when an
    NVIDIA GPU is detected and the engine file exists on disk; in every other
    case (no GPU, non-NVIDIA GPU, missing engine file, or a detection error)
    the original PyTorch ``.pt`` path is returned unchanged.
    """
    try:
        import torch

        # No CUDA device at all: run on CPU with the PyTorch weights.
        if not torch.cuda.is_available():
            logger.info(f"📍 未检测到GPU,使用CPU模式,PyTorch模型: {base_model_path}")
            return base_model_path

        # Scan every visible CUDA device for an NVIDIA part (short-circuits).
        has_nvidia = any(
            'nvidia' in torch.cuda.get_device_name(idx).lower()
            for idx in range(torch.cuda.device_count())
        )
        if not has_nvidia:
            logger.info(f"📍 检测到GPU但非NVIDIA,使用PyTorch模型: {base_model_path}")
            return base_model_path

        # NVIDIA GPU present: prefer the TensorRT engine when it exists.
        engine_path = base_model_path.replace('.pt', '.engine')
        if os.path.exists(engine_path):
            logger.info(f"🚀 检测到NVIDIA GPU,使用TensorRT模型: {engine_path}")
            return engine_path

        logger.info(f"⚠️ NVIDIA GPU已检测到,但TensorRT模型不存在: {engine_path}")
        logger.info(f"📍 使用PyTorch模型: {base_model_path}")
        return base_model_path
    except Exception as e:
        # Detection failure is non-fatal: fall back to the PyTorch model.
        logger.warning(f"⚠️ GPU检测失败,使用默认模型: {str(e)}")
        return base_model_path
def main():
    """Entry point: run the model inspection and log the overall outcome."""
    logger.info("🔍 开始模型详细检查...")
    if inspect_model():
        logger.info("✅ 模型检查完成!")
    else:
        logger.error("❌ 模型检查失败")
if __name__ == "__main__":
main()