File size: 4,074 Bytes
d416acc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Grader functions for OpenEnv task validation
# Each function is referenced in openenv.yaml as tests.test_graders:grade_<task_id>
# Grader functions must return a float score between 0.0 and 1.0

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

from environment.api_triage_env import APITriageEnv
from environment.incident_generator import get_incident_by_type


def _run_agent_on_incident(incident_type, max_steps=10):
    """Play a fixed three-action script against one incident and score it.

    Args:
        incident_type: Incident identifier passed to ``get_incident_by_type``.
        max_steps: Step budget forwarded to ``APITriageEnv``.

    Returns:
        0.95 if the episode finishes with ``info["resolution"] == "success"``,
        0.05 if it finishes with any other resolution, and 0.1 if the
        environment never reports ``done`` during the scripted actions.
    """
    triage_env = APITriageEnv(max_steps=max_steps)

    # Pin the environment to the requested incident and zero its episode
    # bookkeeping by hand, so the curriculum's random selection is bypassed.
    triage_env.incident = get_incident_by_type(incident_type)
    triage_env.fix_applied = False
    triage_env.done = False
    triage_env.step_counter = 0
    triage_env.total_reward = 0.0

    # Scripted plan: look at the logs, apply the incident's own fix, resolve.
    fix_action = triage_env.incident["fix_action"]
    scripted_plan = ("inspect_logs", fix_action, "resolve")

    for planned_action in scripted_plan:
        _state, _reward, finished, details = triage_env.step(planned_action)
        if not finished:
            continue
        if details.get("resolution") == "success":
            return 0.95
        return 0.05

    # The script ran out before the environment signalled completion.
    return 0.1


# ============================================
# Per-task grader functions (referenced in openenv.yaml)
# ============================================

def grade_auth_error():
    """Grader for auth_error task: 401 Unauthorized - expired API key.

    Runs the scripted agent on the ``auth_error`` incident through
    ``_run_agent_on_incident`` and returns the resulting score.

    Returns:
        The score, verified to lie within [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0].
    """
    score = _run_agent_on_incident("auth_error")
    # Raise explicitly rather than using ``assert`` so the range check is not
    # stripped under ``python -O``; the exception type stays AssertionError.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score


def grade_missing_fields():
    """Grader for missing_fields task: 400 Bad Request - missing email field.

    Runs the scripted agent on the ``missing_fields`` incident through
    ``_run_agent_on_incident`` and returns the resulting score.

    Returns:
        The score, verified to lie within [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0].
    """
    score = _run_agent_on_incident("missing_fields")
    # Raise explicitly rather than using ``assert`` so the range check is not
    # stripped under ``python -O``; the exception type stays AssertionError.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score


def grade_rate_limit():
    """Grader for rate_limit task: 429 Too Many Requests.

    Runs the scripted agent on the ``rate_limit`` incident through
    ``_run_agent_on_incident`` and returns the resulting score.

    Returns:
        The score, verified to lie within [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0].
    """
    score = _run_agent_on_incident("rate_limit")
    # Raise explicitly rather than using ``assert`` so the range check is not
    # stripped under ``python -O``; the exception type stays AssertionError.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score


def grade_timeout():
    """Grader for timeout task: 408 Request Timeout.

    Runs the scripted agent on the ``timeout`` incident through
    ``_run_agent_on_incident`` and returns the resulting score.

    Returns:
        The score, verified to lie within [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0].
    """
    score = _run_agent_on_incident("timeout")
    # Raise explicitly rather than using ``assert`` so the range check is not
    # stripped under ``python -O``; the exception type stays AssertionError.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score


def grade_wrong_endpoint():
    """Grader for wrong_endpoint task: 404 Not Found.

    Runs the scripted agent on the ``wrong_endpoint`` incident through
    ``_run_agent_on_incident`` and returns the resulting score.

    Returns:
        The score, verified to lie within [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0].
    """
    score = _run_agent_on_incident("wrong_endpoint")
    # Raise explicitly rather than using ``assert`` so the range check is not
    # stripped under ``python -O``; the exception type stays AssertionError.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score


def grade_server_error():
    """Grader for server_error task: 500 Internal Server Error.

    Runs the scripted agent on the ``server_error`` incident through
    ``_run_agent_on_incident`` and returns the resulting score.

    Returns:
        The score, verified to lie within [0.0, 1.0].

    Raises:
        AssertionError: If the score falls outside [0.0, 1.0].
    """
    score = _run_agent_on_incident("server_error")
    # Raise explicitly rather than using ``assert`` so the range check is not
    # stripped under ``python -O``; the exception type stays AssertionError.
    if not 0.0 <= score <= 1.0:
        raise AssertionError(f"Score {score} out of range")
    return score


# ============================================
# Pytest-compatible test wrappers
# ============================================

def test_grade_auth_error():
    """Pytest check: the auth_error grader must score above 0.5."""
    result = grade_auth_error()
    assert result > 0.5, f"auth_error grader returned low score: {result}"


def test_grade_missing_fields():
    """Pytest check: the missing_fields grader must score above 0.5."""
    result = grade_missing_fields()
    assert result > 0.5, f"missing_fields grader returned low score: {result}"


def test_grade_rate_limit():
    """Pytest check: the rate_limit grader must score above 0.5."""
    result = grade_rate_limit()
    assert result > 0.5, f"rate_limit grader returned low score: {result}"


def test_grade_timeout():
    """Pytest check: the timeout grader must score above 0.5."""
    result = grade_timeout()
    assert result > 0.5, f"timeout grader returned low score: {result}"


def test_grade_wrong_endpoint():
    """Pytest check: the wrong_endpoint grader must score above 0.5."""
    result = grade_wrong_endpoint()
    assert result > 0.5, f"wrong_endpoint grader returned low score: {result}"


def test_grade_server_error():
    """Pytest check: the server_error grader must score above 0.5."""
    result = grade_server_error()
    assert result > 0.5, f"server_error grader returned low score: {result}"


if __name__ == "__main__":
    # Manual smoke run: invoke each grader in declaration order and print its
    # score. Labels are pre-padded so the printed values share one column.
    labelled_graders = (
        ("auth_error score:      ", grade_auth_error),
        ("missing_fields score:  ", grade_missing_fields),
        ("rate_limit score:      ", grade_rate_limit),
        ("timeout score:         ", grade_timeout),
        ("wrong_endpoint score:  ", grade_wrong_endpoint),
        ("server_error score:    ", grade_server_error),
    )
    for label, grader in labelled_graders:
        print(f"{label}{grader()}")
    print("All graders passed!")