Spaces:
Sleeping
Sleeping
| """SchemaShift EA Arena Environment — reset/step/state with schema drift injection.""" | |
| import os, json, copy | |
| from models import EAAction, EAObservation, EpisodeState | |
| from tasks import TASKS | |
| from tools import ALL_TOOLS, CalendarTool, EmailTool, BookingsTool, TravelTool, DocsTool, ExpensesTool, RoomsTool, TeamTool, IncidentsTool | |
| from verifier import verify_episode | |
class SchemaShiftEnvironment:
    """Task environment with a reset/step/state API and mid-episode drift.

    Each episode is driven by one entry of ``TASKS``. ``reset()`` selects the
    next task (round-robin) and seeds the tools listed in its ``seed`` mapping;
    ``step()`` dispatches one action to a tool and, once the task's
    ``drift_at_step`` is reached, applies the task's ``drift_event`` to the
    targeted tool exactly once; ``state()`` returns the current episode state.
    """

    def __init__(self):
        """Create an environment with no active episode; call ``reset()`` next."""
        self._state = None           # EpisodeState of the running episode
        self._task = None            # task dict taken from TASKS
        self._task_index = 0         # index of the next task to serve
        self._tools = {}             # tool name -> seeded tool instance
        self._drift_applied = False  # True once this episode's drift has fired

    @staticmethod
    def _action_field(action, name, default):
        """Read ``name`` from an action given as an object or as a mapping."""
        if hasattr(action, name):
            return getattr(action, name)
        return action.get(name, default)

    @staticmethod
    def _format_drift_notice(drift):
        """Return the warning text appended to the output when drift fires.

        Returns an empty string when ``drift`` is falsy or its ``type`` is not
        one of ``schema_change``, ``policy_change`` or ``actor_conflict``.
        """
        if not drift:
            return ""
        dtype = drift.get("type", "")
        if dtype == "schema_change":
            return f"\n⚠️ SCHEMA CHANGE: {drift.get('change', '')}. Check tool documentation."
        if dtype == "policy_change":
            return f"\n⚠️ POLICY CHANGE: {drift.get('change', '')}. Review updated policies."
        if dtype == "actor_conflict":
            return f"\n⚠️ NEW MESSAGE from {drift.get('actor', 'unknown')}: \"{drift.get('message', '')}\""
        return ""

    def _setup_tools(self, seed):
        """Rebuild ``self._tools`` from a task's ``seed`` mapping.

        Only keys that name a known tool and whose value is a list are
        instantiated and seeded; the ``policies`` key is skipped. Both
        ``emails`` and ``email`` seed keys register the tool as ``email``.
        """
        self._tools = {}
        tool_map = {
            "calendar": CalendarTool, "emails": EmailTool, "email": EmailTool,
            "bookings": BookingsTool, "travel": TravelTool, "docs": DocsTool,
            "expenses": ExpensesTool, "rooms": RoomsTool, "team": TeamTool,
            "incidents": IncidentsTool,
        }
        for key, data in seed.items():
            if key == "policies":
                continue
            cls = tool_map.get(key)
            if cls and isinstance(data, list):
                tool = cls()
                tool.seed(data)
                name = "email" if key == "emails" else key
                self._tools[name] = tool

    def reset(self):
        """Start a new episode on the next task and return the first observation.

        Tasks are served round-robin from ``TASKS``. Tools are re-seeded from
        the task's ``seed`` and the drift flag is cleared. ``max_steps``
        defaults to 15 when the task does not specify it.
        """
        self._task = TASKS[self._task_index % len(TASKS)]
        self._task_index += 1
        self._drift_applied = False
        self._setup_tools(self._task.get("seed", {}))
        self._state = EpisodeState(
            task_id=self._task["id"],
            task_description=self._task["description"],
            max_steps=self._task.get("max_steps", 15),
        )
        return EAObservation(
            success=True,
            output=f"TASK: {self._task['title']}\n\n{self._task['description']}",
            task_description=self._task["description"],
            done=False,
            schema_version=1,
        )

    def _maybe_inject_drift(self):
        """Apply the task's drift event once the drift step has been reached.

        Returns the drift event dict on the step where it is applied, otherwise
        ``None``. The drift is applied at most once per episode and only if the
        tool named by the event (``emails`` is treated as ``email``) is loaded.
        """
        # NOTE(review): an event whose tool is not loaded is never applied or
        # announced, and is re-checked every step — confirm that tool-less
        # policy/actor drift events are not expected here.
        drift_step = self._task.get("drift_at_step")
        if not drift_step or self._drift_applied:
            return None
        if self._state.step_count < drift_step:
            return None

        drift = self._task.get("drift_event", {})
        tool_name = drift.get("tool", "")
        if tool_name == "emails":
            tool_name = "email"
        tool = self._tools.get(tool_name)
        if not tool:
            return None

        tool.apply_drift(drift)
        self._drift_applied = True
        self._state.drift_events.append(drift.get("change", "unknown"))
        return drift

    def step(self, action):
        """Execute one action and return the resulting observation.

        ``action`` may be an object with ``tool``/``action``/``parameters``
        attributes or a mapping with those keys. The action
        ``system``/``submit`` ends the episode, as does reaching ``max_steps``;
        in both cases the verifier's verdict is returned. Calling ``step``
        before ``reset`` returns a failed, terminal observation.
        """
        if self._state is None:
            return EAObservation(success=False, error="Call reset() first", reward=-1.0, done=True)

        self._state.step_count += 1
        tool_name = self._action_field(action, "tool", "")
        act = self._action_field(action, "action", "")
        params = self._action_field(action, "parameters", {})
        if params is None:
            # Tools receive a mapping even when the action carried no parameters.
            params = {}

        # Drift is injected before dispatch so this very call sees the change.
        drift = self._maybe_inject_drift()
        drift_msg = self._format_drift_notice(drift)

        if tool_name == "system" and act == "submit":
            return self._submit()

        budget_exhausted = self._state.step_count >= self._state.max_steps

        tool = self._tools.get(tool_name)
        if not tool:
            self._state.invalid_calls += 1
            if budget_exhausted:
                # Invalid calls consume the step budget too; without this the
                # episode would never terminate on repeated unknown tools.
                return self._submit()
            return EAObservation(
                success=False, error=f"Unknown tool: {tool_name}{drift_msg}",
                step_count=self._state.step_count,
                drift_occurred=bool(drift),
            )

        self._state.tools_used.append(f"{tool_name}.{act}")
        result = tool.execute(act, params)

        if not result.get("success", False):
            if result.get("policy_violated"):
                self._state.policy_violations += 1
            elif "schema_version" not in result:
                # Failures that carry a schema_version are not counted as
                # invalid calls.
                self._state.invalid_calls += 1
        if self._drift_applied and result.get("success"):
            self._state.recovered_from_drift = True

        output = json.dumps(result, indent=2) if isinstance(result, dict) else str(result)
        output += drift_msg

        if budget_exhausted:
            return self._submit()
        return EAObservation(
            success=result.get("success", False),
            output=output,
            error=result.get("error"),
            step_count=self._state.step_count,
            schema_version=getattr(tool, '_schema_version', 1),
            drift_occurred=bool(drift),
        )

    def _submit(self):
        """Score the episode and return the terminal observation.

        Snapshots every tool, records the ``to`` field of each entry in the
        email snapshot's ``outbox`` (when the snapshot is a dict) as
        ``notifications_sent``, passes the counters to ``verify_episode``, and
        stores the verdict on the state. The observation's output is the
        verdict serialised as JSON.
        """
        snapshots = {name: tool.snapshot() for name, tool in self._tools.items()}

        # Reuse the snapshot taken above instead of snapshotting email twice.
        email_snap = snapshots.get("email")
        if isinstance(email_snap, dict):
            self._state.notifications_sent = [e.get("to", "") for e in email_snap.get("outbox", [])]

        reward, violations, verdict = verify_episode(
            task=self._task,
            snapshots=snapshots,
            policy_violations=self._state.policy_violations,
            invalid_calls=self._state.invalid_calls,
            tool_calls_made=self._state.step_count,
            drift_events_handled=len(self._state.drift_events),
            recovered_from_drift=self._state.recovered_from_drift,
        )
        self._state.completed = True
        self._state.verdict = verdict
        return EAObservation(
            success=True,
            output=json.dumps(verdict, indent=2),
            reward=reward,
            done=True,
            step_count=self._state.step_count,
        )

    def state(self):
        """Return the current episode state, or ``None`` before ``reset()``."""
        return self._state