keerthanas1011's picture
API Contract Debugger OpenEnv Environment
5cf6185
"""
Test suite for the API Contract Debugger environment.
Coverage:
- Violation detection (all violation types)
- Grader scoring
- Per-step reward shaping
- Environment reset / step / state
- All three tasks end-to-end
- Edge cases: malformed actions, double-fix, already-clean spec
- HTTP API routes (via TestClient)
"""
from __future__ import annotations
import copy
import sys
import os
import pytest
# Make sure the project root is on the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from server.fixtures import TASK_EASY, TASK_HARD, TASK_MEDIUM, TASKS
from server.graders import detect_violations, grade_episode, step_reward
from server.models import ActionKind, DebugAction
from server.environment import APIContractDebuggerEnv
# ===========================================================================
# Helpers
# ===========================================================================
def make_env(task: str = "easy") -> APIContractDebuggerEnv:
env = APIContractDebuggerEnv(task_name=task)
env.reset()
return env
def action(**kwargs) -> DebugAction:
defaults = dict(
kind=ActionKind.NO_OP,
endpoint_index=0,
location="response_body",
field_name=None,
new_value=None,
)
defaults.update(kwargs)
return DebugAction(**defaults)
# ===========================================================================
# 1. Fixture sanity
# ===========================================================================
class TestFixtures:
def test_all_tasks_present(self):
assert set(TASKS.keys()) == {"easy", "medium", "hard"}
def test_easy_has_violations(self):
v = detect_violations(TASK_EASY["broken_endpoints"], TASK_EASY["golden_endpoints"])
assert len(v) == 1
def test_medium_has_three_violations(self):
v = detect_violations(TASK_MEDIUM["broken_endpoints"], TASK_MEDIUM["golden_endpoints"])
assert len(v) == 3
def test_hard_has_six_violations(self):
v = detect_violations(TASK_HARD["broken_endpoints"], TASK_HARD["golden_endpoints"])
assert len(v) == 6
def test_golden_specs_are_clean(self):
for task in TASKS.values():
v = detect_violations(task["golden_endpoints"], task["golden_endpoints"])
assert v == [], f"Golden spec for '{task['name']}' has violations: {v}"
def test_broken_and_golden_same_length(self):
for task in TASKS.values():
assert len(task["broken_endpoints"]) == len(task["golden_endpoints"])
# ===========================================================================
# 2. Violation detection
# ===========================================================================
class TestViolationDetection:
def test_missing_field_detected(self):
current = [{"method": "GET", "path": "/x", "status_code": 200,
"request_body": {}, "response_body": {}}]
golden = [{"method": "GET", "path": "/x", "status_code": 200,
"request_body": {}, "response_body": {
"id": {"type": "integer", "required": True, "description": ""}
}}]
v = detect_violations(current, golden)
assert len(v) == 1
assert v[0]["violation_type"] == "missing_field"
assert v[0]["field_name"] == "id"
def test_extra_field_detected(self):
current = [{"method": "GET", "path": "/x", "status_code": 200,
"request_body": {}, "response_body": {
"secret": {"type": "string", "required": False, "description": ""}
}}]
golden = [{"method": "GET", "path": "/x", "status_code": 200,
"request_body": {}, "response_body": {}}]
v = detect_violations(current, golden)
assert len(v) == 1
assert v[0]["violation_type"] == "extra_field"
def test_wrong_type_detected(self):
current = [{"method": "GET", "path": "/x", "status_code": 200,
"request_body": {}, "response_body": {
"count": {"type": "string", "required": True, "description": ""}
}}]
golden = [{"method": "GET", "path": "/x", "status_code": 200,
"request_body": {}, "response_body": {
"count": {"type": "integer", "required": True, "description": ""}
}}]
v = detect_violations(current, golden)
assert len(v) == 1
assert v[0]["violation_type"] == "wrong_type"
def test_wrong_status_detected(self):
current = [{"method": "DELETE", "path": "/x", "status_code": 200,
"request_body": {}, "response_body": {}}]
golden = [{"method": "DELETE", "path": "/x", "status_code": 204,
"request_body": {}, "response_body": {}}]
v = detect_violations(current, golden)
assert len(v) == 1
assert v[0]["violation_type"] == "wrong_status"
def test_no_violations_on_matching_spec(self):
golden = TASK_EASY["golden_endpoints"]
v = detect_violations(golden, golden)
assert v == []
def test_violation_severity_range(self):
v = detect_violations(TASK_HARD["broken_endpoints"], TASK_HARD["golden_endpoints"])
for viol in v:
assert 0.0 < viol["severity"] <= 1.0
# ===========================================================================
# 3. Grader scoring
# ===========================================================================
class TestGrader:
def test_perfect_score_when_all_fixed(self):
golden = TASK_EASY["golden_endpoints"]
initial = detect_violations(TASK_EASY["broken_endpoints"], golden)
score = grade_episode(golden, golden, initial)
assert score == pytest.approx(1.0)
def test_zero_score_when_nothing_fixed(self):
broken = TASK_EASY["broken_endpoints"]
golden = TASK_EASY["golden_endpoints"]
initial = detect_violations(broken, golden)
score = grade_episode(broken, golden, initial)
assert score == pytest.approx(0.0)
def test_partial_score_medium(self):
broken = copy.deepcopy(TASK_MEDIUM["broken_endpoints"])
golden = TASK_MEDIUM["golden_endpoints"]
initial = detect_violations(broken, golden)
# Fix only violation 1: product_id type
broken[0]["response_body"]["product_id"]["type"] = "integer"
score = grade_episode(broken, golden, initial)
assert 0.0 < score < 1.0
def test_score_clamped_to_zero_when_extra_violations_introduced(self):
broken = copy.deepcopy(TASK_EASY["broken_endpoints"])
golden = TASK_EASY["golden_endpoints"]
initial = detect_violations(broken, golden)
# Introduce more violations
broken[0]["response_body"]["user_id"]["type"] = "string"
broken[0]["response_body"]["username"]["type"] = "boolean"
score = grade_episode(broken, golden, initial)
assert score == 0.0
def test_score_in_range(self):
for task in TASKS.values():
broken = task["broken_endpoints"]
golden = task["golden_endpoints"]
initial = detect_violations(broken, golden)
score = grade_episode(broken, golden, initial)
assert 0.0 <= score <= 1.0, f"Out-of-range score for task '{task['name']}'"
def test_already_clean_spec_scores_one(self):
golden = TASK_EASY["golden_endpoints"]
initial: list = [] # no violations at start
score = grade_episode(golden, golden, initial)
assert score == pytest.approx(1.0)
# ===========================================================================
# 4. Step reward
# ===========================================================================
class TestStepReward:
def _make_violation(self, vtype="missing_field", severity=1.0):
return {
"endpoint_index": 0, "location": "response_body",
"field_name": "foo", "violation_type": vtype,
"description": "test", "severity": severity,
}
def test_positive_reward_for_fix(self):
v = self._make_violation()
r = step_reward(prev_violations=[v], new_violations=[], initial_violations=[v], action_error=False)
assert r > 0
def test_negative_reward_for_introduction(self):
v = self._make_violation()
r = step_reward(prev_violations=[], new_violations=[v], initial_violations=[], action_error=False)
assert r < 0
def test_penalty_for_action_error(self):
r = step_reward(prev_violations=[], new_violations=[], initial_violations=[], action_error=True)
assert r == pytest.approx(-0.05)
def test_zero_reward_for_no_op(self):
r = step_reward(prev_violations=[], new_violations=[], initial_violations=[], action_error=False)
assert r == pytest.approx(0.0)
# ===========================================================================
# 5. Environment — reset
# ===========================================================================
class TestEnvReset:
def test_reset_returns_observation(self):
env = APIContractDebuggerEnv(task_name="easy")
obs = env.reset()
assert obs.task_name == "easy"
assert len(obs.violations) == 1
assert obs.done is False
assert obs.step_count == 0
def test_reset_clears_state(self):
env = make_env("easy")
# Take a step, then reset
env.step(action(
kind=ActionKind.ADD_FIELD,
location="response_body",
field_name="created_at",
new_value={"type": "string", "required": True, "description": "timestamp"},
))
obs = env.reset()
assert obs.step_count == 0
assert len(obs.violations) == 1 # back to broken state
def test_reset_switches_task(self):
env = APIContractDebuggerEnv(task_name="easy")
obs = env.reset(task_name="medium")
assert obs.task_name == "medium"
assert len(obs.violations) == 3
def test_reset_preserves_golden(self):
env = make_env("hard")
obs = env.reset()
assert obs.total_violations_at_start == 6
def test_episode_id_set_on_reset(self):
env = APIContractDebuggerEnv(task_name="easy")
env.reset(episode_id="test-123")
assert env.state.episode_id == "test-123"
# ===========================================================================
# 6. Environment — step mechanics
# ===========================================================================
class TestEnvStep:
def test_add_missing_field_fixes_easy(self):
env = make_env("easy")
obs = env.step(action(
kind=ActionKind.ADD_FIELD,
location="response_body",
field_name="created_at",
new_value={"type": "string", "required": True, "description": "ISO timestamp"},
))
assert len(obs.violations) == 0
assert obs.done is True
assert obs.reward > 0
def test_wrong_type_action_introduces_violation(self):
env = make_env("easy")
obs = env.step(action(
kind=ActionKind.ADD_FIELD,
location="response_body",
field_name="created_at",
new_value={"type": "integer", "required": True, "description": "wrong type"},
))
# Still has a violation (wrong type now)
assert len(obs.violations) == 1
assert obs.violations[0]["violation_type"] == "wrong_type"
def test_out_of_range_endpoint_index(self):
env = make_env("easy")
obs = env.step(action(kind=ActionKind.ADD_FIELD, endpoint_index=99,
field_name="x", new_value={"type": "string"}))
assert obs.last_action_error is not None
assert "out of range" in obs.last_action_error
def test_change_type_fixes_medium_violation(self):
env = make_env("medium")
# Fix violation 1: product_id type string→integer in response
obs = env.step(action(
kind=ActionKind.CHANGE_TYPE,
endpoint_index=0,
location="response_body",
field_name="product_id",
new_value="integer",
))
assert obs.violations_fixed_this_step == 1
assert len(obs.violations) == 2 # 2 remaining
def test_change_status_fixes_medium_violation(self):
env = make_env("medium")
obs = env.step(action(
kind=ActionKind.CHANGE_STATUS,
endpoint_index=2,
location="status_code",
new_value=204,
))
assert obs.violations_fixed_this_step == 1
def test_remove_field_fixes_hard_extra_field(self):
env = make_env("hard")
obs = env.step(action(
kind=ActionKind.REMOVE_FIELD,
endpoint_index=1,
location="response_body",
field_name="password_hash",
))
assert obs.violations_fixed_this_step == 1
def test_no_op_does_not_change_violations(self):
env = make_env("easy")
before = len(env.state.violations)
obs = env.step(action(kind=ActionKind.NO_OP))
assert len(obs.violations) == before
def test_step_after_done_returns_done(self):
env = make_env("easy")
# Solve it
env.step(action(
kind=ActionKind.ADD_FIELD,
location="response_body",
field_name="created_at",
new_value={"type": "string", "required": True, "description": "ts"},
))
# Step again — should get done=True with error message
obs = env.step(action(kind=ActionKind.NO_OP))
assert obs.done is True
assert obs.last_action_error is not None
def test_max_steps_terminates_episode(self):
env = APIContractDebuggerEnv(task_name="easy")
env.reset()
obs = None
for _ in range(env._task_cfg["max_steps"]):
obs = env.step(action(kind=ActionKind.NO_OP))
assert obs.done is True
def test_step_count_increments(self):
env = make_env("easy")
env.step(action(kind=ActionKind.NO_OP))
env.step(action(kind=ActionKind.NO_OP))
assert env.state.step_count == 2
# ===========================================================================
# 7. Environment — state
# ===========================================================================
class TestEnvState:
def test_state_reflects_current_endpoints(self):
env = make_env("easy")
state = env.state
assert len(state.current_endpoints) == 1
assert state.task_name == "easy"
def test_state_tracks_step_count(self):
env = make_env("easy")
env.step(action(kind=ActionKind.NO_OP))
assert env.state.step_count == 1
def test_original_endpoints_unchanged_after_steps(self):
env = make_env("easy")
original_before = copy.deepcopy(env.state.original_endpoints)
env.step(action(
kind=ActionKind.ADD_FIELD,
location="response_body",
field_name="created_at",
new_value={"type": "string", "required": True, "description": "ts"},
))
assert env.state.original_endpoints == original_before
# ===========================================================================
# 8. Full episode walkthroughs
# ===========================================================================
class TestFullEpisodes:
def test_easy_perfect_solve(self):
env = make_env("easy")
env.step(action(
kind=ActionKind.ADD_FIELD,
location="response_body",
field_name="created_at",
new_value={"type": "string", "required": True, "description": "ISO timestamp"},
))
assert env.score() == pytest.approx(1.0)
def test_medium_perfect_solve(self):
env = make_env("medium")
# Fix 1: product_id type
env.step(action(kind=ActionKind.CHANGE_TYPE, endpoint_index=0,
location="response_body", field_name="product_id", new_value="integer"))
# Fix 2: quantity type
env.step(action(kind=ActionKind.CHANGE_TYPE, endpoint_index=1,
location="request_body", field_name="quantity", new_value="integer"))
# Fix 3: DELETE status code
env.step(action(kind=ActionKind.CHANGE_STATUS, endpoint_index=2,
location="status_code", new_value=204))
assert env.score() == pytest.approx(1.0)
def test_hard_perfect_solve(self):
env = make_env("hard")
# Fix 1: add refresh_token to /auth/login response
env.step(action(kind=ActionKind.ADD_FIELD, endpoint_index=0,
location="response_body", field_name="refresh_token",
new_value={"type": "string", "required": True, "description": "Refresh token"}))
# Fix 2: expires_in type string→integer in /auth/login response
env.step(action(kind=ActionKind.CHANGE_TYPE, endpoint_index=0,
location="response_body", field_name="expires_in", new_value="integer"))
# Fix 3: add created_at to /users/{id}/profile response
env.step(action(kind=ActionKind.ADD_FIELD, endpoint_index=1,
location="response_body", field_name="created_at",
new_value={"type": "string", "required": True, "description": "ISO timestamp"}))
# Fix 4: remove password_hash from /users/{id}/profile response
env.step(action(kind=ActionKind.REMOVE_FIELD, endpoint_index=1,
location="response_body", field_name="password_hash"))
# Fix 5: PATCH status 500→200
env.step(action(kind=ActionKind.CHANGE_STATUS, endpoint_index=2,
location="status_code", new_value=200))
# Fix 6: add updated_at to PATCH response
env.step(action(kind=ActionKind.ADD_FIELD, endpoint_index=2,
location="response_body", field_name="updated_at",
new_value={"type": "string", "required": True, "description": "ISO timestamp"}))
assert env.score() == pytest.approx(1.0)
def test_score_after_partial_solve(self):
env = make_env("medium")
# Fix only 1 of 3
env.step(action(kind=ActionKind.CHANGE_TYPE, endpoint_index=0,
location="response_body", field_name="product_id", new_value="integer"))
score = env.score()
assert 0.0 < score < 1.0
def test_unknown_task_raises(self):
with pytest.raises(ValueError, match="Unknown task"):
APIContractDebuggerEnv(task_name="impossible")
# ===========================================================================
# 9. HTTP API routes (FastAPI TestClient)
# ===========================================================================
class TestHTTPRoutes:
@pytest.fixture(autouse=True)
def client(self):
from fastapi.testclient import TestClient
from server.app import app
self.client = TestClient(app)
def test_health_endpoint(self):
r = self.client.get("/health")
assert r.status_code == 200
def test_reset_returns_200(self):
r = self.client.post("/reset", json={})
assert r.status_code == 200
data = r.json()
assert "violations" in data
assert "endpoints" in data
def test_reset_switches_task(self):
r = self.client.post("/reset", json={"task_name": "medium"})
assert r.status_code == 200
assert r.json()["task_name"] == "medium"
def test_reset_unknown_task_422(self):
r = self.client.post("/reset", json={"task_name": "impossible"})
assert r.status_code == 422
def test_step_add_field(self):
self.client.post("/reset", json={"task_name": "easy"})
r = self.client.post("/step", json={
"action": {
"kind": "add_field",
"endpoint_index": 0,
"location": "response_body",
"field_name": "created_at",
"new_value": {"type": "string", "required": True, "description": "ts"},
}
})
assert r.status_code == 200
data = r.json()
assert data["done"] is True
assert data["reward"] > 0
def test_step_invalid_action_422(self):
self.client.post("/reset", json={})
r = self.client.post("/step", json={"action": {"kind": "nonexistent_kind"}})
assert r.status_code == 422
def test_state_endpoint(self):
self.client.post("/reset", json={"task_name": "easy"})
r = self.client.get("/state")
assert r.status_code == 200
assert "current_endpoints" in r.json()
def test_score_endpoint(self):
self.client.post("/reset", json={"task_name": "easy"})
r = self.client.get("/score")
assert r.status_code == 200
data = r.json()
assert "score" in data
assert 0.0 <= data["score"] <= 1.0
def test_tasks_endpoint(self):
r = self.client.get("/tasks")
assert r.status_code == 200
data = r.json()
assert len(data["tasks"]) == 3
def test_schema_endpoint(self):
r = self.client.get("/schema")
assert r.status_code == 200
schema = r.json()
assert "action" in schema
assert "observation" in schema
def test_full_easy_solve_via_http(self):
self.client.post("/reset", json={"task_name": "easy"})
r = self.client.post("/step", json={
"action": {
"kind": "add_field",
"endpoint_index": 0,
"location": "response_body",
"field_name": "created_at",
"new_value": {"type": "string", "required": True, "description": "ts"},
}
})
assert r.json()["done"] is True
score_r = self.client.get("/score")
assert score_r.json()["score"] == pytest.approx(1.0)