"""OpenEnv-compatible environment wrapper."""

import uuid

from src.city_schema import CitySchemaLoader
from src.grading import grade_episode
from src.models import Action, Observation, State
from src.state_machine import DispatchStateMachine


class OpenEnvEnvironment:
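    """OpenEnv-style wrapper around the dispatch simulator.

    Exposes an async ``reset()`` / ``step()`` lifecycle plus ``close()``.
    Per-step rewards come from ``DispatchStateMachine.step()``; the running
    episode score is recomputed with ``grade_episode()`` and published on
    ``Observation.score``.
    """
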
    def __init__(self, task_id: str, seed: int | None = None) -> None:
        self.task_id = task_id
        self.seed_value = seed
        schema = CitySchemaLoader.load("metro_city")
        self._machine = DispatchStateMachine(schema=schema, seed=seed)
        self._state: State | None = None
        self._last_observation: Observation | None = None

    async def reset(self) -> Observation:
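        """Start a fresh episode and return the initial observation."""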
        episode_id = str(uuid.uuid4())
        self._state = self._machine.reset(task_id=self.task_id, episode_id=episode_id)
        self._state.metadata["cumulative_reward"] = 0.0
        self._state.metadata["episode_rewards"] = []
        self._state.metadata["episode_score"] = 0.0
        active_p1 = sum(
            1
            for i in self._state.incidents.values()
            if i.severity.value == "PRIORITY_1" and i.status.value not in {"RESOLVED", "ESCALATED"}
        )
        avail = sum(1 for u in self._state.units.values() if u.status.value == "AVAILABLE")

        self._last_observation = Observation(
            result="dispatch center online",
            score=0.0,
            protocol_ok=True,
            issues=[],
            reward_breakdown={
                "response_time": 0.0,
                "triage": 0.0,
                "survival": 0.0,
                "coverage": 0.0,
                "protocol": 1.0,
            },
            phraseology_score=1.0,
            active_p1_count=active_p1,
            units_available=avail,
        )
        return self._last_observation

    async def step(self, action: Action) -> tuple[Observation, float, bool]:
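        """Apply ``action``, update episode bookkeeping, and return
        ``(observation, step_reward, done)``."""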
        if self._state is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        state, obs = self._machine.step(self._state, action)
        self._state = state

        # `DispatchStateMachine.step()` sets `obs.score` to the per-step reward.
        # OpenEnv consumers often interpret `observation.score` as an episode score,
        # so we keep the per-step reward in `reward` and publish the episode score
        # into `observation.score`.
        step_reward = float(obs.score)

        rewards: list[float] = list(self._state.metadata.get("episode_rewards", []))
        rewards.append(step_reward)
        self._state.metadata["episode_rewards"] = rewards

        cumulative = float(self._state.metadata.get("cumulative_reward", 0.0))
        self._state.metadata["cumulative_reward"] = cumulative + step_reward

        # Episode score is derived from the same grading logic as benchmark runs.
        episode_score = grade_episode(task_id=self.task_id, state=self._state, rewards=rewards)
        episode_score = max(0.0, min(1.0, float(episode_score)))
        self._state.metadata["episode_score"] = episode_score

        done = self._machine.is_terminal(state)

        active_p1_count = sum(
            1
            for i in state.incidents.values()
            if i.severity.value == "PRIORITY_1" and i.status.value not in {"RESOLVED", "ESCALATED"}
        )
        units_available = sum(1 for u in state.units.values() if u.status.value == "AVAILABLE")

        phraseology = 0.0
        if obs.reward_breakdown:
            phraseology = obs.reward_breakdown.get("protocol", 0.0)

        obs = obs.model_copy(
            update={
                "score": episode_score,
                "phraseology_score": phraseology,
                "active_p1_count": active_p1_count,
                "units_available": units_available,
                "step_count": state.step_count,
                "episode_done": done,
            }
        )
        self._last_observation = obs
        return obs, step_reward, done

    def state(self) -> State:
        if self._state is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        return self._state

    def last_observation(self) -> Observation | None:
        return self._last_observation

    def legal_actions(self) -> list[Action]:
        if self._state is None:
            return []
        return self._machine.get_legal_actions(self._state)

    def close(self) -> None:
        self._state = None
        self._last_observation = None
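

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the wrapper's public API).
# The task id "metro_city_basic" is a placeholder; substitute a real task id
# from this project's task catalogue. The loop simply takes the first legal
# action each step, which is enough to show the reset/step/close flow.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        env = OpenEnvEnvironment(task_id="metro_city_basic", seed=0)
        obs = await env.reset()
        print(f"reset: {obs.result} (score={obs.score:.2f})")

        done = False
        while not done:
            actions = env.legal_actions()
            if not actions:
                break  # no legal action left; end the demo early
            obs, reward, done = await env.step(actions[0])
            print(f"step {obs.step_count}: reward={reward:.3f} score={obs.score:.3f}")

        env.close()

    asyncio.run(_demo())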