File size: 5,453 Bytes
9fe417b
 
 
 
 
2dbf205
9fe417b
 
 
 
ec7e9a5
9fe417b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dbf205
9fe417b
 
 
 
 
 
2dbf205
ec7e9a5
9fe417b
 
2dbf205
9fe417b
 
 
 
 
 
 
 
 
 
 
2dbf205
ec7e9a5
9fe417b
 
 
ec7e9a5
 
 
 
 
 
 
9fe417b
 
2dbf205
9fe417b
 
 
 
 
 
2dbf205
 
 
9fe417b
 
 
 
 
 
 
 
 
 
 
 
2dbf205
 
 
 
 
 
 
 
 
9fe417b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dbf205
 
9fe417b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dbf205
 
ec7e9a5
2dbf205
 
 
 
 
 
 
 
 
 
 
ec7e9a5
2dbf205
 
9fe417b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""OpenEnv server wrapper for the focus scheduling simulator."""

from __future__ import annotations

from uuid import uuid4
import os

from openenv.core.env_server.interfaces import Environment, EnvironmentMetadata
from openenv.core.env_server.types import State

from benchmark_tasks import TASK_SPECS, apply_task
from focus_resource_env import FocusResourceEnv

try:
    from ..models import EngineerManagerAction, EngineerManagerObservation
except ImportError:
    from models import EngineerManagerAction, EngineerManagerObservation


class EngineerManagerEnvironment(
    Environment[EngineerManagerAction, EngineerManagerObservation, State]
):
    """Expose the scheduling simulator through the OpenEnv HTTP contract."""

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        start_hour: str = "09:00",
        end_hour: str = "17:00",
        distraction_risk: float = 0.15,
        seed: int | None = 7,
        task_name: str | None = None,
    ) -> None:
        super().__init__()
        self._start_hour = start_hour
        self._end_hour = end_hour
        self._distraction_risk = distraction_risk
        self._seed = seed
        self._task_name = task_name or os.getenv("TASK_NAME")
        self._task_id = 0
        self._step_count = 0
        self._episode_id = str(uuid4())
        self._trajectory: list[dict[str, object]] = []
        self._env = FocusResourceEnv(
            start_hour=start_hour,
            end_hour=end_hour,
            distraction_risk=distraction_risk,
            seed=seed,
        )

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        task_name: str | None = None,
        task_id: int | None = None,
        **_: object,
    ) -> EngineerManagerObservation:
        self._seed = self._seed if seed is None else seed
        task_names = ["quiet-morning", "meeting-surgery", "delivery-triage"]
        if task_id is not None and 0 <= int(task_id) < len(task_names):
            self._task_id = int(task_id)
            self._task_name = task_names[self._task_id]
        else:
            self._task_name = task_name or self._task_name or os.getenv("TASK_NAME")
            self._task_id = task_names.index(self._task_name) if self._task_name in task_names else 0
        self._episode_id = episode_id or str(uuid4())
        self._step_count = 0
        self._trajectory = []
        self._env = FocusResourceEnv(
            start_hour=self._start_hour,
            end_hour=self._end_hour,
            distraction_risk=self._distraction_risk,
            seed=self._seed,
        )
        self._env.reset()
        apply_task(self._env, self._task_name)
        return self._to_observation(self._env._observation(), reward=0.0, done=False)

    def step(
        self,
        action: EngineerManagerAction,
        timeout_s: float | None = None,
        **_: object,
    ) -> EngineerManagerObservation:
        del timeout_s
        observation, reward, done, info = self._env.step(
            (action.target_slot, action.operation)
        )
        self._step_count += 1
        self._trajectory.append(
            {
                "action": {"target_slot": int(action.target_slot), "operation": int(action.operation)},
                "observation": observation,
                "reward": float(reward),
                "done": bool(done),
                "info": info,
            }
        )
        return self._to_observation(observation, reward=reward, done=done, info=info)

    @property
    def state(self) -> State:
        return State(
            episode_id=self._episode_id,
            step_count=self._step_count,
            current_slot=self._env.current_slot,
            done=self._env.current_slot >= self._env.timeline_length,
        )

    def get_metadata(self) -> EnvironmentMetadata:
        return EnvironmentMetadata(
            name="Engineer Manager",
            description=(
                "Manage a workday by scheduling deep work, rescheduling meetings, "
                "and controlling communication noise. "
                f"Available tasks: {', '.join(sorted(TASK_SPECS))}."
            ),
            version="0.1.0",
        )

    def _to_observation(
        self,
        observation: dict[str, object],
        *,
        reward: float | None,
        done: bool,
        info: dict[str, object] | None = None,
    ) -> EngineerManagerObservation:
        payload = dict(observation)
        payload["reward"] = reward
        payload["done"] = done
        metadata = dict(info or {})
        metadata["task_name"] = self._task_name
        metadata["task_id"] = self._task_id
        metadata["episode_metrics"] = {
            "interruptions": int(self._env.interruptions),
            "invalid_actions": int(self._env.invalid_actions),
            "remaining_tasks": len(self._env.task_buffer),
            "scheduled_work_slots": sum(1 for slot in self._env.timeline if int(slot) == 1),
            "successful_reschedules": sum(
                1
                for step in self._trajectory
                if step["info"].get("action_info", {}).get("status") == "meeting_rescheduled"
            ),
            "total_score": float(self._env._total_score()),
            "grader_score": min(max(float(reward or 0.0), 0.0), 1.0),
        }
        payload["metadata"] = metadata
        return EngineerManagerObservation.model_validate(payload)