File size: 4,764 Bytes
4673545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c924202
 
 
 
4673545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c924202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4673545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python3
"""
模型检查脚本 - 验证单鼠姿态检测模型是否正确加载
"""

import os
import sys
import logging
import numpy as np
import cv2

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def check_model():
    """检查模型文件"""
    base_model_path = "models/kunin-mice-pose.v0.1.5n.pt"
    
    # 选择模型路径(与SingleMouseProcessor保持一致)
    model_path = select_model_path(base_model_path)
    
    logger.info(f"检查模型文件: {model_path}")
    
    # 1. 检查文件是否存在
    if not os.path.exists(model_path):
        logger.error(f"❌ 模型文件不存在: {model_path}")
        return False
    
    logger.info(f"✅ 模型文件存在: {model_path}")
    
    # 2. 检查文件大小
    file_size = os.path.getsize(model_path)
    logger.info(f"📄 文件大小: {file_size / (1024*1024):.2f} MB")
    
    # 3. 检查依赖
    try:
        import torch
        import ultralytics
        from ultralytics import YOLO
        logger.info("✅ 依赖包检查通过")
        
        # 检查CUDA
        if torch.cuda.is_available():
            logger.info(f"✅ CUDA可用: {torch.cuda.get_device_name(0)}")
        else:
            logger.info("⚠️ CUDA不可用,将使用CPU")
            
    except ImportError as e:
        logger.error(f"❌ 缺少依赖包: {str(e)}")
        return False
    
    # 4. 尝试加载模型
    try:
        logger.info("正在加载YOLO模型...")
        model = YOLO(model_path, task="pose")
        logger.info("✅ 模型加载成功")
    except Exception as e:
        logger.error(f"❌ 模型加载失败: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return False
    
    # 5. 测试模型推理
    try:
        logger.info("测试模型推理...")
        
        # 创建测试图像
        test_img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
        
        # 进行推理
        results = model(test_img, conf=0.3, verbose=False)
        
        logger.info("✅ 模型推理测试成功")
        
        # 检查输出结构
        result = results[0]
        logger.info(f"📊 模型输出信息:")
        logger.info(f"  - 是否有boxes属性: {hasattr(result, 'boxes')}")
        logger.info(f"  - 是否有keypoints属性: {hasattr(result, 'keypoints')}")
        
        if hasattr(result, 'keypoints') and result.keypoints is not None:
            logger.info(f"  - Keypoints shape: {result.keypoints.data.shape}")
        
        return True
        
    except Exception as e:
        logger.error(f"❌ 模型推理测试失败: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return False

def select_model_path(base_model_path: str) -> str:
    """根据GPU情况选择模型路径"""
    try:
        import torch
        # 检测是否有NVIDIA GPU
        if torch.cuda.is_available():
            nvidia_gpu_found = False
            for i in range(torch.cuda.device_count()):
                gpu_name = torch.cuda.get_device_name(i).lower()
                if 'nvidia' in gpu_name:
                    nvidia_gpu_found = True
                    break
            
            if nvidia_gpu_found:
                # 构建.engine模型路径
                engine_path = base_model_path.replace('.pt', '.engine')
                if os.path.exists(engine_path):
                    logger.info(f"🚀 检测到NVIDIA GPU,使用TensorRT模型: {engine_path}")
                    return engine_path
                else:
                    logger.info(f"⚠️ NVIDIA GPU已检测到,但TensorRT模型不存在: {engine_path}")
                    logger.info(f"📍 使用PyTorch模型: {base_model_path}")
                    return base_model_path
            else:
                logger.info(f"📍 检测到GPU但非NVIDIA,使用PyTorch模型: {base_model_path}")
                return base_model_path
        else:
            logger.info(f"📍 未检测到GPU,使用CPU模式,PyTorch模型: {base_model_path}")
            return base_model_path
            
    except Exception as e:
        logger.warning(f"⚠️ GPU检测失败,使用默认模型: {str(e)}")
        return base_model_path

def main():
    """主函数"""
    logger.info("🔍 开始模型检查...")
    
    success = check_model()
    
    if success:
        logger.info("🎉 模型检查完成,所有测试通过!")
        sys.exit(0)
    else:
        logger.error("❌ 模型检查失败,请检查模型文件和依赖")
        sys.exit(1)

if __name__ == "__main__":
    main()