File size: 5,898 Bytes
6762657
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af8810b
 
 
6762657
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""Environment factory for TRL GRPOTrainer integration.

Wraps CommitmentOS as a callable that accepts model completions and
returns rewards, making it compatible with TRL's ``environment_factory``
pattern for multi-turn RL training.
"""

from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from server.domain import ScenarioDef
from server.environment import CommitmentEnvironment
from server.tasks import get_all_scenarios
from models import CommitmentAction


# Tool catalogue shown to the model: one example JSON payload per supported
# ``action_type`` (nine in total, ending with ``submit_plan``).
# ``build_system_prompt`` embeds this text verbatim in the system message, so
# any edit here changes the prompt the policy is trained against.
TOOL_DESCRIPTIONS = """Available tools (respond with JSON):
- {"action_type": "view_calendar", "date": "2026-04-25"}
- {"action_type": "check_availability", "person": "Name"}
- {"action_type": "search_restaurants", "cuisine": "...", "max_price": 50, "dietary": "..."}
- {"action_type": "schedule_meeting", "title": "...", "date": "...", "time": "HH:MM", "participants": [...]}
- {"action_type": "reschedule_event", "event_id": "evt_X", "new_time": "HH:MM"}
- {"action_type": "cancel_event", "event_id": "evt_X"}
- {"action_type": "send_email", "to": "Name", "subject": "...", "body": "..."}
- {"action_type": "book_restaurant", "restaurant_name": "..."}
- {"action_type": "submit_plan"}"""


def build_system_prompt() -> str:
    """Return the system message used for every episode.

    The message has three parts separated by blank lines: the assistant's
    role statement, the ``TOOL_DESCRIPTIONS`` catalogue, and a numbered list
    of behavioural rules.
    """
    role = (
        "You are an expert executive assistant AI managing calendars, emails, and "
        "dining reservations. For each turn, respond with EXACTLY ONE JSON tool call."
    )
    rules = "\n".join(
        [
            "Rules:",
            "1. Respond with ONLY JSON, no markdown or explanation",
            "2. Handle higher-priority items first",
            "3. When cancelling/rescheduling commitments, ALWAYS email affected parties",
            "4. Call submit_plan when all issues are resolved",
            "5. Never silently drop a commitment",
        ]
    )
    return "\n\n".join([role, TOOL_DESCRIPTIONS, rules])


def build_initial_prompt(scenario: ScenarioDef) -> str:
    """Compose the opening user message of an episode for *scenario*.

    A fresh ``WorldState`` is created from the scenario and its calendar and
    inbox snapshots are serialised as indented JSON. The message contains the
    scenario briefing, both snapshots, and a closing instruction, separated
    by blank lines.
    """
    from server.world import WorldState

    state = WorldState(scenario)
    calendar_json = json.dumps(state.get_calendar_snapshot(), indent=2)
    inbox_json = json.dumps(state.get_inbox_snapshot(), indent=2)

    sections = [
        f"SCENARIO: {scenario.briefing}",
        f"CALENDAR:\n{calendar_json}",
        f"INBOX:\n{inbox_json}",
        "What is your first action? Respond with a JSON tool call.",
    ]
    return "\n\n".join(sections)


def parse_action_from_text(text: str) -> Dict[str, Any]:
    """Extract the first JSON action object from model output.

    The text is scanned left to right for ``{`` characters, and a JSON value
    is decoded starting at each one. The first decoded object that is a dict
    containing an ``"action_type"`` key is returned. This finds actions that
    are:

    * the whole text,
    * wrapped in a Markdown code fence (with or without a closing fence),
    * preceded or followed by prose on the same line,
    * pretty-printed across several lines,
    * the first element of a JSON list.

    A decoded object without ``"action_type"`` is skipped as a whole, so a
    nested ``action_type`` inside such an object is not picked up.

    Args:
        text: Raw model output for one turn.

    Returns:
        The parsed action dict, or ``{"action_type": "submit_plan"}`` when no
        valid action can be found (including when *text* is not a string).
    """
    if not isinstance(text, str):
        return {"action_type": "submit_plan"}

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, end = decoder.raw_decode(text, start)
        except ValueError:
            # Not valid JSON from this brace (json.JSONDecodeError is a
            # ValueError subclass); try the next opening brace.
            start = text.find("{", start + 1)
            continue
        if isinstance(candidate, dict) and "action_type" in candidate:
            return candidate
        # Valid JSON but not an action: resume after it rather than
        # descending into its nested objects.
        start = text.find("{", end)

    return {"action_type": "submit_plan"}


class CommitmentOSEnvFactory:
    """Wraps CommitmentOS for use with TRL's GRPOTrainer.

    Instances are callables mapping a batch of completions to a list of
    float rewards; each completion is replayed against a fresh
    ``CommitmentEnvironment``.

    Usage with TRL::

        from training.env_factory import CommitmentOSEnvFactory

        factory = CommitmentOSEnvFactory(max_turns=8)

        trainer = GRPOTrainer(
            ...
            environment_factory=factory,
        )
    """

    # Reward returned when an episode is aborted by an invalid action payload.
    # Also the initial value of the running reward before any step has run.
    PENALTY_REWARD: float = 0.01

    def __init__(
        self,
        max_turns: int = 8,
        scenario_ids: Optional[List[str]] = None,
    ) -> None:
        """Configure the factory.

        Args:
            max_turns: Maximum number of action lines replayed per completion.
            scenario_ids: Scenario ids to sample from. ``None`` or an empty
                list means every id returned by ``get_all_scenarios()``.
        """
        self.max_turns = max_turns
        self.scenario_ids = scenario_ids or list(get_all_scenarios().keys())
        self.system_prompt = build_system_prompt()

    def __call__(self, completions: List[str], **kwargs: Any) -> List[float]:
        """Evaluate a batch of model completions.

        Each completion is treated as a full multi-turn transcript where
        each line is one JSON action. Returns a list of final rewards, one
        per completion and in the same order. Extra keyword arguments are
        accepted and ignored.
        """
        return [self._evaluate_single(completion) for completion in completions]

    def _evaluate_single(self, completion: str) -> float:
        """Replay one completion in a fresh environment and return its reward.

        A scenario id is drawn at random from ``self.scenario_ids`` for every
        call. Non-blank lines of *completion* are parsed as actions and
        stepped through the environment, up to ``self.max_turns`` of them. If
        the episode has not finished afterwards, ``submit_plan`` is issued to
        obtain the final reward.

        Returns:
            The reward of the last environment step, or ``PENALTY_REWARD``
            when an action could not be constructed or executed.
        """
        import random

        env = CommitmentEnvironment()
        scenario_id = random.choice(self.scenario_ids)
        env.reset(task_id=scenario_id)

        # Skip blank lines: the parser would turn each one into a fallback
        # submit_plan and end the episode before later actions are replayed.
        action_lines = [line for line in completion.split("\n") if line.strip()]
        last_reward = self.PENALTY_REWARD

        for action_text in action_lines[: self.max_turns]:
            action_dict = parse_action_from_text(action_text)
            try:
                action = CommitmentAction(**action_dict)
                obs = env.step(action)
            except Exception:
                # Invalid action payloads should be penalized, not silently
                # ignored. Return immediately: falling through to the forced
                # submit_plan below would overwrite the penalty with the
                # plan's reward.
                return float(self.PENALTY_REWARD)
            last_reward = obs.reward
            if obs.done:
                break

        # NOTE(review): reads the environment's private ``_done`` flag; switch
        # to a public accessor if CommitmentEnvironment exposes one.
        if not env._done:
            obs = env.step(CommitmentAction(action_type="submit_plan"))
            last_reward = obs.reward

        return float(last_reward)

    def get_prompt(self, scenario_id: Optional[str] = None) -> List[Dict[str, str]]:
        """Build chat messages for a scenario.

        Args:
            scenario_id: Scenario to use; when omitted (or empty) one is
                chosen at random from ``self.scenario_ids``.

        Returns:
            A two-element list: the system message followed by the user
            message produced by ``build_initial_prompt``.

        Raises:
            ValueError: If ``get_scenario`` returns ``None`` for the id.
        """
        import random
        from server.tasks import get_scenario

        sid = scenario_id or random.choice(self.scenario_ids)
        scenario = get_scenario(sid)
        if scenario is None:
            raise ValueError(f"Unknown scenario: {sid}")

        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": build_initial_prompt(scenario)},
        ]