Spaces:
Build error
Build error
"""
BankBot WebSocket Streaming Validation Script
==============================================
Tests the /api/ai/chat/ws WebSocket endpoint for:
  1. Streaming chat response (chat_start → chat_chunk(s) → chat_end)
  2. Ping/pong keepalive
  3. Invalid JSON error handling

Usage:
    # From the backend/ directory with the server running:
    python app/scripts/test_websocket.py

Exit codes:
    0 — all tests passed
    1 — one or more tests failed
"""
import asyncio
import json
import os
import sys

import websockets
# Endpoint under test. Defaults to a locally running backend; set the
# BANKBOT_WS_URL environment variable to target a different host/port.
WS_URL = os.environ.get("BANKBOT_WS_URL", "ws://127.0.0.1:8000/api/ai/chat/ws")

# --- Result tracking ---------------------------------------------------------
# One (name, passed, detail) tuple per recorded outcome, in recording order.
results = []


def record(name: str, passed: bool, detail: str = "") -> None:
    """Append one test outcome to the module-level ``results`` list.

    Args:
        name: Test name shown in the summary table.
        passed: True if the test passed, False otherwise.
        detail: Optional free-form text shown next to the result.
    """
    results.append((name, passed, detail))
# --- Tests -------------------------------------------------------------------
async def test_chat_streaming() -> None:
    """Send one chat message and verify the full streaming protocol.

    Expected server messages, in order:
    ``chat_start`` -> one or more ``chat_chunk`` -> ``chat_end``.
    The ``content`` of every chunk is concatenated into the reply. The whole
    exchange (after sending) must complete within a single 30-second budget.

    Raises:
        AssertionError: if the server sends an ``error`` message, a protocol
            step is missing (including when the 30-second budget runs out),
            or the assembled reply is 5 characters or shorter.
    """
    async with websockets.connect(WS_URL, open_timeout=10) as ws:
        await ws.send(json.dumps({
            "type": "chat",
            "message": "What is my current balance and savings rate?"
        }))

        got_start = False
        got_chunk = False
        got_end = False
        chunks = []

        # One overall 30-second budget: each recv() only gets the time that is
        # left, instead of a fresh 30 seconds per message.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + 30
        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                raw = await asyncio.wait_for(ws.recv(), timeout=remaining)
            except asyncio.TimeoutError:
                break

            msg = json.loads(raw)
            t = msg.get("type")
            if t == "chat_start":
                got_start = True
            elif t == "chat_chunk":
                got_chunk = True
                # A missing or null "content" field counts as an empty chunk.
                chunks.append(msg.get("content") or "")
            elif t == "chat_end":
                got_end = True
                break
            elif t == "error":
                raise AssertionError(f"Server returned error: {msg.get('message')}")
            # Any other message type is ignored.

    full_reply = "".join(chunks)

    # Explicit raises instead of `assert`: `python -O` strips assert statements,
    # which would turn every failed check into a silent pass.
    if not got_start:
        raise AssertionError("Never received chat_start")
    if not got_chunk:
        raise AssertionError("Never received any chat_chunk")
    if not got_end:
        raise AssertionError("Never received chat_end")
    if len(full_reply) <= 5:
        raise AssertionError(f"Assembled reply is too short: '{full_reply}'")

    record("WS chat streaming", True,
           f"reply_len={len(full_reply)} chars | preview: {full_reply[:80].strip()}...")
async def test_ping_pong() -> None:
    """Send a ``ping`` message and verify the server answers with ``pong``.

    The first message received within 10 seconds must be a JSON object whose
    ``type`` is ``"pong"``.

    Raises:
        AssertionError: if the reply is not a JSON object with type "pong".
        asyncio.TimeoutError: if no reply arrives within 10 seconds.
    """
    async with websockets.connect(WS_URL, open_timeout=10) as ws:
        await ws.send(json.dumps({"type": "ping"}))
        raw = await asyncio.wait_for(ws.recv(), timeout=10)

    msg = json.loads(raw)
    # Explicit raise rather than `assert` (stripped under `python -O`). The
    # isinstance check turns a non-object reply into a clear failure instead
    # of an AttributeError on .get().
    if not isinstance(msg, dict) or msg.get("type") != "pong":
        raise AssertionError(f"Expected pong, got: {msg}")

    record("WS ping/pong", True)
async def test_invalid_json() -> None:
    """Send a non-JSON text frame and verify the server replies with an error.

    The first message received within 10 seconds must be a JSON object whose
    ``type`` is ``"error"``. The first 60 characters of its ``message`` field
    are recorded as the result detail.

    Raises:
        AssertionError: if the reply is not a JSON object with type "error".
        asyncio.TimeoutError: if no reply arrives within 10 seconds.
    """
    async with websockets.connect(WS_URL, open_timeout=10) as ws:
        await ws.send("this is not valid json {{{{")
        raw = await asyncio.wait_for(ws.recv(), timeout=10)

    msg = json.loads(raw)
    # Explicit raise rather than `assert` (stripped under `python -O`).
    if not isinstance(msg, dict) or msg.get("type") != "error":
        raise AssertionError(f"Expected error response, got: {msg}")

    # Normalise a missing/null or non-string "message" so slicing cannot fail.
    message = msg.get("message")
    error_text = "" if message is None else str(message)
    record("WS invalid JSON handling", True,
           f"error_msg={error_text[:60]}")
# --- Runner ------------------------------------------------------------------
async def main() -> int:
    """Run every WebSocket test in order and print a pass/fail summary.

    Each test coroutine is awaited in turn, so one failing test does not stop
    the others. An ``AssertionError`` is recorded as a failure carrying its
    message. Any other ``Exception`` (e.g. connection refused, timeout, bad
    JSON) is recorded as a failure carrying the exception type and message.
    Passing tests record their own result.

    Returns:
        The number of failed entries in ``results`` (0 means all passed).
    """
    # Plain ASCII rules and status labels: the non-ASCII glyphs that were here
    # had been mis-encoded, and non-ASCII output can raise UnicodeEncodeError
    # when stdout uses a non-UTF-8 encoding (e.g. redirected output).
    rule = "=" * 60

    print(f"\n{rule}")
    print(f" BankBot WebSocket Validation - {WS_URL}")
    print(f"{rule}\n")

    tests = [
        ("WS chat streaming", test_chat_streaming),
        ("WS ping/pong", test_ping_pong),
        ("WS invalid JSON handling", test_invalid_json),
    ]

    for name, test_fn in tests:
        try:
            await test_fn()
        except AssertionError as e:
            record(name, False, str(e))
        except Exception as e:
            record(name, False, f"Exception: {type(e).__name__}: {e}")

    # --- Summary table ---
    print(f"\n{rule}")
    print(f" {'TEST':<35} {'RESULT':<8} DETAIL")
    print(rule)

    passed = 0
    failed = 0
    for test_name, ok, detail in results:
        status = "PASS" if ok else "FAIL"
        print(f" {test_name:<35} {status:<8} {detail}")
        if ok:
            passed += 1
        else:
            failed += 1

    print(rule)
    print(f" {passed} passed | {failed} failed | {len(results)} total")
    print(f"{rule}\n")
    return failed
if __name__ == "__main__":
    # main() reports how many tests failed; collapse that into the
    # documented exit status (0 = all passed, 1 = at least one failure).
    failures = asyncio.run(main())
    sys.exit(1 if failures else 0)