commitment-os / training /env_factory.py
jayantaggarwal-sketch
Sync latest code and non-binary artifacts
af8810b
"""Environment factory for TRL GRPOTrainer integration.
Wraps CommitmentOS as a callable that accepts model completions and
returns rewards, making it compatible with TRL's ``environment_factory``
pattern for multi-turn RL training.
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from server.domain import ScenarioDef
from server.environment import CommitmentEnvironment
from server.tasks import get_all_scenarios
from models import CommitmentAction
# Tool catalogue embedded verbatim in the system prompt by build_system_prompt().
# Each bullet is a JSON example of one action the model may emit; the
# "action_type" key is the field parse_action_from_text() looks for.
TOOL_DESCRIPTIONS = """Available tools (respond with JSON):
- {"action_type": "view_calendar", "date": "2026-04-25"}
- {"action_type": "check_availability", "person": "Name"}
- {"action_type": "search_restaurants", "cuisine": "...", "max_price": 50, "dietary": "..."}
- {"action_type": "schedule_meeting", "title": "...", "date": "...", "time": "HH:MM", "participants": [...]}
- {"action_type": "reschedule_event", "event_id": "evt_X", "new_time": "HH:MM"}
- {"action_type": "cancel_event", "event_id": "evt_X"}
- {"action_type": "send_email", "to": "Name", "subject": "...", "body": "..."}
- {"action_type": "book_restaurant", "restaurant_name": "..."}
- {"action_type": "submit_plan"}"""
def build_system_prompt() -> str:
    """Assemble the system message: role statement, tool list, then rules."""
    role_intro = (
        "You are an expert executive assistant AI managing calendars, emails, and "
        "dining reservations. For each turn, respond with EXACTLY ONE JSON tool call.\n\n"
    )
    # The tool catalogue is interpolated as-is, followed by a blank line.
    tool_section = f"{TOOL_DESCRIPTIONS}\n\n"
    rule_section = (
        "Rules:\n"
        "1. Respond with ONLY JSON, no markdown or explanation\n"
        "2. Handle higher-priority items first\n"
        "3. When cancelling/rescheduling commitments, ALWAYS email affected parties\n"
        "4. Call submit_plan when all issues are resolved\n"
        "5. Never silently drop a commitment"
    )
    return role_intro + tool_section + rule_section
def build_initial_prompt(scenario: ScenarioDef) -> str:
    """Render the opening user message for an episode.

    A ``WorldState`` is built from *scenario* and its calendar and inbox
    snapshots are serialized as indented JSON, preceded by the scenario
    briefing and followed by a request for the first tool call.
    """
    # Imported lazily, as in the rest of this module's helpers.
    from server.world import WorldState

    state = WorldState(scenario)
    calendar_json = json.dumps(state.get_calendar_snapshot(), indent=2)
    inbox_json = json.dumps(state.get_inbox_snapshot(), indent=2)

    briefing_part = f"SCENARIO: {scenario.briefing}\n\n"
    calendar_part = f"CALENDAR:\n{calendar_json}\n\n"
    inbox_part = f"INBOX:\n{inbox_json}\n\n"
    closing_part = "What is your first action? Respond with a JSON tool call."
    return briefing_part + calendar_part + inbox_part + closing_part
def parse_action_from_text(text: str) -> Dict[str, Any]:
    """Extract a JSON action from model output, with fallback to submit.

    The text is scanned left to right for the first JSON object that has an
    ``"action_type"`` key. The object may be the whole text, sit inside a
    markdown code fence (closed or not), be surrounded by prose on the same
    line, or be pretty-printed across several lines.

    Args:
        text: Raw model output for one turn.

    Returns:
        The first top-level JSON object containing ``"action_type"``, or
        ``{"action_type": "submit_plan"}`` when none is found.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, end = decoder.raw_decode(text, start)
        except ValueError:
            # Not valid JSON from this brace (json.JSONDecodeError is a
            # ValueError); retry from the next opening brace.
            start = text.find("{", start + 1)
            continue
        if isinstance(candidate, dict) and "action_type" in candidate:
            return candidate
        # Resume after the decoded object so that dicts nested inside it
        # are never mistaken for a top-level action.
        start = text.find("{", end)
    return {"action_type": "submit_plan"}
class CommitmentOSEnvFactory:
    """Wraps CommitmentOS for use with TRL's GRPOTrainer.

    An instance is a callable that scores a batch of completions: each
    completion is replayed line by line against a fresh
    ``CommitmentEnvironment`` and the reward of the last step is returned.

    Usage with TRL::

        from training.env_factory import CommitmentOSEnvFactory

        factory = CommitmentOSEnvFactory(max_turns=8)
        trainer = GRPOTrainer(
            ...
            environment_factory=factory,
        )
    """

    # Reward floor: the starting value and the score for an invalid action.
    _MIN_REWARD = 0.01

    def __init__(
        self,
        max_turns: int = 8,
        scenario_ids: Optional[List[str]] = None,
    ) -> None:
        """Configure the factory.

        Args:
            max_turns: Maximum number of action lines replayed per completion.
            scenario_ids: Scenario ids to sample from. ``None`` or an empty
                list falls back to every key of ``get_all_scenarios()``.
        """
        self.max_turns = max_turns
        self.scenario_ids = scenario_ids or list(get_all_scenarios().keys())
        self.system_prompt = build_system_prompt()

    def __call__(self, completions: List[str], **kwargs: Any) -> List[float]:
        """Evaluate a batch of model completions.

        Each completion is treated as a full multi-turn transcript where
        each line is one JSON action. Extra keyword arguments are accepted
        and ignored.

        Returns:
            One final reward per completion, in input order.
        """
        return [self._evaluate_single(completion) for completion in completions]

    def _evaluate_single(self, completion: str) -> float:
        """Replay one completion in a fresh environment and return its reward.

        Blank lines are skipped. At most ``max_turns`` action lines are
        executed. If constructing or stepping an action raises, the episode
        is scored at the reward floor immediately. Otherwise, if the
        environment is not done after the replay, a final ``submit_plan``
        is issued and its reward is returned.
        """
        import random

        env = CommitmentEnvironment()
        # NOTE(review): the scenario is drawn at random here, independently of
        # whichever scenario a prompt from get_prompt() described - confirm
        # this is intended.
        env.reset(task_id=random.choice(self.scenario_ids))

        # Skip blank lines: parse_action_from_text would turn them into
        # submit_plan and end the episode before later actions run.
        action_lines = [ln for ln in completion.split("\n") if ln.strip()]

        last_reward = self._MIN_REWARD
        for action_text in action_lines[: self.max_turns]:
            action_dict = parse_action_from_text(action_text)
            try:
                action = CommitmentAction(**action_dict)
                obs = env.step(action)
            except Exception:
                # Invalid action payloads are penalized. Return right away so
                # the fallback submit_plan below cannot overwrite the penalty.
                return float(self._MIN_REWARD)
            last_reward = obs.reward
            if obs.done:
                break

        # Reads the environment's private ``_done`` flag, as the original did.
        if not env._done:
            obs = env.step(CommitmentAction(action_type="submit_plan"))
            last_reward = obs.reward
        return float(last_reward)

    def get_prompt(self, scenario_id: Optional[str] = None) -> List[Dict[str, str]]:
        """Build chat messages for a scenario.

        Args:
            scenario_id: Scenario to render; a random id from
                ``self.scenario_ids`` is used when omitted or empty.

        Returns:
            A system message followed by the initial user message.

        Raises:
            ValueError: If ``get_scenario`` returns ``None`` for the id.
        """
        import random

        from server.tasks import get_scenario

        sid = scenario_id or random.choice(self.scenario_ids)
        scenario = get_scenario(sid)
        if scenario is None:
            raise ValueError(f"Unknown scenario: {sid}")
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": build_initial_prompt(scenario)},
        ]