# ambiguity-env / test_server.py
# Author: Yaser77
# Commit c06cf60 — feat: ambiguity resolution environment v1.0 - OpenEnv Hackathon
"""
test_server.py
Validates all server endpoints using FastAPI TestClient (no real server needed).
"""
import sys, importlib.util, io, types, os
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
def _load(path, name):
spec = importlib.util.spec_from_file_location(name, path)
mod = importlib.util.module_from_spec(spec)
sys.modules[name] = mod
spec.loader.exec_module(mod)
return mod
_load("models/models.py", "models.models")
_load("skills/ambiguity_detection.py", "skills.ambiguity_detection")
_load("skills/conversation_memory.py", "skills.conversation_memory")
_load("skills/reward_system.py", "skills.reward_system")
_load("env/env.py", "env.env")
_load("tasks/tasks.py", "tasks.tasks")
_load("grader/grader.py", "grader.grader")
# stub dotenv
fake_dl = types.ModuleType("dotenv")
fake_dl.load_dotenv = lambda: None
sys.modules["dotenv"] = fake_dl
os.environ.setdefault("HF_TOKEN", "test")
_load("server.py", "server")
# In-process HTTP client wrapped around the ASGI app imported from server.py;
# requests are dispatched directly to the app, so no real server is started.
# NOTE(review): the client is not entered as a context manager
# (`with TestClient(app) as client:`), so Starlette startup/lifespan handlers
# are not run. Confirm that server.py does not depend on them.
from fastapi.testclient import TestClient
from server import app

client = TestClient(app=app)
passed = []  # names of checks that succeeded, in run order
failed = []  # names of checks that failed, in run order

# Sentinel that tells "no observed value supplied" apart from an observed None.
_NOT_GIVEN = object()


def check(name, cond, got=_NOT_GIVEN):
    """Record and print the outcome of one named check.

    Args:
        name: Label shown on the PASS/FAIL line and in the final report.
        cond: Truthy when the check succeeded.
        got: Observed value to show on failure. It is omitted from the message
            when not supplied; an explicit ``None`` is still shown.

    Returns:
        ``bool(cond)``, so callers may branch on the outcome.
    """
    if cond:
        passed.append(name)
        print(f" [PASS] {name}")
    else:
        failed.append(name)
        # repr() keeps e.g. the string "0.3" distinguishable from the float 0.3.
        detail = "" if got is _NOT_GIVEN else f" (got: {got!r})"
        print(f" [FAIL] {name}{detail}")
    return bool(cond)
SEP = "=" * 55
print(SEP)
print("SERVER ENDPOINT TESTS")
print(SEP)
# GET /
print("\nGET /")
r = client.get("/")
check("GET / β†’ 200", r.status_code == 200, r.status_code)
check("GET / β†’ status=running", r.json().get("status") == "running")
check("GET / β†’ tasks list", isinstance(r.json().get("tasks"), list))
# GET /health
print("\nGET /health")
r = client.get("/health")
check("GET /health β†’ 200", r.status_code == 200, r.status_code)
check("GET /health β†’ ok", r.json().get("status") == "ok")
# GET /tasks
print("\nGET /tasks")
r = client.get("/tasks")
check("GET /tasks β†’ 200", r.status_code == 200, r.status_code)
check("GET /tasks β†’ 4 items", len(r.json().get("tasks", [])) == 4)
# POST /reset (default)
print("\nPOST /reset (no body)")
r = client.post("/reset")
check("POST /reset β†’ 200", r.status_code == 200, r.status_code)
body = r.json()
check("/reset β†’ status=ok", body.get("status") == "ok")
check("/reset β†’ observation", "observation" in body)
check("/reset β†’ instruction", "instruction" in body["observation"])
# POST /reset (specific task)
print("\nPOST /reset (easy_explicit)")
r = client.post("/reset", json={"task_name": "easy_explicit"})
check("/reset easy β†’ 200", r.status_code == 200, r.status_code)
check("/reset easy β†’ task", r.json().get("task") == "easy_explicit")
# POST /reset (bad task)
print("\nPOST /reset (nonexistent_task)")
r = client.post("/reset", json={"task_name": "nonexistent_task"})
check("/reset bad_task β†’ 404", r.status_code == 404, r.status_code)
# POST /step (ask)
print("\nPOST /step (ask)")
client.post("/reset", json={"task_name": "hard_ambiguous"})
r = client.post("/step", json={"type": "ask", "question": "When should this happen?"})
check("/step ask β†’ 200", r.status_code == 200, r.status_code)
body = r.json()
check("/step ask β†’ observation", "observation" in body)
check("/step ask β†’ reward", "reward" in body)
check("/step ask β†’ done=false", body.get("done") == False)
check("/step ask β†’ reward=0.3", body.get("reward") == 0.3, body.get("reward"))
# POST /step (execute)
print("\nPOST /step (execute)")
client.post("/step", json={"type": "ask", "question": "Who are the participants?"})
r = client.post("/step", json={
"type": "execute",
"proposed_time": "10 AM",
"proposed_participants": ["Team A"]
})
check("/step execute β†’ 200", r.status_code == 200, r.status_code)
body = r.json()
check("/step execute β†’ done=true", body.get("done") == True)
check("/step execute β†’ reward=1.0", body.get("reward") == 1.0, body.get("reward"))
# POST /step (bad action)
print("\nPOST /step (bad action)")
r = client.post("/step", json={"type": "fly"})
check("/step bad_action β†’ 422", r.status_code == 422, r.status_code)
# GET /state
print("\nGET /state")
client.post("/reset", json={"task_name": "hard_ambiguous"})
r = client.get("/state")
check("/state β†’ 200", r.status_code == 200, r.status_code)
s = r.json().get("state", {})
check("/state β†’ instruction", "instruction" in s)
check("/state β†’ question_count", "question_count" in s)
check("/state β†’ done", "done" in s)
# POST /run_task
print("\nPOST /run_task (hard_ambiguous)")
r = client.post("/run_task", json={"task_name": "hard_ambiguous"})
check("/run_task β†’ 200", r.status_code == 200, r.status_code)
body = r.json()
check("/run_task β†’ total_reward", "total_reward" in body)
check("/run_task β†’ success=true", body.get("success") == True, body.get("success"))
check("/run_task β†’ log present", len(body.get("log", [])) > 0)
print(f" total_reward={body.get('total_reward')} steps={body.get('steps')}")
print(f"\n{SEP}")
print("FINAL REPORT")
print(SEP)
total = len(passed) + len(failed)
print(f"\n Passed : {len(passed)}/{total}")
print(f" Failed : {len(failed)}/{total}")
if failed:
print("\n FAILURES:")
for f in failed:
print(f" - {f}")
sys.exit(1)
else:
print("\n ALL SERVER TESTS PASSED")
print(" Ready to build Docker image and deploy to HF Spaces.")