Hakureirm commited on
Commit
c924202
·
1 Parent(s): 48af9b3

Add NVIDIA GPU detection and TensorRT engine model support

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. check_model.py +38 -1
  3. gradio_webrtc_server.py +36 -1
  4. inspect_model.py +38 -1
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.jpg filter=lfs diff=lfs merge=lfs -text
37
  *.png filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.jpg filter=lfs diff=lfs merge=lfs -text
37
  *.png filter=lfs diff=lfs merge=lfs -text
38
+ *.engine filter=lfs diff=lfs merge=lfs -text
check_model.py CHANGED
@@ -15,7 +15,10 @@ logger = logging.getLogger(__name__)
15
 
16
  def check_model():
17
  """检查模型文件"""
18
- model_path = "models/kunin-mice-pose.v0.1.5n.pt"
 
 
 
19
 
20
  logger.info(f"检查模型文件: {model_path}")
21
 
@@ -87,6 +90,40 @@ def check_model():
87
  logger.error(traceback.format_exc())
88
  return False
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def main():
91
  """主函数"""
92
  logger.info("🔍 开始模型检查...")
 
15
 
16
  def check_model():
17
  """检查模型文件"""
18
+ base_model_path = "models/kunin-mice-pose.v0.1.5n.pt"
19
+
20
+ # 选择模型路径(与SingleMouseProcessor保持一致)
21
+ model_path = select_model_path(base_model_path)
22
 
23
  logger.info(f"检查模型文件: {model_path}")
24
 
 
90
  logger.error(traceback.format_exc())
91
  return False
92
 
93
+ def select_model_path(base_model_path: str) -> str:
94
+ """根据GPU情况选择模型路径"""
95
+ try:
96
+ import torch
97
+ # 检测是否有NVIDIA GPU
98
+ if torch.cuda.is_available():
99
+ nvidia_gpu_found = False
100
+ for i in range(torch.cuda.device_count()):
101
+ gpu_name = torch.cuda.get_device_name(i).lower()
102
+ if 'nvidia' in gpu_name:
103
+ nvidia_gpu_found = True
104
+ break
105
+
106
+ if nvidia_gpu_found:
107
+ # 构建.engine模型路径
108
+ engine_path = base_model_path.replace('.pt', '.engine')
109
+ if os.path.exists(engine_path):
110
+ logger.info(f"🚀 检测到NVIDIA GPU,使用TensorRT模型: {engine_path}")
111
+ return engine_path
112
+ else:
113
+ logger.info(f"⚠️ NVIDIA GPU已检测到,但TensorRT模型不存在: {engine_path}")
114
+ logger.info(f"📍 使用PyTorch模型: {base_model_path}")
115
+ return base_model_path
116
+ else:
117
+ logger.info(f"📍 检测到GPU但非NVIDIA,使用PyTorch模型: {base_model_path}")
118
+ return base_model_path
119
+ else:
120
+ logger.info(f"📍 未检测到GPU,使用CPU模式,PyTorch模型: {base_model_path}")
121
+ return base_model_path
122
+
123
+ except Exception as e:
124
+ logger.warning(f"⚠️ GPU检测失败,使用默认模型: {str(e)}")
125
+ return base_model_path
126
+
127
  def main():
128
  """主函数"""
129
  logger.info("🔍 开始模型检查...")
gradio_webrtc_server.py CHANGED
@@ -24,7 +24,8 @@ class SingleMouseProcessor:
24
  """单鼠姿态检测处理器"""
25
 
26
  def __init__(self, model_path: str = "models/kunin-mice-pose.v0.1.5n.pt"):
27
- self.model_path = model_path
 
28
  self.model = None
29
  self.lock = Lock()
30
  self.frame_count = 0
@@ -57,6 +58,40 @@ class SingleMouseProcessor:
57
  # 加载模型
58
  self._load_model()
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def _load_model(self):
61
  """加载YOLO模型"""
62
  try:
 
24
  """单鼠姿态检测处理器"""
25
 
26
  def __init__(self, model_path: str = "models/kunin-mice-pose.v0.1.5n.pt"):
27
+ self.base_model_path = model_path
28
+ self.model_path = self._select_model_path()
29
  self.model = None
30
  self.lock = Lock()
31
  self.frame_count = 0
 
58
  # 加载模型
59
  self._load_model()
60
 
61
+ def _select_model_path(self) -> str:
62
+ """根据GPU情况选择模型路径"""
63
+ try:
64
+ import torch
65
+ # 检测是否有NVIDIA GPU
66
+ if torch.cuda.is_available():
67
+ nvidia_gpu_found = False
68
+ for i in range(torch.cuda.device_count()):
69
+ gpu_name = torch.cuda.get_device_name(i).lower()
70
+ if 'nvidia' in gpu_name:
71
+ nvidia_gpu_found = True
72
+ break
73
+
74
+ if nvidia_gpu_found:
75
+ # 构建.engine模型路径
76
+ engine_path = self.base_model_path.replace('.pt', '.engine')
77
+ if os.path.exists(engine_path):
78
+ logger.info(f"🚀 检测到NVIDIA GPU,使用TensorRT模型: {engine_path}")
79
+ return engine_path
80
+ else:
81
+ logger.info(f"⚠️ NVIDIA GPU已检测到,但TensorRT模型不存在: {engine_path}")
82
+ logger.info(f"📍 使用PyTorch模型: {self.base_model_path}")
83
+ return self.base_model_path
84
+ else:
85
+ logger.info(f"📍 检测到GPU但非NVIDIA,使用PyTorch模型: {self.base_model_path}")
86
+ return self.base_model_path
87
+ else:
88
+ logger.info(f"📍 未检测到GPU,使用CPU模式,PyTorch模型: {self.base_model_path}")
89
+ return self.base_model_path
90
+
91
+ except Exception as e:
92
+ logger.warning(f"⚠️ GPU检测失败,使用默认模型: {str(e)}")
93
+ return self.base_model_path
94
+
95
  def _load_model(self):
96
  """加载YOLO模型"""
97
  try:
inspect_model.py CHANGED
@@ -15,7 +15,10 @@ logger = logging.getLogger(__name__)
15
 
16
  def inspect_model():
17
  """检查模型详细信息"""
18
- model_path = "models/kunin-mice-pose.v0.1.5n.pt"
 
 
 
19
 
20
  try:
21
  from ultralytics import YOLO
@@ -122,6 +125,40 @@ def inspect_model():
122
  logger.error(traceback.format_exc())
123
  return False
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def main():
126
  """主函数"""
127
  logger.info("🔍 开始模型详细检查...")
 
15
 
16
  def inspect_model():
17
  """检查模型详细信息"""
18
+ base_model_path = "models/kunin-mice-pose.v0.1.5n.pt"
19
+
20
+ # 选择模型路径(与SingleMouseProcessor保持一致)
21
+ model_path = select_model_path(base_model_path)
22
 
23
  try:
24
  from ultralytics import YOLO
 
125
  logger.error(traceback.format_exc())
126
  return False
127
 
128
+ def select_model_path(base_model_path: str) -> str:
129
+ """根据GPU情况选择模型路径"""
130
+ try:
131
+ import torch
132
+ # 检测是否有NVIDIA GPU
133
+ if torch.cuda.is_available():
134
+ nvidia_gpu_found = False
135
+ for i in range(torch.cuda.device_count()):
136
+ gpu_name = torch.cuda.get_device_name(i).lower()
137
+ if 'nvidia' in gpu_name:
138
+ nvidia_gpu_found = True
139
+ break
140
+
141
+ if nvidia_gpu_found:
142
+ # 构建.engine模型路径
143
+ engine_path = base_model_path.replace('.pt', '.engine')
144
+ if os.path.exists(engine_path):
145
+ logger.info(f"🚀 检测到NVIDIA GPU,使用TensorRT模型: {engine_path}")
146
+ return engine_path
147
+ else:
148
+ logger.info(f"⚠️ NVIDIA GPU已检测到,但TensorRT模型不存在: {engine_path}")
149
+ logger.info(f"📍 使用PyTorch模型: {base_model_path}")
150
+ return base_model_path
151
+ else:
152
+ logger.info(f"📍 检测到GPU但非NVIDIA,使用PyTorch模型: {base_model_path}")
153
+ return base_model_path
154
+ else:
155
+ logger.info(f"📍 未检测到GPU,使用CPU模式,PyTorch模型: {base_model_path}")
156
+ return base_model_path
157
+
158
+ except Exception as e:
159
+ logger.warning(f"⚠️ GPU检测失败,使用默认模型: {str(e)}")
160
+ return base_model_path
161
+
162
  def main():
163
  """主函数"""
164
  logger.info("🔍 开始模型详细检查...")