Spaces:
Sleeping
Sleeping
| # Grader functions for OpenEnv task validation | |
| # Each function is referenced in openenv.yaml as tests.test_graders:grade_<task_id> | |
| # Grader functions must return a float score between 0.0 and 1.0 | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from environment.api_triage_env import APITriageEnv | |
| from environment.incident_generator import get_incident_by_type | |
def _run_agent_on_incident(incident_type, max_steps=10):
    """Drive a scripted agent through one incident and score the outcome.

    A fresh ``APITriageEnv`` is created, the requested incident is installed
    on it directly, and three actions are played in order: ``"inspect_logs"``,
    the incident's own ``"fix_action"``, then ``"resolve"``.

    Args:
        incident_type: Incident identifier passed to ``get_incident_by_type``.
        max_steps: Forwarded to ``APITriageEnv(max_steps=...)``.

    Returns:
        0.95 if the environment reports done with ``info["resolution"]`` equal
        to ``"success"``; 0.05 if it reports done with any other resolution;
        0.1 if the scripted actions run out before the environment is done.
    """
    env = APITriageEnv(max_steps=max_steps)

    # Install the incident by hand so the grader always exercises the
    # requested type (bypass curriculum randomness).
    env.incident = get_incident_by_type(incident_type)
    env.fix_applied = False
    env.done = False
    env.step_counter = 0
    env.total_reward = 0.0

    # Optimal action sequence: inspect -> incident-specific fix -> resolve.
    plan = ("inspect_logs", env.incident["fix_action"], "resolve")

    for action in plan:
        _state, _reward, finished, info = env.step(action)
        if not finished:
            continue
        # The episode ended on this step: score by the reported resolution.
        if info.get("resolution") == "success":
            return 0.95
        return 0.05

    # The scripted plan was exhausted without the environment finishing.
    return 0.1
| # ============================================ | |
| # Per-task grader functions (referenced in openenv.yaml) | |
| # ============================================ | |
def grade_auth_error():
    """Grader for auth_error task: 401 Unauthorized - expired API key.

    Returns:
        The score produced by ``_run_agent_on_incident("auth_error")``.

    Raises:
        AssertionError: If the score is outside the range [0.0, 1.0].
    """
    score = _run_agent_on_incident("auth_error")
    # Raise explicitly instead of using ``assert`` so the range contract is
    # still enforced under ``python -O``, which strips assert statements.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
def grade_missing_fields():
    """Grader for missing_fields task: 400 Bad Request - missing email field.

    Returns:
        The score produced by ``_run_agent_on_incident("missing_fields")``.

    Raises:
        AssertionError: If the score is outside the range [0.0, 1.0].
    """
    score = _run_agent_on_incident("missing_fields")
    # Raise explicitly instead of using ``assert`` so the range contract is
    # still enforced under ``python -O``, which strips assert statements.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
def grade_rate_limit():
    """Grader for rate_limit task: 429 Too Many Requests.

    Returns:
        The score produced by ``_run_agent_on_incident("rate_limit")``.

    Raises:
        AssertionError: If the score is outside the range [0.0, 1.0].
    """
    score = _run_agent_on_incident("rate_limit")
    # Raise explicitly instead of using ``assert`` so the range contract is
    # still enforced under ``python -O``, which strips assert statements.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
def grade_timeout():
    """Grader for timeout task: 408 Request Timeout.

    Returns:
        The score produced by ``_run_agent_on_incident("timeout")``.

    Raises:
        AssertionError: If the score is outside the range [0.0, 1.0].
    """
    score = _run_agent_on_incident("timeout")
    # Raise explicitly instead of using ``assert`` so the range contract is
    # still enforced under ``python -O``, which strips assert statements.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
def grade_wrong_endpoint():
    """Grader for wrong_endpoint task: 404 Not Found.

    Returns:
        The score produced by ``_run_agent_on_incident("wrong_endpoint")``.

    Raises:
        AssertionError: If the score is outside the range [0.0, 1.0].
    """
    score = _run_agent_on_incident("wrong_endpoint")
    # Raise explicitly instead of using ``assert`` so the range contract is
    # still enforced under ``python -O``, which strips assert statements.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
def grade_server_error():
    """Grader for server_error task: 500 Internal Server Error.

    Returns:
        The score produced by ``_run_agent_on_incident("server_error")``.

    Raises:
        AssertionError: If the score is outside the range [0.0, 1.0].
    """
    score = _run_agent_on_incident("server_error")
    # Raise explicitly instead of using ``assert`` so the range contract is
    # still enforced under ``python -O``, which strips assert statements.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score
| # ============================================ | |
| # Pytest-compatible test wrappers | |
| # ============================================ | |
def test_grade_auth_error():
    """Pytest wrapper: the auth_error grader must score above 0.5."""
    score = grade_auth_error()
    above_threshold = score > 0.5
    assert above_threshold, f"auth_error grader returned low score: {score}"
def test_grade_missing_fields():
    """Pytest wrapper: the missing_fields grader must score above 0.5."""
    score = grade_missing_fields()
    above_threshold = score > 0.5
    assert above_threshold, f"missing_fields grader returned low score: {score}"
def test_grade_rate_limit():
    """Pytest wrapper: the rate_limit grader must score above 0.5."""
    score = grade_rate_limit()
    above_threshold = score > 0.5
    assert above_threshold, f"rate_limit grader returned low score: {score}"
def test_grade_timeout():
    """Pytest wrapper: the timeout grader must score above 0.5."""
    score = grade_timeout()
    above_threshold = score > 0.5
    assert above_threshold, f"timeout grader returned low score: {score}"
def test_grade_wrong_endpoint():
    """Pytest wrapper: the wrong_endpoint grader must score above 0.5."""
    score = grade_wrong_endpoint()
    above_threshold = score > 0.5
    assert above_threshold, f"wrong_endpoint grader returned low score: {score}"
def test_grade_server_error():
    """Pytest wrapper: the server_error grader must score above 0.5."""
    score = grade_server_error()
    above_threshold = score > 0.5
    assert above_threshold, f"server_error grader returned low score: {score}"
if __name__ == "__main__":
    # Run every grader once, in declaration order, and print its score.
    _graders = (
        ("auth_error", grade_auth_error),
        ("missing_fields", grade_missing_fields),
        ("rate_limit", grade_rate_limit),
        ("timeout", grade_timeout),
        ("wrong_endpoint", grade_wrong_endpoint),
        ("server_error", grade_server_error),
    )
    for _task_id, _grader in _graders:
        print(f"{_task_id} score: {_grader()}")
    # Reached only if no grader raised.
    print("All graders passed!")