#!/usr/bin/env python3 """ 测试可视化图像API功能 """ import requests import base64 import json import cv2 import numpy as np import time import os def load_test_image(): """加载测试图像""" # 尝试查找测试图像 test_images = ["test_mouse.jpg", "test_mouse.png", "sample.jpg"] for img_path in test_images: if os.path.exists(img_path): print(f"✅ 找到测试图像: {img_path}") return cv2.imread(img_path) # 如果没有找到测试图像,创建一个简单的 print("📸 创建合成测试图像...") test_img = np.random.randint(50, 200, (640, 640, 3), dtype=np.uint8) # 添加简单的小鼠形状 cv2.circle(test_img, (320, 300), 50, (100, 100, 100), -1) # 头部 cv2.ellipse(test_img, (320, 380), (80, 120), 0, 0, 360, (120, 120, 120), -1) # 身体 cv2.ellipse(test_img, (320, 500), (20, 80), 0, 0, 360, (90, 90, 90), -1) # 尾巴 return test_img def image_to_base64(image): """将OpenCV图像转换为Base64字符串""" _, buffer = cv2.imencode('.jpg', image) return base64.b64encode(buffer).decode('utf-8') def base64_to_image(base64_str): """将Base64字符串转换为OpenCV图像""" img_bytes = base64.b64decode(base64_str) nparr = np.frombuffer(img_bytes, np.uint8) return cv2.imdecode(nparr, cv2.IMREAD_COLOR) def test_api_with_visualization(): """测试带可视化功能的API""" print("=== 测试可视化图像API功能 ===") # 加载测试图像 test_image = load_test_image() if test_image is None: print("❌ 无法加载测试图像") return False print(f"测试图像尺寸: {test_image.shape}") # 转换为Base64 image_base64 = image_to_base64(test_image) # 测试不同的配置 test_configs = [ { "name": "不返回图像", "data": { "image": image_base64, "conf_threshold": 0.1, "return_image": False, "frame_id": 1, "timestamp": time.time() } }, { "name": "返回图像(不包括边界框)", "data": { "image": image_base64, "conf_threshold": 0.1, "return_image": True, "draw_bbox": False, "frame_id": 2, "timestamp": time.time() } }, { "name": "返回图像(包括边界框)", "data": { "image": image_base64, "conf_threshold": 0.1, "return_image": True, "draw_bbox": True, "frame_id": 3, "timestamp": time.time() } } ] api_url = "http://localhost:8765/api/process_frame" for i, config in enumerate(test_configs): print(f"\n--- 测试 {i+1}: {config['name']} ---") try: # 发送请求 start_time = time.time() response = requests.post(api_url, json=config["data"]) request_time = time.time() - start_time if response.status_code == 200: result = response.json() print(f"✅ 请求成功 (耗时: {request_time:.3f}s)") print(f"检测到小鼠: {result['mouse_detected']}") if result['mouse_detected']: print(f"置信度: {result['confidence']:.3f}") print(f"关键点数量: {len(result['keypoints'])}") print(f"处理FPS: {result['fps']:.1f}") # 检查是否有可视化图像 if "visualization_image" in result and result["visualization_image"]: print("✅ 包含可视化图像") # 解码并保存可视化图像 vis_image = base64_to_image(result["visualization_image"]) output_filename = f"test_result_{i+1}_{config['name'].replace(' ', '_').replace('(', '').replace(')', '')}.jpg" cv2.imwrite(output_filename, vis_image) print(f"可视化图像已保存: {output_filename}") else: if "visualization_image" in result: print(f"❌ visualization_image存在但为空: {result['visualization_image'] is None}") else: print("❌ 响应中没有visualization_image字段") print(f"响应中的所有字段: {list(result.keys())}") # 打印部分响应内容(不包括过长的字段) debug_result = {k: v for k, v in result.items() if k != 'visualization_image'} print(f"响应内容(除图像外): {json.dumps(debug_result, indent=2, ensure_ascii=False)}") else: print(f"❌ 请求失败: {response.status_code}") print(f"错误信息: {response.text}") except Exception as e: print(f"❌ 请求出错: {str(e)}") return True def test_websocket_with_visualization(): """测试WebSocket的可视化功能""" print("\n=== 测试WebSocket可视化功能 ===") try: import websockets import asyncio async def test_ws(): # 加载测试图像 test_image = load_test_image() image_base64 = image_to_base64(test_image) uri = "ws://localhost:8765/ws/stream" async with websockets.connect(uri) as websocket: print("✅ WebSocket连接成功") # 发送带可视化的数据 frame_data = { "image": image_base64, "conf_threshold": 0.1, # 降低置信度阈值 "return_image": True, "draw_bbox": False, "frame_id": 100, "timestamp": time.time() } await websocket.send(json.dumps(frame_data)) # 接收结果 response = await websocket.recv() result = json.loads(response) if result['success']: print(f"✅ WebSocket处理成功") print(f"检测到小鼠: {result['mouse_detected']}") if "visualization_image" in result and result["visualization_image"]: print("✅ 包含可视化图像") # 保存WebSocket结果 vis_image = base64_to_image(result["visualization_image"]) cv2.imwrite("websocket_result.jpg", vis_image) print("WebSocket可视化图像已保存: websocket_result.jpg") else: print("❌ WebSocket结果不包含可视化图像") else: print(f"❌ WebSocket处理失败: {result.get('error', 'Unknown error')}") # 运行异步测试 asyncio.run(test_ws()) return True except ImportError: print("❌ 未安装websockets库,跳过WebSocket测试") return False except Exception as e: print(f"❌ WebSocket测试出错: {str(e)}") return False def main(): """主函数""" print("🚀 开始测试可视化图像API功能") print("确保WebRTC服务已启动在 http://localhost:8765") # 检查服务是否可用 try: response = requests.get("http://localhost:8765/api/status", timeout=5) if response.status_code == 200: print("✅ WebRTC服务可用") else: print("❌ WebRTC服务不可用") return except Exception as e: print(f"❌ 无法连接到WebRTC服务: {str(e)}") print("请先启动服务: python gradio_webrtc_api.py") return # 测试REST API test_api_with_visualization() # 测试WebSocket test_websocket_with_visualization() print("\n🎉 测试完成!检查生成的图像文件查看结果。") if __name__ == "__main__": main()