File size: 4,246 Bytes
34ef1f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a79aee3
 
34ef1f5
 
 
 
 
 
 
 
13b03e3
 
 
6c58fca
 
 
 
a8426bb
 
6e7ed91
 
 
 
a79aee3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34ef1f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8426bb
 
34ef1f5
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations

import importlib
import sys
from pathlib import Path
from unittest.mock import patch

from fastapi.testclient import TestClient

# Directory one level above the folder containing this file.
# NOTE(review): presumably the project root that holds the ``server``
# package — confirm against the repository layout.
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT at the front of the import path, unless it is already listed, so
# the ``server.app`` import below can resolve from ROOT.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# Imported by dotted name at runtime; this must stay after the sys.path
# adjustment above.
server_app = importlib.import_module("server.app")


def _expect_status(response, expected: int) -> None:
    """Assert that ``response`` carries the HTTP status code ``expected``.

    On failure the assertion message contains the request URL, the expected
    and actual status codes, and the response body, so the failing endpoint
    can be identified from the traceback alone.

    Raises:
        AssertionError: If ``response.status_code`` differs from ``expected``.
    """
    assert response.status_code == expected, (
        f"{response.url}: expected HTTP {expected}, "
        f"got {response.status_code}: {response.text}"
    )


def main() -> None:
    """Run the Space API smoke checks against an in-process ``TestClient``.

    Each endpoint is requested once and its status code and/or payload is
    asserted. The ``TRAINING_MANAGER`` methods ``generate_code`` and
    ``start_training`` are patched for the ``/generate-code`` and ``/train``
    requests, so those calls return canned values instead of running the real
    implementation. A confirmation line is printed when every check passes.

    Raises:
        AssertionError: If any response differs from what is expected.
    """
    client = TestClient(server_app.app)

    # Landing page is served as HTML and contains the demo title.
    root = client.get("/")
    _expect_status(root, 200)
    assert "text/html" in root.headers["content-type"], root.headers["content-type"]
    assert "ADAPT Judge Demo" in root.text, "landing page title not found"

    model_status = client.get("/model/status")
    _expect_status(model_status, 200)
    assert "loaded" in model_status.json(), model_status.json()

    # Keys the /train/status payload must expose.
    required_status_keys = (
        "status",
        "completed_steps",
        "remaining_steps",
        "phase",
        "started_at",
        "finished_at",
        "elapsed_seconds",
        "timing_summary",
        "latest_uploaded_checkpoint_step",
        "latest_uploaded_checkpoint_repo_path",
        "run_manifest_path",
        "events_path",
        "latest_checkpoint_path",
        "logs_deleted_from_space",
        "overall_accuracy",
        "reward_curve",
    )
    train_status = client.get("/train/status")
    _expect_status(train_status, 200)
    # Parse the body once and report every missing key together rather than
    # stopping at the first one.
    train_payload = train_status.json()
    missing_keys = [key for key in required_status_keys if key not in train_payload]
    assert not missing_keys, f"/train/status is missing keys: {missing_keys}"

    # Successful execution: stdin is fed to the submitted program and its
    # stdout is returned verbatim.
    run_code = client.post(
        "/run-code",
        json={
            "code": "data = input().strip()\nprint(data[::-1])",
            "stdin": "adapt\n",
        },
    )
    _expect_status(run_code, 200)
    assert run_code.json() == {"stdout": "tpada\n", "stderr": ""}, run_code.json()

    # A raising program still yields HTTP 200; the error text is expected in
    # the "stderr" field of the payload.
    run_code_error = client.post(
        "/run-code",
        json={
            "code": "raise ValueError('boom')",
            "stdin": "",
        },
    )
    _expect_status(run_code_error, 200)
    error_payload = run_code_error.json()
    assert error_payload.keys() == {"stdout", "stderr"}, error_payload
    assert error_payload["stdout"] == "", error_payload
    assert "ValueError: boom" in error_payload["stderr"], error_payload

    # Episode flow: reset creates a session, step submits code for it, and
    # state reads the same session back.
    reset = client.post("/reset", json={"difficulty": "easy", "problem_id": "sum_even_numbers"})
    _expect_status(reset, 200)
    session_id = reset.json()["session_id"]

    step = client.post(
        "/step",
        json={
            "session_id": session_id,
            "code": "n=int(input())\nnums=list(map(int,input().split()))\nprint(sum(x for x in nums if x % 2 == 0))",
        },
    )
    _expect_status(step, 200)
    assert step.json()["done"] is True, step.json()

    state = client.get("/state", params={"session_id": session_id})
    _expect_status(state, 200)
    assert state.json()["session_id"] == session_id, state.json()

    # NOTE(review): 409 is presumably returned because no trained model is
    # loaded in a fresh process — confirm against the endpoint implementation.
    no_model = client.post("/run-trained-policy", json={"difficulty": "easy"})
    _expect_status(no_model, 409)

    # generate_code is replaced by a mock returning a canned result.
    with patch.object(
        server_app.TRAINING_MANAGER,
        "generate_code",
        return_value={"code": "print(42)", "completion": "print(42)", "problem_id": "custom_problem"},
    ):
        generate = client.post(
            "/generate-code",
            json={
                "problem": "Print 42.",
                "input_format": "No input.",
                "constraints": "None",
            },
        )
        _expect_status(generate, 200)
        assert generate.json()["code"] == "print(42)", generate.json()

    # start_training is replaced by a mock, so no real run is started.
    with patch.object(server_app.TRAINING_MANAGER, "start_training", return_value={"status": "running", "run_id": "demo"}):
        train = client.post("/train", json={})
        _expect_status(train, 200)
        assert train.json()["status"] == "running", train.json()

    # Default preset of an empty training request.
    assert server_app.TrainRequest().preset == "overnight"

    # A RuntimeError raised by start_training is expected to surface as 409.
    with patch.object(
        server_app.TRAINING_MANAGER,
        "start_training",
        side_effect=RuntimeError("A training run is already in progress."),
    ):
        conflict = client.post("/train", json={})
        _expect_status(conflict, 409)

    print("Space API smoke tests passed")


# Run the smoke checks only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()