API_DEBUG_SOLVER / tests /test_graders.py
Kavya988's picture
Upload 29 files
d416acc verified
raw
history blame
4.07 kB
# Grader functions for OpenEnv task validation
# Each function is referenced in openenv.yaml as tests.test_graders:grade_<task_id>
# Grader functions must return a float score between 0.0 and 1.0
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from environment.api_triage_env import APITriageEnv
from environment.incident_generator import get_incident_by_type
def _run_agent_on_incident(incident_type, max_steps=10):
    """Play a scripted episode against one incident type and score it.

    A fresh ``APITriageEnv`` is created. Its incident is overwritten with
    ``get_incident_by_type(incident_type)``, bypassing curriculum randomness.
    The per-episode fields ``fix_applied``, ``done``, ``step_counter`` and
    ``total_reward`` are then set back to their starting values. The scripted
    agent issues ``inspect_logs``, the incident's own ``fix_action``, and
    ``resolve``, stopping at the first step that reports ``done``.

    Args:
        incident_type: Identifier passed to ``get_incident_by_type``.
        max_steps: Step budget forwarded to the ``APITriageEnv`` constructor.

    Returns:
        0.95 if the episode ends with ``info["resolution"] == "success"``;
        0.05 if it ends with any other resolution;
        0.1 if none of the scripted steps ends the episode.
    """
    env = APITriageEnv(max_steps=max_steps)

    # Pin the incident instead of letting the curriculum choose one, then put
    # the per-episode bookkeeping back to its starting values.
    env.incident = get_incident_by_type(incident_type)
    env.fix_applied = False
    env.done = False
    env.step_counter = 0
    env.total_reward = 0.0

    # Scripted agent: look at the logs, apply the known fix, close out.
    script = ("inspect_logs", env.incident["fix_action"], "resolve")

    for move in script:
        _observation, _reward, finished, details = env.step(move)
        if not finished:
            continue
        succeeded = details.get("resolution") == "success"
        return 0.95 if succeeded else 0.05

    # The script ran out without the environment signalling termination.
    return 0.1
# ============================================
# Per-task grader functions (referenced in openenv.yaml)
# ============================================
def grade_auth_error():
    """Grade the auth_error task: 401 Unauthorized - expired API key.

    Referenced from openenv.yaml as ``tests.test_graders:grade_auth_error``.

    Returns:
        The score of the scripted run, guaranteed to lie in [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0]. It is raised
            explicitly instead of via ``assert`` so the check still runs when
            Python is started with ``-O``.
    """
    score = _run_agent_on_incident("auth_error")
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
def grade_missing_fields():
    """Grade the missing_fields task: 400 Bad Request - missing email field.

    Referenced from openenv.yaml as ``tests.test_graders:grade_missing_fields``.

    Returns:
        The score of the scripted run, guaranteed to lie in [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0]. It is raised
            explicitly instead of via ``assert`` so the check still runs when
            Python is started with ``-O``.
    """
    score = _run_agent_on_incident("missing_fields")
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
def grade_rate_limit():
    """Grade the rate_limit task: 429 Too Many Requests.

    Referenced from openenv.yaml as ``tests.test_graders:grade_rate_limit``.

    Returns:
        The score of the scripted run, guaranteed to lie in [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0]. It is raised
            explicitly instead of via ``assert`` so the check still runs when
            Python is started with ``-O``.
    """
    score = _run_agent_on_incident("rate_limit")
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
def grade_timeout():
    """Grade the timeout task: 408 Request Timeout.

    Referenced from openenv.yaml as ``tests.test_graders:grade_timeout``.

    Returns:
        The score of the scripted run, guaranteed to lie in [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0]. It is raised
            explicitly instead of via ``assert`` so the check still runs when
            Python is started with ``-O``.
    """
    score = _run_agent_on_incident("timeout")
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
def grade_wrong_endpoint():
    """Grade the wrong_endpoint task: 404 Not Found.

    Referenced from openenv.yaml as ``tests.test_graders:grade_wrong_endpoint``.

    Returns:
        The score of the scripted run, guaranteed to lie in [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0]. It is raised
            explicitly instead of via ``assert`` so the check still runs when
            Python is started with ``-O``.
    """
    score = _run_agent_on_incident("wrong_endpoint")
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
def grade_server_error():
    """Grade the server_error task: 500 Internal Server Error.

    Referenced from openenv.yaml as ``tests.test_graders:grade_server_error``.

    Returns:
        The score of the scripted run, guaranteed to lie in [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0]. It is raised
            explicitly instead of via ``assert`` so the check still runs when
            Python is started with ``-O``.
    """
    score = _run_agent_on_incident("server_error")
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
# ============================================
# Pytest-compatible test wrappers
# ============================================
def test_grade_auth_error():
    """The scripted run must clear the 0.5 pass mark on auth_error."""
    score = grade_auth_error()
    assert 0.5 < score, f"auth_error grader returned low score: {score}"
def test_grade_missing_fields():
    """The scripted run must clear the 0.5 pass mark on missing_fields."""
    score = grade_missing_fields()
    assert 0.5 < score, f"missing_fields grader returned low score: {score}"
def test_grade_rate_limit():
    """The scripted run must clear the 0.5 pass mark on rate_limit."""
    score = grade_rate_limit()
    assert 0.5 < score, f"rate_limit grader returned low score: {score}"
def test_grade_timeout():
    """The scripted run must clear the 0.5 pass mark on timeout."""
    score = grade_timeout()
    assert 0.5 < score, f"timeout grader returned low score: {score}"
def test_grade_wrong_endpoint():
    """The scripted run must clear the 0.5 pass mark on wrong_endpoint."""
    score = grade_wrong_endpoint()
    assert 0.5 < score, f"wrong_endpoint grader returned low score: {score}"
def test_grade_server_error():
    """The scripted run must clear the 0.5 pass mark on server_error."""
    score = grade_server_error()
    assert 0.5 < score, f"server_error grader returned low score: {score}"
if __name__ == "__main__":
    # Run every grader in declaration order and print its score. The final
    # message only means no grader raised; scores are not compared to the
    # 0.5 threshold used by the pytest wrappers.
    for task_name, grader in (
        ("auth_error", grade_auth_error),
        ("missing_fields", grade_missing_fields),
        ("rate_limit", grade_rate_limit),
        ("timeout", grade_timeout),
        ("wrong_endpoint", grade_wrong_endpoint),
        ("server_error", grade_server_error),
    ):
        print(f"{task_name} score: {grader()}")
    print("All graders passed!")