|
|
|
|
|
""" |
|
|
测试可视化图像API功能 |
|
|
""" |
|
|
|
|
|
import requests |
|
|
import base64 |
|
|
import json |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import time |
|
|
import os |
|
|
|
|
|
def load_test_image():
    """Load a test image from disk, or synthesize one if none is usable.

    The candidate files are tried in order; the first one that exists *and*
    decodes successfully is returned. ``cv2.imread`` returns ``None`` (it does
    not raise) for files it cannot read, so such files are skipped instead of
    being returned to the caller.

    Returns:
        A 3-channel ``numpy.ndarray`` (OpenCV BGR order). When no candidate
        file loads, a 640x640 uint8 image with random background noise and a
        few filled grey shapes is generated instead.
    """
    test_images = ["test_mouse.jpg", "test_mouse.png", "sample.jpg"]

    for img_path in test_images:
        if not os.path.exists(img_path):
            continue
        image = cv2.imread(img_path)
        if image is None:
            # The file exists but OpenCV could not decode it; try the next one.
            print(f"⚠️ 无法读取测试图像: {img_path}")
            continue
        print(f"✅ 找到测试图像: {img_path}")
        return image

    print("📸 创建合成测试图像...")
    test_img = np.random.randint(50, 200, (640, 640, 3), dtype=np.uint8)

    # Filled (thickness=-1) grey shapes stacked vertically — presumably a crude
    # mouse silhouette (head / body / tail) for the detector to look at.
    cv2.circle(test_img, (320, 300), 50, (100, 100, 100), -1)
    cv2.ellipse(test_img, (320, 380), (80, 120), 0, 0, 360, (120, 120, 120), -1)
    cv2.ellipse(test_img, (320, 500), (20, 80), 0, 0, 360, (90, 90, 90), -1)

    return test_img
|
|
|
|
|
def image_to_base64(image, ext='.jpg'):
    """Encode an OpenCV image as a Base64 string.

    Args:
        image: Image array accepted by ``cv2.imencode``.
        ext: File extension selecting the codec. Defaults to ``'.jpg'``, the
            format this script has always sent.

    Returns:
        The encoded image bytes as a Base64 ``str``.

    Raises:
        ValueError: If ``image`` is ``None`` or OpenCV reports that encoding
            failed.
    """
    if image is None:
        raise ValueError("image_to_base64: image is None")
    ok, buffer = cv2.imencode(ext, image)
    if not ok:
        # imencode signals failure through its flag rather than raising;
        # without this check the failure would go unnoticed.
        raise ValueError(f"image_to_base64: cv2.imencode failed for {ext!r}")
    return base64.b64encode(buffer).decode('utf-8')
|
|
|
|
|
def base64_to_image(base64_str):
    """Decode a Base64-encoded image into an OpenCV (BGR) array.

    A leading ``data:<mime>;base64,`` prefix, if present, is stripped first.
    The Base64 alphabet contains neither ``:`` nor ``,``, so plain payloads
    are unaffected.

    Args:
        base64_str: Base64 text of an encoded image file (e.g. JPEG/PNG).

    Returns:
        The decoded image as a ``numpy.ndarray``.

    Raises:
        ValueError: If the payload is empty, is malformed Base64
            (``binascii.Error`` is a ``ValueError`` subclass), or OpenCV
            cannot decode the bytes as an image.
    """
    if base64_str.startswith('data:'):
        base64_str = base64_str.split(',', 1)[-1]
    img_bytes = base64.b64decode(base64_str)
    if not img_bytes:
        raise ValueError("base64_to_image: empty image payload")
    nparr = np.frombuffer(img_bytes, np.uint8)
    image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode returns None instead of raising on undecodable data.
        raise ValueError("base64_to_image: cv2.imdecode could not decode the payload")
    return image
|
|
|
|
|
def _build_api_test_configs(image_base64):
    """Return the HTTP test cases: no image, image without bboxes, image with bboxes."""
    return [
        {
            "name": "不返回图像",
            "data": {
                "image": image_base64,
                "conf_threshold": 0.1,
                "return_image": False,
                "frame_id": 1,
                "timestamp": time.time(),
            },
        },
        {
            "name": "返回图像(不包括边界框)",
            "data": {
                "image": image_base64,
                "conf_threshold": 0.1,
                "return_image": True,
                "draw_bbox": False,
                "frame_id": 2,
                "timestamp": time.time(),
            },
        },
        {
            "name": "返回图像(包括边界框)",
            "data": {
                "image": image_base64,
                "conf_threshold": 0.1,
                "return_image": True,
                "draw_bbox": True,
                "frame_id": 3,
                "timestamp": time.time(),
            },
        },
    ]


def _write_image_file(path, image):
    """Encode *image* using *path*'s extension and write it; return True on success.

    Encodes in memory and writes with Python file I/O rather than calling
    ``cv2.imwrite``. ``imwrite`` reports failure only via its return value and
    can fail on non-ASCII paths on some platforms (notably Windows builds), and
    these output names include the Chinese test-case name.
    """
    ext = os.path.splitext(path)[1] or '.jpg'
    ok, buffer = cv2.imencode(ext, image)
    if not ok:
        return False
    with open(path, 'wb') as fh:
        fh.write(buffer.tobytes())
    return True


def _run_api_test(index, config, api_url, timeout):
    """Run one HTTP test case, print a report, and return whether it passed.

    A case passes when the server answers 200 and either returns a
    visualization image that decodes and saves successfully, or returns none
    because ``return_image`` was not requested.
    """
    print(f"\n--- 测试 {index}: {config['name']} ---")
    try:
        start_time = time.time()
        # Without a timeout, requests waits indefinitely on a stalled server.
        response = requests.post(api_url, json=config["data"], timeout=timeout)
        request_time = time.time() - start_time

        if response.status_code != 200:
            print(f"❌ 请求失败: {response.status_code}")
            print(f"错误信息: {response.text}")
            return False

        result = response.json()
        print(f"✅ 请求成功 (耗时: {request_time:.3f}s)")
        print(f"检测到小鼠: {result['mouse_detected']}")
        if result['mouse_detected']:
            print(f"置信度: {result['confidence']:.3f}")
            print(f"关键点数量: {len(result['keypoints'])}")
            print(f"处理FPS: {result['fps']:.1f}")

        if result.get("visualization_image"):
            print("✅ 包含可视化图像")
            vis_image = base64_to_image(result["visualization_image"])
            output_filename = f"test_result_{index}_{config['name'].replace(' ', '_').replace('(', '').replace(')', '')}.jpg"
            if not _write_image_file(output_filename, vis_image):
                print(f"❌ 保存可视化图像失败: {output_filename}")
                return False
            print(f"可视化图像已保存: {output_filename}")
            return True

        if not config["data"].get("return_image"):
            # This case deliberately asked for no image, so absence is expected.
            print("✅ 未请求可视化图像,响应中也没有图像")
            return True

        if "visualization_image" in result:
            print(f"❌ visualization_image存在但为空: {result['visualization_image'] is None}")
        else:
            print("❌ 响应中没有visualization_image字段")
            print(f"响应中的所有字段: {list(result.keys())}")

        debug_result = {k: v for k, v in result.items() if k != 'visualization_image'}
        print(f"响应内容(除图像外): {json.dumps(debug_result, indent=2, ensure_ascii=False)}")
        return False

    except Exception as e:
        # Per-case boundary of a test script: report and move on to the next case.
        print(f"❌ 请求出错: {str(e)}")
        return False


def test_api_with_visualization():
    """Exercise ``/api/process_frame`` with three visualization settings.

    Sends the same test frame with ``return_image`` off, on without bounding
    boxes, and on with bounding boxes. Prints a report per case and saves any
    returned visualization image as ``test_result_<n>_<name>.jpg``.

    Returns:
        True if every case passed; False if the test image could not be
        loaded or any case failed.
    """
    print("=== 测试可视化图像API功能 ===")

    test_image = load_test_image()
    if test_image is None:
        print("❌ 无法加载测试图像")
        return False

    print(f"测试图像尺寸: {test_image.shape}")
    image_base64 = image_to_base64(test_image)

    api_url = "http://localhost:8765/api/process_frame"
    # Generous per-request limit: the first inference call may be slow.
    request_timeout = 60

    all_passed = True
    for index, config in enumerate(_build_api_test_configs(image_base64), start=1):
        if not _run_api_test(index, config, api_url, request_timeout):
            all_passed = False
    return all_passed
|
|
|
|
|
def test_websocket_with_visualization():
    """Send one frame over ``/ws/stream`` and save the returned visualization.

    Returns:
        True if the server reported success and a visualization image was
        received and saved to ``websocket_result.jpg``. False otherwise,
        including when ``websockets`` is not installed, on connection errors,
        or when no response arrives in time.
    """
    print("\n=== 测试WebSocket可视化功能 ===")

    try:
        import websockets
    except ImportError:
        print("❌ 未安装websockets库,跳过WebSocket测试")
        return False

    import asyncio

    uri = "ws://localhost:8765/ws/stream"
    recv_timeout = 60
    # websockets rejects incoming messages larger than max_size (1 MiB by
    # default); a Base64-encoded visualization image can exceed that.
    max_message_size = 32 * 2 ** 20

    async def test_ws():
        test_image = load_test_image()
        image_base64 = image_to_base64(test_image)

        async with websockets.connect(uri, max_size=max_message_size) as websocket:
            print("✅ WebSocket连接成功")

            frame_data = {
                "image": image_base64,
                "conf_threshold": 0.1,
                "return_image": True,
                "draw_bbox": False,
                "frame_id": 100,
                "timestamp": time.time(),
            }
            await websocket.send(json.dumps(frame_data))

            # Bound the wait so a silent server cannot hang the script.
            response = await asyncio.wait_for(websocket.recv(), timeout=recv_timeout)

        result = json.loads(response)

        if not result.get('success'):
            print(f"❌ WebSocket处理失败: {result.get('error', 'Unknown error')}")
            return False

        print("✅ WebSocket处理成功")
        print(f"检测到小鼠: {result['mouse_detected']}")

        if not result.get("visualization_image"):
            print("❌ WebSocket结果不包含可视化图像")
            return False

        print("✅ 包含可视化图像")
        vis_image = base64_to_image(result["visualization_image"])
        # imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite("websocket_result.jpg", vis_image):
            print("❌ 保存可视化图像失败: websocket_result.jpg")
            return False
        print("WebSocket可视化图像已保存: websocket_result.jpg")
        return True

    try:
        return asyncio.run(test_ws())
    except asyncio.TimeoutError:
        print(f"❌ WebSocket超时 ({recv_timeout}s)")
        return False
    except Exception as e:
        print(f"❌ WebSocket测试出错: {str(e)}")
        return False
|
|
|
|
|
def main():
    """Script entry point.

    Checks that the service's ``/api/status`` endpoint answers with HTTP 200.
    If it does, runs the HTTP test and then the WebSocket visualization test.
    If it does not, prints a hint on how to start the service and returns
    without running any tests.
    """
    print("🚀 开始测试可视化图像API功能")
    print("确保WebRTC服务已启动在 http://localhost:8765")

    # Health check before running any test.
    try:
        status_response = requests.get("http://localhost:8765/api/status", timeout=5)
        service_ok = status_response.status_code == 200
        print("✅ WebRTC服务可用" if service_ok else "❌ WebRTC服务不可用")
    except Exception as exc:
        print(f"❌ 无法连接到WebRTC服务: {exc}")
        print("请先启动服务: python gradio_webrtc_api.py")
        return

    if not service_ok:
        return

    for run_test in (test_api_with_visualization, test_websocket_with_visualization):
        run_test()

    print("\n🎉 测试完成!检查生成的图像文件查看结果。")
|
|
|
|
|
# Run the checks only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()