| from pydantic import BaseModel, Field |
| from typing import List, Optional, Literal, Tuple, Dict, Any |
| import random |
|
|
| |
| |
| |
# Closed vocabularies for the kind of a task and its priority level.
TaskType = Literal["email", "meeting", "code_review", "report", "call"]
Priority = Literal["critical", "high", "normal", "low"]

# Multiplier applied per priority level when weighting progress, rewards and stress.
PRIORITY_WEIGHT = {
    "critical": 1.5,
    "high": 1.2,
    "normal": 1.0,
    "low": 0.7,
}

# Energy a worker spends on one step of work, per task type.
TASK_ENERGY_COST = {
    "email": 0.08,
    "meeting": 0.18,
    "code_review": 0.20,
    "report": 0.14,
    "call": 0.11,
}

# Base progress added by one step of work, per task type (before efficiency scaling).
TASK_PROGRESS_RATE = {
    "email": 0.35,
    "meeting": 0.30,
    "code_review": 0.20,
    "report": 0.22,
    "call": 0.28,
}

# Cognitive bucket of each task type; compared when a worker changes task.
COGNITIVE_BUCKETS = {
    "email": "social",
    "meeting": "social",
    "code_review": "analytical",
    "report": "analytical",
    "call": "social",
}

# Every allowed value of the two vocabularies above, as plain lists.
ALL_TASK_TYPES: list[TaskType] = ["email", "meeting", "code_review", "report", "call"]
ALL_PRIORITIES: list[Priority] = ["critical", "high", "normal", "low"]
|
|
| |
| |
| |
class Task(BaseModel):
    """One unit of work tracked by the environment."""

    id: str
    # Difficulty label of the level this task belongs to ("easy", "medium", ...).
    difficulty: str
    task_type: TaskType = "report"
    priority: Priority = "normal"
    # Completion fraction; a value >= 1.0 is treated as finished.
    progress: float = 0.0
    # Time step by which the task is due; None means no deadline.
    deadline: Optional[int] = None
    # Id of a task that must be finished before this one can be worked on.
    depends_on: Optional[str] = None
    # True for tasks that were injected mid-episode.
    is_interrupted: bool = False
|
|
class WorkerState(BaseModel):
    """Full internal state of one worker (not shown to the agent directly)."""

    id: str
    # Remaining energy; the environment keeps it within [0.0, 1.0].
    energy: float = 1.0
    # Accumulated stress; the environment keeps it within [0.0, 1.0].
    stress: float = 0.0
    # Id of the task the worker is currently assigned to, if any.
    current_task_id: Optional[str] = None
    # Cognitive bucket label of the worker, e.g. "analytical" or "social".
    expertise: str = "analytical"
|
|
class VisibleWorker(BaseModel):
    """Coarse, agent-facing view of a worker: labels instead of raw energy/stress values."""

    id: str
    # "high", "medium" or "low", derived from the worker's energy.
    fatigue_level: str
    # "critical", "elevated" or "calm", derived from the worker's stress.
    stress_level: str
    # Set when the worker's stress exceeds the warning threshold.
    stress_warning: bool
    expertise: str
    current_task_id: Optional[str] = None
|
|
class VisibleState(BaseModel):
    """
    Partial observability for the Oracle Manager.

    Holds only what the agent may see: per-worker labels, the focus flag, and
    the ids of tasks that are close to their deadline or blocked by a dependency.
    """

    workers: List[VisibleWorker] = []
    focus_mode: bool = False
    # Ids of unfinished tasks whose deadline falls within the look-ahead window.
    upcoming_deadlines: List[str] = []
    # Ids of tasks whose prerequisite is not finished yet.
    blocked_tasks: List[str] = []
|
|
class Observation(BaseModel):
    """What the agent receives after ``reset`` and after every ``step``."""

    # Current task list, including tasks added mid-episode.
    tasks: List[Task]
    # Coarse view of workers plus deadline / dependency hints.
    visible_state: VisibleState
    # Number of steps taken so far in the episode.
    time_step: int
|
|
class Action(BaseModel):
    """A single agent decision for one time step."""

    type: Literal["work", "break", "switch", "delay", "focus"]
    # Task the action targets; optional for actions that keep the current task.
    task_id: Optional[str] = None
    # Worker the action applies to; the environment falls back to its first
    # worker when no worker has this id.
    worker_id: Optional[str] = None
|
|
class EnvState(BaseModel):
    """Complete mutable state of one episode."""

    workers: List[WorkerState] = []
    time_step: int = 0
    tasks: List[Task] = []
    focus_mode: bool = False
    # Number of random interruptions injected so far; also used to number their ids.
    interruption_count: int = 0
    # Milestone bonuses already paid, keyed by "<task id>@<threshold>".
    milestone_rewards: Dict[str, float] = {}
    # Earliest time step at which the next random interruption may fire.
    next_interrupt_eligible: int = 999
    # How many random interruptions may still be injected this episode.
    interrupt_budget: int = 0
    # While True, code_review work costs double energy.
    server_outage_active: bool = False
|
|
|
|
| |
| |
| |
| |
| |
# Per-level task blueprints. Each row is
#   (task id, task-type pool, priority, base deadline, prerequisite id)
# where ``priority`` is either a fixed priority string or a pool to draw from,
# ``base deadline`` is None for tasks without a deadline, and ``prerequisite id``
# is None for tasks without a dependency.
_LEVEL_TASK_SPECS = {
    "easy": (
        ("e1", ("email", "report"), ("normal", "high"), None, None),
        ("e2", ("report", "code_review"), ("normal", "low"), None, None),
    ),
    "medium": (
        ("m1", ("email", "call"), "critical", 14, None),
        ("m2", ("meeting", "code_review"), ("high", "normal"), 20, None),
        ("m3", ("code_review", "report"), ("normal", "high"), 28, None),
        ("m4", ("report", "meeting"), ("high", "normal"), 35, None),
        ("m5", ("call", "email"), ("low", "normal"), 45, None),
    ),
    "hard": (
        ("h1", ("email", "call"), "critical", 12, None),
        ("h2", ("code_review", "report"), ("high", "normal"), 16, None),
        ("h3", ("meeting", "call"), "critical", 20, "h1"),
        ("h4", ("report", "code_review"), ("high", "normal"), 24, None),
        ("h5", ("call", "meeting"), ("normal", "high"), 28, "h2"),
        ("h6", ("email", "report"), ("high", "normal"), 32, None),
        ("h7", ("code_review", "meeting"), "critical", 38, "h4"),
        ("h8", ("report", "email"), ("normal", "low"), 46, None),
    ),
    "expert": (
        ("x1", ("email", "call"), "critical", 8, None),
        ("x2", ("code_review", "report"), ("high", "critical"), 12, None),
        ("x3", ("meeting", "call"), "critical", 14, "x1"),
        ("x4", ("report", "code_review"), ("high", "normal"), 18, "x2"),
        ("x5", ("call", "meeting"), ("normal", "high"), 22, "x3"),
        ("x6", ("code_review", "email"), "critical", 24, None),
        ("x7", ("email", "report"), ("high", "normal"), 28, "x4"),
        ("x8", ("report", "call"), ("normal", "high"), 33, "x6"),
        ("x9", ("meeting", "code_review"), "critical", 36, "x5"),
        ("x10", ("call", "email"), ("high", "normal"), 44, None),
    ),
}


def generate_tasks(level: str, seed: Optional[int] = None) -> list[Task]:
    """
    Build the task list for a difficulty level.

    Args:
        level: "easy", "medium", "hard" or "expert"; any other value yields
            an empty list.
        seed: None for a fresh random seed (live play), or an explicit int
            for reproducible evaluation runs.

    Returns:
        Newly created tasks in their fixed per-level order. Task type and
        non-fixed priorities are drawn from small pools, and each base
        deadline is shifted by a random offset in [-3, 3], never below 1.
    """
    rng = random.Random(seed)
    tasks: list[Task] = []

    for task_id, type_pool, priority_spec, base_deadline, prerequisite in _LEVEL_TASK_SPECS.get(level, ()):
        # The draw order (type, then priority, then deadline) is fixed per
        # task so that a given seed always produces the same task set.
        task_type = rng.choice(type_pool)
        if isinstance(priority_spec, str):
            priority = priority_spec
        else:
            priority = rng.choice(priority_spec)
        if base_deadline is None:
            deadline = None
        else:
            deadline = max(1, base_deadline + rng.randint(-3, 3))

        fields = {
            "id": task_id,
            "difficulty": level,
            "task_type": task_type,
            "priority": priority,
            "deadline": deadline,
        }
        # Only pass depends_on when the task really has a prerequisite.
        if prerequisite is not None:
            fields["depends_on"] = prerequisite
        tasks.append(Task(**fields))

    return tasks
|
|
|
|
def _inject_interruption(state: EnvState, step: int) -> None:
    """
    Append one urgent email task to ``state`` mid-episode (hard/expert levels).

    The new task is critical, flagged as an interruption, due 8 steps after
    ``step``, and inherits the difficulty of the first task in ``state``.
    Its id is "int<k>" where k counts the interruptions injected so far.
    """
    sequence_no = state.interruption_count + 1
    urgent_email = Task(
        id=f"int{sequence_no}",
        difficulty=state.tasks[0].difficulty,
        task_type="email",
        priority="critical",
        deadline=step + 8,
        is_interrupted=True,
    )
    state.tasks.append(urgent_email)
    state.interruption_count = sequence_no
|
|
|
|
| |
| |
| |
def grader(trajectory: dict) -> float:
    """
    Score a trajectory snapshot by delegating to ``deterministic_grader``.

    Args:
        trajectory: mapping with
            "tasks"     - list of Task objects or dicts of Task fields (required),
            "time_step" - final time step (defaults to 50),
            "workers"   - list of worker dicts or objects exposing ``energy``
                          (optional).

    Returns:
        0.01 when the trajectory is empty or carries no tasks; otherwise the
        score computed from the tasks, the time step and the mean worker
        energy (0.5 when no workers are given, and 0.5 per worker that has
        no energy value).
    """
    if not trajectory or not trajectory.get("tasks"):
        return 0.01

    def _energy_of(worker) -> float:
        # Workers may arrive serialized (dict) or as model objects, the same
        # way tasks are accepted in both forms below.
        if isinstance(worker, dict):
            return worker.get("energy", 0.5)
        return getattr(worker, "energy", 0.5)

    time_step = trajectory.get("time_step", 50)

    workers = trajectory.get("workers") or []
    if workers:
        mean_energy = sum(_energy_of(w) for w in workers) / len(workers)
    else:
        mean_energy = 0.5

    tasks = [Task(**t) if isinstance(t, dict) else t for t in trajectory["tasks"]]
    return deterministic_grader(tasks, time_step, mean_energy)
|
|
|
|
def deterministic_grader(tasks: list[Task], time_step: int, final_energy: float) -> float:
    """
    Score the ACTUAL final task state.

    The result is clamped to [0.01, 0.99] and rounded to 4 decimals; an empty
    task list scores 0.01.

    Components, summed before clamping:
        weighted completion  priority-weighted mean progress            x 0.60
        deadline adherence   share of deadline tasks that are finished
                             with time_step <= deadline (1.0 when no
                             task has a deadline)                       x 0.22
        energy efficiency    max(0, final_energy - 0.10)                x 0.13
        dependency bonus     0.015 per finished task whose prerequisite
                             is also finished, capped at 0.05
        interruption bonus   share of interrupted tasks finished        x 0.03

    Note that deadline adherence compares each deadline with the final
    ``time_step`` passed in, not with the step at which the task was finished.
    """
    if not tasks:
        return 0.01

    def is_finished(task) -> bool:
        return task.progress >= 1.0

    # Weighted completion.
    weight_sum = sum(PRIORITY_WEIGHT[task.priority] for task in tasks)
    weighted_progress = sum(task.progress * PRIORITY_WEIGHT[task.priority] for task in tasks)
    completion = weighted_progress / max(weight_sum, 0.01)

    # Deadline adherence, over tasks that have a deadline at all.
    with_deadline = [task for task in tasks if task.deadline is not None]
    if with_deadline:
        on_time = sum(
            1 for task in with_deadline
            if is_finished(task) and time_step <= task.deadline
        )
        adherence = on_time / len(with_deadline)
    else:
        adherence = 1.0

    # Energy efficiency: only energy above 0.10 earns anything.
    energy_term = max(0.0, (final_energy - 0.10) * 0.13)

    # Dependency bonus: the first task with a given id is the one that counts
    # as the prerequisite.
    first_by_id = {}
    for task in tasks:
        first_by_id.setdefault(task.id, task)
    dependency_term = 0.0
    for task in tasks:
        if not (task.depends_on and is_finished(task)):
            continue
        prerequisite = first_by_id.get(task.depends_on)
        if prerequisite and is_finished(prerequisite):
            dependency_term += 0.015
    dependency_term = min(0.05, dependency_term)

    # Interruption bonus.
    injected = [task for task in tasks if task.is_interrupted]
    interruption_term = 0.0
    if injected:
        handled = sum(1 for task in injected if is_finished(task))
        interruption_term = min(0.03, (handled / len(injected)) * 0.03)

    raw = completion * 0.60 + adherence * 0.22 + energy_term + dependency_term + interruption_term
    return round(max(0.01, min(0.99, raw)), 4)
|
|
|
|
| |
| |
| |
| |
| |
| |
# Random-interruption settings per difficulty level, as
#   (probability per step, first eligible step, cooldown in steps, budget).
# Levels without an entry get no random interruptions.
_INTERRUPT_CONFIG = {
    "hard": (0.18, 10, 8, 2),
    "expert": (0.22, 6, 7, 3),
}


# Scripted events that fire when the episode reaches ``trigger_step``.
# ``name`` selects the effect applied by the environment; ``effect`` and
# ``announcement`` are human-readable descriptions.
DRIFT_EVENTS = [
    {
        "name": "server_outage",
        "trigger_step": 10,
        "effect": "code_review energy cost doubles",
        "announcement": "URGENT: Production server down, all code reviews now critical",
    },
    {
        "name": "urgent_interrupt",
        "trigger_step": 20,
        "effect": "Investor call added mid-episode",
        "announcement": "Urgent interrupt — investor call added mid-episode",
    },
    {
        "name": "deadline_crunch",
        "trigger_step": 35,
        "effect": "All deadlines reduced by 5 steps",
        "announcement": "Client moved deadline up. All deliverables due earlier.",
    },
]
|
|
class CLMEnvironment:
    """
    Step-based simulation in which an agent directs three workers through a task list.

    Working costs energy and advances tasks; approaching or missed deadlines
    raise stress on every worker. Scripted DRIFT_EVENTS and, on levels listed
    in _INTERRUPT_CONFIG, random interruptions change the task set mid-episode.
    An episode ends when all tasks are finished, when any worker's energy
    drops below 0.07 (burnout), or when ``max_steps`` is reached.

    NOTE(review): the nesting of statements inside ``step`` was reconstructed
    from a source whose indentation had been stripped. Spots where more than
    one nesting was plausible are marked below; confirm them against the
    original file.
    """

    def __init__(self, tasks: list[Task], max_steps: int = 50,
                 seed: Optional[int] = None):
        """
        Args:
            tasks: initial task list; the difficulty of the first task selects
                the interruption settings ("easy" is assumed for an empty list).
            max_steps: step count at which the episode times out.
            seed: seed for the RNG that decides random interruptions.
        """
        self.max_steps = max_steps
        # Keep private copies so that later changes to the caller's list or
        # tasks cannot leak into subsequent resets.
        self.initial_tasks = [t.model_copy() for t in tasks]
        self.difficulty = tasks[0].difficulty if tasks else "easy"
        self._rng = random.Random(seed)
        self._interrupt_prob, _, self._cooldown, _ = self._interrupt_settings()
        self.state = self._fresh_state()

    def _interrupt_settings(self) -> tuple:
        """Return (probability, first eligible step, cooldown, budget) for this difficulty."""
        # The fallback disables random interruptions: zero probability and zero budget.
        return _INTERRUPT_CONFIG.get(self.difficulty, (0.0, 999, 999, 0))

    def _fresh_state(self) -> EnvState:
        """Build the state an episode starts from: copied tasks and rested workers."""
        _, eligible_from, _, budget = self._interrupt_settings()
        return EnvState(
            tasks=[t.model_copy() for t in self.initial_tasks],
            workers=self._init_workers(),
            next_interrupt_eligible=eligible_from,
            interrupt_budget=budget,
        )

    def _init_workers(self) -> List[WorkerState]:
        """Create the fixed team: two analytical workers and one social worker."""
        return [
            WorkerState(id="w1", expertise="analytical"),
            WorkerState(id="w2", expertise="social"),
            WorkerState(id="w3", expertise="analytical"),
        ]

    def reset(self) -> Observation:
        """
        Start a new episode from the initial tasks and return its first observation.

        The interruption RNG is not reseeded, so successive episodes of the
        same environment continue the same random stream.
        """
        self.state = self._fresh_state()
        return self._get_observation()

    def _find_task(self, task_id: Optional[str]) -> Optional[Task]:
        """Return the first task with the given id, or None."""
        return next((t for t in self.state.tasks if t.id == task_id), None)

    def _blocked_ids(self) -> set[str]:
        """Return ids of tasks whose prerequisite is not finished (or does not exist)."""
        finished = {t.id for t in self.state.tasks if t.progress >= 1.0}
        return {
            t.id for t in self.state.tasks
            if t.depends_on and t.depends_on not in finished
        }

    def apply_schema_drift(self, step: int) -> Optional[dict]:
        """
        Apply the scripted drift event whose trigger step equals ``step``, if any.

        Effects by event name:
            deadline_crunch  - every set deadline moves 5 steps earlier, but
                               never earlier than ``step + 1``;
            urgent_interrupt - a critical call task "drift_<step>" due in 10
                               steps is appended;
            server_outage    - code_review energy cost is doubled from now on.

        Returns:
            A dict with "title", "message" and "step" describing the event,
            or None when no event triggers at ``step``.
        """
        event = next((e for e in DRIFT_EVENTS if e["trigger_step"] == step), None)
        if event is None:
            return None

        name = event["name"]
        if name == "deadline_crunch":
            for t in self.state.tasks:
                if t.deadline:
                    t.deadline = max(step + 1, t.deadline - 5)
        elif name == "urgent_interrupt":
            self.state.tasks.append(Task(
                id=f"drift_{step}", difficulty=self.difficulty,
                task_type="call", priority="critical",
                deadline=step + 10, is_interrupted=True,
            ))
        elif name == "server_outage":
            self.state.server_outage_active = True

        return {
            "title": name,
            "message": event["announcement"],
            "step": step,
        }

    def _upcoming_ids(self, window: int = 5) -> list[str]:
        """Return ids of unfinished tasks due within the next ``window`` steps."""
        now = self.state.time_step
        return [
            t.id for t in self.state.tasks
            if t.deadline and 0 < (t.deadline - now) <= window and t.progress < 1.0
        ]

    def _get_observation(self) -> Observation:
        """Build the agent-facing observation; worker energy/stress are reduced to labels."""
        visible = []
        for w in self.state.workers:
            if w.energy < 0.30:
                fatigue = "high"
            elif w.energy < 0.60:
                fatigue = "medium"
            else:
                fatigue = "low"

            if w.stress > 0.75:
                stress = "critical"
            elif w.stress > 0.45:
                stress = "elevated"
            else:
                stress = "calm"

            visible.append(VisibleWorker(
                id=w.id,
                fatigue_level=fatigue,
                stress_level=stress,
                stress_warning=w.stress > 0.65,
                expertise=w.expertise,
                current_task_id=w.current_task_id,
            ))

        return Observation(
            tasks=self.state.tasks,
            visible_state=VisibleState(
                workers=visible,
                focus_mode=self.state.focus_mode,
                upcoming_deadlines=self._upcoming_ids(),
                blocked_tasks=list(self._blocked_ids()),
            ),
            time_step=self.state.time_step,
        )

    def _work(self, worker: WorkerState, action: Action,
              blocked: set[str], reward: float) -> float:
        """
        Handle a "work" or "focus" action for ``worker`` and return the updated reward.

        "focus" doubles both the energy cost and the progress of the step.
        """
        is_focus = action.type == "focus"
        focus_factor = 2.0 if is_focus else 1.0

        if action.task_id:
            if action.task_id in blocked:
                reward -= 0.15
            else:
                previous_id = worker.current_task_id
                if previous_id and previous_id != action.task_id:
                    old_t = self._find_task(previous_id)
                    new_t = self._find_task(action.task_id)
                    if old_t and new_t:
                        # NOTE(review): a switch within the same cognitive
                        # bucket is charged more (0.15 plus stress) than a
                        # switch across buckets (0.05); confirm this is
                        # intended and not inverted.
                        if COGNITIVE_BUCKETS.get(old_t.task_type) == COGNITIVE_BUCKETS.get(new_t.task_type):
                            reward -= 0.15
                            worker.stress = min(1.0, worker.stress + 0.05)
                        else:
                            reward -= 0.05
                worker.current_task_id = action.task_id
        # NOTE(review): nesting reconstructed; the focus flag is assumed to
        # follow the action type regardless of whether a task id was given.
        self.state.focus_mode = is_focus

        task = self._find_task(worker.current_task_id)
        if task and task.progress < 1.0 and task.id not in blocked:
            energy_cost = TASK_ENERGY_COST.get(task.task_type, 0.14) * focus_factor
            if self.state.server_outage_active and task.task_type == "code_review":
                energy_cost *= 2.0
            base_rate = TASK_PROGRESS_RATE.get(task.task_type, 0.22)
            # Efficiency uses the energy the worker had BEFORE paying for this step.
            efficiency = max(0.15, worker.energy) * (1.0 - worker.stress * 0.45)
            gained = base_rate * focus_factor * efficiency
            pw = PRIORITY_WEIGHT[task.priority]

            worker.energy = max(0.0, worker.energy - energy_cost)
            before = task.progress
            task.progress = min(1.0, task.progress + gained)

            reward += 0.10 * (task.progress - before) * pw

            # Each milestone pays its bonus once per task; several can be
            # crossed in a single step.
            for ms, bonus in [(0.25, 0.04), (0.50, 0.07), (0.75, 0.09), (1.00, 0.18)]:
                key = f"{task.id}@{ms}"
                if task.progress >= ms and key not in self.state.milestone_rewards:
                    self.state.milestone_rewards[key] = bonus
                    reward += bonus * pw
        else:
            # Nothing workable (no task, finished, or blocked): the step
            # still costs a little energy.
            worker.energy = max(0.0, worker.energy - 0.04)

        return reward

    def _apply_deadline_pressure(self) -> None:
        """Raise every worker's stress for each unfinished task that is due soon or overdue."""
        now = self.state.time_step
        for t in self.state.tasks:
            if not (t.progress < 1.0 and t.deadline):
                continue
            remaining = t.deadline - now
            pw = PRIORITY_WEIGHT[t.priority]
            if 0 <= remaining <= 3:
                increase = 0.06 * pw
            elif remaining < 0:
                increase = 0.12 * pw
            else:
                continue
            for w in self.state.workers:
                w.stress = min(1.0, w.stress + increase)

    def step(self, action: Action) -> Tuple[Observation, float, bool, dict]:
        """
        Advance the episode by one time step.

        Order of effects: possible random interruption, the action itself,
        time advance, deadline stress, termination check and terminal
        bonus/penalty, then (only while the episode continues) scripted drift.

        Returns:
            (observation, reward clamped to [-1, 1], done flag, info) where
            ``info`` is a dump of the state after this step, plus
            "schema_drift" when a drift event fired and "final_score" when
            the episode ended.
        """
        reward = 0.0
        # Dependencies are evaluated once, before anything changes this step.
        blocked = self._blocked_ids()

        # An unknown or missing worker id falls back to the first worker.
        worker = next(
            (w for w in self.state.workers if w.id == action.worker_id),
            self.state.workers[0],
        )

        # The RNG is consulted only when budget and cooldown allow an interruption.
        if (self.state.interrupt_budget > 0
                and self.state.time_step >= self.state.next_interrupt_eligible
                and self._rng.random() < self._interrupt_prob):
            _inject_interruption(self.state, self.state.time_step)
            self.state.interrupt_budget -= 1
            self.state.next_interrupt_eligible = self.state.time_step + self._cooldown
            reward -= 0.05

        if action.type in ("work", "focus"):
            reward = self._work(worker, action, blocked, reward)

        elif action.type == "break":
            self.state.focus_mode = False
            worker.energy = min(1.0, worker.energy + 0.22)
            worker.stress = max(0.0, worker.stress - 0.18)
            reward += 0.03

        elif action.type == "switch":
            self.state.focus_mode = False
            if action.task_id and action.task_id not in blocked:
                worker.current_task_id = action.task_id
            # NOTE(review): nesting reconstructed; the switch cost is assumed
            # to apply even when the target is missing or blocked.
            reward -= 0.07

        elif action.type == "delay":
            worker.stress = min(1.0, worker.stress + 0.05)
            reward -= 0.05

        self.state.time_step += 1

        self._apply_deadline_pressure()

        all_done = all(t.progress >= 1.0 for t in self.state.tasks)
        burnout = any(w.energy < 0.07 for w in self.state.workers)
        timeout = self.state.time_step >= self.max_steps
        done = all_done or burnout or timeout

        if any(w.stress > 0.80 for w in self.state.workers):
            reward -= 0.07

        if done:
            if burnout:
                reward -= 1.0
            elif all_done:
                # Compares deadlines with the FINAL time step, so a task that
                # was finished on time still counts as missed if the episode
                # ran past its deadline.
                missed = any(
                    t.deadline and self.state.time_step > t.deadline
                    for t in self.state.tasks
                )
                reward += 0.5 if missed else 1.0

        reward = max(-1.0, min(1.0, float(reward)))

        # Drift only fires while the episode continues: a task injected on the
        # terminal step could never be worked on, yet would be graded as
        # unfinished in the final score.
        drift = None if done else self.apply_schema_drift(self.state.time_step)

        # Snapshot after drift so that info matches the returned observation.
        info = self.state.model_dump()
        if drift:
            info["schema_drift"] = drift

        if done:
            mean_energy = sum(w.energy for w in self.state.workers) / max(1, len(self.state.workers))
            info["final_score"] = deterministic_grader(
                self.state.tasks, self.state.time_step, mean_energy
            )
        return self._get_observation(), reward, done, info

    def state_dict(self) -> dict:
        """Return the full environment state as a plain dict."""
        return self.state.model_dump()
|
|