single-mouse-webrtc-pose / test_visualization_api.py
Hakureirm's picture
Add image back
402fd16
#!/usr/bin/env python3
"""
测试可视化图像API功能
"""
import requests
import base64
import json
import cv2
import numpy as np
import time
import os
def load_test_image():
    """Return a test image for the API tests.

    Looks for ``test_mouse.jpg``, ``test_mouse.png`` and ``sample.jpg`` (in
    that order, relative to the current working directory) and returns the
    first one that ``cv2.imread`` can decode. If none is usable, returns a
    synthetic 640x640 3-channel uint8 image: uniform random noise in
    [50, 200) overlaid with filled grey shapes roughly resembling a mouse
    (head circle, body ellipse, tail ellipse).
    """
    candidates = ["test_mouse.jpg", "test_mouse.png", "sample.jpg"]
    for img_path in candidates:
        if not os.path.exists(img_path):
            continue
        print(f"✅ 找到测试图像: {img_path}")
        image = cv2.imread(img_path)
        if image is not None:
            return image
        # cv2.imread signals an unreadable/corrupt file by returning None
        # rather than raising; try the next candidate (or the synthetic
        # fallback) instead of handing None back to the caller.
        print(f"⚠️ 无法读取测试图像: {img_path}")

    # No usable file on disk: build a simple synthetic image instead.
    print("📸 创建合成测试图像...")
    test_img = np.random.randint(50, 200, (640, 640, 3), dtype=np.uint8)
    # Draw a crude mouse silhouette (thickness -1 = filled shape).
    cv2.circle(test_img, (320, 300), 50, (100, 100, 100), -1)  # head
    cv2.ellipse(test_img, (320, 380), (80, 120), 0, 0, 360, (120, 120, 120), -1)  # body
    cv2.ellipse(test_img, (320, 500), (20, 80), 0, 0, 360, (90, 90, 90), -1)  # tail
    return test_img
def image_to_base64(image):
    """Encode an OpenCV image as JPEG and return it as a Base64 ASCII string.

    Args:
        image: Image array accepted by ``cv2.imencode``.

    Returns:
        The JPEG bytes, Base64-encoded and decoded to ``str``.

    Raises:
        ValueError: If ``cv2.imencode`` reports that encoding failed.
    """
    ok, buffer = cv2.imencode('.jpg', image)
    # imencode returns a success flag rather than always raising; without this
    # check a failed encode would be sent to the server as a bogus payload.
    if not ok:
        raise ValueError("cv2.imencode failed to encode image as JPEG")
    return base64.b64encode(buffer).decode('utf-8')
def base64_to_image(base64_str):
    """Decode a Base64 string back into an OpenCV image.

    The decoded bytes are handed to ``cv2.imdecode`` with
    ``cv2.IMREAD_COLOR``; its result is returned unchanged, which is ``None``
    when the bytes are not a decodable image.
    """
    encoded_bytes = base64.b64decode(base64_str)
    pixel_buffer = np.frombuffer(encoded_bytes, dtype=np.uint8)
    decoded = cv2.imdecode(pixel_buffer, cv2.IMREAD_COLOR)
    return decoded
def _build_test_configs(image_base64):
    """Return the REST test cases as ``{"name", "data"}`` dicts (no timestamp).

    Covers: no visualization image, visualization without bounding boxes,
    and visualization with bounding boxes.
    """
    common = {"image": image_base64, "conf_threshold": 0.1}
    return [
        {
            "name": "不返回图像",
            "data": {**common, "return_image": False, "frame_id": 1},
        },
        {
            "name": "返回图像(不包括边界框)",
            "data": {**common, "return_image": True, "draw_bbox": False, "frame_id": 2},
        },
        {
            "name": "返回图像(包括边界框)",
            "data": {**common, "return_image": True, "draw_bbox": True, "frame_id": 3},
        },
    ]


def _report_result(result, index, name, request_time):
    """Print a successful (HTTP 200) response and save its visualization image, if any."""
    print(f"✅ 请求成功 (耗时: {request_time:.3f}s)")
    print(f"检测到小鼠: {result['mouse_detected']}")
    if result['mouse_detected']:
        print(f"置信度: {result['confidence']:.3f}")
        print(f"关键点数量: {len(result['keypoints'])}")
    print(f"处理FPS: {result['fps']:.1f}")

    # Check whether the response carries a visualization image.
    vis_base64 = result.get("visualization_image")
    if vis_base64:
        print("✅ 包含可视化图像")
        safe_name = name.replace(' ', '_').replace('(', '').replace(')', '')
        output_filename = f"test_result_{index}_{safe_name}.jpg"
        vis_image = base64_to_image(vis_base64)
        # imdecode returns None and imwrite returns False on failure instead
        # of raising, so both results must be checked explicitly.
        if vis_image is None:
            print("❌ 无法解码可视化图像")
        elif cv2.imwrite(output_filename, vis_image):
            print(f"可视化图像已保存: {output_filename}")
        else:
            print(f"❌ 无法保存可视化图像: {output_filename}")
        return

    if "visualization_image" in result:
        print(f"❌ visualization_image存在但为空: {result['visualization_image'] is None}")
    else:
        print("❌ 响应中没有visualization_image字段")
        print(f"响应中的所有字段: {list(result.keys())}")
    # Dump the response minus the (potentially very long) image field.
    debug_result = {k: v for k, v in result.items() if k != 'visualization_image'}
    print(f"响应内容(除图像外): {json.dumps(debug_result, indent=2, ensure_ascii=False)}")


def test_api_with_visualization(*, api_url="http://localhost:8765/api/process_frame", timeout=30.0):
    """Exercise the REST frame-processing endpoint with visualization options.

    Sends the same test image once per configuration from
    ``_build_test_configs``, prints the detection summary, and saves any
    returned visualization image as ``test_result_<n>_<name>.jpg``.

    Args:
        api_url: URL of the frame-processing endpoint.
        timeout: Per-request timeout in seconds, so an unresponsive server
            cannot hang the test indefinitely.

    Returns:
        True if the test image loaded and every request returned HTTP 200
        without raising; False otherwise.
    """
    print("=== 测试可视化图像API功能 ===")
    test_image = load_test_image()
    if test_image is None:
        print("❌ 无法加载测试图像")
        return False
    print(f"测试图像尺寸: {test_image.shape}")

    image_base64 = image_to_base64(test_image)

    all_ok = True
    for index, config in enumerate(_build_test_configs(image_base64), start=1):
        print(f"\n--- 测试 {index}: {config['name']} ---")
        # Stamp each frame when it is actually sent, not when the list was built.
        payload = {**config["data"], "timestamp": time.time()}
        try:
            start = time.perf_counter()  # monotonic clock for durations
            response = requests.post(api_url, json=payload, timeout=timeout)
            request_time = time.perf_counter() - start
            if response.status_code != 200:
                print(f"❌ 请求失败: {response.status_code}")
                print(f"错误信息: {response.text}")
                all_ok = False
                continue
            _report_result(response.json(), index, config["name"], request_time)
        except Exception as e:  # report and continue with the next configuration
            print(f"❌ 请求出错: {str(e)}")
            all_ok = False
    return all_ok
def test_websocket_with_visualization(*, uri="ws://localhost:8765/ws/stream", timeout=30.0):
    """Exercise the WebSocket streaming endpoint with a visualization request.

    Sends one frame that asks for a visualization image without bounding
    boxes, prints the reply, and saves any returned image to
    ``websocket_result.jpg``.

    Args:
        uri: WebSocket URL of the streaming endpoint.
        timeout: Seconds to wait for the server's reply before giving up.

    Returns:
        True if the server replied with a truthy ``success`` field; False if
        the ``websockets`` package is missing, the test image could not be
        loaded, the server reported failure, or any error occurred.
    """
    print("\n=== 测试WebSocket可视化功能 ===")
    # Guard only the optional-dependency import, so an ImportError raised
    # anywhere else is not misreported as "websockets not installed".
    try:
        import websockets
    except ImportError:
        print("❌ 未安装websockets库,跳过WebSocket测试")
        return False
    import asyncio

    async def test_ws():
        test_image = load_test_image()
        if test_image is None:
            print("❌ 无法加载测试图像")
            return False
        image_base64 = image_to_base64(test_image)

        async with websockets.connect(uri) as websocket:
            print("✅ WebSocket连接成功")
            frame_data = {
                "image": image_base64,
                "conf_threshold": 0.1,  # lowered confidence threshold
                "return_image": True,
                "draw_bbox": False,
                "frame_id": 100,
                "timestamp": time.time(),
            }
            await websocket.send(json.dumps(frame_data))
            # Bound the wait so a server that never answers cannot hang the test.
            response = await asyncio.wait_for(websocket.recv(), timeout=timeout)
            result = json.loads(response)

        # .get() so a reply without a "success" key is reported as a failure,
        # not as an unexpected KeyError.
        if not result.get('success'):
            print(f"❌ WebSocket处理失败: {result.get('error', 'Unknown error')}")
            return False

        print("✅ WebSocket处理成功")
        print(f"检测到小鼠: {result.get('mouse_detected')}")
        vis_base64 = result.get("visualization_image")
        if not vis_base64:
            print("❌ WebSocket结果不包含可视化图像")
            return True

        print("✅ 包含可视化图像")
        vis_image = base64_to_image(vis_base64)
        # imdecode returns None and imwrite returns False instead of raising.
        if vis_image is None:
            print("❌ 无法解码可视化图像")
        elif cv2.imwrite("websocket_result.jpg", vis_image):
            print("WebSocket可视化图像已保存: websocket_result.jpg")
        else:
            print("❌ 无法保存可视化图像: websocket_result.jpg")
        return True

    try:
        return asyncio.run(test_ws())
    except Exception as e:  # includes connection errors and asyncio timeouts
        print(f"❌ WebSocket测试出错: {str(e)}")
        return False
def main():
    """Check that the local service answers, then run the REST and WebSocket tests.

    Queries ``/api/status`` (5 s timeout) first. If the request fails or the
    status code is not 200, a message is printed and no tests run.
    """
    print("🚀 开始测试可视化图像API功能")
    print("确保WebRTC服务已启动在 http://localhost:8765")

    # Pre-flight: make sure the service is reachable before running tests.
    try:
        status_response = requests.get("http://localhost:8765/api/status", timeout=5)
    except Exception as err:
        print(f"❌ 无法连接到WebRTC服务: {str(err)}")
        print("请先启动服务: python gradio_webrtc_api.py")
        return

    if status_response.status_code != 200:
        print("❌ WebRTC服务不可用")
        return
    print("✅ WebRTC服务可用")

    # REST API first, then the WebSocket stream.
    for run_test in (test_api_with_visualization, test_websocket_with_visualization):
        run_test()

    print("\n🎉 测试完成!检查生成的图像文件查看结果。")
# Run the tests only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()