File size: 3,425 Bytes
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43f41de
 
 
 
 
 
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
43f41de
 
 
 
 
 
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43f41de
eb1ebe6
43f41de
 
eb1ebe6
 
 
 
43f41de
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Integration test: start server, connect client, run explore→generate.

Usage:
    uv run python tests/test_client_server.py          # auto-starts server
    uv run python tests/test_client_server.py --url http://localhost:8000
"""

import argparse
import subprocess
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from client import ExplainerEnv
from models import ExplainerAction


def wait_for_server(url: str, timeout: int = 15) -> bool:
    """Poll ``{url}/health`` until a request succeeds or *timeout* seconds pass.

    Args:
        url: Base URL of the server, without a trailing slash.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as one health request completes without raising,
        False if the deadline is reached first.
    """
    import urllib.request

    # Monotonic clock so a wall-clock adjustment cannot stretch or cut the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # The context manager closes the response (and its socket) promptly.
            with urllib.request.urlopen(f"{url}/health", timeout=2):
                return True
        except Exception:
            # Any failure (refused connection, timeout, HTTP error status)
            # is treated as "not ready yet"; retry until the deadline.
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never sleep past the deadline.
            time.sleep(min(0.5, remaining))
    return False


def run_tests(base_url: str):
    """Run the reset → explore → generate → reset scenario against *base_url*.

    Opens a synchronous session on an ``ExplainerEnv`` client and asserts on
    the observations and results returned at each stage. Progress is printed
    to stdout; a failed check surfaces as ``AssertionError``.

    Args:
        base_url: Base URL of the running environment server.
    """
    # Notebook source submitted for the generate action and, if needed, the repair action.
    notebook_src = "import marimo as mo\napp = mo.App()\n@app.cell\ndef _():\n    mo.md('hi')\n    return\n"

    with ExplainerEnv(base_url=base_url).sync() as session:
        # --- reset ---
        result = session.reset()
        obs = result.observation
        assert obs.topic, "reset should return a topic"
        assert obs.phase == "explore"
        assert obs.explore_steps_left == 3
        print(f"  reset: topic={obs.topic!r}, phase={obs.phase}")

        # --- explore ---
        result = session.step(
            ExplainerAction(
                action_type="explore",
                tool="search_wikipedia",
                query=obs.topic,
                intent="overview",
            )
        )
        assert not result.done
        assert result.observation.explore_steps_left == 2
        print(f"  explore: reward={result.reward:.3f}, steps_left={result.observation.explore_steps_left}")

        # --- generate ---
        # Submit once as "generate"; if the episode is still open, follow up
        # with a single "repair" carrying the same notebook.
        for action_type in ("generate", "repair"):
            result = session.step(
                ExplainerAction(action_type=action_type, format="marimo", code=notebook_src)
            )
            if result.done:
                break
        assert result.done
        assert isinstance(result.reward, (int, float))
        print(f"  generate: reward={result.reward:.3f}, done={result.done}")

        # --- second episode ---
        result2 = session.reset()
        assert result2.observation.topic
        print(f"  reset2: topic={result2.observation.topic!r}")

    print("PASS: test_client_server (4/4)")


def _stop_server(proc):
    """Terminate *proc* if it is still running, killing it if it ignores SIGTERM."""
    if proc.poll() is not None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        # The server did not exit in time; do not leave it running.
        proc.kill()
        proc.wait()


def main():
    """Parse CLI arguments and run the integration test.

    With ``--url`` the tests run against that already-running server.
    Otherwise a uvicorn server is started on port 8010 from the repository
    root, the tests run once its health endpoint answers, and the server is
    always stopped afterwards. Exits with status 1 if the server never
    becomes healthy.
    """
    import tempfile

    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default=None)
    args = parser.parse_args()

    if args.url:
        run_tests(args.url)
        return

    port = "8010"
    # Server stderr goes to a temp file and stdout is discarded: undrained
    # PIPEs could fill the OS buffer and block the server mid-test.
    with tempfile.TemporaryFile() as server_log:
        proc = subprocess.Popen(
            ["uv", "run", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", port],
            cwd=str(Path(__file__).resolve().parents[1]),
            stdout=subprocess.DEVNULL,
            stderr=server_log,
        )
        try:
            url = f"http://localhost:{port}"
            if not wait_for_server(url):
                # Stop the child first so the log is complete and its shared
                # file offset is no longer moving while we read.
                _stop_server(proc)
                server_log.seek(0)
                stderr = server_log.read().decode(errors="replace")
                print(f"FAIL: server did not start\n{stderr}", file=sys.stderr)
                sys.exit(1)
            run_tests(url)
        finally:
            _stop_server(proc)


# Run the integration test only when this file is executed as a script.
if __name__ == "__main__":
    main()