File size: 7,920 Bytes
395651c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import asyncio
import copy
import json
import os
import time

import httpx
import pytest

from tests.cases.pipeline_cases import (
    Q10_FOLLOW_UP,
    Q13_HISTORY_STEPS,
    QUERIES,
    validate_q10_step2_dsl,
    validate_query_result,
)

# Connection settings are read from the environment once, at import time.
# Only the base URL has a fallback; the user and session ids stay None when unset.
BASE_URL = os.environ.get("TEST_BASE_URL", "http://localhost:8000")
USER_ID = os.environ.get("TEST_USER_ID")
SESSION_ID = os.environ.get("TEST_SESSION_ID")

# One result dict per executed (or skipped) query; reset at the start of each suite run.
test_stats: list[dict] = []


# Substring of the job error message that marks a failure as worth retrying.
_SOLVER_TRANSIENT = "Solver failed after multiple attempts"


async def run_single_api_query(client, q, headers, default_session_id: str | None):
    print(f"\nπŸš€ [RUNNING] {q['id']}: {q['text']}")
    start_time = time.time()

    payload = {
        "text": q["text"],
        "request_video": q.get("request_video", False),
    }

    max_rounds = 3

    try:
        for round_idx in range(max_rounds):
            if q.get("isolate", True):
                session_resp = await client.post("/api/v1/sessions", headers=headers)
                if session_resp.status_code != 200:
                    return {
                        "id": q["id"],
                        "query": q["text"],
                        "success": False,
                        "error": f"Session creation failed: {session_resp.text}",
                    }
                session_id = session_resp.json()["id"]
            else:
                session_id = q.get("session_id", default_session_id)

            res = await client.post(
                f"/api/v1/sessions/{session_id}/solve",
                json=payload,
                headers=headers,
            )
            if res.status_code != 200:
                print(f"   ❌ FAILED: Status {res.status_code} - {res.text}")
                return {
                    "id": q["id"],
                    "query": q["text"],
                    "success": False,
                    "error": f"HTTP {res.status_code}: {res.text}",
                }

            job_id = res.json()["job_id"]
            print(f"   βœ… Job Created: {job_id}")

            max_attempts = 45
            result_data = None
            last_error = None
            for i in range(max_attempts):
                await asyncio.sleep(4)
                res = await client.get(f"/api/v1/solve/{job_id}", headers=headers)
                data = res.json()
                status = data.get("status")
                print(f"      - Polling ({i + 1}): {status}")

                if status == "success":
                    result_data = data["result"]
                    break
                if status == "error":
                    last_error = data.get("result", {}).get("error")
                    print(f"   ❌ ERROR: {last_error}")
                    err_s = str(last_error or "")
                    if _SOLVER_TRANSIENT in err_s and round_idx < max_rounds - 1:
                        print(
                            f"   ↻ Retry {round_idx + 2}/{max_rounds} (transient solver/LLM flake)"
                        )
                        result_data = None
                        break
                    return {
                        "id": q["id"],
                        "query": q["text"],
                        "success": False,
                        "error": last_error,
                    }

                if i == max_attempts - 1:
                    print("   ❌ TIMEOUT")
                    return {"id": q["id"], "query": q["text"], "success": False, "error": "Timeout"}

            if result_data is None:
                continue

            elapsed = time.time() - start_time
            errors = validate_query_result(q, result_data)

            if q.get("request_video") and not result_data.get("video_url"):
                print("      ⚠️ Video requested but no URL found (Expected in some test envs)")

            if errors:
                print(f"   ❌ VALIDATION FAILED: {', '.join(errors)}")
                return {
                    "id": q["id"],
                    "query": q["text"],
                    "success": False,
                    "error": "; ".join(errors),
                    "elapsed": elapsed,
                    "result": result_data,
                }

            print(f"   βœ… PASS ({elapsed:.2f}s)")
            return {
                "id": q["id"],
                "query": q["text"],
                "success": True,
                "elapsed": elapsed,
                "job_id": job_id,
                "result": result_data,
            }

        raise RuntimeError("run_single_api_query: retry loop fell through (bug)")

    except Exception as e:
        print(f"   ❌ EXCEPTION: {str(e)}")
        return {"id": q["id"], "query": q["text"], "success": False, "error": str(e)}


@pytest.mark.real_api
@pytest.mark.slow
@pytest.mark.asyncio
async def test_full_api_suite():
    """Run every pipeline query against the live API and fail on any error.

    The suite has three phases:

    1. Every entry of ``QUERIES`` except ``Q10-Step1`` runs on its own.
    2. ``Q10-Step1`` and ``Q10_FOLLOW_UP`` run in one shared session; the
       follow-up's ``geometry_dsl`` is additionally checked with
       ``validate_q10_step2_dsl``.
    3. ``Q13_HISTORY_STEPS`` run in order in another shared session.

    In the multi-turn phases a step whose predecessor failed is recorded as
    skipped instead of being executed. All results are collected in the
    module-level ``test_stats`` list and written to
    ``temp_suite_results.json``, even when the run is aborted by an exception.
    """
    if not USER_ID or not SESSION_ID:
        pytest.fail("TEST_USER_ID and TEST_SESSION_ID must be set")

    global test_stats
    test_stats = []

    headers = {"Authorization": f"Test {USER_ID}"}

    try:
        async with httpx.AsyncClient(base_url=BASE_URL, timeout=60.0) as client:
            for q in QUERIES:
                if q["id"] == "Q10-Step1":
                    # Runs later, inside the shared multi-turn session.
                    continue
                # Work on a copy so the shared case definitions are never mutated.
                qc = copy.deepcopy(q)
                res = await run_single_api_query(client, qc, headers, SESSION_ID)
                test_stats.append(res)

            print("\n--- Testing Multi-turn API Flow (Q10) ---")
            shared_session_resp = await client.post("/api/v1/sessions", headers=headers)
            assert shared_session_resp.status_code == 200
            shared_session = shared_session_resp.json()["id"]

            q10_1 = copy.deepcopy(next(q for q in QUERIES if q["id"] == "Q10-Step1"))
            q10_1["session_id"] = shared_session
            q10_1["isolate"] = False
            res10_1 = await run_single_api_query(client, q10_1, headers, SESSION_ID)
            test_stats.append(res10_1)

            q10_2 = copy.deepcopy(Q10_FOLLOW_UP)
            if res10_1["success"]:
                q10_2["session_id"] = shared_session
                q10_2["isolate"] = False
                res10_2 = await run_single_api_query(client, q10_2, headers, SESSION_ID)

                if res10_2["success"]:
                    # "result" may be missing or null; fall back to an empty DSL string.
                    dsl = (res10_2.get("result") or {}).get("geometry_dsl", "") or ""
                    if not validate_q10_step2_dsl(dsl):
                        res10_2["success"] = False
                        res10_2["error"] = "DSL did not merge history correctly"

                test_stats.append(res10_2)
            else:
                # Record the follow-up as skipped, like the Q13 flow does, so it
                # does not silently disappear from the report.
                test_stats.append(
                    {
                        "id": q10_2["id"],
                        "query": q10_2["text"],
                        "success": False,
                        "error": "Skipped: previous step in Q10 failed",
                    }
                )

            print("\n--- Testing Multi-turn API Flow (Q13 history in one session) ---")
            q13_session_resp = await client.post("/api/v1/sessions", headers=headers)
            assert q13_session_resp.status_code == 200
            q13_session = q13_session_resp.json()["id"]

            prev_ok = True
            for step in Q13_HISTORY_STEPS:
                if not prev_ok:
                    test_stats.append(
                        {
                            "id": step["id"],
                            "query": step["text"],
                            "success": False,
                            "error": "Skipped: previous step in Q13 failed",
                        }
                    )
                    continue
                sc = copy.deepcopy(step)
                sc["session_id"] = q13_session
                sc["isolate"] = False
                out = await run_single_api_query(client, sc, headers, SESSION_ID)
                test_stats.append(out)
                prev_ok = bool(out.get("success"))
    finally:
        # Persist whatever was collected, even if the run aborted part-way.
        with open("temp_suite_results.json", "w", encoding="utf-8") as f:
            json.dump(test_stats, f, ensure_ascii=False, indent=2)

    failures = [r for r in test_stats if not r.get("success")]
    if failures:
        # Show at most five failures in the message; the JSON file has the rest.
        pytest.fail(
            "Suite failures: "
            + "; ".join(f"{r.get('id')}: {r.get('error')}" for r in failures[:5])
            + (f" (+{len(failures) - 5} more)" if len(failures) > 5 else "")
        )


if __name__ == "__main__":
    # Allow running the suite directly as a script, without the pytest runner.
    suite_coro = test_full_api_suite()
    asyncio.run(suite_coro)