"""
DriftPA — server/environment.py
================================
Core RL environment implementing the OpenEnv 0.2.1 interface.

Implements three novel mechanics:
  1. Schema Drift — API field names change mid-episode (steps 3-6)
  2. Time Pressure — tasks expire if not resolved within N steps
  3. Irreversible Actions — reply/book/cancel cannot be undone

OpenEnv interface contract:
  - reset(seed, episode_id, **kwargs) -> DriftPAObservation
  - step(action, timeout_s, **kwargs) -> DriftPAObservation
  - state (property) -> DriftPAState
"""
|
|
| import uuid |
| import random |
| from typing import Optional |
|
|
| from openenv.core import Environment |
|
|
| from models import DriftPAAction, DriftPAObservation, DriftPAState |
| from scenarios.generator import generate_scenario |
|
|
| |
| |
| |
# --- Reward shaping: positive signals -------------------------------------
R_URGENT_RESOLVED = 3.0      # urgent/critical task resolved in time
R_NORMAL_RESOLVED = 2.0      # normal task resolved / booking made
R_CORRECT_DRIFT = 2.0        # payload used the post-drift field names
R_FEASIBLE_SCHEDULE = 2.0    # calendar has no clashes at episode end
R_POLICY_QUERIED = 1.0       # policy was consulted
R_CORRECT_FORMAT = 1.0       # well-formed reply
R_EPISODE_BONUS = 5.0        # every task resolved, none expired

# --- Reward shaping: penalties --------------------------------------------
R_IRREVERSIBLE_CONFLICT = -4.0
R_DOUBLE_BOOKING = -3.0
R_POLICY_VIOLATION = -3.0
R_TASK_EXPIRED = -3.0
R_STALE_SCHEMA = -2.0
R_MISSED_URGENT = -2.0
R_REDUNDANT_CALL = -1.0
R_CONTRADICTORY_REPLY = -1.0
R_CATASTROPHIC = -5.0

# Hard cap on the number of steps in one episode.
MAX_STEPS = 15

# Every tool name the environment accepts.
VALID_TOOLS = {
    "read_message",
    "read_calendar",
    "reply_message",
    "move_event",
    "book_restaurant",
    "cancel_booking",
    "confirm_booking",
    "book_ride",
    "query_policy",
    "list_tools",
    "finish",
}

# Tools whose effects cannot be undone once executed.
IRREVERSIBLE_TOOLS = {
    "reply_message",
    "book_restaurant",
    "cancel_booking",
    "confirm_booking",
    "book_ride",
}

# Maps the API name used in drift specs to the tool whose schema it affects.
API_TO_TOOL = {
    "restaurant": "book_restaurant",
    "ride": "book_ride",
    "calendar": "move_event",
    "email": "reply_message",
    "booking": "confirm_booking",
}
|
|
|
|
class DriftPAEnvironment(Environment[DriftPAAction, DriftPAObservation, DriftPAState]):
    """Personal assistant RL environment with schema drift and time pressure.

    Each episode loads a scenario (inbox, calendar, policy and a drift
    schedule) via ``generate_scenario``. The agent acts through the tools in
    ``VALID_TOOLS``. Every ``step`` first applies any drift scheduled for
    that step and expires overdue tasks, then executes the requested tool.
    The episode ends when the agent calls ``finish`` or ``MAX_STEPS`` is
    reached.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(self):
        super().__init__()

        self._episode_id: str = ""
        self._last_action_result: str = ""

        # Scenario-derived data; populated by reset().
        self._scenario: dict = {}
        self._inbox: list = []
        self._calendar: list = []
        self._active_schema: dict = {}
        self._active_policy: dict = {}
        self._schema_drift_steps: dict = {}
        self._policy_drift_step: int = -1
        self._policy_drift_spec: dict = {}

        self._clear_episode_progress()

    def _clear_episode_progress(self) -> None:
        """Reset every per-episode counter, log and flag to its initial value.

        Shared by ``__init__`` and ``reset`` so the two cannot fall out of
        sync when a new piece of episode state is added.
        """
        self._step_count: int = 0
        self._schema_version: int = 0
        self._drift_log: list = []
        self._policy_log: list = []
        self._irreversible_taken: list = []
        self._total_reward: float = 0.0
        self._tasks_resolved: int = 0
        self._tasks_expired: int = 0
        self._cascade_failures: int = 0
        self._done: bool = False

        self._bookings: dict = {}
        self._resolved_tasks: set = set()
        self._policy_queried_this_step: bool = False
        self._list_tools_called: bool = False

        self._read_messages: set = set()
        self._calendar_read: bool = False
        self._reply_targets: dict = {}
        self._ride_booked: bool = False
        self._restaurant_booked: bool = False
        self._conflicts: list = []

        # True once the query_policy reward has been paid; cleared again by
        # a policy drift so the agent can earn it once more.
        self._policy_reward_claimed: bool = False

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> DriftPAObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: If given, seeds the global ``random`` module; it is also
                passed to ``generate_scenario``.
            episode_id: Identifier for the episode; a UUID4 string is
                generated when omitted or empty.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        if seed is not None:
            random.seed(seed)

        self._clear_episode_progress()
        self._episode_id = episode_id or str(uuid.uuid4())
        self._last_action_result = "Episode started. Check inbox and calendar."

        # Copy the mutable parts of the scenario so the episode can modify
        # them without touching the generator's output.
        self._scenario = generate_scenario(seed)
        self._inbox = [dict(m) for m in self._scenario["inbox"]]
        self._calendar = [dict(e) for e in self._scenario["calendar"]]
        self._active_policy = dict(self._scenario["initial_policy"])
        self._schema_drift_steps = self._scenario["schema_drift_steps"]
        self._policy_drift_step = self._scenario["policy_drift_step"]
        self._policy_drift_spec = self._scenario["policy_drift_spec"]

        self._active_schema = self._build_tool_schemas()

        return self._make_observation()

    def step(
        self,
        action: DriftPAAction,
        timeout_s: Optional[float] = None,
        **kwargs,
    ) -> DriftPAObservation:
        """Execute one action and advance the environment by one step.

        Scheduled schema/policy drift and task expiry are processed before
        the action, and they are processed even when the action names an
        unknown tool. Otherwise an invalid call on a drift step would make
        that drift disappear for the rest of the episode.

        Args:
            action: The tool call to execute (``tool_name`` and ``payload``).
            timeout_s: Accepted for interface compatibility; unused here.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The observation after the action, carrying this step's reward.
        """
        if self._done:
            return self._make_observation(
                result="Episode is already done. Call reset() to start a new one."
            )

        self._step_count += 1
        self._policy_queried_this_step = False
        step_reward = 0.0

        # World events happen regardless of what the agent asked for.
        if self._step_count in self._schema_drift_steps:
            self._apply_schema_drift(self._step_count)
        if self._step_count == self._policy_drift_step:
            self._apply_policy_drift()

        expired = self._check_expiry()
        step_reward += R_TASK_EXPIRED * len(expired)
        self._tasks_expired += len(expired)
        result_prefix = f"[EXPIRED: {expired}] " if expired else ""

        tool = action.tool_name
        if tool in VALID_TOOLS:
            # Tolerate a missing payload; handlers all call payload.get().
            action_reward, action_result = self._dispatch(tool, action.payload or {})
        else:
            action_reward = R_REDUNDANT_CALL
            action_result = (
                f"ERROR: Unknown tool '{tool}'. Call list_tools() to see valid tools."
            )
        step_reward += action_reward
        result = result_prefix + action_result

        if tool == "finish" or self._step_count >= MAX_STEPS:
            episode_bonus, bonus_msg = self._compute_episode_bonus()
            step_reward += episode_bonus
            result += bonus_msg
            self._done = True

        self._total_reward += step_reward
        return self._make_observation(result=result, reward=step_reward)

    @property
    def state(self) -> DriftPAState:
        """Snapshot of the episode bookkeeping (logs are copied)."""
        return DriftPAState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            schema_version=self._schema_version,
            drift_log=list(self._drift_log),
            policy_log=list(self._policy_log),
            irreversible_actions_taken=list(self._irreversible_taken),
            total_reward=self._total_reward,
            tasks_resolved=self._tasks_resolved,
            tasks_expired=self._tasks_expired,
            cascade_failures=self._cascade_failures,
        )

    # ------------------------------------------------------------------
    # Tool handlers
    # ------------------------------------------------------------------

    def _dispatch(self, tool: str, payload: dict) -> tuple[float, str]:
        """Route to the correct handler. Returns (reward_delta, result_string).

        Raises:
            KeyError: if ``tool`` has no handler.
        """
        handlers = {
            "read_message": self._h_read_message,
            "read_calendar": self._h_read_calendar,
            "reply_message": self._h_reply_message,
            "move_event": self._h_move_event,
            "book_restaurant": self._h_book_restaurant,
            "cancel_booking": self._h_cancel_booking,
            "confirm_booking": self._h_confirm_booking,
            "book_ride": self._h_book_ride,
            "query_policy": self._h_query_policy,
            "list_tools": self._h_list_tools,
            "finish": self._h_finish,
        }
        return handlers[tool](payload)

    def _with_legacy_aliases(self, api: str, payload: dict) -> dict:
        """Return a copy of ``payload`` that also exposes drifted fields
        under their earlier names.

        Handlers read the pre-drift field names. Without this, an agent that
        correctly adopts a renamed field would have its value ignored.
        Explicitly supplied keys are never overwritten.
        """
        renamed_from = {
            d["field_new"]: d["field_old"]
            for d in self._drift_log
            if d["api"] == api
        }
        aliased = dict(payload)
        for key, value in payload.items():
            name = key
            seen = {name}
            # Follow rename chains (a -> b -> c) back to every earlier name;
            # `seen` guards against renames that loop back on themselves.
            while name in renamed_from and renamed_from[name] not in seen:
                name = renamed_from[name]
                seen.add(name)
                aliased.setdefault(name, value)
        return aliased

    def _h_read_message(self, payload: dict) -> tuple[float, str]:
        """Show one inbox message; penalise unknown ids and re-reads."""
        msg_id = payload.get("message_id", "")
        msg = next((m for m in self._inbox if m["id"] == msg_id), None)
        if msg is None:
            return R_REDUNDANT_CALL, f"ERROR: Message '{msg_id}' not found."
        if msg_id in self._read_messages:
            return R_REDUNDANT_CALL, f"Already read message '{msg_id}'."
        self._read_messages.add(msg_id)
        return 0.0, (
            f"FROM: {msg['sender']}\n"
            f"SUBJECT: {msg['subject']}\n"
            f"PRIORITY: {msg['priority']}\n"
            f"EXPIRES AT STEP: {msg.get('expires_at_step', 'none')}\n"
            f"BODY: {msg['body']}"
        )

    def _h_read_calendar(self, payload: dict) -> tuple[float, str]:
        """List calendar events; only a repeated read of an empty calendar
        is penalised."""
        if self._calendar_read and not self._calendar:
            return R_REDUNDANT_CALL, "Calendar already read and empty."
        self._calendar_read = True
        if not self._calendar:
            return 0.0, "Calendar is empty."
        lines = ["CALENDAR:"]
        for e in self._calendar:
            # Events may carry either spelling of these fields.
            title = e.get("title") or e.get("event_name", "Unknown")
            begins_at = e.get("begins_at") or e.get("start_time", "?")
            lines.append(
                f"  [{e['id']}] {title} @ {begins_at} - {e.get('end_time', '?')} "
                f"({'confirmed' if e.get('confirmed') else 'tentative'})"
            )
        return 0.0, "\n".join(lines)

    def _h_reply_message(self, payload: dict) -> tuple[float, str]:
        """IRREVERSIBLE - sends a reply to a message.

        Replying to an unknown message is catastrophic; a second reply to
        the same message counts as contradictory. A first reply resolves the
        message's task.
        """
        fields = self._with_legacy_aliases("email", payload)
        msg_id = fields.get("message_id", "")
        text = fields.get("text", "")

        reward, schema_note = self._check_schema_payload("email", payload)

        msg = next((m for m in self._inbox if m["id"] == msg_id), None)
        if msg is None:
            return R_CATASTROPHIC, f"CATASTROPHIC: Replied to unknown message '{msg_id}'."

        if msg_id in self._reply_targets:
            return R_CONTRADICTORY_REPLY, (
                f"WARNING: Sent contradictory reply to '{msg_id}'. "
                f"Previous reply already sent. {schema_note}"
            )

        self._reply_targets[msg_id] = text
        self._irreversible_taken.append({
            "step": self._step_count,
            "tool_name": "reply_message",
            "payload": payload,
        })

        task_reward = self._resolve_task(msg_id, msg.get("priority", "normal"))
        return reward + task_reward + R_CORRECT_FORMAT, (
            f"SENT reply to {msg['sender']}: '{text}'. "
            f"IRREVERSIBLE. {schema_note}"
        )

    def _h_move_event(self, payload: dict) -> tuple[float, str]:
        """Move a calendar event to a new start time.

        Rejects unknown events, a missing target time, and a target time
        already taken by another event.
        """
        fields = self._with_legacy_aliases("calendar", payload)
        event_id = fields.get("event_id", "")
        new_time = fields.get("new_time") or fields.get("begins_at", "")

        event = next((e for e in self._calendar if e["id"] == event_id), None)
        if event is None:
            return R_REDUNDANT_CALL, f"ERROR: Event '{event_id}' not found."

        if not new_time:
            # Never write an empty start time onto the event.
            return R_REDUNDANT_CALL, (
                f"ERROR: No new time given for event '{event_id}'."
            )

        conflict = self._has_time_conflict(event_id, new_time)
        if conflict:
            self._cascade_failures += 1
            return R_DOUBLE_BOOKING, (
                f"ERROR: Moving '{event_id}' to {new_time} conflicts with "
                f"'{conflict}'. Double-booking penalty applied."
            )

        old_time = event.get("begins_at") or event.get("start_time", "?")
        # Write back under whichever field name the event currently uses.
        if "begins_at" in event:
            event["begins_at"] = new_time
        else:
            event["start_time"] = new_time

        return 0.0, (
            f"Moved '{event_id}' from {old_time} to {new_time}. "
            f"Schedule updated."
        )

    def _h_book_restaurant(self, payload: dict) -> tuple[float, str]:
        """IRREVERSIBLE - books a restaurant table.

        Enforces a single restaurant booking per episode and the active
        policy's ``guest_limit`` and ``max_reimbursement``.
        """
        reward, schema_note = self._check_schema_payload("restaurant", payload)

        if self._restaurant_booked:
            self._cascade_failures += 1
            return R_DOUBLE_BOOKING, (
                f"ERROR: Restaurant already booked. Double-booking penalty. {schema_note}"
            )

        fields = self._with_legacy_aliases("restaurant", payload)
        guests = fields.get("guests") or fields.get("party_size", 0)
        try:
            guests = int(guests)
        except (TypeError, ValueError):
            guests = 0

        guest_limit = self._active_policy.get("guest_limit", 6)
        if guests > guest_limit:
            return R_POLICY_VIOLATION, (
                f"POLICY VIOLATION: guest_limit is "
                f"{guest_limit}, requested {guests}. {schema_note}"
            )

        # Reimbursement is a per-person cap, so compare per-person amounts.
        cost_per_person = self._scenario.get("restaurant_cost_per_person", 0)
        max_reimb = self._active_policy.get("max_reimbursement", 999)
        if cost_per_person > 0 and cost_per_person > max_reimb:
            return R_POLICY_VIOLATION, (
                f"POLICY VIOLATION: dinner costs ${cost_per_person}/person but "
                f"max_reimbursement is ${max_reimb}/person. "
                f"Call query_policy() before booking. {schema_note}"
            )

        booking_id = f"rest_{self._step_count}"
        self._bookings[booking_id] = {
            "type": "restaurant", "payload": payload, "confirmed": False
        }
        self._restaurant_booked = True
        self._irreversible_taken.append({
            "step": self._step_count,
            "tool_name": "book_restaurant",
            "payload": payload,
        })
        return reward + R_NORMAL_RESOLVED, (
            f"Restaurant booked (id={booking_id}). "
            f"Awaiting confirmation. IRREVERSIBLE. {schema_note}"
        )

    def _h_cancel_booking(self, payload: dict) -> tuple[float, str]:
        """IRREVERSIBLE - cancels an existing booking.

        The policy's ``cancellation_window`` (e.g. ``"2hr"``) is converted
        to steps at two steps per hour and counted back from ``MAX_STEPS``;
        cancelling after that cutoff is a policy violation.
        """
        booking_id = payload.get("booking_id", "")
        booking = self._bookings.get(booking_id)
        if booking is None:
            return R_CATASTROPHIC, (
                f"CATASTROPHIC: Tried to cancel unknown booking '{booking_id}'."
            )
        if booking.get("cancelled"):
            return R_REDUNDANT_CALL, f"Booking '{booking_id}' already cancelled."

        window_str = self._active_policy.get("cancellation_window", "2hr")
        try:
            # str() lets a bare number such as 4 be honoured as 4 hours.
            hours = float(str(window_str).replace("hr", "").strip())
        except ValueError:
            hours = 2.0
        steps_needed = int(hours * 2)
        cutoff_step = MAX_STEPS - steps_needed
        if self._step_count > cutoff_step:
            return R_POLICY_VIOLATION, (
                f"POLICY VIOLATION: Too late to cancel. cancellation_window is "
                f"{window_str} - must cancel by step {cutoff_step}, "
                f"but already at step {self._step_count}."
            )

        booking["cancelled"] = True
        self._irreversible_taken.append({
            "step": self._step_count,
            "tool_name": "cancel_booking",
            "payload": payload,
        })
        return 0.0, f"Booking '{booking_id}' cancelled. IRREVERSIBLE."

    def _h_confirm_booking(self, payload: dict) -> tuple[float, str]:
        """IRREVERSIBLE - confirms a pending booking."""
        booking_id = payload.get("booking_id", "")
        booking = self._bookings.get(booking_id)
        if booking is None:
            return R_CATASTROPHIC, (
                f"CATASTROPHIC: Tried to confirm unknown booking '{booking_id}'."
            )
        if booking.get("cancelled"):
            return R_REDUNDANT_CALL, (
                f"ERROR: Booking '{booking_id}' was cancelled and cannot be confirmed."
            )
        if booking.get("confirmed"):
            return R_REDUNDANT_CALL, f"Booking '{booking_id}' already confirmed."

        booking["confirmed"] = True
        self._irreversible_taken.append({
            "step": self._step_count,
            "tool_name": "confirm_booking",
            "payload": payload,
        })
        return 0.0, f"Booking '{booking_id}' confirmed. IRREVERSIBLE."

    def _h_book_ride(self, payload: dict) -> tuple[float, str]:
        """IRREVERSIBLE - books a ride (one per episode, stored as confirmed)."""
        reward, schema_note = self._check_schema_payload("ride", payload)

        if self._ride_booked:
            self._cascade_failures += 1
            return R_DOUBLE_BOOKING, (
                f"ERROR: Ride already booked. Double-booking penalty. {schema_note}"
            )

        booking_id = f"ride_{self._step_count}"
        self._bookings[booking_id] = {
            "type": "ride", "payload": payload, "confirmed": True
        }
        self._ride_booked = True
        self._irreversible_taken.append({
            "step": self._step_count,
            "tool_name": "book_ride",
            "payload": payload,
        })
        return reward + R_NORMAL_RESOLVED, (
            f"Ride booked (id={booking_id}). IRREVERSIBLE. {schema_note}"
        )

    def _h_query_policy(self, payload: dict) -> tuple[float, str]:
        """Show the active policy.

        Pays ``R_POLICY_QUERIED`` on the first query, and again on the first
        query after each policy drift (the drift clears the claimed flag).
        """
        self._policy_queried_this_step = True
        lines = ["CURRENT POLICIES:"]
        for k, v in self._active_policy.items():
            lines.append(f"  {k}: {v}")

        if self._policy_reward_claimed:
            return 0.0, "\n".join(lines)

        self._policy_reward_claimed = True
        if self._policy_log:
            # Only call out the reward when a drift has actually happened.
            lines.append("  [+1 reward: queried policy after drift]")
        return R_POLICY_QUERIED, "\n".join(lines)

    def _h_list_tools(self, payload: dict) -> tuple[float, str]:
        """List every tool with its current (possibly drifted) parameters."""
        self._list_tools_called = True
        lines = ["AVAILABLE TOOLS (current schema):"]
        for tool_name, schema in self._active_schema.items():
            params = ", ".join(
                f"{p}:{t}" for p, t in schema.get("params", {}).items()
            )
            irrev = " [IRREVERSIBLE]" if tool_name in IRREVERSIBLE_TOOLS else ""
            lines.append(f"  {tool_name}({params}){irrev}")
        return 0.0, "\n".join(lines)

    def _h_finish(self, payload: dict) -> tuple[float, str]:
        """Acknowledge ``finish``; ``step`` computes the episode bonus."""
        return 0.0, "Agent called finish(). Computing final score..."

    # ------------------------------------------------------------------
    # Drift
    # ------------------------------------------------------------------

    def _apply_schema_drift(self, step: int):
        """Apply all schema drifts scheduled for this step.

        Each spec renames field ``old`` to ``new`` in the schema of the tool
        mapped from ``api``. Calendar drifts also rename the field on the
        stored events. Every spec bumps ``_schema_version`` and is logged.
        """
        for spec in self._schema_drift_steps.get(step, []):
            api = spec["api"]
            old_key = spec["old"]
            new_key = spec["new"]

            tool_key = API_TO_TOOL.get(api, api)
            if tool_key in self._active_schema:
                params = self._active_schema[tool_key].get("params", {})
                if old_key in params:
                    params[new_key] = params.pop(old_key)

            if api == "calendar":
                for event in self._calendar:
                    if old_key in event:
                        event[new_key] = event.pop(old_key)

            self._schema_version += 1
            self._drift_log.append({
                "step": step, "api": api,
                "field_old": old_key, "field_new": new_key,
            })

    def _apply_policy_drift(self):
        """Apply the scheduled policy drift and log each changed rule."""
        for rule, new_val in self._policy_drift_spec.items():
            old_val = self._active_policy.get(rule)
            self._active_policy[rule] = new_val
            self._policy_log.append({
                "step": self._step_count, "rule": rule,
                "old_value": old_val, "new_value": new_val,
            })
        # Let the agent earn the query_policy reward again after a drift.
        self._policy_reward_claimed = False

    def _check_schema_payload(self, api: str, payload: dict) -> tuple[float, str]:
        """Score a payload against the drifted schema of ``api``.

        Returns (reward, note): ``R_STALE_SCHEMA`` if the payload uses a
        field name that was renamed away, ``R_CORRECT_DRIFT`` if it uses
        current field names, and ``(0.0, "")`` when no drift has touched
        this API.
        """
        if self._schema_version == 0:
            return 0.0, ""

        tool_key = API_TO_TOOL.get(api, api)
        current_params = self._active_schema.get(tool_key, {}).get("params", {})
        if not current_params:
            return 0.0, ""

        renamed_away = {d["field_old"] for d in self._drift_log if d["api"] == api}
        if not renamed_away:
            return 0.0, ""

        # A name renamed away and later renamed back is current, not stale.
        stale_keys = [
            k for k in payload if k in renamed_away and k not in current_params
        ]
        correct_keys = [k for k in payload if k in current_params]

        if stale_keys:
            return R_STALE_SCHEMA, (
                f"DRIFT ERROR: Used stale field(s) {stale_keys} for '{api}'. "
                f"Correct fields: {list(current_params.keys())}. Stale schema penalty."
            )
        if correct_keys:
            return R_CORRECT_DRIFT, f"[Correct post-drift schema used for '{api}']"
        return 0.0, ""

    # ------------------------------------------------------------------
    # Tasks
    # ------------------------------------------------------------------

    def _check_expiry(self) -> list[str]:
        """Mark tasks expired if their deadline has passed.

        Expired ids are added to ``_resolved_tasks`` so they are reported
        only once. Returns the ids that expired on this call.
        """
        expired = []
        for msg in self._inbox:
            exp = msg.get("expires_at_step")
            if exp is None:
                continue
            if self._step_count > exp and msg["id"] not in self._resolved_tasks:
                expired.append(msg["id"])
                self._resolved_tasks.add(msg["id"])
        return expired

    def _resolve_task(self, msg_id: str, priority: str) -> float:
        """Mark a task resolved and return the reward for doing so.

        A task that is already resolved (or already expired) yields
        ``R_REDUNDANT_CALL``. Urgent/critical tasks earn
        ``R_URGENT_RESOLVED`` when handled on time, which includes tasks
        with no deadline at all.
        """
        if msg_id in self._resolved_tasks:
            return R_REDUNDANT_CALL

        msg = next((m for m in self._inbox if m["id"] == msg_id), None)
        exp = msg.get("expires_at_step") if msg else None

        self._resolved_tasks.add(msg_id)
        self._tasks_resolved += 1

        if priority in ("critical", "urgent"):
            if exp is None or self._step_count <= exp:
                return R_URGENT_RESOLVED
            return R_MISSED_URGENT
        return R_NORMAL_RESOLVED

    def _has_time_conflict(self, moving_id: str, new_time: str) -> Optional[str]:
        """Check if placing event at new_time conflicts with another event.

        Only exact start-time equality is treated as a conflict.
        Returns the conflicting event id, or None.
        """
        for event in self._calendar:
            if event["id"] == moving_id:
                continue
            existing_time = event.get("begins_at") or event.get("start_time", "")
            if existing_time == new_time:
                return event["id"]
        return None

    def _compute_episode_bonus(self) -> tuple[float, str]:
        """Compute the end-of-episode bonus and its explanation.

        * ``R_IRREVERSIBLE_CONFLICT`` per duplicated start time, when any
          irreversible action was taken.
        * ``R_FEASIBLE_SCHEDULE`` when no start times are duplicated.
        * ``R_EPISODE_BONUS`` when every task with a deadline was resolved
          and none expired.
        """
        bonus = 0.0
        notes = []

        times = [
            e.get("begins_at") or e.get("start_time")
            for e in self._calendar
        ]
        non_null_times = [t for t in times if t]
        conflict_count = len(non_null_times) - len(set(non_null_times))

        if conflict_count > 0 and self._irreversible_taken:
            cascade_penalty = R_IRREVERSIBLE_CONFLICT * conflict_count
            bonus += cascade_penalty
            self._cascade_failures += conflict_count
            notes.append(
                f"{cascade_penalty:.0f} CASCADE: {conflict_count} calendar conflict(s) "
                f"after irreversible actions"
            )

        if conflict_count == 0:
            bonus += R_FEASIBLE_SCHEDULE
            notes.append(f"+{R_FEASIBLE_SCHEDULE} feasible schedule")

        # Check the deadline tasks themselves, not a count of resolved
        # messages. Expired tasks also sit in _resolved_tasks, hence the
        # separate expiry test.
        deadline_ids = {
            m["id"] for m in self._inbox
            if m.get("expires_at_step") is not None
        }
        if deadline_ids <= self._resolved_tasks and self._tasks_expired == 0:
            bonus += R_EPISODE_BONUS
            notes.append(f"+{R_EPISODE_BONUS} EPISODE BONUS: all tasks resolved!")

        note_str = (" EPISODE END: " + ", ".join(notes)) if notes else " EPISODE END."
        return bonus, note_str

    # ------------------------------------------------------------------
    # Schemas and observations
    # ------------------------------------------------------------------

    def _build_tool_schemas(self) -> dict:
        """Build the tool schema dict shown to the agent via list_tools()."""
        return {
            "read_message": {"params": {"message_id": "str"}},
            "read_calendar": {"params": {}},
            "reply_message": {"params": {"message_id": "str", "text": "str"}},
            "move_event": {"params": {"event_id": "str", "new_time": "str"}},
            "book_restaurant": {"params": {
                "party_size": "int",
                "date": "str",
                "time": "str",
                "restaurant": "str",
            }},
            "cancel_booking": {"params": {"booking_id": "str"}},
            "confirm_booking": {"params": {"booking_id": "str"}},
            "book_ride": {"params": {
                "pickup_time": "str",
                "destination": "str",
            }},
            "query_policy": {"params": {}},
            "list_tools": {"params": {}},
            "finish": {"params": {}},
        }

    def _make_observation(
        self,
        result: str = "",
        reward: float = 0.0,
    ) -> DriftPAObservation:
        """Build the observation returned to the agent.

        A non-empty ``result`` replaces the stored last-action message.
        Inbox, calendar, tool schema and policy are copied so the caller
        cannot mutate the environment's own state through the observation.
        """
        if result:
            self._last_action_result = result

        # Unresolved tasks whose deadline is at most two steps away.
        urgent = [
            m["id"] for m in self._inbox
            if m.get("expires_at_step") is not None
            and m["id"] not in self._resolved_tasks
            and m["expires_at_step"] <= self._step_count + 2
        ]

        tools = {
            name: {**spec, "params": dict(spec.get("params", {}))}
            for name, spec in self._active_schema.items()
        }

        return DriftPAObservation(
            inbox=[dict(m) for m in self._inbox],
            calendar=[dict(e) for e in self._calendar],
            available_tools=tools,
            policy=dict(self._active_policy),
            time_step=self._step_count,
            urgent_expiring=urgent,
            last_action_result=self._last_action_result,
            done=self._done,
            reward=reward,
        )
|
|