| """ |
| Test suite for the API Contract Debugger environment. |
| |
| Coverage: |
| - Violation detection (all violation types) |
| - Grader scoring |
| - Per-step reward shaping |
| - Environment reset / step / state |
| - All three tasks end-to-end |
- Edge cases: malformed actions, actions after the episode has ended, already-clean spec
| - HTTP API routes (via TestClient) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import copy |
| import sys |
| import os |
|
|
| import pytest |
|
|
| |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
|
|
| from server.fixtures import TASK_EASY, TASK_HARD, TASK_MEDIUM, TASKS |
| from server.graders import detect_violations, grade_episode, step_reward |
| from server.models import ActionKind, DebugAction |
| from server.environment import APIContractDebuggerEnv |
|
|
|
|
| |
| |
| |
|
|
def make_env(task: str = "easy") -> APIContractDebuggerEnv:
    """Return an environment for ``task`` that has already been reset.

    Tests that need a ready-to-step environment use this instead of
    constructing and resetting ``APIContractDebuggerEnv`` themselves.
    """
    fresh_env = APIContractDebuggerEnv(task_name=task)
    # The initial observation is discarded; callers inspect ``env.state`` or
    # the observation returned by their own ``step`` calls.
    fresh_env.reset()
    return fresh_env
|
|
|
|
def action(**kwargs) -> DebugAction:
    """Build a ``DebugAction`` whose unspecified fields default to a NO_OP.

    Any keyword argument overrides the matching default below; keywords not
    listed there are passed straight through to ``DebugAction``.
    """
    base_fields = {
        "kind": ActionKind.NO_OP,
        "endpoint_index": 0,
        "location": "response_body",
        "field_name": None,
        "new_value": None,
    }
    return DebugAction(**{**base_fields, **kwargs})
|
|
|
|
| |
| |
| |
|
|
class TestFixtures:
    """Sanity checks on the task fixtures themselves."""

    @staticmethod
    def _initial_violations(task_cfg):
        # Violations present in a task's broken spec relative to its golden spec.
        return detect_violations(task_cfg["broken_endpoints"], task_cfg["golden_endpoints"])

    def test_all_tasks_present(self):
        expected_names = {"easy", "medium", "hard"}
        assert set(TASKS) == expected_names

    def test_easy_has_violations(self):
        assert len(self._initial_violations(TASK_EASY)) == 1

    def test_medium_has_three_violations(self):
        assert len(self._initial_violations(TASK_MEDIUM)) == 3

    def test_hard_has_six_violations(self):
        assert len(self._initial_violations(TASK_HARD)) == 6

    def test_golden_specs_are_clean(self):
        # A golden spec compared against itself must report nothing.
        for cfg in TASKS.values():
            reference = cfg["golden_endpoints"]
            found = detect_violations(reference, reference)
            assert found == [], f"Golden spec for '{cfg['name']}' has violations: {found}"

    def test_broken_and_golden_same_length(self):
        for cfg in TASKS.values():
            broken_count = len(cfg["broken_endpoints"])
            golden_count = len(cfg["golden_endpoints"])
            assert broken_count == golden_count
|
|
|
|
| |
| |
| |
|
|
class TestViolationDetection:
    """Unit tests for ``detect_violations`` on small hand-built specs."""

    @staticmethod
    def _endpoint(method="GET", status=200, response_body=None):
        # A one-endpoint spec at path "/x" with an empty request body.
        return [{
            "method": method,
            "path": "/x",
            "status_code": status,
            "request_body": {},
            "response_body": {} if response_body is None else response_body,
        }]

    @staticmethod
    def _field(type_name, required):
        # A field descriptor in the shape the fixtures use.
        return {"type": type_name, "required": required, "description": ""}

    def test_missing_field_detected(self):
        actual = self._endpoint()
        expected = self._endpoint(response_body={"id": self._field("integer", True)})
        found = detect_violations(actual, expected)
        assert len(found) == 1
        assert found[0]["violation_type"] == "missing_field"
        assert found[0]["field_name"] == "id"

    def test_extra_field_detected(self):
        actual = self._endpoint(response_body={"secret": self._field("string", False)})
        expected = self._endpoint()
        found = detect_violations(actual, expected)
        assert len(found) == 1
        assert found[0]["violation_type"] == "extra_field"

    def test_wrong_type_detected(self):
        actual = self._endpoint(response_body={"count": self._field("string", True)})
        expected = self._endpoint(response_body={"count": self._field("integer", True)})
        found = detect_violations(actual, expected)
        assert len(found) == 1
        assert found[0]["violation_type"] == "wrong_type"

    def test_wrong_status_detected(self):
        actual = self._endpoint(method="DELETE", status=200)
        expected = self._endpoint(method="DELETE", status=204)
        found = detect_violations(actual, expected)
        assert len(found) == 1
        assert found[0]["violation_type"] == "wrong_status"

    def test_no_violations_on_matching_spec(self):
        reference = TASK_EASY["golden_endpoints"]
        assert detect_violations(reference, reference) == []

    def test_violation_severity_range(self):
        found = detect_violations(TASK_HARD["broken_endpoints"], TASK_HARD["golden_endpoints"])
        for item in found:
            assert 0.0 < item["severity"] <= 1.0
|
|
|
|
| |
| |
| |
|
|
class TestGrader:
    """Tests for the episode score returned by ``grade_episode``."""

    def test_perfect_score_when_all_fixed(self):
        target = TASK_EASY["golden_endpoints"]
        baseline = detect_violations(TASK_EASY["broken_endpoints"], target)
        assert grade_episode(target, target, baseline) == pytest.approx(1.0)

    def test_zero_score_when_nothing_fixed(self):
        untouched = TASK_EASY["broken_endpoints"]
        target = TASK_EASY["golden_endpoints"]
        baseline = detect_violations(untouched, target)
        assert grade_episode(untouched, target, baseline) == pytest.approx(0.0)

    def test_partial_score_medium(self):
        # Deep copy so the shared fixture is not mutated for other tests.
        working = copy.deepcopy(TASK_MEDIUM["broken_endpoints"])
        target = TASK_MEDIUM["golden_endpoints"]
        baseline = detect_violations(working, target)

        # Repair only the product_id type; the other medium violations remain.
        working[0]["response_body"]["product_id"]["type"] = "integer"

        partial = grade_episode(working, target, baseline)
        assert 0.0 < partial < 1.0

    def test_score_clamped_to_zero_when_extra_violations_introduced(self):
        working = copy.deepcopy(TASK_EASY["broken_endpoints"])
        target = TASK_EASY["golden_endpoints"]
        baseline = detect_violations(working, target)

        # Introduce two new type violations without fixing the original one.
        response = working[0]["response_body"]
        response["user_id"]["type"] = "string"
        response["username"]["type"] = "boolean"

        assert grade_episode(working, target, baseline) == 0.0

    def test_score_in_range(self):
        for cfg in TASKS.values():
            untouched = cfg["broken_endpoints"]
            baseline = detect_violations(untouched, cfg["golden_endpoints"])
            result = grade_episode(untouched, cfg["golden_endpoints"], baseline)
            assert 0.0 <= result <= 1.0, f"Out-of-range score for task '{cfg['name']}'"

    def test_already_clean_spec_scores_one(self):
        target = TASK_EASY["golden_endpoints"]
        no_initial: list = []
        assert grade_episode(target, target, no_initial) == pytest.approx(1.0)
|
|
|
|
| |
| |
| |
|
|
class TestStepReward:
    """Tests for the per-step shaped reward returned by ``step_reward``."""

    def _make_violation(self, vtype="missing_field", severity=1.0):
        # One synthetic violation on endpoint 0's response body.
        return dict(
            endpoint_index=0,
            location="response_body",
            field_name="foo",
            violation_type=vtype,
            description="test",
            severity=severity,
        )

    def test_positive_reward_for_fix(self):
        fixed = self._make_violation()
        # Present before the step, absent after it.
        reward = step_reward(
            prev_violations=[fixed],
            new_violations=[],
            initial_violations=[fixed],
            action_error=False,
        )
        assert reward > 0

    def test_negative_reward_for_introduction(self):
        introduced = self._make_violation()
        # Absent before the step, present after it.
        reward = step_reward(
            prev_violations=[],
            new_violations=[introduced],
            initial_violations=[],
            action_error=False,
        )
        assert reward < 0

    def test_penalty_for_action_error(self):
        reward = step_reward(
            prev_violations=[],
            new_violations=[],
            initial_violations=[],
            action_error=True,
        )
        assert reward == pytest.approx(-0.05)

    def test_zero_reward_for_no_op(self):
        reward = step_reward(
            prev_violations=[],
            new_violations=[],
            initial_violations=[],
            action_error=False,
        )
        assert reward == pytest.approx(0.0)
|
|
|
|
| |
| |
| |
|
|
class TestEnvReset:
    """Tests for ``APIContractDebuggerEnv.reset``."""

    def test_reset_returns_observation(self):
        env = APIContractDebuggerEnv(task_name="easy")
        obs = env.reset()
        assert obs.task_name == "easy"
        assert len(obs.violations) == 1
        assert obs.done is False
        assert obs.step_count == 0

    def test_reset_clears_state(self):
        env = make_env("easy")
        # This action solves the easy task, so the episode is finished
        # before reset is called.
        env.step(action(
            kind=ActionKind.ADD_FIELD,
            location="response_body",
            field_name="created_at",
            new_value={"type": "string", "required": True, "description": "timestamp"},
        ))
        obs = env.reset()
        assert obs.step_count == 0
        assert len(obs.violations) == 1
        # Resetting a finished episode must start a live one again.
        assert obs.done is False

    def test_reset_switches_task(self):
        env = APIContractDebuggerEnv(task_name="easy")
        obs = env.reset(task_name="medium")
        assert obs.task_name == "medium"
        assert len(obs.violations) == 3

    def test_reset_preserves_golden(self):
        env = make_env("hard")
        # Modify the working spec first. Otherwise this would only re-observe
        # an untouched environment and could not detect a reset that fails to
        # restore the broken spec or that corrupts the golden spec.
        step_obs = env.step(action(
            kind=ActionKind.REMOVE_FIELD,
            endpoint_index=1,
            location="response_body",
            field_name="password_hash",
        ))
        assert step_obs.violations_fixed_this_step == 1

        obs = env.reset()
        assert obs.total_violations_at_start == 6
        assert len(obs.violations) == 6

    def test_episode_id_set_on_reset(self):
        env = APIContractDebuggerEnv(task_name="easy")
        env.reset(episode_id="test-123")
        assert env.state.episode_id == "test-123"
|
|
|
|
| |
| |
| |
|
|
class TestEnvStep:
    """Tests for single ``APIContractDebuggerEnv.step`` calls and episode termination."""

    def test_add_missing_field_fixes_easy(self):
        env = make_env("easy")
        obs = env.step(action(
            kind=ActionKind.ADD_FIELD,
            location="response_body",
            field_name="created_at",
            new_value={"type": "string", "required": True, "description": "ISO timestamp"},
        ))
        assert len(obs.violations) == 0
        assert obs.done is True
        assert obs.reward > 0

    def test_wrong_type_action_introduces_violation(self):
        env = make_env("easy")
        obs = env.step(action(
            kind=ActionKind.ADD_FIELD,
            location="response_body",
            field_name="created_at",
            new_value={"type": "integer", "required": True, "description": "wrong type"},
        ))
        # The field now exists, so the missing-field violation becomes a
        # wrong-type violation instead of disappearing.
        assert len(obs.violations) == 1
        assert obs.violations[0]["violation_type"] == "wrong_type"

    def test_out_of_range_endpoint_index(self):
        env = make_env("easy")
        obs = env.step(action(kind=ActionKind.ADD_FIELD, endpoint_index=99,
                              field_name="x", new_value={"type": "string"}))
        assert obs.last_action_error is not None
        assert "out of range" in obs.last_action_error

    def test_change_type_fixes_medium_violation(self):
        env = make_env("medium")
        obs = env.step(action(
            kind=ActionKind.CHANGE_TYPE,
            endpoint_index=0,
            location="response_body",
            field_name="product_id",
            new_value="integer",
        ))
        assert obs.violations_fixed_this_step == 1
        assert len(obs.violations) == 2

    def test_change_status_fixes_medium_violation(self):
        env = make_env("medium")
        obs = env.step(action(
            kind=ActionKind.CHANGE_STATUS,
            endpoint_index=2,
            location="status_code",
            new_value=204,
        ))
        assert obs.violations_fixed_this_step == 1
        # Same check as the change-type test: the fix must not add new violations.
        assert len(obs.violations) == 2

    def test_remove_field_fixes_hard_extra_field(self):
        env = make_env("hard")
        obs = env.step(action(
            kind=ActionKind.REMOVE_FIELD,
            endpoint_index=1,
            location="response_body",
            field_name="password_hash",
        ))
        assert obs.violations_fixed_this_step == 1
        assert len(obs.violations) == 5

    def test_no_op_does_not_change_violations(self):
        env = make_env("easy")
        before = len(env.state.violations)
        obs = env.step(action(kind=ActionKind.NO_OP))
        assert len(obs.violations) == before

    def test_step_after_done_returns_done(self):
        env = make_env("easy")
        # Solve the task so the episode ends.
        env.step(action(
            kind=ActionKind.ADD_FIELD,
            location="response_body",
            field_name="created_at",
            new_value={"type": "string", "required": True, "description": "ts"},
        ))
        # Stepping a finished episode reports done and an error instead of acting.
        obs = env.step(action(kind=ActionKind.NO_OP))
        assert obs.done is True
        assert obs.last_action_error is not None

    def test_max_steps_terminates_episode(self):
        env = APIContractDebuggerEnv(task_name="easy")
        env.reset()
        max_steps = env._task_cfg["max_steps"]
        # With a zero budget the loop never runs and ``obs`` would stay None.
        assert max_steps > 0, "easy task must allow at least one step"

        obs = None
        for step_no in range(1, max_steps + 1):
            obs = env.step(action(kind=ActionKind.NO_OP))
            if step_no < max_steps:
                # NO_OPs leave the easy violation in place, so only the step
                # budget can end the episode, and not before the final step.
                assert obs.done is False, f"episode ended early at step {step_no}"
        assert obs.done is True

    def test_step_count_increments(self):
        env = make_env("easy")
        env.step(action(kind=ActionKind.NO_OP))
        env.step(action(kind=ActionKind.NO_OP))
        assert env.state.step_count == 2
|
|
|
|
| |
| |
| |
|
|
class TestEnvState:
    """Tests for the ``env.state`` snapshot."""

    def test_state_reflects_current_endpoints(self):
        snapshot = make_env("easy").state
        assert len(snapshot.current_endpoints) == 1
        assert snapshot.task_name == "easy"

    def test_state_tracks_step_count(self):
        environment = make_env("easy")
        environment.step(action(kind=ActionKind.NO_OP))
        assert environment.state.step_count == 1

    def test_original_endpoints_unchanged_after_steps(self):
        environment = make_env("easy")
        # Deep copy so later mutation of the live state cannot affect the baseline.
        baseline = copy.deepcopy(environment.state.original_endpoints)
        fix = action(
            kind=ActionKind.ADD_FIELD,
            location="response_body",
            field_name="created_at",
            new_value={"type": "string", "required": True, "description": "ts"},
        )
        environment.step(fix)
        assert environment.state.original_endpoints == baseline
|
|
|
|
| |
| |
| |
|
|
class TestFullEpisodes:
    """End-to-end solves of each task, checked with ``env.score()``."""

    @staticmethod
    def _play(env, planned_actions):
        # Apply each prepared action in order; the observations are not inspected.
        for planned in planned_actions:
            env.step(planned)

    def test_easy_perfect_solve(self):
        env = make_env("easy")
        self._play(env, [
            action(kind=ActionKind.ADD_FIELD, location="response_body",
                   field_name="created_at",
                   new_value={"type": "string", "required": True, "description": "ISO timestamp"}),
        ])
        assert env.score() == pytest.approx(1.0)

    def test_medium_perfect_solve(self):
        env = make_env("medium")
        self._play(env, [
            # Endpoint 0: response product_id has the wrong type.
            action(kind=ActionKind.CHANGE_TYPE, endpoint_index=0,
                   location="response_body", field_name="product_id", new_value="integer"),
            # Endpoint 1: request quantity has the wrong type.
            action(kind=ActionKind.CHANGE_TYPE, endpoint_index=1,
                   location="request_body", field_name="quantity", new_value="integer"),
            # Endpoint 2: wrong status code.
            action(kind=ActionKind.CHANGE_STATUS, endpoint_index=2,
                   location="status_code", new_value=204),
        ])
        assert env.score() == pytest.approx(1.0)

    def test_hard_perfect_solve(self):
        env = make_env("hard")
        self._play(env, [
            action(kind=ActionKind.ADD_FIELD, endpoint_index=0,
                   location="response_body", field_name="refresh_token",
                   new_value={"type": "string", "required": True, "description": "Refresh token"}),
            action(kind=ActionKind.CHANGE_TYPE, endpoint_index=0,
                   location="response_body", field_name="expires_in", new_value="integer"),
            action(kind=ActionKind.ADD_FIELD, endpoint_index=1,
                   location="response_body", field_name="created_at",
                   new_value={"type": "string", "required": True, "description": "ISO timestamp"}),
            action(kind=ActionKind.REMOVE_FIELD, endpoint_index=1,
                   location="response_body", field_name="password_hash"),
            action(kind=ActionKind.CHANGE_STATUS, endpoint_index=2,
                   location="status_code", new_value=200),
            action(kind=ActionKind.ADD_FIELD, endpoint_index=2,
                   location="response_body", field_name="updated_at",
                   new_value={"type": "string", "required": True, "description": "ISO timestamp"}),
        ])
        assert env.score() == pytest.approx(1.0)

    def test_score_after_partial_solve(self):
        env = make_env("medium")
        self._play(env, [
            action(kind=ActionKind.CHANGE_TYPE, endpoint_index=0,
                   location="response_body", field_name="product_id", new_value="integer"),
        ])
        partial = env.score()
        assert 0.0 < partial < 1.0

    def test_unknown_task_raises(self):
        with pytest.raises(ValueError, match="Unknown task"):
            APIContractDebuggerEnv(task_name="impossible")
|
|
|
|
| |
| |
| |
|
|
class TestHTTPRoutes:
    """HTTP-level tests against the FastAPI ``app`` from ``server.app``.

    Every test talks to the same imported ``app`` object. Server-side
    environment state presumably persists between tests, so tests that depend
    on a specific task call ``/reset`` with an explicit ``task_name`` first.
    """

    @pytest.fixture(autouse=True)
    def _attach_client(self):
        """Expose a ``TestClient`` as ``self.client`` and close it after the test.

        The fixture is deliberately not named ``client``: assigning
        ``self.client`` would otherwise shadow the fixture method on the
        instance.
        """
        from fastapi.testclient import TestClient
        from server.app import app

        http_client = TestClient(app)
        self.client = http_client
        try:
            yield
        finally:
            http_client.close()

    def test_health_endpoint(self):
        r = self.client.get("/health")
        assert r.status_code == 200

    def test_reset_returns_200(self):
        r = self.client.post("/reset", json={})
        assert r.status_code == 200
        data = r.json()
        assert "violations" in data
        assert "endpoints" in data

    def test_reset_switches_task(self):
        r = self.client.post("/reset", json={"task_name": "medium"})
        assert r.status_code == 200
        assert r.json()["task_name"] == "medium"

    def test_reset_unknown_task_422(self):
        r = self.client.post("/reset", json={"task_name": "impossible"})
        assert r.status_code == 422

    def test_step_add_field(self):
        self.client.post("/reset", json={"task_name": "easy"})
        r = self.client.post("/step", json={
            "action": {
                "kind": "add_field",
                "endpoint_index": 0,
                "location": "response_body",
                "field_name": "created_at",
                "new_value": {"type": "string", "required": True, "description": "ts"},
            }
        })
        assert r.status_code == 200
        data = r.json()
        assert data["done"] is True
        assert data["reward"] > 0

    def test_step_invalid_action_422(self):
        self.client.post("/reset", json={})
        r = self.client.post("/step", json={"action": {"kind": "nonexistent_kind"}})
        assert r.status_code == 422

    def test_state_endpoint(self):
        self.client.post("/reset", json={"task_name": "easy"})
        r = self.client.get("/state")
        assert r.status_code == 200
        assert "current_endpoints" in r.json()

    def test_score_endpoint(self):
        self.client.post("/reset", json={"task_name": "easy"})
        r = self.client.get("/score")
        assert r.status_code == 200
        data = r.json()
        assert "score" in data
        assert 0.0 <= data["score"] <= 1.0

    def test_tasks_endpoint(self):
        r = self.client.get("/tasks")
        assert r.status_code == 200
        data = r.json()
        assert len(data["tasks"]) == 3

    def test_schema_endpoint(self):
        r = self.client.get("/schema")
        assert r.status_code == 200
        schema = r.json()
        assert "action" in schema
        assert "observation" in schema

    def test_full_easy_solve_via_http(self):
        reset_r = self.client.post("/reset", json={"task_name": "easy"})
        assert reset_r.status_code == 200
        r = self.client.post("/step", json={
            "action": {
                "kind": "add_field",
                "endpoint_index": 0,
                "location": "response_body",
                "field_name": "created_at",
                "new_value": {"type": "string", "required": True, "description": "ts"},
            }
        })
        # Check the status first so an error response fails here, not on a missing key.
        assert r.status_code == 200
        assert r.json()["done"] is True
        score_r = self.client.get("/score")
        assert score_r.status_code == 200
        assert score_r.json()["score"] == pytest.approx(1.0)
|
|