soumi guria
changes after round 1
60fc766
from pydantic import BaseModel, Field
from typing import List, Optional, Literal, Tuple, Dict, Any
import random
# ==========================================
# TASK TYPES
# ==========================================
TaskType = Literal["email", "meeting", "code_review", "report", "call"]
Priority = Literal["critical", "high", "normal", "low"]
PRIORITY_WEIGHT = {"critical": 1.5, "high": 1.2, "normal": 1.0, "low": 0.7}
TASK_ENERGY_COST = {"email": 0.08, "meeting": 0.18, "code_review": 0.20, "report": 0.14, "call": 0.11}
TASK_PROGRESS_RATE = {"email": 0.35, "meeting": 0.30, "code_review": 0.20, "report": 0.22, "call": 0.28}
COGNITIVE_BUCKETS = {"email": "social", "meeting": "social", "code_review": "analytical", "report": "analytical", "call": "social"}
ALL_TASK_TYPES: list[TaskType] = ["email", "meeting", "code_review", "report", "call"]
ALL_PRIORITIES: list[Priority] = ["critical", "high", "normal", "low"]
# ==========================================
# OPENENV SCHEMAS
# ==========================================
class Task(BaseModel):
id: str
difficulty: str
task_type: TaskType = "report"
priority: Priority = "normal"
progress: float = 0.0
deadline: Optional[int] = None
depends_on: Optional[str] = None
is_interrupted: bool = False
class WorkerState(BaseModel):
id: str
energy: float = 1.0
stress: float = 0.0
current_task_id: Optional[str] = None
expertise: str = "analytical"
class VisibleWorker(BaseModel):
id: str
fatigue_level: str
stress_level: str
stress_warning: bool
expertise: str
current_task_id: Optional[str] = None
class VisibleState(BaseModel):
"""
Partial observability for the Oracle Manager.
"""
workers: List[VisibleWorker] = []
focus_mode: bool = False
upcoming_deadlines: List[str] = []
blocked_tasks: List[str] = []
class Observation(BaseModel):
tasks: List[Task]
visible_state: VisibleState
time_step: int
class Action(BaseModel):
type: Literal["work", "break", "switch", "delay", "focus"]
task_id: Optional[str] = None
worker_id: Optional[str] = None
class EnvState(BaseModel):
workers: List[WorkerState] = []
time_step: int = 0
tasks: List[Task] = []
focus_mode: bool = False
interruption_count: int = 0
milestone_rewards: Dict[str, float] = {}
next_interrupt_eligible: int = 999
interrupt_budget: int = 0
server_outage_active: bool = False
# ==========================================
# FIX 2 — PROCEDURAL TASK GENERATION
# Seed-based so episodes are reproducible on request but vary by default.
# Deadlines jitter +-3 steps; task types and secondary priorities randomised.
# ==========================================
def generate_tasks(level: str, seed: Optional[int] = None) -> list[Task]:
"""
Generate tasks for the given difficulty level.
Pass seed=None for a random seed (default for live play),
or an explicit int for reproducible evaluation runs.
"""
rng = random.Random(seed)
def _jitter(base: int, lo: int = -3, hi: int = 3) -> int:
return max(1, base + rng.randint(lo, hi))
def _p(pool: list) -> str:
return rng.choice(pool)
if level == "easy":
return [
Task(id="e1", difficulty="easy",
task_type=_p(["email", "report"]),
priority=_p(["normal", "high"]),
deadline=None),
Task(id="e2", difficulty="easy",
task_type=_p(["report", "code_review"]),
priority=_p(["normal", "low"]),
deadline=None),
]
elif level == "medium":
return [
Task(id="m1", difficulty="medium",
task_type=_p(["email", "call"]),
priority="critical",
deadline=_jitter(14)),
Task(id="m2", difficulty="medium",
task_type=_p(["meeting", "code_review"]),
priority=_p(["high", "normal"]),
deadline=_jitter(20)),
Task(id="m3", difficulty="medium",
task_type=_p(["code_review", "report"]),
priority=_p(["normal", "high"]),
deadline=_jitter(28)),
Task(id="m4", difficulty="medium",
task_type=_p(["report", "meeting"]),
priority=_p(["high", "normal"]),
deadline=_jitter(35)),
Task(id="m5", difficulty="medium",
task_type=_p(["call", "email"]),
priority=_p(["low", "normal"]),
deadline=_jitter(45)),
]
elif level == "hard":
return [
Task(id="h1", difficulty="hard",
task_type=_p(["email", "call"]),
priority="critical",
deadline=_jitter(12)),
Task(id="h2", difficulty="hard",
task_type=_p(["code_review", "report"]),
priority=_p(["high", "normal"]),
deadline=_jitter(16)),
Task(id="h3", difficulty="hard",
task_type=_p(["meeting", "call"]),
priority="critical",
deadline=_jitter(20),
depends_on="h1"),
Task(id="h4", difficulty="hard",
task_type=_p(["report", "code_review"]),
priority=_p(["high", "normal"]),
deadline=_jitter(24)),
Task(id="h5", difficulty="hard",
task_type=_p(["call", "meeting"]),
priority=_p(["normal", "high"]),
deadline=_jitter(28),
depends_on="h2"),
Task(id="h6", difficulty="hard",
task_type=_p(["email", "report"]),
priority=_p(["high", "normal"]),
deadline=_jitter(32)),
Task(id="h7", difficulty="hard",
task_type=_p(["code_review", "meeting"]),
priority="critical",
deadline=_jitter(38),
depends_on="h4"),
Task(id="h8", difficulty="hard",
task_type=_p(["report", "email"]),
priority=_p(["normal", "low"]),
deadline=_jitter(46)),
]
elif level == "expert":
return [
Task(id="x1", difficulty="expert",
task_type=_p(["email", "call"]),
priority="critical",
deadline=_jitter(8)),
Task(id="x2", difficulty="expert",
task_type=_p(["code_review", "report"]),
priority=_p(["high", "critical"]),
deadline=_jitter(12)),
Task(id="x3", difficulty="expert",
task_type=_p(["meeting", "call"]),
priority="critical",
deadline=_jitter(14),
depends_on="x1"),
Task(id="x4", difficulty="expert",
task_type=_p(["report", "code_review"]),
priority=_p(["high", "normal"]),
deadline=_jitter(18),
depends_on="x2"),
Task(id="x5", difficulty="expert",
task_type=_p(["call", "meeting"]),
priority=_p(["normal", "high"]),
deadline=_jitter(22),
depends_on="x3"),
Task(id="x6", difficulty="expert",
task_type=_p(["code_review", "email"]),
priority="critical",
deadline=_jitter(24)),
Task(id="x7", difficulty="expert",
task_type=_p(["email", "report"]),
priority=_p(["high", "normal"]),
deadline=_jitter(28),
depends_on="x4"),
Task(id="x8", difficulty="expert",
task_type=_p(["report", "call"]),
priority=_p(["normal", "high"]),
deadline=_jitter(33),
depends_on="x6"),
Task(id="x9", difficulty="expert",
task_type=_p(["meeting", "code_review"]),
priority="critical",
deadline=_jitter(36),
depends_on="x5"),
Task(id="x10", difficulty="expert",
task_type=_p(["call", "email"]),
priority=_p(["high", "normal"]),
deadline=_jitter(44)),
]
return []
def _inject_interruption(state: EnvState, step: int) -> None:
"""Inject an urgent email task mid-episode (hard/expert levels)."""
iid = f"int{state.interruption_count + 1}"
state.tasks.append(Task(
id=iid, difficulty=state.tasks[0].difficulty,
task_type="email", priority="critical",
deadline=step + 8, is_interrupted=True,
))
state.interruption_count += 1
# ==========================================
# GRADER
# ==========================================
def grader(trajectory: dict) -> float:
if not trajectory or not trajectory.get("tasks"):
return 0.01
raw_tasks = trajectory["tasks"]
ts = trajectory.get("time_step", 50)
# Average energy across workers for grading purposes
workers = trajectory.get("workers", [])
eng = sum(w.get("energy", 0.5) for w in workers) / max(1, len(workers)) if workers else 0.5
task_objs = [Task(**t) if isinstance(t, dict) else t for t in raw_tasks]
return deterministic_grader(task_objs, ts, eng)
def deterministic_grader(tasks: list[Task], time_step: int, final_energy: float) -> float:
"""
Scores the ACTUAL final task state. Always returns a value in (0.01, 0.99).
Formula:
weighted_completion x 0.60
deadline_adherence x 0.22
energy_efficiency x 0.10
dependency_bonus x 0.05
interruption_bonus x 0.03
"""
if not tasks:
return 0.01
total_weight = sum(PRIORITY_WEIGHT[t.priority] for t in tasks)
# Weighted completion (partial progress counts)
wc = sum(t.progress * PRIORITY_WEIGHT[t.priority] for t in tasks) / max(total_weight, 0.01)
# Deadline adherence
completable = [t for t in tasks if t.deadline is not None]
met_deadline = sum(
1 for t in completable
if t.progress >= 1.0 and time_step <= t.deadline
)
da = (met_deadline / len(completable)) if completable else 1.0
# Energy efficiency
ee = max(0.0, (final_energy - 0.10) * 0.13)
# Dependency ordering bonus
dep_bonus = 0.0
for t in tasks:
if t.depends_on and t.progress >= 1.0:
parent = next((p for p in tasks if p.id == t.depends_on), None)
if parent and parent.progress >= 1.0:
dep_bonus += 0.015
dep_bonus = min(0.05, dep_bonus)
# Interruption handling bonus
interrupted = [t for t in tasks if t.is_interrupted]
int_bonus = 0.0
if interrupted:
handled = sum(1 for t in interrupted if t.progress >= 1.0)
int_bonus = min(0.03, (handled / len(interrupted)) * 0.03)
raw = wc * 0.60 + da * 0.22 + ee + dep_bonus + int_bonus
return round(max(0.01, min(0.99, raw)), 4)
# ==========================================
# FIX 3 — STOCHASTIC INTERRUPTION CONFIG
# Interruptions fire with a per-step probability once an eligibility
# window opens, with a cooldown to prevent back-to-back fires.
# budget = max number of interrupts for the difficulty level.
# ==========================================
_INTERRUPT_CONFIG = {
# prob_per_step eligible_from cooldown_steps budget
"hard": (0.18, 10, 8, 2),
"expert": (0.22, 6, 7, 3),
}
DRIFT_EVENTS = [
{
"name": "server_outage",
"trigger_step": 10,
"effect": "code_review energy cost doubles",
"announcement": "URGENT: Production server down, all code reviews now critical"
},
{
"name": "urgent_interrupt",
"trigger_step": 20,
"effect": "Investor call added mid-episode",
"announcement": "Urgent interrupt — investor call added mid-episode"
},
{
"name": "deadline_crunch",
"trigger_step": 35,
"effect": "All deadlines reduced by 5 steps",
"announcement": "Client moved deadline up. All deliverables due earlier."
}
]
class CLMEnvironment:
def __init__(self, tasks: list[Task], max_steps: int = 50,
seed: Optional[int] = None):
self.max_steps = max_steps
self.initial_tasks = tasks
self.difficulty = tasks[0].difficulty if tasks else "easy"
self._rng = random.Random(seed)
cfg = _INTERRUPT_CONFIG.get(self.difficulty, (0.0, 999, 999, 0))
self._interrupt_prob, eligible_from, self._cooldown, budget = cfg
self.state = EnvState(
tasks=[t.model_copy() for t in tasks],
workers=self._init_workers(),
next_interrupt_eligible=eligible_from,
interrupt_budget=budget,
)
def _init_workers(self) -> List[WorkerState]:
return [
WorkerState(id="w1", expertise="analytical"),
WorkerState(id="w2", expertise="social"),
WorkerState(id="w3", expertise="analytical")
]
def reset(self) -> Observation:
cfg = _INTERRUPT_CONFIG.get(self.difficulty, (0.0, 999, 999, 0))
_, eligible_from, _, budget = cfg
self.state = EnvState(
tasks=[t.model_copy() for t in self.initial_tasks],
workers=self._init_workers(),
next_interrupt_eligible=eligible_from,
interrupt_budget=budget,
)
return self._get_observation()
def _blocked_ids(self) -> set[str]:
done_ids = {t.id for t in self.state.tasks if t.progress >= 1.0}
return {t.id for t in self.state.tasks if t.depends_on and t.depends_on not in done_ids}
def apply_schema_drift(self, step: int) -> Optional[dict]:
for event in DRIFT_EVENTS:
if step == event["trigger_step"]:
if event["name"] == "deadline_crunch":
for t in self.state.tasks:
if t.deadline:
t.deadline = max(step + 1, t.deadline - 5)
elif event["name"] == "urgent_interrupt":
self.state.tasks.append(Task(
id=f"drift_{step}", difficulty=self.difficulty,
task_type="call", priority="critical",
deadline=step + 10, is_interrupted=True,
))
elif event["name"] == "server_outage":
self.state.server_outage_active = True
return {
"title": event["name"],
"message": event["announcement"],
"step": step
}
return None
def _upcoming_ids(self, window: int = 5) -> list[str]:
return [
t.id for t in self.state.tasks
if t.deadline and 0 < (t.deadline - self.state.time_step) <= window and t.progress < 1.0
]
def _get_observation(self) -> Observation:
vis_workers = []
for w in self.state.workers:
e = w.energy
s = w.stress
fatigue_label = "high" if e < 0.30 else ("medium" if e < 0.60 else "low")
stress_label = "critical" if s > 0.75 else ("elevated" if s > 0.45 else "calm")
vis_workers.append(VisibleWorker(
id=w.id, fatigue_level=fatigue_label, stress_level=stress_label,
stress_warning=s > 0.65, expertise=w.expertise, current_task_id=w.current_task_id
))
vs = VisibleState(
workers=vis_workers,
focus_mode=self.state.focus_mode,
upcoming_deadlines=self._upcoming_ids(),
blocked_tasks=list(self._blocked_ids()),
)
return Observation(tasks=self.state.tasks, visible_state=vs, time_step=self.state.time_step)
def step(self, action: Action) -> Tuple[Observation, float, bool, dict]:
reward = 0.0
blocked = self._blocked_ids()
# Oracle manager assigns action to specific worker
worker = next((w for w in self.state.workers if w.id == action.worker_id), self.state.workers[0])
if (self.state.interrupt_budget > 0
and self.state.time_step >= self.state.next_interrupt_eligible
and self._rng.random() < self._interrupt_prob):
_inject_interruption(self.state, self.state.time_step)
self.state.interrupt_budget -= 1
self.state.next_interrupt_eligible = self.state.time_step + self._cooldown
reward -= 0.05
if action.type in ("work", "focus"):
is_focus = (action.type == "focus")
if action.task_id:
if action.task_id in blocked:
reward -= 0.15
else:
if worker.current_task_id and worker.current_task_id != action.task_id:
# Context switching penalty logic
old_t = next((t for t in self.state.tasks if t.id == worker.current_task_id), None)
new_t = next((t for t in self.state.tasks if t.id == action.task_id), None)
if old_t and new_t:
# If similar task type, HIGH penalty. If dissimilar, LOW penalty.
if COGNITIVE_BUCKETS.get(old_t.task_type) == COGNITIVE_BUCKETS.get(new_t.task_type):
reward -= 0.15 # Penalty for monotony
worker.stress = min(1.0, worker.stress + 0.05)
else:
reward -= 0.05 # Refreshing context switch
worker.current_task_id = action.task_id
self.state.focus_mode = is_focus
task = next((t for t in self.state.tasks if t.id == worker.current_task_id), None)
if task and task.progress < 1.0 and task.id not in blocked:
ecost = TASK_ENERGY_COST.get(task.task_type, 0.14) * (2.0 if is_focus else 1.0)
if self.state.server_outage_active and task.task_type == "code_review":
ecost *= 2.0
base_rate = TASK_PROGRESS_RATE.get(task.task_type, 0.22)
efficiency = max(0.15, worker.energy) * (1.0 - worker.stress * 0.45)
progress = base_rate * (2.0 if is_focus else 1.0) * efficiency
pw = PRIORITY_WEIGHT[task.priority]
worker.energy = max(0.0, worker.energy - ecost)
old_p = task.progress
task.progress = min(1.0, task.progress + progress)
reward += 0.10 * (task.progress - old_p) * pw
for ms, bonus in [(0.25, 0.04), (0.50, 0.07), (0.75, 0.09), (1.00, 0.18)]:
key = f"{task.id}@{ms}"
if task.progress >= ms and key not in self.state.milestone_rewards:
self.state.milestone_rewards[key] = bonus
reward += bonus * pw
else:
worker.energy = max(0.0, worker.energy - 0.04)
elif action.type == "break":
self.state.focus_mode = False
worker.energy = min(1.0, worker.energy + 0.22)
worker.stress = max(0.0, worker.stress - 0.18)
reward += 0.03
elif action.type == "switch":
self.state.focus_mode = False
if action.task_id and action.task_id not in blocked:
worker.current_task_id = action.task_id
reward -= 0.07
elif action.type == "delay":
# Pushing to tomorrow: Moderate penalty (not extreme)
worker.stress = min(1.0, worker.stress + 0.05)
reward -= 0.05
self.state.time_step += 1
# Stress dynamics for all workers
for t in (tt for tt in self.state.tasks if tt.progress < 1.0):
if t.deadline:
ttd = t.deadline - self.state.time_step
pw = PRIORITY_WEIGHT[t.priority]
if 0 <= ttd <= 3:
for w in self.state.workers:
w.stress = min(1.0, w.stress + 0.06 * pw)
elif ttd < 0:
for w in self.state.workers:
w.stress = min(1.0, w.stress + 0.12 * pw)
# Episode termination
all_done = all(t.progress >= 1.0 for t in self.state.tasks)
# Burnout condition: ANY worker hits 0 energy
burnout = any(w.energy < 0.07 for w in self.state.workers)
timeout = self.state.time_step >= self.max_steps
done = all_done or burnout or timeout
if any(w.stress > 0.80 for w in self.state.workers):
reward -= 0.07
if done:
if burnout:
reward -= 1.0
elif all_done:
missed = any(t.deadline and self.state.time_step > t.deadline for t in self.state.tasks)
reward += 0.5 if missed else 1.0
reward = max(-1.0, min(1.0, float(reward)))
info = self.state.model_dump()
drift = self.apply_schema_drift(self.state.time_step)
if drift:
info["schema_drift"] = drift
if done:
eng = sum(w.energy for w in self.state.workers) / max(1, len(self.state.workers))
info["final_score"] = deterministic_grader(
self.state.tasks, self.state.time_step, eng
)
return self._get_observation(), reward, done, info
def state_dict(self) -> dict:
return self.state.model_dump()