garvitsachdeva commited on
Commit
4dc3d0a
·
1 Parent(s): 43f2683

Fix triage reward lookup for enum-qualified metadata

Browse files
Files changed (2) hide show
  1. src/rewards.py +42 -8
  2. tests/test_rewards.py +3 -0
src/rewards.py CHANGED
@@ -16,6 +16,35 @@ def _clamp01(value: float) -> float:
16
  return max(0.0, min(1.0, float(value)))
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class RewardSignal(BaseModel):
20
  """Signal components for reward breakdown."""
21
 
@@ -102,17 +131,22 @@ class RewardCalculator:
102
  if unit is None or incident is None:
103
  return 0.0
104
 
105
- required_map = state.metadata.get("default_required_units", {})
106
- # Try both formats: plain value and StrEnum repr
107
- required_types = (
108
- required_map.get(incident.incident_type.value, [])
109
- or required_map.get(str(incident.incident_type), [])
110
- )
 
 
 
 
 
111
  if not required_types:
112
  return 0.5
113
 
114
- # required_types are stored as strings in metadata.
115
- if unit.unit_type.value in set(required_types):
116
  return 1.0
117
  return 0.0
118
 
 
16
  return max(0.0, min(1.0, float(value)))
17
 
18
 
19
+ def _normalize_enumish_key(value: object) -> str:
20
+ """Normalize keys that may be stored as Enum-ish strings.
21
+
22
+ We accept forms like:
23
+ - "CARDIAC_ARREST"
24
+ - "IncidentType.CARDIAC_ARREST"
25
+ - "src.models.IncidentType.CARDIAC_ARREST"
26
+ - Enum members (IncidentType.CARDIAC_ARREST)
27
+ """
28
+
29
+ if isinstance(value, str):
30
+ text = value
31
+ else:
32
+ text = getattr(value, "value", None) or str(value)
33
+
34
+ # If the value looks like a qualified enum name, use the trailing segment.
35
+ if "." in text:
36
+ return text.split(".")[-1]
37
+ return text
38
+
39
+
40
+ def _normalize_str_list(values: object) -> list[str]:
41
+ if values is None:
42
+ return []
43
+ if not isinstance(values, (list, tuple, set)):
44
+ return [_normalize_enumish_key(values)]
45
+ return [_normalize_enumish_key(v) for v in values]
46
+
47
+
48
  class RewardSignal(BaseModel):
49
  """Signal components for reward breakdown."""
50
 
 
131
  if unit is None or incident is None:
132
  return 0.0
133
 
134
+ required_map_raw = state.metadata.get("default_required_units", {})
135
+ if not isinstance(required_map_raw, dict):
136
+ return 0.5
137
+
138
+ # Normalize metadata so lookups work across serialization styles.
139
+ required_map: dict[str, list[str]] = {
140
+ _normalize_enumish_key(k): _normalize_str_list(v) for k, v in required_map_raw.items()
141
+ }
142
+
143
+ incident_key = _normalize_enumish_key(incident.incident_type)
144
+ required_types = required_map.get(incident_key, [])
145
  if not required_types:
146
  return 0.5
147
 
148
+ # required_types are stored as strings in metadata (often with enum qualifiers).
149
+ if _normalize_enumish_key(unit.unit_type) in set(required_types):
150
  return 1.0
151
  return 0.0
152
 
tests/test_rewards.py CHANGED
@@ -69,6 +69,9 @@ def test_compute_reward_returns_tuple() -> None:
69
  obs = Observation(result="ok", score=0.8, protocol_ok=True, issues=[])
70
  signal, total = calc.compute_reward(state, action, obs)
71
  assert isinstance(signal, RewardSignal)
 
 
 
72
  assert 0.0 <= total <= 1.0
73
 
74
 
 
69
  obs = Observation(result="ok", score=0.8, protocol_ok=True, issues=[])
70
  signal, total = calc.compute_reward(state, action, obs)
71
  assert isinstance(signal, RewardSignal)
72
+ # Fixture metadata stores enum-ish strings (e.g. "IncidentType.CARDIAC_ARREST").
73
+ # Triage should still award full credit for a correct match.
74
+ assert signal.triage == 1.0
75
  assert 0.0 <= total <= 1.0
76
 
77