Spaces:
Running
Running
File size: 4,246 Bytes
34ef1f5 a79aee3 34ef1f5 13b03e3 6c58fca a8426bb 6e7ed91 a79aee3 34ef1f5 a8426bb 34ef1f5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | from __future__ import annotations
import importlib
import sys
from pathlib import Path
from unittest.mock import patch
from fastapi.testclient import TestClient
# Make the project root (the parent of the directory containing this file)
# importable so the local ``server`` package resolves when this script is run
# directly. It is inserted at the front of sys.path so it takes precedence.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
# Imported dynamically, and only after the sys.path tweak above, because
# ``server.app`` is located relative to ROOT.
server_app = importlib.import_module("server.app")
# Keys that every /train/status response must expose.
_TRAIN_STATUS_KEYS = (
    "status",
    "completed_steps",
    "remaining_steps",
    "phase",
    "started_at",
    "finished_at",
    "elapsed_seconds",
    "timing_summary",
    "latest_uploaded_checkpoint_step",
    "latest_uploaded_checkpoint_repo_path",
    "run_manifest_path",
    "events_path",
    "latest_checkpoint_path",
    "logs_deleted_from_space",
    "overall_accuracy",
    "reward_curve",
)


def _assert_status(response, expected: int) -> None:
    """Assert that *response* carries HTTP status *expected*.

    The response body is included in the failure message so a broken
    endpoint can be diagnosed without replaying the request by hand.
    """
    assert response.status_code == expected, (
        f"expected HTTP {expected}, got {response.status_code}: {response.text}"
    )


def _check_root_and_status(client: TestClient) -> None:
    """Check the HTML landing page and the model/training status endpoints."""
    root = client.get("/")
    _assert_status(root, 200)
    assert "text/html" in root.headers["content-type"]
    assert "ADAPT Judge Demo" in root.text

    model_status = client.get("/model/status")
    _assert_status(model_status, 200)
    assert "loaded" in model_status.json()

    train_status = client.get("/train/status")
    _assert_status(train_status, 200)
    # Parse the body once and report every missing key, not just the first one.
    payload = train_status.json()
    missing = [key for key in _TRAIN_STATUS_KEYS if key not in payload]
    assert not missing, f"/train/status response is missing keys: {missing}"


def _check_run_code(client: TestClient) -> None:
    """Check /run-code for a program that succeeds and one that raises."""
    ok = client.post(
        "/run-code",
        json={
            "code": "data = input().strip()\nprint(data[::-1])",
            "stdin": "adapt\n",
        },
    )
    _assert_status(ok, 200)
    assert ok.json() == {"stdout": "tpada\n", "stderr": ""}

    # An exception in user code is still HTTP 200; the error text is
    # expected on stderr with nothing on stdout.
    failing = client.post(
        "/run-code",
        json={
            "code": "raise ValueError('boom')",
            "stdin": "",
        },
    )
    _assert_status(failing, 200)
    body = failing.json()
    assert set(body.keys()) == {"stdout", "stderr"}
    assert body["stdout"] == ""
    assert "ValueError: boom" in body["stderr"]


def _check_episode(client: TestClient) -> None:
    """Run one reset -> step -> state episode on the ``sum_even_numbers`` problem."""
    reset = client.post(
        "/reset", json={"difficulty": "easy", "problem_id": "sum_even_numbers"}
    )
    _assert_status(reset, 200)
    session_id = reset.json()["session_id"]

    # Submit a solution that sums the even inputs; the episode is expected
    # to be finished after this single step.
    step = client.post(
        "/step",
        json={
            "session_id": session_id,
            "code": (
                "n=int(input())\n"
                "nums=list(map(int,input().split()))\n"
                "print(sum(x for x in nums if x % 2 == 0))"
            ),
        },
    )
    _assert_status(step, 200)
    assert step.json()["done"] is True

    state = client.get("/state", params={"session_id": session_id})
    _assert_status(state, 200)
    assert state.json()["session_id"] == session_id


def _check_model_endpoints(client: TestClient) -> None:
    """Check the trained-policy and code-generation endpoints."""
    # This smoke test expects no trained model to be loaded, so the endpoint
    # must answer 409.
    no_model = client.post("/run-trained-policy", json={"difficulty": "easy"})
    _assert_status(no_model, 409)

    # Stub generation so only the endpoint plumbing is exercised.
    with patch.object(
        server_app.TRAINING_MANAGER,
        "generate_code",
        return_value={
            "code": "print(42)",
            "completion": "print(42)",
            "problem_id": "custom_problem",
        },
    ):
        generate = client.post(
            "/generate-code",
            json={
                "problem": "Print 42.",
                "input_format": "No input.",
                "constraints": "None",
            },
        )
    _assert_status(generate, 200)
    assert generate.json()["code"] == "print(42)"


def _check_training(client: TestClient) -> None:
    """Check /train start, the default preset, and the already-running conflict."""
    with patch.object(
        server_app.TRAINING_MANAGER,
        "start_training",
        return_value={"status": "running", "run_id": "demo"},
    ):
        train = client.post("/train", json={})
    _assert_status(train, 200)
    assert train.json()["status"] == "running"

    assert server_app.TrainRequest().preset == "overnight"

    # A RuntimeError raised by the manager must surface as HTTP 409.
    with patch.object(
        server_app.TRAINING_MANAGER,
        "start_training",
        side_effect=RuntimeError("A training run is already in progress."),
    ):
        conflict = client.post("/train", json={})
    _assert_status(conflict, 409)


def main() -> None:
    """Run the Space API smoke tests in-process against ``server_app.app``.

    The checks run in a fixed order, because the episode checks rely on the
    session created by ``/reset``.

    Raises:
        AssertionError: when a check fails.

    NOTE: the checks are plain ``assert`` statements, so running under
    ``python -O`` silently disables them.
    """
    # Not entered as a context manager, so (per Starlette's TestClient docs)
    # the app's startup/lifespan handlers are not run for these requests.
    client = TestClient(server_app.app)
    _check_root_and_status(client)
    _check_run_code(client)
    _check_episode(client)
    _check_model_endpoints(client)
    _check_training(client)
    print("Space API smoke tests passed")
# Run the smoke tests only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|