File size: 5,398 Bytes
a282d4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
BankBot WebSocket Streaming Validation Script
==============================================
Tests the /api/ai/chat/ws WebSocket endpoint for:
  1. Streaming chat response (chat_start → chat_chunk(s) → chat_end)
  2. Ping/pong keepalive
  3. Invalid JSON error handling

Usage:
    # From the backend/ directory with the server running:
    python app/scripts/test_websocket.py

Exit codes:
    0 — all tests passed
    1 — one or more tests failed
"""

import asyncio
import json
import os
import sys

import websockets

# Endpoint under test. Defaults to a locally running server; set the
# BANKBOT_WS_URL environment variable to target another host/port or a
# wss:// deployment. An empty value falls back to the default.
WS_URL = os.environ.get("BANKBOT_WS_URL") or "ws://127.0.0.1:8000/api/ai/chat/ws"

# ─── Result tracking ──────────────────────────────────────────────────────────

results = []  # list of (name, passed, detail) tuples, appended via record()


def record(name: str, passed: bool, detail: str = ""):
    """Append one test outcome to ``results`` for the final summary table.

    Args:
        name: Test label printed in the TEST column of the summary.
        passed: True for a pass, False for a failure.
        detail: Optional free-form text printed in the DETAIL column.
    """
    results.append((name, passed, detail))

# ─── Tests ────────────────────────────────────────────────────────────────────

async def test_chat_streaming():
    """
    Send one chat message and verify the full streaming protocol:
    chat_start → one or more chat_chunk → chat_end.

    The whole exchange shares a single 30-second budget. Each ``recv``
    waits only for the time still left, so a slow stream cannot push the
    test past its deadline.

    Raises:
        AssertionError: if the server sends an ``error`` message, a protocol
            step never arrives, or the assembled reply is 5 chars or shorter.
    """
    async with websockets.connect(WS_URL, open_timeout=10) as ws:
        await ws.send(json.dumps({
            "type": "chat",
            "message": "What is my current balance and savings rate?"
        }))

        got_start  = False
        got_chunk  = False
        got_end    = False
        full_reply = ""

        # Collect messages until chat_end or until the 30-second budget runs out.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + 30
        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                # Bound each receive by the remaining budget, not a fresh 30 s.
                raw = await asyncio.wait_for(ws.recv(), timeout=remaining)
            except asyncio.TimeoutError:
                break

            msg = json.loads(raw)
            t   = msg.get("type")

            if t == "chat_start":
                got_start = True
            elif t == "chat_chunk":
                got_chunk = True
                # A missing or null "content" counts as an empty chunk.
                full_reply += msg.get("content") or ""
            elif t == "chat_end":
                got_end = True
                break
            elif t == "error":
                raise AssertionError(f"Server returned error: {msg.get('message')}")

        # Explicit raises instead of ``assert``: these checks ARE the test and
        # must not be stripped when Python runs with -O.
        if not got_start:
            raise AssertionError("Never received chat_start")
        if not got_chunk:
            raise AssertionError("Never received any chat_chunk")
        if not got_end:
            raise AssertionError("Never received chat_end")
        if len(full_reply) <= 5:
            raise AssertionError(f"Assembled reply is too short: '{full_reply}'")

        record("WS chat streaming", True,
               f"reply_len={len(full_reply)} chars | preview: {full_reply[:80].strip()}...")


async def test_ping_pong():
    """
    Send a ``ping`` message and verify the server's next message is ``pong``.

    Raises:
        AssertionError: if the first reply's "type" is not "pong".
        asyncio.TimeoutError: if no reply arrives within 10 seconds.
    """
    async with websockets.connect(WS_URL, open_timeout=10) as ws:
        await ws.send(json.dumps({"type": "ping"}))

        raw = await asyncio.wait_for(ws.recv(), timeout=10)
        msg = json.loads(raw)

        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if msg.get("type") != "pong":
            raise AssertionError(f"Expected pong, got: {msg}")

        record("WS ping/pong", True)


async def test_invalid_json():
    """
    Send a non-JSON text frame and verify the server's next message has
    type ``error``.

    Raises:
        AssertionError: if the first reply's "type" is not "error".
        asyncio.TimeoutError: if no reply arrives within 10 seconds.
    """
    async with websockets.connect(WS_URL, open_timeout=10) as ws:
        await ws.send("this is not valid json {{{{")

        raw = await asyncio.wait_for(ws.recv(), timeout=10)
        msg = json.loads(raw)

        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if msg.get("type") != "error":
            raise AssertionError(f"Expected error response, got: {msg}")

        # "message" may be absent, null, or a non-string. Coerce it before
        # slicing so a valid error reply is never misreported as a crash.
        detail = msg.get("message")
        detail = "" if detail is None else str(detail)
        record("WS invalid JSON handling", True,
               f"error_msg={detail[:60]}")


# ─── Runner ───────────────────────────────────────────────────────────────────

async def main():
    """
    Run every WebSocket test in order, then print a summary table.

    Tests record their own success via ``record``. An AssertionError or any
    other exception a test raises is caught here and recorded as a failure,
    so one broken test never stops the remaining ones from running.

    Returns:
        The number of failed entries in ``results`` (0 when all passed).
    """
    print(f"\n{'─'*60}")
    print(f"  BankBot WebSocket Validation  —  {WS_URL}")
    print(f"{'─'*60}\n")

    tests = [
        ("WS chat streaming",       test_chat_streaming),
        ("WS ping/pong",            test_ping_pong),
        ("WS invalid JSON handling", test_invalid_json),
    ]

    for name, test_fn in tests:
        try:
            await test_fn()
        except AssertionError as e:
            record(name, False, str(e))
        except Exception as e:
            # Top-level boundary: connection errors, timeouts, bad JSON, etc.
            # are reported as failures instead of aborting the whole run.
            record(name, False, f"Exception: {type(e).__name__}: {e}")

    # ── Summary table ─────────────────────────────────────────────────────────
    print(f"\n{'─'*60}")
    print(f"  {'TEST':<35} {'RESULT':<8} DETAIL")
    print(f"{'─'*60}")

    passed = 0
    failed = 0
    for test_name, ok, detail in results:
        status = "✅ PASS" if ok else "❌ FAIL"
        print(f"  {test_name:<35} {status:<8} {detail}")
        if ok:
            passed += 1
        else:
            failed += 1

    print(f"{'─'*60}")
    print(f"  {passed} passed  |  {failed} failed  |  {len(results)} total")
    print(f"{'─'*60}\n")

    return failed


if __name__ == "__main__":
    # Translate the failure count into the documented exit codes:
    # 0 when nothing failed, 1 otherwise.
    num_failed = asyncio.run(main())
    sys.exit(1 if num_failed else 0)