Spaces:
Sleeping
Sleeping
"""SchemaShiftEnvironment — episode scheduler. reset/step loop with drift ticks + grader.

Round 1 bug prevention: step() raises RuntimeError if called before reset(). Never lazy-init.
"""
| from __future__ import annotations | |
| import uuid | |
| from copy import deepcopy | |
| from typing import Any | |
| from drift import DriftInjector | |
| from graders import build_grader, compute_step_shaping | |
| from models import ( | |
| Action, | |
| DriftReportParams, | |
| EpisodeState, | |
| HistoryStep, | |
| Observation, | |
| RewardBreakdown, | |
| ToolResponse, | |
| ) | |
| from scenarios import SCENARIOS | |
| def _instantiate_tool(name: str, seed_data: dict) -> Any: | |
| """Lazy import so missing stretch tools don't break core scenarios.""" | |
| if name == "mail": | |
| from tools.mail import MailAPI | |
| return MailAPI(seed_data) | |
| if name == "calendar": | |
| from tools.calendar import CalendarAPI | |
| return CalendarAPI(seed_data) | |
| if name == "crm": | |
| from tools.crm import CRMAPI | |
| return CRMAPI(seed_data) | |
| if name == "chat": | |
| from tools.chat import ChatAPI # type: ignore[attr-defined] | |
| return ChatAPI(seed_data) | |
| if name == "docs": | |
| from tools.docs import DocsAPI # type: ignore[attr-defined] | |
| return DocsAPI(seed_data) | |
| raise ValueError(f"Unknown tool: {name}") | |
class SchemaShiftEnvironment:
    """The SchemaShift RL environment. One instance = one episode at a time.

    Lifecycle: call ``reset(task_id)`` to start an episode, then ``step(action)``
    repeatedly until the returned observation has ``done=True``. ``step()``
    raises ``RuntimeError`` when no episode is active or the episode is over;
    the environment is never lazily initialised.
    """

    # Recipients that must all have been mailed before
    # ``mail.sent_to_all_three_recipients`` is set.
    _E2_REQUIRED_RECIPIENTS = frozenset(
        {"alex@company.com", "jordan@company.com", "sam@company.com"}
    )

    # (needles, agent_state key): the key is set to True when any needle occurs
    # in the lower-cased subject of a successfully sent mail.
    _SUBJECT_FLAGS = (
        (("welcome",), "mail.last_subject_contains_welcome"),
        (("all-hands", "all hands"), "mail.last_subject_contains_allhands"),
        (("priority support",), "mail.last_subject_contains_priority_support"),
        (("weekly",), "mail.last_subject_contains_weekly"),
        (("calendar updated",), "mail.last_subject_contains_calendar_updated"),
    )

    # Company names looked for (case-insensitively) in the completion summary.
    _KNOWN_COMPANIES = ("Globex", "Acme", "Initech")

    def __init__(self) -> None:
        """Create an idle environment; no episode exists until ``reset()``."""
        self._state: EpisodeState | None = None
        self._tools: dict[str, Any] = {}
        self._grader = build_grader()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def reset(self, task_id: str, seed: int = 0) -> Observation:
        """Start a new episode for *task_id* and return its first observation.

        Args:
            task_id: Key into ``SCENARIOS``.
            seed: Accepted for interface compatibility; not read by this method.

        Returns:
            The initial observation, with feedback "Episode started.".

        Raises:
            ValueError: If *task_id* is not a known scenario.
        """
        if task_id not in SCENARIOS:
            raise ValueError(
                f"Unknown task_id: {task_id}. Available: {list(SCENARIOS.keys())}"
            )
        scenario = SCENARIOS[task_id]

        # Build tools and state into locals first: if a tool constructor (or
        # EpisodeState) raises, the previous episode's tools/state stay paired
        # instead of being left half-replaced.
        tools: dict[str, Any] = {}
        for tool_name in scenario["required_tools"]:
            # Deep-copy so a tool that mutates its seed cannot leak changes
            # into the shared scenario definition and later episodes.
            tool_seed = deepcopy(scenario["seed_data"].get(tool_name, {}))
            tools[tool_name] = _instantiate_tool(tool_name, tool_seed)

        state = EpisodeState(
            episode_id=str(uuid.uuid4()),
            task_id=task_id,
            difficulty=scenario["difficulty"],
            step=0,
            max_steps=scenario["max_steps"],
            token_budget=scenario["token_budget"],
            token_budget_remaining=scenario["token_budget"],
            drift_plan=deepcopy(scenario["drift_plan"]),
            ground_truth_final_state=dict(scenario["ground_truth_final_state"]),
            agent_state={},
            history=[],
            done=False,
            cumulative_reward=0.0,
        )

        self._tools = tools
        self._state = state
        return self._observation("Episode started.")

    def step(
        self, action: Action, tokens_used: int = 0
    ) -> tuple[Observation, RewardBreakdown]:
        """Advance the active episode by one action.

        Args:
            action: The agent's action for this step.
            tokens_used: Tokens to charge against the budget; negative values
                are treated as 0 so they cannot refund budget.

        Returns:
            ``(observation, reward)`` for this step. ``reward.shaped_total``
            includes this step's shaping term.

        Raises:
            RuntimeError: If ``reset()`` has not been called, or the episode
                is already done.
        """
        if self._state is None:
            raise RuntimeError(
                "Call reset() before step(). "
                "SchemaShiftEnvironment requires an active episode."
            )
        s = self._state
        if s.done:
            raise RuntimeError(
                "Episode already done. Call reset() to start a new episode."
            )

        s.step += 1
        s.token_budget_remaining = max(
            0, s.token_budget_remaining - max(0, tokens_used)
        )

        # 1. Apply any scheduled drifts for this step.
        fired_drifts = DriftInjector.tick(s, self._tools)

        # 2. Dispatch the action (does NOT mark drift detected; see step 4).
        response, feedback = self._dispatch_action(action)

        # 3. Compute step shaping BEFORE marking drift detected
        #    (shaping checks `not d.detected_by_agent`, so it must run pre-mark).
        step_shape = compute_step_shaping(s, action, response)

        # 4. Apply the drift-detection mark so the grader sees it this step.
        if action.type == "report_drift" and action.report is not None:
            self._mark_drift_detected(action.report)

        # 5. Update agent_state from action + response.
        self._update_agent_state(action, response)

        # 6. Log the history step (its reward is filled in at step 9).
        s.history.append(
            HistoryStep(
                step=s.step,
                action=action,
                response=response,
                reward_breakdown=None,
            )
        )

        # 7. Terminal conditions: step limit, exhausted budget, or explicit completion.
        if (
            s.step >= s.max_steps
            or s.token_budget_remaining <= 0
            or action.type == "complete_task"
        ):
            s.done = True

        # 8. Run the grader (sees marked drifts, updated agent_state, done flag).
        reward = self._grader(s)
        reward.step_shaping = step_shape
        reward.shaped_total += step_shape

        # 9. Record the reward in history and the running total.
        s.history[-1].reward_breakdown = reward.model_dump()
        s.cumulative_reward += reward.shaped_total

        # 10. Decorate feedback with drift info and return.
        if fired_drifts:
            feedback += (
                f" [DRIFT FIRED: {len(fired_drifts)} drift event(s) on step {s.step}.]"
            )
        return self._observation(feedback), reward

    # ------------------------------------------------------------------
    # Action dispatch. Never marks drifts as detected (that happens after
    # shaping, in step()). The only state it writes is the completion
    # summary / summary flag for `complete_task`.
    # ------------------------------------------------------------------
    def _dispatch_action(self, action: Action) -> tuple[ToolResponse | None, str]:
        """Execute *action* and return ``(tool response or None, feedback text)``.

        Tool-facing actions (``call_tool``, ``inspect_schema``,
        ``retry_with_variant``) return a ``ToolResponse``; a 404 response is
        synthesised when the target tool is not part of this scenario.
        ``report_drift`` and ``complete_task`` return ``None`` as the response.
        """
        s = self._state
        assert s is not None

        if action.type == "call_tool":
            if action.tool_call is None:
                return None, "Invalid call_tool: missing tool_call params."
            call = action.tool_call
            if call.tool not in self._tools:
                return (
                    ToolResponse(
                        ok=False,
                        status=404,
                        error=f"Tool '{call.tool}' not available in this scenario.",
                    ),
                    "Tool not available.",
                )
            response = self._tools[call.tool].call(call.endpoint, call.params)
            return (
                response,
                f"Called {call.tool}.{call.endpoint}: status={response.status}",
            )

        if action.type == "inspect_schema":
            if action.inspect is None or action.inspect.tool not in self._tools:
                return (
                    ToolResponse(
                        ok=False,
                        status=404,
                        error="Tool unavailable for inspection.",
                    ),
                    "Inspect target missing.",
                )
            schema = self._tools[action.inspect.tool].get_schema()
            return (
                ToolResponse(ok=True, status=200, body={"schema": schema}),
                f"Inspected {action.inspect.tool} schema.",
            )

        if action.type == "retry_with_variant":
            if action.retry is None or action.retry.tool not in self._tools:
                return (
                    ToolResponse(
                        ok=False,
                        status=404,
                        error="Retry target unavailable.",
                    ),
                    "Retry target missing.",
                )
            retry = action.retry
            response = self._tools[retry.tool].call(retry.endpoint, retry.params)
            return (
                response,
                f"Retried {retry.tool}.{retry.endpoint}: status={response.status}",
            )

        if action.type == "report_drift":
            if action.report is None:
                return None, "Invalid report_drift: missing report params."
            drift = self._find_reportable_drift(action.report)
            if drift is not None:
                return (
                    None,
                    f"Drift correctly reported: {drift.kind} on {drift.tool} "
                    f"at step {drift.fires_at_step}.",
                )
            return None, "Drift report did not match any undetected fired drift."

        if action.type == "complete_task":
            summary = action.complete.summary if action.complete else ""
            s.agent_state["_completion_summary"] = summary
            self._check_completion_summary(summary)
            return None, f"Episode marked complete. Summary: {summary[:80]}"

        return None, f"Unknown action type: {action.type}"

    # ------------------------------------------------------------------
    # Drift matching / post-shaping state mutations
    # ------------------------------------------------------------------
    def _find_reportable_drift(self, report: DriftReportParams) -> Any | None:
        """Return the first drift that *report* correctly identifies, else None.

        A drift matches when its tool and kind equal the report's, its
        ``fires_at_step`` is not after the current step, and it has not
        already been detected. Shared by dispatch (feedback text) and
        marking so both always agree on what counts as a match.
        """
        s = self._state
        assert s is not None
        for drift in s.drift_plan:
            if (
                drift.tool == report.tool
                and drift.kind == report.drift_kind
                and drift.fires_at_step <= s.step
                and not drift.detected_by_agent
            ):
                return drift
        return None

    def _mark_drift_detected(self, report: DriftReportParams) -> None:
        """Flag the first drift matching *report* as detected; no-op if none match."""
        drift = self._find_reportable_drift(report)
        if drift is not None:
            drift.detected_by_agent = True

    # ------------------------------------------------------------------
    # State tracking: populates agent_state so the grader can read it
    # ------------------------------------------------------------------
    def _update_agent_state(
        self, action: Action, response: ToolResponse | None
    ) -> None:
        """Record the effects of a successful tool call in ``agent_state``.

        Only ``call_tool`` and ``retry_with_variant`` actions whose response
        has ``ok`` set are tracked; everything else leaves the state untouched.
        """
        if response is None or not response.ok:
            return
        s = self._state
        assert s is not None

        if action.type == "call_tool" and action.tool_call is not None:
            target = action.tool_call
        elif action.type == "retry_with_variant" and action.retry is not None:
            target = action.retry
        else:
            return

        st = s.agent_state
        if target.tool == "mail":
            self._track_mail(st, target.endpoint, target.params)
        elif target.tool == "calendar":
            self._track_calendar(st, target.endpoint, target.params, response)
        elif target.tool == "crm":
            self._track_crm(st, target.endpoint, target.params)

    @staticmethod
    def _normalize_recipients(value: Any) -> list[str]:
        """Flatten a mail ``to`` value into a list of non-empty address strings.

        Accepts a single string, or a list/tuple/set whose items are strings
        or dicts carrying an ``"email"`` key. Anything else yields no
        addresses, so unexpected shapes cannot crash state tracking.
        """
        if isinstance(value, str):
            return [value] if value else []
        if not isinstance(value, (list, tuple, set, frozenset)):
            return []
        addresses: list[str] = []
        for item in value:
            if isinstance(item, dict):
                item = item.get("email", "")
            if isinstance(item, str) and item:
                addresses.append(item)
        return addresses

    def _track_mail(self, st: dict, endpoint: str, params: dict) -> None:
        """Update mail counters, subject flags and recipient set for a sent mail."""
        if endpoint not in ("send_message", "messages.send"):
            return
        st["mail.sent_count"] = st.get("mail.sent_count", 0) + 1
        sent_to = params.get("to", "")
        st["mail.last_sent_to"] = sent_to

        # Flags are sticky: once set they are not cleared by later mails.
        subject = str(params.get("subject", "")).lower()
        for needles, key in self._SUBJECT_FLAGS:
            if any(needle in subject for needle in needles):
                st[key] = True

        # Work on a copy: observations hold shallow copies of agent_state, so
        # mutating the stored list in place would rewrite earlier observations.
        recipients: list[str] = list(st.get("mail.all_recipients", []))
        for address in self._normalize_recipients(sent_to):
            if address not in recipients:
                recipients.append(address)
        st["mail.all_recipients"] = recipients
        if self._E2_REQUIRED_RECIPIENTS.issubset(recipients):
            st["mail.sent_to_all_three_recipients"] = True

    def _track_calendar(
        self, st: dict, endpoint: str, params: dict, response: ToolResponse
    ) -> None:
        """Update calendar event counters, attendee flags and status changes."""
        if endpoint == "create_event":
            st["calendar.events_count"] = st.get("calendar.events_count", 0) + 1
            body = response.body or {}
            # The response may use either key for the attendee list.
            raw = body.get("attendees") or body.get("participants") or []
            emails: list[str] = []
            for attendee in raw:
                if isinstance(attendee, str):
                    emails.append(attendee)
                elif isinstance(attendee, dict):
                    emails.append(attendee.get("email", ""))
            st["calendar.last_event_attendees"] = emails

            # Recognised attendee pairs (E1 + M1 share this key by design).
            has_alex = "alex@company.com" in emails
            if has_alex and (
                "priya@company.com" in emails or "bob@customer.com" in emails
            ):
                st["calendar.last_event_has_both_attendees"] = True
            if "sarah@company.com" in emails and "mike@company.com" in emails:
                st["calendar.last_event_has_both_sales_leads"] = True

            # M3: Friday Wrap-up event counter.
            title = str(body.get("title") or params.get("title") or "").lower()
            if "friday wrap-up" in title:
                st["calendar.events_count_new_friday_wrapup"] = (
                    st.get("calendar.events_count_new_friday_wrapup", 0) + 1
                )
        elif endpoint == "update_event":
            # M3: track per-event status transitions (cancellations).
            event_id = params.get("event_id", "")
            status = params.get("status")
            if event_id and status:
                st[f"calendar.{event_id}_status"] = status

    def _track_crm(self, st: dict, endpoint: str, params: dict) -> None:
        """Record the latest status written to a CRM contact."""
        if endpoint not in ("update_contact", "contacts.patch"):
            return
        cid = params.get("contact_id", "")
        status = params.get("status")
        if cid and status:
            st[f"crm.contact_{cid}_status"] = status

    def _check_completion_summary(self, summary: str) -> None:
        """Set ``complete_summary_mentions_company`` when the summary names a company.

        Only evaluated when the scenario's ground truth contains that key; the
        match against the known company names is case-insensitive.
        """
        s = self._state
        assert s is not None
        if "complete_summary_mentions_company" not in s.ground_truth_final_state:
            return
        lowered = summary.lower()
        if any(company.lower() in lowered for company in self._KNOWN_COMPANIES):
            s.agent_state["complete_summary_mentions_company"] = True

    # ------------------------------------------------------------------
    # Observation construction
    # ------------------------------------------------------------------
    def _observation(self, feedback: str) -> Observation:
        """Build the agent-facing observation for the current episode state.

        Exposes the last five history steps, a shallow copy of ``agent_state``
        and only those drifts the agent has already detected.
        """
        s = self._state
        assert s is not None
        scenario = SCENARIOS[s.task_id]
        detected_drifts = [
            {
                "tool": d.tool,
                "kind": d.kind,
                "endpoint": d.endpoint,
                "fires_at_step": d.fires_at_step,
            }
            for d in s.drift_plan
            if d.detected_by_agent
        ]
        return Observation(
            episode_id=s.episode_id,
            task_id=s.task_id,
            difficulty=s.difficulty,
            step=s.step,
            max_steps=s.max_steps,
            token_budget_remaining=s.token_budget_remaining,
            task_description=scenario["task_description"],
            success_criteria=list(scenario["success_criteria"]),
            tool_schemas={name: t.get_schema() for name, t in self._tools.items()},
            known_state=dict(s.agent_state),
            history=list(s.history[-5:]),
            last_response=s.history[-1].response if s.history else None,
            drift_events_visible=detected_drifts,
            done=s.done,
            feedback=feedback,
        )