"""Environment factory for TRL GRPOTrainer integration.
Wraps CommitmentOS as a callable that accepts model completions and
returns rewards, making it compatible with TRL's ``environment_factory``
pattern for multi-turn RL training.
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from server.domain import ScenarioDef
from server.environment import CommitmentEnvironment
from server.tasks import get_all_scenarios
from models import CommitmentAction
# Catalogue of the JSON tool calls shown to the model, one example payload per
# ``action_type``. The text is embedded verbatim in the system prompt by
# ``build_system_prompt``, so any edit here changes what the model is told.
TOOL_DESCRIPTIONS = """Available tools (respond with JSON):
- {"action_type": "view_calendar", "date": "2026-04-25"}
- {"action_type": "check_availability", "person": "Name"}
- {"action_type": "search_restaurants", "cuisine": "...", "max_price": 50, "dietary": "..."}
- {"action_type": "schedule_meeting", "title": "...", "date": "...", "time": "HH:MM", "participants": [...]}
- {"action_type": "reschedule_event", "event_id": "evt_X", "new_time": "HH:MM"}
- {"action_type": "cancel_event", "event_id": "evt_X"}
- {"action_type": "send_email", "to": "Name", "subject": "...", "body": "..."}
- {"action_type": "book_restaurant", "restaurant_name": "..."}
- {"action_type": "submit_plan"}"""
def build_system_prompt() -> str:
    """Return the system message that frames the assistant's role.

    The message is made of three parts, in order: a role statement, the
    tool catalogue from ``TOOL_DESCRIPTIONS``, and a numbered list of rules.
    """
    role_lines = [
        "You are an expert executive assistant AI managing calendars, emails, and ",
        "dining reservations. For each turn, respond with EXACTLY ONE JSON tool call.\n\n",
    ]
    tool_lines = [f"{TOOL_DESCRIPTIONS}\n\n"]
    rule_lines = [
        "Rules:\n",
        "1. Respond with ONLY JSON, no markdown or explanation\n",
        "2. Handle higher-priority items first\n",
        "3. When cancelling/rescheduling commitments, ALWAYS email affected parties\n",
        "4. Call submit_plan when all issues are resolved\n",
        "5. Never silently drop a commitment",
    ]
    # Fragments already carry their own separators, so join with nothing.
    return "".join(role_lines + tool_lines + rule_lines)
def build_initial_prompt(scenario: ScenarioDef) -> str:
    """Compose the first user message of an episode for ``scenario``.

    A fresh ``WorldState`` is built from the scenario and its calendar and
    inbox snapshots are rendered as indented JSON beneath the scenario
    briefing, followed by a request for the first tool call.
    """
    # Imported here rather than at module top, as in the original code.
    from server.world import WorldState

    state = WorldState(scenario)
    calendar_json = json.dumps(state.get_calendar_snapshot(), indent=2)
    inbox_json = json.dumps(state.get_inbox_snapshot(), indent=2)

    sections = [
        f"SCENARIO: {scenario.briefing}\n\n",
        f"CALENDAR:\n{calendar_json}\n\n",
        f"INBOX:\n{inbox_json}\n\n",
        "What is your first action? Respond with a JSON tool call.",
    ]
    return "".join(sections)
def parse_action_from_text(text: str) -> Dict[str, Any]:
    """Extract a JSON action from model output, with fallback to submit.

    The text is searched for the first JSON object that has an
    ``"action_type"`` key. Surrounding prose, Markdown code fences (closed or
    not) and pretty-printed multi-line objects are all tolerated.

    Args:
        text: Raw model output for one turn.

    Returns:
        The decoded action dict, or ``{"action_type": "submit_plan"}`` when
        no usable action object can be found.
    """
    text = text.strip()

    if text.startswith("```"):
        # Drop the opening fence line (it may carry a language tag). Drop the
        # last line only when it really is a closing fence, so an unterminated
        # fence does not lose its final line of content.
        body = text.split("\n")[1:]
        if body and body[-1].strip().startswith("```"):
            body = body[:-1]
        if body:
            text = "\n".join(body).strip()

    # Fast path: the whole payload is a single JSON document.
    try:
        data = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        data = None
    if isinstance(data, dict) and "action_type" in data:
        return data

    # Slow path: try to decode an object at every "{" so that JSON embedded
    # in prose, or followed by trailing text, is still recognised.
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            data, end = decoder.raw_decode(text, pos)
        except ValueError:
            pos = text.find("{", pos + 1)
            continue
        if isinstance(data, dict) and "action_type" in data:
            return data
        # Resume after the decoded object so that objects nested inside it
        # are never mistaken for top-level actions.
        pos = text.find("{", end)

    return {"action_type": "submit_plan"}
class CommitmentOSEnvFactory:
    """Wraps CommitmentOS for use with TRL's GRPOTrainer.

    Instances are callable: given a batch of completions they replay each one
    against a fresh ``CommitmentEnvironment`` and return one reward per
    completion.

    Usage with TRL::

        from training.env_factory import CommitmentOSEnvFactory

        factory = CommitmentOSEnvFactory(max_turns=8)
        trainer = GRPOTrainer(
            ...
            environment_factory=factory,
        )
    """

    # Reward reported when no step has been scored yet, and when an action
    # payload is rejected.
    _PENALTY_REWARD = 0.01

    def __init__(
        self,
        max_turns: int = 8,
        scenario_ids: Optional[List[str]] = None,
    ) -> None:
        """Configure the factory.

        Args:
            max_turns: Maximum number of actions replayed per completion.
            scenario_ids: Scenario ids to sample from. When omitted (or
                empty), every id returned by ``get_all_scenarios()`` is used.
        """
        self.max_turns = max_turns
        self.scenario_ids = scenario_ids or list(get_all_scenarios().keys())
        self.system_prompt = build_system_prompt()

    def __call__(self, completions: List[str], **kwargs: Any) -> List[float]:
        """Evaluate a batch of model completions.

        Each completion is treated as a full multi-turn transcript where
        each line is one JSON action. Extra keyword arguments are accepted
        and ignored.

        Returns:
            One final reward per completion, in input order.
        """
        return [self._evaluate_single(completion) for completion in completions]

    def _evaluate_single(self, completion: str) -> float:
        """Replay one completion in a fresh environment and return its reward.

        At most ``max_turns`` non-blank lines are executed. If the episode is
        still open afterwards, a ``submit_plan`` action is issued so that a
        final reward is produced. An action that raises (bad payload or a
        failing step) ends the evaluation immediately with the penalty reward.

        NOTE(review): the scenario is drawn at random here and is not tied to
        any prompt produced by ``get_prompt`` -- confirm this is intended.
        """
        import random

        env = CommitmentEnvironment()
        scenario_id = random.choice(self.scenario_ids)
        env.reset(task_id=scenario_id)

        # Blank lines carry no action. Feeding them to the parser would turn
        # them into ``submit_plan`` and end the episode early.
        action_lines = [
            line for line in completion.strip().split("\n") if line.strip()
        ]

        last_reward = self._PENALTY_REWARD
        for action_text in action_lines[: self.max_turns]:
            action_dict = parse_action_from_text(action_text)
            try:
                action = CommitmentAction(**action_dict)
                obs = env.step(action)
            except Exception:
                # Invalid action payloads should be penalized, not silently
                # ignored. Return straight away: falling through to the
                # auto-submit below would replace the penalty with the
                # reward of ``submit_plan``.
                return float(self._PENALTY_REWARD)
            last_reward = obs.reward
            if obs.done:
                break

        if not env._done:
            # The transcript ended without closing the episode; submit on the
            # model's behalf so that the environment yields a final reward.
            obs = env.step(CommitmentAction(action_type="submit_plan"))
            last_reward = obs.reward
        return float(last_reward)

    def get_prompt(self, scenario_id: Optional[str] = None) -> List[Dict[str, str]]:
        """Build chat messages for a scenario.

        Args:
            scenario_id: Scenario to render; a random one from
                ``self.scenario_ids`` is used when omitted.

        Returns:
            A two-message chat: the system prompt, then the initial user
            prompt for the scenario.

        Raises:
            ValueError: If the scenario id is unknown.
        """
        import random

        from server.tasks import get_scenario

        sid = scenario_id or random.choice(self.scenario_ids)
        scenario = get_scenario(sid)
        if scenario is None:
            raise ValueError(f"Unknown scenario: {sid}")
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": build_initial_prompt(scenario)},
        ]
|