schemashift / schemashift_environment.py
SidraMiconi's picture
deploy SchemaShift
a17a9f5
"""SchemaShift EA Arena Environment — reset/step/state with schema drift injection."""
import os, json, copy
from models import EAAction, EAObservation, EpisodeState
from tasks import TASKS
from tools import ALL_TOOLS, CalendarTool, EmailTool, BookingsTool, TravelTool, DocsTool, ExpensesTool, RoomsTool, TeamTool, IncidentsTool
from verifier import verify_episode
class SchemaShiftEnvironment:
    """Episode driver for the SchemaShift EA arena.

    ``reset()`` picks the next task from ``TASKS`` (round-robin) and seeds its
    tools; ``step()`` routes one agent action to a tool, injecting the task's
    drift event once its trigger step is reached; the episode is scored with
    ``verify_episode`` on ``system.submit`` or when ``max_steps`` is reached.
    """

    # Alternate tool names accepted in task seeds, drift events and actions.
    # Task seeds key email data as "emails" while the tool is registered as "email".
    _TOOL_ALIASES = {"emails": "email"}

    def __init__(self):
        self._state = None  # EpisodeState of the current episode; None until reset()
        self._task = None  # task dict chosen by the last reset()
        self._task_index = 0  # round-robin cursor into TASKS
        self._tools = {}  # registered tool name -> seeded tool instance
        self._drift_applied = False  # whether this episode's drift event has fired

    @classmethod
    def _canonical_tool_name(cls, name):
        """Map an aliased tool name (e.g. "emails") to its registered name."""
        return cls._TOOL_ALIASES.get(name, name)

    def _setup_tools(self, seed):
        """Create and seed one tool per list-valued entry of *seed*.

        The "policies" entry, keys without a matching tool class, and
        non-list values are ignored. A missing/None seed yields no tools.
        """
        tool_map = {
            "calendar": CalendarTool, "emails": EmailTool, "email": EmailTool,
            "bookings": BookingsTool, "travel": TravelTool, "docs": DocsTool,
            "expenses": ExpensesTool, "rooms": RoomsTool, "team": TeamTool,
            "incidents": IncidentsTool,
        }
        self._tools = {}
        for key, data in (seed or {}).items():
            if key == "policies":
                continue
            tool_cls = tool_map.get(key)
            if tool_cls and isinstance(data, list):
                tool = tool_cls()
                tool.seed(data)
                self._tools[self._canonical_tool_name(key)] = tool

    def reset(self):
        """Start the next task (round-robin over TASKS) and return its opening observation."""
        self._task = TASKS[self._task_index % len(TASKS)]
        self._task_index += 1
        self._drift_applied = False
        self._setup_tools(self._task.get("seed", {}))
        self._state = EpisodeState(
            task_id=self._task["id"],
            task_description=self._task["description"],
            max_steps=self._task.get("max_steps", 15),
        )
        return EAObservation(
            success=True,
            output=f"TASK: {self._task['title']}\n\n{self._task['description']}",
            task_description=self._task["description"],
            done=False,
            schema_version=1,
        )

    def _maybe_inject_drift(self):
        """Fire the task's drift event once the step counter reaches its trigger.

        Returns the drift event dict on the step it fires, otherwise None.
        A missing or falsy "drift_at_step" (including 0) disables drift, and
        an empty/missing "drift_event" is ignored. The event fires at most once
        per episode.
        """
        drift_step = self._task.get("drift_at_step")
        if self._drift_applied or not drift_step or self._state.step_count < drift_step:
            return None
        drift = self._task.get("drift_event") or {}
        if not drift:
            return None
        tool_name = self._canonical_tool_name(drift.get("tool", ""))
        if tool_name:
            tool = self._tools.get(tool_name)
            if tool is None:
                # NOTE(review): the drift targets a tool this episode did not seed
                # (e.g. "policies", which _setup_tools skips), so it is left
                # unfired as before. Confirm this is intended for the task data.
                return None
            tool.apply_drift(drift)
        # A drift naming no tool (e.g. a message from another actor) has nothing
        # to mutate, but must still fire so step() can announce it.
        self._drift_applied = True
        self._state.drift_events.append(drift.get("change", "unknown"))
        return drift

    @classmethod
    def _unpack_action(cls, action):
        """Return ``(tool_name, action_name, parameters)`` from an action object or dict.

        Each field is read as an attribute when present, otherwise via
        ``action.get``. The tool name is canonicalised through the alias map and
        missing/None parameters become ``{}``.
        """
        def field(attr, default):
            return getattr(action, attr) if hasattr(action, attr) else action.get(attr, default)

        tool_name = cls._canonical_tool_name(field("tool", ""))
        act = field("action", "")
        params = field("parameters", {})
        if params is None:
            params = {}
        return tool_name, act, params

    @staticmethod
    def _format_drift_message(drift):
        """Return the warning suffix shown on the step a drift fires ("" if none)."""
        if not drift:
            return ""
        dtype = drift.get("type", "")
        if dtype == "schema_change":
            return f"\n⚠️ SCHEMA CHANGE: {drift.get('change', '')}. Check tool documentation."
        if dtype == "policy_change":
            return f"\n⚠️ POLICY CHANGE: {drift.get('change', '')}. Review updated policies."
        if dtype == "actor_conflict":
            return f"\n⚠️ NEW MESSAGE from {drift.get('actor', 'unknown')}: \"{drift.get('message', '')}\""
        return ""

    def step(self, action):
        """Apply one agent action and return the resulting observation.

        *action* is an object with ``tool``/``action``/``parameters`` attributes
        or a dict with those keys. ``system.submit`` scores the episode; so does
        reaching ``max_steps`` (every step counts, including invalid ones).
        Calling step() before reset() or after the episode ended returns a
        terminal error observation without changing state.
        """
        if self._state is None:
            return EAObservation(success=False, error="Call reset() first", reward=-1.0, done=True)
        if getattr(self._state, "completed", False):
            # Already scored: do not mutate counters or re-run verification.
            return EAObservation(
                success=False,
                error="Episode is done; call reset() to start a new one",
                done=True,
                step_count=self._state.step_count,
            )
        self._state.step_count += 1
        tool_name, act, params = self._unpack_action(action)
        drift = self._maybe_inject_drift()
        drift_msg = self._format_drift_message(drift)
        if tool_name == "system" and act == "submit":
            return self._submit()
        tool = self._tools.get(tool_name)
        if tool is None:
            self._state.invalid_calls += 1
            if self._state.step_count >= self._state.max_steps:
                # Invalid calls also consume the budget; otherwise an agent issuing
                # only unknown tools would never reach a terminal observation.
                return self._submit()
            return EAObservation(
                success=False, error=f"Unknown tool: {tool_name}{drift_msg}",
                step_count=self._state.step_count,
                drift_occurred=bool(drift),
            )
        self._state.tools_used.append(f"{tool_name}.{act}")
        result = tool.execute(act, params)
        if not result.get("success", False):
            if result.get("policy_violated"):
                self._state.policy_violations += 1
            elif "schema_version" not in result:
                # presumably failures that report a schema_version are drift-induced
                # schema mismatches and are not counted as invalid calls — TODO confirm
                self._state.invalid_calls += 1
        elif self._drift_applied:
            # A successful call after the drift fired counts as recovery.
            self._state.recovered_from_drift = True
        output = json.dumps(result, indent=2) if isinstance(result, dict) else str(result)
        output += drift_msg
        if self._state.step_count >= self._state.max_steps:
            return self._submit()
        return EAObservation(
            success=result.get("success", False),
            output=output,
            error=result.get("error"),
            step_count=self._state.step_count,
            schema_version=getattr(tool, '_schema_version', 1),
            drift_occurred=bool(drift),
        )

    def _submit(self):
        """Score the episode with verify_episode, mark it completed, and return the terminal observation."""
        snapshots = {name: tool.snapshot() for name, tool in self._tools.items()}
        email_snap = snapshots.get("email")
        if isinstance(email_snap, dict):
            # Recipients of every outbox message, as recorded on the episode state.
            self._state.notifications_sent = [e.get("to", "") for e in email_snap.get("outbox", [])]
        reward, violations, verdict = verify_episode(
            task=self._task,
            snapshots=snapshots,
            policy_violations=self._state.policy_violations,
            invalid_calls=self._state.invalid_calls,
            tool_calls_made=self._state.step_count,
            drift_events_handled=len(self._state.drift_events),
            recovered_from_drift=self._state.recovered_from_drift,
        )
        self._state.completed = True
        self._state.verdict = verdict
        return EAObservation(
            success=True,
            output=json.dumps(verdict, indent=2),
            reward=reward,
            done=True,
            step_count=self._state.step_count,
        )

    @property
    def state(self):
        """The current EpisodeState (None before the first reset())."""
        return self._state