Spaces:
Sleeping
Sleeping
| """Environment factory for TRL GRPOTrainer integration. | |
| Wraps CommitmentOS as a callable that accepts model completions and | |
| returns rewards, making it compatible with TRL's ``environment_factory`` | |
| pattern for multi-turn RL training. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | |
| from server.domain import ScenarioDef | |
| from server.environment import CommitmentEnvironment | |
| from server.tasks import get_all_scenarios | |
| from models import CommitmentAction | |
| TOOL_DESCRIPTIONS = """Available tools (respond with JSON): | |
| - {"action_type": "view_calendar", "date": "2026-04-25"} | |
| - {"action_type": "check_availability", "person": "Name"} | |
| - {"action_type": "search_restaurants", "cuisine": "...", "max_price": 50, "dietary": "..."} | |
| - {"action_type": "schedule_meeting", "title": "...", "date": "...", "time": "HH:MM", "participants": [...]} | |
| - {"action_type": "reschedule_event", "event_id": "evt_X", "new_time": "HH:MM"} | |
| - {"action_type": "cancel_event", "event_id": "evt_X"} | |
| - {"action_type": "send_email", "to": "Name", "subject": "...", "body": "..."} | |
| - {"action_type": "book_restaurant", "restaurant_name": "..."} | |
| - {"action_type": "submit_plan"}""" | |
| def build_system_prompt() -> str: | |
| return ( | |
| "You are an expert executive assistant AI managing calendars, emails, and " | |
| "dining reservations. For each turn, respond with EXACTLY ONE JSON tool call.\n\n" | |
| f"{TOOL_DESCRIPTIONS}\n\n" | |
| "Rules:\n" | |
| "1. Respond with ONLY JSON, no markdown or explanation\n" | |
| "2. Handle higher-priority items first\n" | |
| "3. When cancelling/rescheduling commitments, ALWAYS email affected parties\n" | |
| "4. Call submit_plan when all issues are resolved\n" | |
| "5. Never silently drop a commitment" | |
| ) | |
| def build_initial_prompt(scenario: ScenarioDef) -> str: | |
| """Build the user message for the first turn of an episode.""" | |
| from server.world import WorldState | |
| world = WorldState(scenario) | |
| calendar = json.dumps(world.get_calendar_snapshot(), indent=2) | |
| inbox = json.dumps(world.get_inbox_snapshot(), indent=2) | |
| return ( | |
| f"SCENARIO: {scenario.briefing}\n\n" | |
| f"CALENDAR:\n{calendar}\n\n" | |
| f"INBOX:\n{inbox}\n\n" | |
| "What is your first action? Respond with a JSON tool call." | |
| ) | |
| def parse_action_from_text(text: str) -> Dict[str, Any]: | |
| """Extract a JSON action from model output, with fallback to submit.""" | |
| text = text.strip() | |
| if text.startswith("```"): | |
| lines = text.split("\n") | |
| text = "\n".join(lines[1:-1]) if len(lines) > 2 else text | |
| try: | |
| data = json.loads(text) | |
| if isinstance(data, dict) and "action_type" in data: | |
| return data | |
| except (json.JSONDecodeError, ValueError): | |
| pass | |
| for line in text.split("\n"): | |
| line = line.strip() | |
| if line.startswith("{"): | |
| try: | |
| data = json.loads(line) | |
| if isinstance(data, dict) and "action_type" in data: | |
| return data | |
| except (json.JSONDecodeError, ValueError): | |
| continue | |
| return {"action_type": "submit_plan"} | |
| class CommitmentOSEnvFactory: | |
| """Wraps CommitmentOS for use with TRL's GRPOTrainer. | |
| Usage with TRL:: | |
| from training.env_factory import CommitmentOSEnvFactory | |
| factory = CommitmentOSEnvFactory(max_turns=8) | |
| trainer = GRPOTrainer( | |
| ... | |
| environment_factory=factory, | |
| ) | |
| """ | |
| def __init__( | |
| self, | |
| max_turns: int = 8, | |
| scenario_ids: Optional[List[str]] = None, | |
| ) -> None: | |
| self.max_turns = max_turns | |
| self.scenario_ids = scenario_ids or list(get_all_scenarios().keys()) | |
| self.system_prompt = build_system_prompt() | |
| def __call__(self, completions: List[str], **kwargs: Any) -> List[float]: | |
| """Evaluate a batch of model completions. | |
| Each completion is treated as a full multi-turn transcript where | |
| each line is one JSON action. Returns a list of final rewards. | |
| """ | |
| rewards: List[float] = [] | |
| for completion in completions: | |
| reward = self._evaluate_single(completion) | |
| rewards.append(reward) | |
| return rewards | |
| def _evaluate_single(self, completion: str) -> float: | |
| import random | |
| env = CommitmentEnvironment() | |
| scenario_id = random.choice(self.scenario_ids) | |
| env.reset(task_id=scenario_id) | |
| actions = completion.strip().split("\n") | |
| last_reward = 0.01 | |
| for i, action_text in enumerate(actions[: self.max_turns]): | |
| action_dict = parse_action_from_text(action_text) | |
| try: | |
| action = CommitmentAction(**action_dict) | |
| obs = env.step(action) | |
| last_reward = obs.reward | |
| if obs.done: | |
| break | |
| except Exception: | |
| # Invalid action payloads should be penalized, not silently ignored. | |
| last_reward = 0.01 | |
| break | |
| if not env._done: | |
| obs = env.step(CommitmentAction(action_type="submit_plan")) | |
| last_reward = obs.reward | |
| return float(last_reward) | |
| def get_prompt(self, scenario_id: Optional[str] = None) -> List[Dict[str, str]]: | |
| """Build chat messages for a scenario.""" | |
| import random | |
| from server.tasks import get_scenario | |
| sid = scenario_id or random.choice(self.scenario_ids) | |
| scenario = get_scenario(sid) | |
| if scenario is None: | |
| raise ValueError(f"Unknown scenario: {sid}") | |
| return [ | |
| {"role": "system", "content": self.system_prompt}, | |
| {"role": "user", "content": build_initial_prompt(scenario)}, | |
| ] | |