Add image back
Browse files- demo_visualization_api.py +187 -0
- demo_visualization_result.jpg +3 -0
- gradio_webrtc_api.py +26 -6
- gradio_webrtc_server.py +117 -4
- test_mouse.jpg +2 -2
- test_result_2_返回图像不包括边界框.jpg +3 -0
- test_result_3_返回图像包括边界框.jpg +3 -0
- test_visualization_api.py +235 -0
- websocket_result.jpg +3 -0
demo_visualization_api.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
演示可视化API功能的示例脚本
|
| 4 |
+
展示如何获取绘制了关键点和骨架的图像
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
import base64
|
| 9 |
+
import json
|
| 10 |
+
import cv2
|
| 11 |
+
import numpy as np
|
| 12 |
+
import time
|
| 13 |
+
|
| 14 |
+
def demo_api_usage():
    """Demonstrate calling the pose-detection API and retrieving the visualization image.

    Posts a Base64-encoded JPEG to the local service with return_image=True,
    prints the detection results, and saves the returned visualization image
    (keypoints + skeleton) to disk.
    """

    print("🎯 单鼠姿态检测可视化API演示")
    print("=" * 50)

    # 1. Prepare the image data (the sample file shipped with the repo).
    image_path = "test_mouse.jpg"

    # BUG FIX: original was `if not cv2.imread(image_path) is not None:` — a
    # confusing double negative. Check directly that the image is unreadable.
    if cv2.imread(image_path) is None:
        print("❌ 找不到测试图像,请确保test_mouse.jpg存在")
        return

    # Encode the raw file bytes as Base64 for transport.
    with open(image_path, "rb") as image_file:
        image_base64 = base64.b64encode(image_file.read()).decode('utf-8')

    # 2. Build the API request payload.
    request_data = {
        "image": image_base64,
        "conf_threshold": 0.3,
        "frame_id": 1,
        "timestamp": 1641234567.123,
        "return_image": True,  # key parameter: ask the server for the visualization image
        "draw_bbox": False     # draw only keypoints and skeleton, no bounding box
    }

    print("📤 发送API请求...")
    print(f"请求参数: conf_threshold={request_data['conf_threshold']}")
    print(f" return_image={request_data['return_image']}")
    print(f" draw_bbox={request_data['draw_bbox']}")

    # 3. Send the request and time the round trip.
    api_url = "http://localhost:8765/api/process_frame"

    try:
        start_time = time.time()
        response = requests.post(api_url, json=request_data)
        request_time = time.time() - start_time

        if response.status_code == 200:
            result = response.json()

            print(f"✅ 请求成功!(耗时: {request_time:.3f}s)")
            print("\n📊 检测结果:")
            print(f" - 检测到小鼠: {result['mouse_detected']}")

            if result['mouse_detected']:
                print(f" - 检测置信度: {result['confidence']:.3f}")
                print(f" - 关键点数量: {len(result['keypoints'])}")
                print(f" - 处理FPS: {result['fps']:.1f}")

                # Show each detected keypoint with position and confidence.
                print("\n🔍 检测到的关键点:")
                for kpt in result['keypoints']:
                    print(f" {kpt['id']}: {kpt['name']} "
                          f"({kpt['x']:.1f}, {kpt['y']:.1f}) "
                          f"conf={kpt['confidence']:.3f}")

            # 4. Handle the visualization image, if the server returned one.
            if "visualization_image" in result and result["visualization_image"]:
                print("\n🎨 可视化图像处理:")
                print(" ✅ 接收到绘制了关键点和骨架的图像")

                # Decode the Base64 JPEG back into an OpenCV image.
                img_bytes = base64.b64decode(result["visualization_image"])
                nparr = np.frombuffer(img_bytes, np.uint8)
                vis_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

                # Persist the visualization result for inspection.
                output_path = "demo_visualization_result.jpg"
                cv2.imwrite(output_path, vis_image)

                print(f" 💾 可视化图像已保存: {output_path}")
                print(f" 📐 图像尺寸: {vis_image.shape}")

                # Report transport/storage sizes.
                print(f" 📊 Base64数据大小: {len(result['visualization_image'])} 字符")
                print(f" 📁 文件大小: {len(img_bytes)} 字节")

            else:
                print("❌ 没有接收到可视化图像")

            print("\n" + "=" * 50)
            print("🎉 演示完成!")
            print("\n💡 使用说明:")
            print(" 1. 设置 return_image=true 来获取可视化图像")
            print(" 2. 设置 draw_bbox=false 只绘制关键点和骨架(如您所需)")
            print(" 3. 设置 draw_bbox=true 同时绘制边界框")
            print(" 4. 返回的 visualization_image 是Base64编码的JPEG图像")

        else:
            print(f"❌ API请求失败: {response.status_code}")
            print(f"错误信息: {response.text}")

    except Exception as e:
        print(f"❌ 请求出错: {str(e)}")
|
| 112 |
+
|
| 113 |
+
def compare_with_without_visualization():
    """Compare API responses requested with and without the visualization image."""

    print("\n" + "=" * 50)
    print("🔄 对比测试:有无可视化图像的API响应")
    print("=" * 50)

    # Load and encode the test image once; both requests reuse it.
    with open("test_mouse.jpg", "rb") as image_file:
        image_base64 = base64.b64encode(image_file.read()).decode('utf-8')

    # Two payloads: one plain, one asking for the visualization image.
    configs = [
        {
            "name": "不返回可视化图像",
            "data": {
                "image": image_base64,
                "conf_threshold": 0.3,
                "return_image": False
            }
        },
        {
            "name": "返回可视化图像",
            "data": {
                "image": image_base64,
                "conf_threshold": 0.3,
                "return_image": True,
                "draw_bbox": False
            }
        }
    ]

    for cfg in configs:
        print(f"\n📋 测试: {cfg['name']}")

        try:
            response = requests.post("http://localhost:8765/api/process_frame",
                                     json=cfg["data"])

            if response.status_code == 200:
                result = response.json()

                # Measure the raw HTTP payload size for the comparison.
                response_size = len(response.content)
                has_vis_image = "visualization_image" in result and result["visualization_image"]

                print(f" 📊 响应大小: {response_size:,} 字节")
                print(f" 🖼️ 包含可视化图像: {'是' if has_vis_image else '否'}")

                if has_vis_image:
                    vis_size = len(result["visualization_image"])
                    print(f" 📸 图像数据大小: {vis_size:,} 字符")

                print(f" ⚡ 处理时间: {result['processing_time']:.3f}s")
                print(f" 🎯 检测结果: {'检测到小鼠' if result['mouse_detected'] else '未检测到小鼠'}")

        except Exception as e:
            print(f" ❌ 测试失败: {str(e)}")
|
| 171 |
+
|
| 172 |
+
if __name__ == "__main__":
    # Make sure the WebRTC service is reachable before running any demo.
    try:
        status_response = requests.get("http://localhost:8765/api/status", timeout=5)
    except Exception:
        print("❌ 无法连接到WebRTC服务,请先启动服务:")
        print(" python gradio_webrtc_api.py")
        exit(1)

    if status_response.status_code != 200:
        print("❌ WebRTC服务不可用,请先启动服务:")
        print(" python gradio_webrtc_api.py")
        exit(1)

    # Run both demos.
    demo_api_usage()
    compare_with_without_visualization()
|
demo_visualization_result.jpg
ADDED
|
Git LFS Details
|
gradio_webrtc_api.py
CHANGED
|
@@ -75,6 +75,8 @@ class FrameData(BaseModel):
|
|
| 75 |
conf_threshold: float = 0.3
|
| 76 |
timestamp: float = None
|
| 77 |
frame_id: int = None
|
|
|
|
|
|
|
| 78 |
|
| 79 |
class ProcessingResult(BaseModel):
|
| 80 |
success: bool
|
|
@@ -87,6 +89,7 @@ class ProcessingResult(BaseModel):
|
|
| 87 |
processing_time: float = 0.0
|
| 88 |
fps: float = 0.0
|
| 89 |
error: Optional[str] = None
|
|
|
|
| 90 |
|
| 91 |
# API端点
|
| 92 |
@app.get("/")
|
|
@@ -120,8 +123,13 @@ async def process_frame_api(frame_data: FrameData):
|
|
| 120 |
if frame is None:
|
| 121 |
raise HTTPException(status_code=400, detail="无法解码图像")
|
| 122 |
|
| 123 |
-
#
|
| 124 |
-
result = processor.process_frame(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
# 如果有frame_id,使用提供的值
|
| 127 |
if frame_data.frame_id is not None:
|
|
@@ -229,9 +237,16 @@ async def process_frame_websocket(frame_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 229 |
|
| 230 |
# 获取配置参数
|
| 231 |
conf_threshold = frame_data.get("conf_threshold", 0.3)
|
|
|
|
|
|
|
| 232 |
|
| 233 |
-
#
|
| 234 |
-
result = processor.process_frame(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
# 保留原始的frame_id和timestamp(如果提供)
|
| 237 |
if "frame_id" in frame_data:
|
|
@@ -277,8 +292,13 @@ async def process_batch_frames(frames: List[FrameData]):
|
|
| 277 |
})
|
| 278 |
continue
|
| 279 |
|
| 280 |
-
#
|
| 281 |
-
result = processor.process_frame(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
# 保留原始信息
|
| 284 |
if frame_data.frame_id is not None:
|
|
|
|
| 75 |
conf_threshold: float = 0.3
|
| 76 |
timestamp: float = None
|
| 77 |
frame_id: int = None
|
| 78 |
+
return_image: bool = False # 是否返回可视化图像
|
| 79 |
+
draw_bbox: bool = False # 是否绘制边界框
|
| 80 |
|
| 81 |
class ProcessingResult(BaseModel):
|
| 82 |
success: bool
|
|
|
|
| 89 |
processing_time: float = 0.0
|
| 90 |
fps: float = 0.0
|
| 91 |
error: Optional[str] = None
|
| 92 |
+
visualization_image: Optional[str] = None # Base64编码的可视化图像
|
| 93 |
|
| 94 |
# API端点
|
| 95 |
@app.get("/")
|
|
|
|
| 123 |
if frame is None:
|
| 124 |
raise HTTPException(status_code=400, detail="无法解码图像")
|
| 125 |
|
| 126 |
+
# 处理帧,支持返回可视化图像
|
| 127 |
+
result = processor.process_frame(
|
| 128 |
+
frame,
|
| 129 |
+
frame_data.conf_threshold,
|
| 130 |
+
return_image=frame_data.return_image,
|
| 131 |
+
draw_bbox=frame_data.draw_bbox
|
| 132 |
+
)
|
| 133 |
|
| 134 |
# 如果有frame_id,使用提供的值
|
| 135 |
if frame_data.frame_id is not None:
|
|
|
|
| 237 |
|
| 238 |
# 获取配置参数
|
| 239 |
conf_threshold = frame_data.get("conf_threshold", 0.3)
|
| 240 |
+
return_image = frame_data.get("return_image", False)
|
| 241 |
+
draw_bbox = frame_data.get("draw_bbox", False)
|
| 242 |
|
| 243 |
+
# 处理帧,支持返回可视化图像
|
| 244 |
+
result = processor.process_frame(
|
| 245 |
+
frame,
|
| 246 |
+
conf_threshold,
|
| 247 |
+
return_image=return_image,
|
| 248 |
+
draw_bbox=draw_bbox
|
| 249 |
+
)
|
| 250 |
|
| 251 |
# 保留原始的frame_id和timestamp(如果提供)
|
| 252 |
if "frame_id" in frame_data:
|
|
|
|
| 292 |
})
|
| 293 |
continue
|
| 294 |
|
| 295 |
+
# 处理帧,支持返回可视化图像
|
| 296 |
+
result = processor.process_frame(
|
| 297 |
+
frame,
|
| 298 |
+
frame_data.conf_threshold,
|
| 299 |
+
return_image=frame_data.return_image,
|
| 300 |
+
draw_bbox=frame_data.draw_bbox
|
| 301 |
+
)
|
| 302 |
|
| 303 |
# 保留原始信息
|
| 304 |
if frame_data.frame_id is not None:
|
gradio_webrtc_server.py
CHANGED
|
@@ -121,13 +121,57 @@ class SingleMouseProcessor:
|
|
| 121 |
logger.error(traceback.format_exc())
|
| 122 |
return False
|
| 123 |
|
| 124 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
"""
|
| 126 |
处理单帧图像并返回检测结果
|
| 127 |
|
| 128 |
Args:
|
| 129 |
frame: 输入图像帧
|
| 130 |
conf_threshold: 置信度阈值
|
|
|
|
|
|
|
| 131 |
|
| 132 |
Returns:
|
| 133 |
包含检测结果的字典
|
|
@@ -220,6 +264,27 @@ class SingleMouseProcessor:
|
|
| 220 |
"confidence": float(kpt[2])
|
| 221 |
})
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
# 更新统计信息
|
| 224 |
process_time = time.time() - start_time
|
| 225 |
self.frame_count += 1
|
|
@@ -384,9 +449,50 @@ def create_gradio_app():
|
|
| 384 |
"image": "base64_encoded_image",
|
| 385 |
"conf_threshold": 0.3,
|
| 386 |
"frame_id": 1,
|
| 387 |
-
"timestamp": 1641234567.123
|
|
|
|
|
|
|
| 388 |
}
|
| 389 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
""")
|
| 391 |
|
| 392 |
return app
|
|
@@ -432,9 +538,16 @@ def process_webrtc_data(data: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 432 |
|
| 433 |
# 获取配置参数
|
| 434 |
conf_threshold = data.get("conf_threshold", 0.3)
|
|
|
|
|
|
|
| 435 |
|
| 436 |
-
#
|
| 437 |
-
result = processor.process_frame(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
return result
|
| 440 |
|
|
|
|
| 121 |
logger.error(traceback.format_exc())
|
| 122 |
return False
|
| 123 |
|
| 124 |
+
def _draw_keypoints_and_skeleton(self, frame: np.ndarray, keypoints: List[Dict], draw_bbox: bool = False, bbox: Dict = None) -> np.ndarray:
    """
    Draw keypoints and their skeleton connections onto a copy of the image.

    Args:
        frame: input image frame (assumed BGR as per OpenCV convention — the
            color tuples below are BGR; confirm against the capture pipeline)
        keypoints: list of keypoint dicts with "id", "name", "x", "y" keys
        draw_bbox: whether to also draw the bounding box
        bbox: bounding-box dict with "x1", "y1", "x2", "y2" keys; only used
            when draw_bbox is True

    Returns:
        A new annotated image; the caller's frame is not modified.
    """
    # Work on a copy so the input frame stays untouched.
    vis_frame = frame.copy()

    # Optionally draw the detection bounding box (green, 2 px).
    if draw_bbox and bbox is not None:
        cv2.rectangle(vis_frame,
                      (int(bbox["x1"]), int(bbox["y1"])),
                      (int(bbox["x2"]), int(bbox["y2"])),
                      (0, 255, 0), 2)

    # Draw each keypoint as a filled circle with its name labeled beside it.
    for kpt in keypoints:
        cv2.circle(vis_frame,
                   (int(kpt["x"]), int(kpt["y"])),
                   5, (0, 0, 255), -1)
        cv2.putText(vis_frame,
                    kpt["name"],
                    (int(kpt["x"]) + 5, int(kpt["y"]) - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

    # Connect keypoint pairs listed in self.keypoint_connections; a line is
    # drawn only when both endpoints of a connection were detected.
    kpt_dict = {kpt["id"]: (kpt["x"], kpt["y"]) for kpt in keypoints}
    for connection in self.keypoint_connections:
        if connection[0] in kpt_dict and connection[1] in kpt_dict:
            pt1 = (int(kpt_dict[connection[0]][0]), int(kpt_dict[connection[0]][1]))
            pt2 = (int(kpt_dict[connection[1]][0]), int(kpt_dict[connection[1]][1]))
            cv2.line(vis_frame, pt1, pt2, (255, 0, 0), 2)

    return vis_frame
|
| 165 |
+
|
| 166 |
+
def process_frame(self, frame: np.ndarray, conf_threshold: float = 0.3, return_image: bool = False, draw_bbox: bool = False) -> Dict[str, Any]:
|
| 167 |
"""
|
| 168 |
处理单帧图像并返回检测结果
|
| 169 |
|
| 170 |
Args:
|
| 171 |
frame: 输入图像帧
|
| 172 |
conf_threshold: 置信度阈值
|
| 173 |
+
return_image: 是否返回绘制了关键点的图像
|
| 174 |
+
draw_bbox: 是否在返回图像中绘制边界框
|
| 175 |
|
| 176 |
Returns:
|
| 177 |
包含检测结果的字典
|
|
|
|
| 264 |
"confidence": float(kpt[2])
|
| 265 |
})
|
| 266 |
|
| 267 |
+
# 生成可视化图像(如果需要)
|
| 268 |
+
if return_image and detection_data["mouse_detected"]:
|
| 269 |
+
vis_frame = self._draw_keypoints_and_skeleton(
|
| 270 |
+
frame,
|
| 271 |
+
detection_data["keypoints"],
|
| 272 |
+
draw_bbox=draw_bbox,
|
| 273 |
+
bbox=detection_data["bbox"]
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
# 将图像编码为Base64
|
| 277 |
+
_, buffer = cv2.imencode('.jpg', vis_frame)
|
| 278 |
+
image_base64 = base64.b64encode(buffer).decode('utf-8')
|
| 279 |
+
detection_data["visualization_image"] = image_base64
|
| 280 |
+
elif return_image:
|
| 281 |
+
# 即使没有检测到小鼠也返回原图
|
| 282 |
+
_, buffer = cv2.imencode('.jpg', frame)
|
| 283 |
+
image_base64 = base64.b64encode(buffer).decode('utf-8')
|
| 284 |
+
detection_data["visualization_image"] = image_base64
|
| 285 |
+
else:
|
| 286 |
+
logger.info("不需要生成可视化图像")
|
| 287 |
+
|
| 288 |
# 更新统计信息
|
| 289 |
process_time = time.time() - start_time
|
| 290 |
self.frame_count += 1
|
|
|
|
| 449 |
"image": "base64_encoded_image",
|
| 450 |
"conf_threshold": 0.3,
|
| 451 |
"frame_id": 1,
|
| 452 |
+
"timestamp": 1641234567.123,
|
| 453 |
+
"return_image": true,
|
| 454 |
+
"draw_bbox": false
|
| 455 |
}
|
| 456 |
```
|
| 457 |
+
|
| 458 |
+
### 新增参数说明
|
| 459 |
+
- **return_image** (bool, 可选): 是否返回绘制了关键点和骨架的可视化图像
|
| 460 |
+
- **draw_bbox** (bool, 可选): 是否在可视化图像中绘制边界框(需要return_image=true)
|
| 461 |
+
|
| 462 |
+
### 响应格式
|
| 463 |
+
当 `return_image=true` 时,响应会包含额外的字段:
|
| 464 |
+
```json
|
| 465 |
+
{
|
| 466 |
+
"success": true,
|
| 467 |
+
"mouse_detected": true,
|
| 468 |
+
"keypoints": [...],
|
| 469 |
+
"visualization_image": "base64_encoded_image_with_keypoints_and_skeleton"
|
| 470 |
+
}
|
| 471 |
+
```
|
| 472 |
+
|
| 473 |
+
### 使用示例
|
| 474 |
+
```javascript
|
| 475 |
+
// 只获取检测数据
|
| 476 |
+
fetch('/api/process_frame', {
|
| 477 |
+
method: 'POST',
|
| 478 |
+
headers: {'Content-Type': 'application/json'},
|
| 479 |
+
body: JSON.stringify({
|
| 480 |
+
image: base64Image,
|
| 481 |
+
return_image: false
|
| 482 |
+
})
|
| 483 |
+
});
|
| 484 |
+
|
| 485 |
+
// 获取检测数据和可视化图像
|
| 486 |
+
fetch('/api/process_frame', {
|
| 487 |
+
method: 'POST',
|
| 488 |
+
headers: {'Content-Type': 'application/json'},
|
| 489 |
+
body: JSON.stringify({
|
| 490 |
+
image: base64Image,
|
| 491 |
+
return_image: true,
|
| 492 |
+
draw_bbox: false // 只绘制关键点和骨架,不绘制边界框
|
| 493 |
+
})
|
| 494 |
+
});
|
| 495 |
+
```
|
| 496 |
""")
|
| 497 |
|
| 498 |
return app
|
|
|
|
| 538 |
|
| 539 |
# 获取配置参数
|
| 540 |
conf_threshold = data.get("conf_threshold", 0.3)
|
| 541 |
+
return_image = data.get("return_image", False)
|
| 542 |
+
draw_bbox = data.get("draw_bbox", False)
|
| 543 |
|
| 544 |
+
# 处理帧,支持返回可视化图像
|
| 545 |
+
result = processor.process_frame(
|
| 546 |
+
frame,
|
| 547 |
+
conf_threshold,
|
| 548 |
+
return_image=return_image,
|
| 549 |
+
draw_bbox=draw_bbox
|
| 550 |
+
)
|
| 551 |
|
| 552 |
return result
|
| 553 |
|
test_mouse.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
test_result_2_返回图像不包括边界框.jpg
ADDED
|
Git LFS Details
|
test_result_3_返回图像包括边界框.jpg
ADDED
|
Git LFS Details
|
test_visualization_api.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
测试可视化图像API功能
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import requests
|
| 7 |
+
import base64
|
| 8 |
+
import json
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
import time
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
def load_test_image():
    """Load a test image from disk, or synthesize one when none is available."""
    # Prefer a real image file when one of the known candidates exists.
    for candidate in ["test_mouse.jpg", "test_mouse.png", "sample.jpg"]:
        if os.path.exists(candidate):
            print(f"✅ 找到测试图像: {candidate}")
            return cv2.imread(candidate)

    # No file found: build a synthetic frame with a rough mouse silhouette.
    print("📸 创建合成测试图像...")
    synthetic = np.random.randint(50, 200, (640, 640, 3), dtype=np.uint8)

    cv2.circle(synthetic, (320, 300), 50, (100, 100, 100), -1)                      # head
    cv2.ellipse(synthetic, (320, 380), (80, 120), 0, 0, 360, (120, 120, 120), -1)   # body
    cv2.ellipse(synthetic, (320, 500), (20, 80), 0, 0, 360, (90, 90, 90), -1)       # tail

    return synthetic
|
| 34 |
+
|
| 35 |
+
def image_to_base64(image):
    """Encode an OpenCV image as a Base64 JPEG string."""
    ok_flag, jpeg_buffer = cv2.imencode('.jpg', image)
    encoded = base64.b64encode(jpeg_buffer)
    return encoded.decode('utf-8')
|
| 39 |
+
|
| 40 |
+
def base64_to_image(base64_str):
    """Decode a Base64 JPEG string back into an OpenCV image."""
    raw_bytes = base64.b64decode(base64_str)
    pixel_array = np.frombuffer(raw_bytes, np.uint8)
    return cv2.imdecode(pixel_array, cv2.IMREAD_COLOR)
|
| 45 |
+
|
| 46 |
+
def test_api_with_visualization():
    """
    Exercise the REST endpoint with three configurations (no image, image
    without bbox, image with bbox) and save any returned visualization images.

    Returns:
        False when no test image could be loaded; True once all three
        configurations have been attempted (per-request failures are logged
        but do not abort the loop).
    """
    print("=== 测试可视化图像API功能 ===")

    # Load (or synthesize) the test image.
    test_image = load_test_image()
    if test_image is None:
        print("❌ 无法加载测试图像")
        return False

    print(f"测试图像尺寸: {test_image.shape}")

    # Encode once; all three request payloads reuse the same image.
    image_base64 = image_to_base64(test_image)

    # Three configurations covering the return_image / draw_bbox combinations.
    # conf_threshold is deliberately low (0.1) so a detection is likely.
    test_configs = [
        {
            "name": "不返回图像",
            "data": {
                "image": image_base64,
                "conf_threshold": 0.1,
                "return_image": False,
                "frame_id": 1,
                "timestamp": time.time()
            }
        },
        {
            "name": "返回图像(不包括边界框)",
            "data": {
                "image": image_base64,
                "conf_threshold": 0.1,
                "return_image": True,
                "draw_bbox": False,
                "frame_id": 2,
                "timestamp": time.time()
            }
        },
        {
            "name": "返回图像(包括边界框)",
            "data": {
                "image": image_base64,
                "conf_threshold": 0.1,
                "return_image": True,
                "draw_bbox": True,
                "frame_id": 3,
                "timestamp": time.time()
            }
        }
    ]

    api_url = "http://localhost:8765/api/process_frame"

    for i, config in enumerate(test_configs):
        print(f"\n--- 测试 {i+1}: {config['name']} ---")

        try:
            # Time the round trip for this configuration.
            start_time = time.time()
            response = requests.post(api_url, json=config["data"])
            request_time = time.time() - start_time

            if response.status_code == 200:
                result = response.json()
                print(f"✅ 请求成功 (耗时: {request_time:.3f}s)")
                print(f"检测到小鼠: {result['mouse_detected']}")

                if result['mouse_detected']:
                    print(f"置信度: {result['confidence']:.3f}")
                    print(f"关键点数量: {len(result['keypoints'])}")
                    print(f"处理FPS: {result['fps']:.1f}")

                # Save the visualization image when the server returned one.
                if "visualization_image" in result and result["visualization_image"]:
                    print("✅ 包含可视化图像")

                    # Decode and persist the image for manual inspection; the
                    # filename strips spaces/parentheses from the config name.
                    vis_image = base64_to_image(result["visualization_image"])
                    output_filename = f"test_result_{i+1}_{config['name'].replace(' ', '_').replace('(', '').replace(')', '')}.jpg"
                    cv2.imwrite(output_filename, vis_image)
                    print(f"可视化图像已保存: {output_filename}")

                else:
                    # Explain why no image was saved, for easier debugging.
                    if "visualization_image" in result:
                        print(f"❌ visualization_image存在但为空: {result['visualization_image'] is None}")
                    else:
                        print("❌ 响应中没有visualization_image字段")
                    print(f"响应中的所有字段: {list(result.keys())}")
                    # Dump the response minus the (potentially huge) image field.
                    debug_result = {k: v for k, v in result.items() if k != 'visualization_image'}
                    print(f"响应内容(除图像外): {json.dumps(debug_result, indent=2, ensure_ascii=False)}")

            else:
                print(f"❌ 请求失败: {response.status_code}")
                print(f"错误信息: {response.text}")

        except Exception as e:
            print(f"❌ 请求出错: {str(e)}")

    return True
|
| 146 |
+
|
| 147 |
+
def test_websocket_with_visualization():
    """
    Exercise the WebSocket streaming endpoint with return_image enabled and
    save the returned visualization image.

    Returns:
        True when the WebSocket round trip completed; False when the
        websockets package is missing or the test raised an error.
    """
    print("\n=== 测试WebSocket可视化功能 ===")

    try:
        # Imported lazily so the REST tests still run without the package.
        import websockets
        import asyncio

        async def test_ws():
            # Reuse the same test image as the REST tests.
            test_image = load_test_image()
            image_base64 = image_to_base64(test_image)

            uri = "ws://localhost:8765/ws/stream"

            async with websockets.connect(uri) as websocket:
                print("✅ WebSocket连接成功")

                # One frame requesting the visualization (no bounding box).
                frame_data = {
                    "image": image_base64,
                    "conf_threshold": 0.1,  # low threshold so a detection is likely
                    "return_image": True,
                    "draw_bbox": False,
                    "frame_id": 100,
                    "timestamp": time.time()
                }

                await websocket.send(json.dumps(frame_data))

                # Wait for the processing result.
                response = await websocket.recv()
                result = json.loads(response)

                if result['success']:
                    print(f"✅ WebSocket处理成功")
                    print(f"检测到小鼠: {result['mouse_detected']}")

                    if "visualization_image" in result and result["visualization_image"]:
                        print("✅ 包含可视化图像")

                        # Persist the returned image for manual inspection.
                        vis_image = base64_to_image(result["visualization_image"])
                        cv2.imwrite("websocket_result.jpg", vis_image)
                        print("WebSocket可视化图像已保存: websocket_result.jpg")
                    else:
                        print("❌ WebSocket结果不包含可视化图像")
                else:
                    print(f"❌ WebSocket处理失败: {result.get('error', 'Unknown error')}")

        # Drive the coroutine to completion.
        asyncio.run(test_ws())
        return True

    except ImportError:
        print("❌ 未安装websockets库,跳过WebSocket测试")
        return False
    except Exception as e:
        print(f"❌ WebSocket测试出错: {str(e)}")
        return False
|
| 207 |
+
|
| 208 |
+
def main():
    """Entry point: verify the service is up, then run both test suites."""
    print("🚀 开始测试可视化图像API功能")
    print("确保WebRTC服务已启动在 http://localhost:8765")

    # Abort early unless the status endpoint answers with HTTP 200.
    try:
        response = requests.get("http://localhost:8765/api/status", timeout=5)
    except Exception as e:
        print(f"❌ 无法连接到WebRTC服务: {str(e)}")
        print("请先启动服务: python gradio_webrtc_api.py")
        return

    if response.status_code != 200:
        print("❌ WebRTC服务不可用")
        return
    print("✅ WebRTC服务可用")

    # REST API tests first, then the WebSocket test.
    test_api_with_visualization()
    test_websocket_with_visualization()

    print("\n🎉 测试完成!检查生成的图像文件查看结果。")
|
| 233 |
+
|
| 234 |
+
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
websocket_result.jpg
ADDED
|
Git LFS Details
|