Hakureirm committed on
Commit
402fd16
·
1 Parent(s): 2fe14bb

Add image back

Browse files
demo_visualization_api.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 演示可视化API功能的示例脚本
4
+ 展示如何获取绘制了关键点和骨架的图像
5
+ """
6
+
7
+ import requests
8
+ import base64
9
+ import json
10
+ import cv2
11
+ import numpy as np
12
+ import time
13
+
14
def demo_api_usage():
    """Demonstrate the visualization API end to end.

    Sends one frame to the local pose-detection service, prints the
    detection summary, and saves the returned keypoint/skeleton overlay
    image. Requires ``test_mouse.jpg`` in the working directory and the
    service running on localhost:8765.
    """
    print("🎯 单鼠姿态检测可视化API演示")
    print("=" * 50)

    # 1. Prepare the image data (sample format as provided).
    image_path = "test_mouse.jpg"

    # Fixed: the original `if not cv2.imread(...) is not None:` only worked
    # by operator-precedence accident; this states the intent plainly.
    if cv2.imread(image_path) is None:
        print("❌ 找不到测试图像,请确保test_mouse.jpg存在")
        return

    # Base64-encode the raw file bytes for the JSON payload.
    with open(image_path, "rb") as image_file:
        image_base64 = base64.b64encode(image_file.read()).decode('utf-8')

    # 2. Build the API request (matches the documented example format).
    request_data = {
        "image": image_base64,
        "conf_threshold": 0.3,
        "frame_id": 1,
        "timestamp": 1641234567.123,
        "return_image": True,   # key flag: ask for the visualization image
        "draw_bbox": False      # keypoints/skeleton only, no bounding box
    }

    print("📤 发送API请求...")
    print(f"请求参数: conf_threshold={request_data['conf_threshold']}")
    print(f"          return_image={request_data['return_image']}")
    print(f"          draw_bbox={request_data['draw_bbox']}")

    # 3. Send the request.
    api_url = "http://localhost:8765/api/process_frame"

    try:
        start_time = time.time()
        # Timeout added so a hung service cannot block the demo forever.
        response = requests.post(api_url, json=request_data, timeout=30)
        request_time = time.time() - start_time

        if response.status_code == 200:
            result = response.json()

            print(f"✅ 请求成功!(耗时: {request_time:.3f}s)")
            print("\n📊 检测结果:")
            print(f"  - 检测到小鼠: {result['mouse_detected']}")

            if result['mouse_detected']:
                print(f"  - 检测置信度: {result['confidence']:.3f}")
                print(f"  - 关键点数量: {len(result['keypoints'])}")
                print(f"  - 处理FPS: {result['fps']:.1f}")

                # Show each detected keypoint.
                print("\n🔍 检测到的关键点:")
                for kpt in result['keypoints']:
                    print(f"  {kpt['id']}: {kpt['name']} "
                          f"({kpt['x']:.1f}, {kpt['y']:.1f}) "
                          f"conf={kpt['confidence']:.3f}")

            # 4. Handle the visualization image, if the service returned one.
            if "visualization_image" in result and result["visualization_image"]:
                print("\n🎨 可视化图像处理:")
                print("  ✅ 接收到绘制了关键点和骨架的图像")

                # Decode the Base64 JPEG back into an OpenCV image.
                img_bytes = base64.b64decode(result["visualization_image"])
                nparr = np.frombuffer(img_bytes, np.uint8)
                vis_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

                # Save the visualization result to disk.
                output_path = "demo_visualization_result.jpg"
                cv2.imwrite(output_path, vis_image)

                print(f"  💾 可视化图像已保存: {output_path}")
                print(f"  📐 图像尺寸: {vis_image.shape}")

                # Report payload sizes for reference.
                print(f"  📊 Base64数据大小: {len(result['visualization_image'])} 字符")
                print(f"  📁 文件大小: {len(img_bytes)} 字节")

            else:
                print("❌ 没有接收到可视化图像")

            print("\n" + "=" * 50)
            print("🎉 演示完成!")
            print("\n💡 使用说明:")
            print("  1. 设置 return_image=true 来获取可视化图像")
            print("  2. 设置 draw_bbox=false 只绘制关键点和骨架(如您所需)")
            print("  3. 设置 draw_bbox=true 同时绘制边界框")
            print("  4. 返回的 visualization_image 是Base64编码的JPEG图像")

        else:
            print(f"❌ API请求失败: {response.status_code}")
            print(f"错误信息: {response.text}")

    except Exception as e:
        print(f"❌ 请求出错: {str(e)}")
113
def compare_with_without_visualization():
    """Compare API responses with and without the visualization image.

    Sends the same test frame twice — once with ``return_image`` off and
    once on — and reports response size, image payload size, processing
    time and detection outcome for each configuration.
    """
    print("\n" + "=" * 50)
    print("🔄 对比测试:有无可视化图像的API响应")
    print("=" * 50)

    # Load and Base64-encode the shared test frame once.
    with open("test_mouse.jpg", "rb") as image_file:
        image_base64 = base64.b64encode(image_file.read()).decode('utf-8')

    # Both runs share the same frame and threshold.
    base_payload = {"image": image_base64, "conf_threshold": 0.3}
    configs = [
        {"name": "不返回可视化图像",
         "data": dict(base_payload, return_image=False)},
        {"name": "返回可视化图像",
         "data": dict(base_payload, return_image=True, draw_bbox=False)},
    ]

    for config in configs:
        print(f"\n📋 测试: {config['name']}")

        try:
            response = requests.post("http://localhost:8765/api/process_frame",
                                     json=config["data"])

            # Only successful responses are inspected further.
            if response.status_code != 200:
                continue

            result = response.json()

            # Raw HTTP body size vs. the embedded image payload size.
            response_size = len(response.content)
            has_vis_image = bool(result.get("visualization_image"))

            print(f"  📊 响应大小: {response_size:,} 字节")
            print(f"  🖼️ 包含可视化图像: {'是' if has_vis_image else '否'}")

            if has_vis_image:
                vis_size = len(result["visualization_image"])
                print(f"  📸 图像数据大小: {vis_size:,} 字符")

            print(f"  ⚡ 处理时间: {result['processing_time']:.3f}s")
            print(f"  🎯 检测结果: {'检测到小鼠' if result['mouse_detected'] else '未检测到小鼠'}")

        except Exception as e:
            print(f"  ❌ 测试失败: {str(e)}")
172
if __name__ == "__main__":
    # Fail fast if the WebRTC service is not reachable before demoing.
    service_ok = False
    try:
        status = requests.get("http://localhost:8765/api/status", timeout=5)
        service_ok = status.status_code == 200
        if not service_ok:
            print("❌ WebRTC服务不可用,请先启动服务:")
            print("   python gradio_webrtc_api.py")
    except Exception:
        print("❌ 无法连接到WebRTC服务,请先启动服务:")
        print("   python gradio_webrtc_api.py")
    if not service_ok:
        exit(1)

    # Run the demonstrations.
    demo_api_usage()
    compare_with_without_visualization()
demo_visualization_result.jpg ADDED

Git LFS Details

  • SHA256: cc7cea677be19c837544b088a442a496e48e5292379779a00c6ec2d82aa580cd
  • Pointer size: 131 Bytes
  • Size of remote file: 210 kB
gradio_webrtc_api.py CHANGED
@@ -75,6 +75,8 @@ class FrameData(BaseModel):
75
  conf_threshold: float = 0.3
76
  timestamp: float = None
77
  frame_id: int = None
 
 
78
 
79
  class ProcessingResult(BaseModel):
80
  success: bool
@@ -87,6 +89,7 @@ class ProcessingResult(BaseModel):
87
  processing_time: float = 0.0
88
  fps: float = 0.0
89
  error: Optional[str] = None
 
90
 
91
  # API端点
92
  @app.get("/")
@@ -120,8 +123,13 @@ async def process_frame_api(frame_data: FrameData):
120
  if frame is None:
121
  raise HTTPException(status_code=400, detail="无法解码图像")
122
 
123
- # 处理帧
124
- result = processor.process_frame(frame, frame_data.conf_threshold)
 
 
 
 
 
125
 
126
  # 如果有frame_id,使用提供的值
127
  if frame_data.frame_id is not None:
@@ -229,9 +237,16 @@ async def process_frame_websocket(frame_data: Dict[str, Any]) -> Dict[str, Any]:
229
 
230
  # 获取配置参数
231
  conf_threshold = frame_data.get("conf_threshold", 0.3)
 
 
232
 
233
- # 处理帧
234
- result = processor.process_frame(frame, conf_threshold)
 
 
 
 
 
235
 
236
  # 保留原始的frame_id和timestamp(如果提供)
237
  if "frame_id" in frame_data:
@@ -277,8 +292,13 @@ async def process_batch_frames(frames: List[FrameData]):
277
  })
278
  continue
279
 
280
- # 处理帧
281
- result = processor.process_frame(frame, frame_data.conf_threshold)
 
 
 
 
 
282
 
283
  # 保留原始信息
284
  if frame_data.frame_id is not None:
 
75
  conf_threshold: float = 0.3
76
  timestamp: float = None
77
  frame_id: int = None
78
+ return_image: bool = False # 是否返回可视化图像
79
+ draw_bbox: bool = False # 是否绘制边界框
80
 
81
  class ProcessingResult(BaseModel):
82
  success: bool
 
89
  processing_time: float = 0.0
90
  fps: float = 0.0
91
  error: Optional[str] = None
92
+ visualization_image: Optional[str] = None # Base64编码的可视化图像
93
 
94
  # API端点
95
  @app.get("/")
 
123
  if frame is None:
124
  raise HTTPException(status_code=400, detail="无法解码图像")
125
 
126
+ # 处理帧,支持返回可视化图像
127
+ result = processor.process_frame(
128
+ frame,
129
+ frame_data.conf_threshold,
130
+ return_image=frame_data.return_image,
131
+ draw_bbox=frame_data.draw_bbox
132
+ )
133
 
134
  # 如果有frame_id,使用提供的值
135
  if frame_data.frame_id is not None:
 
237
 
238
  # 获取配置参数
239
  conf_threshold = frame_data.get("conf_threshold", 0.3)
240
+ return_image = frame_data.get("return_image", False)
241
+ draw_bbox = frame_data.get("draw_bbox", False)
242
 
243
+ # 处理帧,支持返回可视化图像
244
+ result = processor.process_frame(
245
+ frame,
246
+ conf_threshold,
247
+ return_image=return_image,
248
+ draw_bbox=draw_bbox
249
+ )
250
 
251
  # 保留原始的frame_id和timestamp(如果提供)
252
  if "frame_id" in frame_data:
 
292
  })
293
  continue
294
 
295
+ # 处理帧,支持返回可视化图像
296
+ result = processor.process_frame(
297
+ frame,
298
+ frame_data.conf_threshold,
299
+ return_image=frame_data.return_image,
300
+ draw_bbox=frame_data.draw_bbox
301
+ )
302
 
303
  # 保留原始信息
304
  if frame_data.frame_id is not None:
gradio_webrtc_server.py CHANGED
@@ -121,13 +121,57 @@ class SingleMouseProcessor:
121
  logger.error(traceback.format_exc())
122
  return False
123
 
124
- def process_frame(self, frame: np.ndarray, conf_threshold: float = 0.3) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  """
126
  处理单帧图像并返回检测结果
127
 
128
  Args:
129
  frame: 输入图像帧
130
  conf_threshold: 置信度阈值
 
 
131
 
132
  Returns:
133
  包含检测结果的字典
@@ -220,6 +264,27 @@ class SingleMouseProcessor:
220
  "confidence": float(kpt[2])
221
  })
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  # 更新统计信息
224
  process_time = time.time() - start_time
225
  self.frame_count += 1
@@ -384,9 +449,50 @@ def create_gradio_app():
384
  "image": "base64_encoded_image",
385
  "conf_threshold": 0.3,
386
  "frame_id": 1,
387
- "timestamp": 1641234567.123
 
 
388
  }
389
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  """)
391
 
392
  return app
@@ -432,9 +538,16 @@ def process_webrtc_data(data: Dict[str, Any]) -> Dict[str, Any]:
432
 
433
  # 获取配置参数
434
  conf_threshold = data.get("conf_threshold", 0.3)
 
 
435
 
436
- # 处理帧
437
- result = processor.process_frame(frame, conf_threshold)
 
 
 
 
 
438
 
439
  return result
440
 
 
121
  logger.error(traceback.format_exc())
122
  return False
123
 
124
+ def _draw_keypoints_and_skeleton(self, frame: np.ndarray, keypoints: List[Dict], draw_bbox: bool = False, bbox: Dict = None) -> np.ndarray:
125
+ """
126
+ 在图像上绘制关键点和骨架
127
+
128
+ Args:
129
+ frame: 输入图像帧
130
+ keypoints: 关键点列表
131
+ draw_bbox: 是否绘制边界框
132
+ bbox: 边界框信息
133
+
134
+ Returns:
135
+ 绘制后的图像
136
+ """
137
+ vis_frame = frame.copy()
138
+
139
+ # 绘制边界框(如果需要)
140
+ if draw_bbox and bbox is not None:
141
+ cv2.rectangle(vis_frame,
142
+ (int(bbox["x1"]), int(bbox["y1"])),
143
+ (int(bbox["x2"]), int(bbox["y2"])),
144
+ (0, 255, 0), 2)
145
+
146
+ # 绘制关键点
147
+ for kpt in keypoints:
148
+ cv2.circle(vis_frame,
149
+ (int(kpt["x"]), int(kpt["y"])),
150
+ 5, (0, 0, 255), -1)
151
+ cv2.putText(vis_frame,
152
+ kpt["name"],
153
+ (int(kpt["x"]) + 5, int(kpt["y"]) - 5),
154
+ cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
155
+
156
+ # 绘制骨架连接
157
+ kpt_dict = {kpt["id"]: (kpt["x"], kpt["y"]) for kpt in keypoints}
158
+ for connection in self.keypoint_connections:
159
+ if connection[0] in kpt_dict and connection[1] in kpt_dict:
160
+ pt1 = (int(kpt_dict[connection[0]][0]), int(kpt_dict[connection[0]][1]))
161
+ pt2 = (int(kpt_dict[connection[1]][0]), int(kpt_dict[connection[1]][1]))
162
+ cv2.line(vis_frame, pt1, pt2, (255, 0, 0), 2)
163
+
164
+ return vis_frame
165
+
166
+ def process_frame(self, frame: np.ndarray, conf_threshold: float = 0.3, return_image: bool = False, draw_bbox: bool = False) -> Dict[str, Any]:
167
  """
168
  处理单帧图像并返回检测结果
169
 
170
  Args:
171
  frame: 输入图像帧
172
  conf_threshold: 置信度阈值
173
+ return_image: 是否返回绘制了关键点的图像
174
+ draw_bbox: 是否在返回图像中绘制边界框
175
 
176
  Returns:
177
  包含检测结果的字典
 
264
  "confidence": float(kpt[2])
265
  })
266
 
267
+ # 生成可视化图像(如果需要)
268
+ if return_image and detection_data["mouse_detected"]:
269
+ vis_frame = self._draw_keypoints_and_skeleton(
270
+ frame,
271
+ detection_data["keypoints"],
272
+ draw_bbox=draw_bbox,
273
+ bbox=detection_data["bbox"]
274
+ )
275
+
276
+ # 将图像编码为Base64
277
+ _, buffer = cv2.imencode('.jpg', vis_frame)
278
+ image_base64 = base64.b64encode(buffer).decode('utf-8')
279
+ detection_data["visualization_image"] = image_base64
280
+ elif return_image:
281
+ # 即使没有检测到小鼠也返回原图
282
+ _, buffer = cv2.imencode('.jpg', frame)
283
+ image_base64 = base64.b64encode(buffer).decode('utf-8')
284
+ detection_data["visualization_image"] = image_base64
285
+ else:
286
+ logger.info("不需要生成可视化图像")
287
+
288
  # 更新统计信息
289
  process_time = time.time() - start_time
290
  self.frame_count += 1
 
449
  "image": "base64_encoded_image",
450
  "conf_threshold": 0.3,
451
  "frame_id": 1,
452
+ "timestamp": 1641234567.123,
453
+ "return_image": true,
454
+ "draw_bbox": false
455
  }
456
  ```
457
+
458
+ ### 新增参数说明
459
+ - **return_image** (bool, 可选): 是否返回绘制了关键点和骨架的可视化图像
460
+ - **draw_bbox** (bool, 可选): 是否在可视化图像中绘制边界框(需要return_image=true)
461
+
462
+ ### 响应格式
463
+ 当 `return_image=true` 时,响应会包含额外的字段:
464
+ ```json
465
+ {
466
+ "success": true,
467
+ "mouse_detected": true,
468
+ "keypoints": [...],
469
+ "visualization_image": "base64_encoded_image_with_keypoints_and_skeleton"
470
+ }
471
+ ```
472
+
473
+ ### 使用示例
474
+ ```javascript
475
+ // 只获取检测数据
476
+ fetch('/api/process_frame', {
477
+ method: 'POST',
478
+ headers: {'Content-Type': 'application/json'},
479
+ body: JSON.stringify({
480
+ image: base64Image,
481
+ return_image: false
482
+ })
483
+ });
484
+
485
+ // 获取检测数据和可视化图像
486
+ fetch('/api/process_frame', {
487
+ method: 'POST',
488
+ headers: {'Content-Type': 'application/json'},
489
+ body: JSON.stringify({
490
+ image: base64Image,
491
+ return_image: true,
492
+ draw_bbox: false // 只绘制关键点和骨架,不绘制边界框
493
+ })
494
+ });
495
+ ```
496
  """)
497
 
498
  return app
 
538
 
539
  # 获取配置参数
540
  conf_threshold = data.get("conf_threshold", 0.3)
541
+ return_image = data.get("return_image", False)
542
+ draw_bbox = data.get("draw_bbox", False)
543
 
544
+ # 处理帧,支持返回可视化图像
545
+ result = processor.process_frame(
546
+ frame,
547
+ conf_threshold,
548
+ return_image=return_image,
549
+ draw_bbox=draw_bbox
550
+ )
551
 
552
  return result
553
 
test_mouse.jpg CHANGED

Git LFS Details

  • SHA256: ddb0c43a5e575857e392996434085356b978223bd58973ee12de2972aa9ab1d9
  • Pointer size: 131 Bytes
  • Size of remote file: 224 kB

Git LFS Details

  • SHA256: 0aa47d294ab4714ea74f9bb73abb56bc08db1df60565f31244ea537ce6e0baf8
  • Pointer size: 131 Bytes
  • Size of remote file: 224 kB
test_result_2_返回图像不包括边界框.jpg ADDED

Git LFS Details

  • SHA256: 5bffd19c9e5041ccbcea0fa7705704c08179954ed127bc1340d2408679e16229
  • Pointer size: 131 Bytes
  • Size of remote file: 212 kB
test_result_3_返回图像包括边界框.jpg ADDED

Git LFS Details

  • SHA256: 624b623df6f99f8c1f81cea6fcd448a7562f635e5570f7efaad18c2588434c7c
  • Pointer size: 131 Bytes
  • Size of remote file: 213 kB
test_visualization_api.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 测试可视化图像API功能
4
+ """
5
+
6
+ import requests
7
+ import base64
8
+ import json
9
+ import cv2
10
+ import numpy as np
11
+ import time
12
+ import os
13
+
14
def load_test_image():
    """Return a test image for the API tests.

    Uses the first existing candidate file on disk; if none is found,
    synthesizes a noisy frame with a crude mouse-like silhouette.
    """
    # Candidate image files, in preference order.
    for candidate in ["test_mouse.jpg", "test_mouse.png", "sample.jpg"]:
        if os.path.exists(candidate):
            print(f"✅ 找到测试图像: {candidate}")
            return cv2.imread(candidate)

    # Nothing on disk — build a synthetic image instead.
    print("📸 创建合成测试图像...")
    synthetic = np.random.randint(50, 200, (640, 640, 3), dtype=np.uint8)

    # Crude mouse shape: head circle, body ellipse, tail ellipse.
    cv2.circle(synthetic, (320, 300), 50, (100, 100, 100), -1)
    cv2.ellipse(synthetic, (320, 380), (80, 120), 0, 0, 360, (120, 120, 120), -1)
    cv2.ellipse(synthetic, (320, 500), (20, 80), 0, 0, 360, (90, 90, 90), -1)

    return synthetic
35
def image_to_base64(image):
    """Encode an OpenCV image as a Base64 JPEG string."""
    ok, jpeg_buf = cv2.imencode('.jpg', image)
    return base64.b64encode(jpeg_buf).decode('utf-8')
40
def base64_to_image(base64_str):
    """Decode a Base64 string back into an OpenCV BGR image."""
    raw = np.frombuffer(base64.b64decode(base64_str), np.uint8)
    return cv2.imdecode(raw, cv2.IMREAD_COLOR)
46
def test_api_with_visualization():
    """Exercise the REST API's visualization-image feature.

    Sends the same test frame under three configurations (no image
    returned / image without bbox / image with bbox), prints the
    detection summary for each, and saves any returned visualization
    image to disk.

    Returns:
        bool: False only when the test image cannot be loaded; individual
        request failures are printed but still yield True.
    """
    print("=== 测试可视化图像API功能 ===")

    # Load the test image (disk file or synthetic fallback).
    test_image = load_test_image()
    if test_image is None:
        print("❌ 无法加载测试图像")
        return False

    print(f"测试图像尺寸: {test_image.shape}")

    # Convert to Base64 for the JSON payload.
    image_base64 = image_to_base64(test_image)

    # The three request configurations under test.
    test_configs = [
        {
            "name": "不返回图像",
            "data": {
                "image": image_base64,
                "conf_threshold": 0.1,
                "return_image": False,
                "frame_id": 1,
                "timestamp": time.time()
            }
        },
        {
            "name": "返回图像(不包括边界框)",
            "data": {
                "image": image_base64,
                "conf_threshold": 0.1,
                "return_image": True,
                "draw_bbox": False,
                "frame_id": 2,
                "timestamp": time.time()
            }
        },
        {
            "name": "返回图像(包括边界框)",
            "data": {
                "image": image_base64,
                "conf_threshold": 0.1,
                "return_image": True,
                "draw_bbox": True,
                "frame_id": 3,
                "timestamp": time.time()
            }
        }
    ]

    api_url = "http://localhost:8765/api/process_frame"

    for i, config in enumerate(test_configs):
        print(f"\n--- 测试 {i+1}: {config['name']} ---")

        try:
            # Send the request and time the round trip.
            start_time = time.time()
            response = requests.post(api_url, json=config["data"])
            request_time = time.time() - start_time

            if response.status_code == 200:
                result = response.json()
                print(f"✅ 请求成功 (耗时: {request_time:.3f}s)")
                print(f"检测到小鼠: {result['mouse_detected']}")

                if result['mouse_detected']:
                    print(f"置信度: {result['confidence']:.3f}")
                    print(f"关键点数量: {len(result['keypoints'])}")
                    print(f"处理FPS: {result['fps']:.1f}")

                # Check whether a visualization image came back.
                if "visualization_image" in result and result["visualization_image"]:
                    print("✅ 包含可视化图像")

                    # Decode and save the visualization image.
                    vis_image = base64_to_image(result["visualization_image"])
                    output_filename = f"test_result_{i+1}_{config['name'].replace(' ', '_').replace('(', '').replace(')', '')}.jpg"
                    cv2.imwrite(output_filename, vis_image)
                    print(f"可视化图像已保存: {output_filename}")

                else:
                    if "visualization_image" in result:
                        print(f"❌ visualization_image存在但为空: {result['visualization_image'] is None}")
                    else:
                        print("❌ 响应中没有visualization_image字段")
                    print(f"响应中的所有字段: {list(result.keys())}")
                    # Dump the response minus the (potentially huge) image field.
                    debug_result = {k: v for k, v in result.items() if k != 'visualization_image'}
                    print(f"响应内容(除图像外): {json.dumps(debug_result, indent=2, ensure_ascii=False)}")

            else:
                print(f"❌ 请求失败: {response.status_code}")
                print(f"错误信息: {response.text}")

        except Exception as e:
            print(f"❌ 请求出错: {str(e)}")

    return True
147
def test_websocket_with_visualization():
    """Exercise the WebSocket endpoint's visualization-image feature.

    Connects to ws://localhost:8765/ws/stream, sends one frame with
    return_image=True, and saves any returned visualization image.

    Returns:
        bool: True when the round trip completed; False when the
        websockets package is missing or any error occurred.
    """
    print("\n=== 测试WebSocket可视化功能 ===")

    try:
        # Imported lazily so the REST test still runs without websockets.
        import websockets
        import asyncio

        async def test_ws():
            # Load and encode the test image.
            test_image = load_test_image()
            image_base64 = image_to_base64(test_image)

            uri = "ws://localhost:8765/ws/stream"

            async with websockets.connect(uri) as websocket:
                print("✅ WebSocket连接成功")

                # Frame payload that requests a visualization image.
                frame_data = {
                    "image": image_base64,
                    "conf_threshold": 0.1,  # lowered confidence threshold
                    "return_image": True,
                    "draw_bbox": False,
                    "frame_id": 100,
                    "timestamp": time.time()
                }

                await websocket.send(json.dumps(frame_data))

                # Receive and decode the processing result.
                response = await websocket.recv()
                result = json.loads(response)

                if result['success']:
                    print(f"✅ WebSocket处理成功")
                    print(f"检测到小鼠: {result['mouse_detected']}")

                    if "visualization_image" in result and result["visualization_image"]:
                        print("✅ 包含可视化图像")

                        # Save the WebSocket visualization result.
                        vis_image = base64_to_image(result["visualization_image"])
                        cv2.imwrite("websocket_result.jpg", vis_image)
                        print("WebSocket可视化图像已保存: websocket_result.jpg")
                    else:
                        print("❌ WebSocket结果不包含可视化图像")
                else:
                    print(f"❌ WebSocket处理失败: {result.get('error', 'Unknown error')}")

        # Drive the async test to completion.
        asyncio.run(test_ws())
        return True

    except ImportError:
        print("❌ 未安装websockets库,跳过WebSocket测试")
        return False
    except Exception as e:
        print(f"❌ WebSocket测试出错: {str(e)}")
        return False
208
def main():
    """Entry point: verify the service is reachable, then run the REST
    and WebSocket visualization tests."""
    print("🚀 开始测试可视化图像API功能")
    print("确保WebRTC服务已启动在 http://localhost:8765")

    # Probe the status endpoint before running any tests.
    try:
        status_response = requests.get("http://localhost:8765/api/status", timeout=5)
    except Exception as e:
        print(f"❌ 无法连接到WebRTC服务: {str(e)}")
        print("请先启动服务: python gradio_webrtc_api.py")
        return

    if status_response.status_code != 200:
        print("❌ WebRTC服务不可用")
        return
    print("✅ WebRTC服务可用")

    # REST API test.
    test_api_with_visualization()

    # WebSocket test.
    test_websocket_with_visualization()

    print("\n🎉 测试完成!检查生成的图像文件查看结果。")
234
# Run the full test suite when executed as a script.
if __name__ == "__main__":
    main()
websocket_result.jpg ADDED

Git LFS Details

  • SHA256: 5bffd19c9e5041ccbcea0fa7705704c08179954ed127bc1340d2408679e16229
  • Pointer size: 131 Bytes
  • Size of remote file: 212 kB