File size: 7,421 Bytes
27158b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
graders.py β€” Reward shaping logic for MediRoute OpenEnv.

Each action is evaluated against the ground-truth task expectations.
Rewards are incremental per-step values; the environment accumulates and
clamps the episode total to [0.0, 1.0].

Reward table
─────────────────────────────────────────────────────────────────
 Correct severity classification (analyze_symptoms)   +0.30
 Correct specialist recommendation                    +0.30
 Correct hospital selection                           +0.20
 Successful appointment booking (non-emergency)       +0.20
 Correct emergency escalation (call_ambulance)        +0.50
 Wrong department / specialist                        -0.20
 Unnecessary loop / duplicate action                  -0.30
 Calling ambulance on non-emergency                   -0.30
 Booking appointment in emergency case                -0.30
─────────────────────────────────────────────────────────────────
"""

from __future__ import annotations

from typing import Any, Dict, List

from models import Action


# ─────────────────────────────────────────────
#  Internal helpers
# ─────────────────────────────────────────────

def _is_duplicate(action: Action, previous_actions: List[str]) -> bool:
    return action.as_key() in previous_actions


# ─────────────────────────────────────────────
#  Public API
# ─────────────────────────────────────────────

def grade_step(
    task: Dict[str, Any],
    action: Action,
    previous_actions: List[str],
) -> float:
    """
    Compute the incremental reward for a single action taken in *task*.

    Args:
        task:             The full task dict as returned by tasks.get_task().
        action:           The Action the agent wants to execute.
        previous_actions: Actions already taken this episode (as 'type:target' strings).

    Returns:
        A float reward value (can be negative; clamping is done in the environment).
    """

    # ── Duplicate penalty ────────────────────────────────────────────────────
    if _is_duplicate(action, previous_actions):
        return -0.30

    action_type = action.action_type
    target = (action.target or "").strip()

    # ── analyze_symptoms ─────────────────────────────────────────────────────
    if action_type == "analyze_symptoms":
        if target.lower() == task["expected_severity"].lower():
            return 0.30
        else:
            return -0.10   # Incorrect severity assessment

    # ── request_more_info ────────────────────────────────────────────────────
    elif action_type == "request_more_info":
        # Neutral in most cases; mild reward only if no prior analysis done
        analyzed = any(a.startswith("analyze_symptoms") for a in previous_actions)
        return 0.05 if not analyzed else -0.05

    # ── recommend_specialist ─────────────────────────────────────────────────
    elif action_type == "recommend_specialist":
        if target == task["expected_specialist"]:
            return 0.30
        else:
            return -0.20   # Wrong department

    # ── select_hospital ──────────────────────────────────────────────────────
    elif action_type == "select_hospital":
        if target == task["expected_hospital"]:
            return 0.20
        elif target in task["nearby_hospitals"]:
            return 0.05    # Nearby but not optimal
        else:
            return -0.10   # Unknown / unreachable hospital

    # ── book_appointment ─────────────────────────────────────────────────────
    elif action_type == "book_appointment":
        if task["requires_ambulance"]:
            # Trying to book appointment in a life-threatening emergency is wrong
            return -0.30
        return 0.20

    # ── call_ambulance ───────────────────────────────────────────────────────
    elif action_type == "call_ambulance":
        if task["requires_ambulance"]:
            return 0.50    # Correct emergency escalation
        else:
            return -0.30   # Unnecessary ambulance dispatch

    # ── provide_temp_guidance ─────────────────────────────────────────────────
    elif action_type == "provide_temp_guidance":
        # Acceptable as a closing action for non-emergencies
        if not task["requires_ambulance"]:
            return 0.10
        else:
            return -0.10   # Not enough for a critical patient

    # ── Unknown action ────────────────────────────────────────────────────────
    return -0.10


def grade_episode(
    task: Dict[str, Any],
    all_actions: List[str],
    final_total_reward: float,
) -> Dict[str, Any]:
    """
    Produce a final episode summary / score report.

    Args:
        task:               Task dict.
        all_actions:        Full list of action keys taken during the episode.
        final_total_reward: Accumulated clamped reward from the environment.

    Returns:
        A dict with score, pass/fail, and diagnostic breakdown.
    """
    score = round(final_total_reward, 4)
    passed = score >= 0.5

    breakdown = {
        "severity_classified": any(
            a.startswith(f"analyze_symptoms:{task['expected_severity']}")
            for a in all_actions
        ),
        "correct_specialist": any(
            a.startswith(f"recommend_specialist:{task['expected_specialist']}")
            for a in all_actions
        ),
        "correct_hospital": any(
            a.startswith(f"select_hospital:{task['expected_hospital']}")
            for a in all_actions
        ),
        "ambulance_called": any(a.startswith("call_ambulance") for a in all_actions),
        "appointment_booked": any(a.startswith("book_appointment") for a in all_actions),
    }

    return {
        "score": score,
        "passed": passed,
        "difficulty": task["difficulty"],
        "total_steps": len(all_actions),
        "breakdown": breakdown,
    }