Spaces:
Sleeping
Sleeping
jayantaggarwal-sketch commited on
Commit ·
6762657
0
Parent(s):
CommitmentOS: temporal commitment coherence RL environment
Browse files- .gitignore +13 -0
- Dockerfile +25 -0
- HF_README.md +77 -0
- README.md +190 -0
- __init__.py +9 -0
- conftest.py +10 -0
- constants.py +14 -0
- inference.py +225 -0
- models.py +87 -0
- openenv.yaml +82 -0
- pyproject.toml +43 -0
- requirements.txt +6 -0
- server/__init__.py +0 -0
- server/app.py +54 -0
- server/domain.py +131 -0
- server/environment.py +244 -0
- server/graders.py +236 -0
- server/mcp.py +65 -0
- server/tasks.py +616 -0
- server/world.py +290 -0
- tests/__init__.py +0 -0
- tests/test_environment.py +523 -0
- training/__init__.py +0 -0
- training/env_factory.py +167 -0
- training/train_grpo.py +174 -0
- uv.lock +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
.venv/
|
| 5 |
+
venv/
|
| 6 |
+
.env
|
| 7 |
+
*.egg-info/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
.pytest_cache/
|
| 11 |
+
.ruff_cache/
|
| 12 |
+
*.log
|
| 13 |
+
.DS_Store
|
Dockerfile
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y --no-install-recommends curl \
|
| 6 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 7 |
+
|
| 8 |
+
COPY requirements.txt requirements.txt
|
| 9 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 10 |
+
|
| 11 |
+
COPY constants.py ./constants.py
|
| 12 |
+
COPY models.py ./models.py
|
| 13 |
+
COPY __init__.py ./__init__.py
|
| 14 |
+
COPY server/ ./server/
|
| 15 |
+
COPY openenv.yaml ./openenv.yaml
|
| 16 |
+
COPY inference.py ./inference.py
|
| 17 |
+
|
| 18 |
+
ENV PORT=7860
|
| 19 |
+
|
| 20 |
+
EXPOSE 7860
|
| 21 |
+
|
| 22 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 23 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 24 |
+
|
| 25 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
HF_README.md
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: CommitmentOS
|
| 3 |
+
emoji: 📋
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
- reinforcement-learning
|
| 11 |
+
- commitment-coherence
|
| 12 |
+
- personal-task-management
|
| 13 |
+
- multi-turn
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# CommitmentOS: Training Temporal Commitment Coherence in LLMs
|
| 17 |
+
|
| 18 |
+
**The first RL environment that trains LLMs to keep their promises.**
|
| 19 |
+
|
| 20 |
+
CommitmentOS is a multi-turn personal task management environment where
|
| 21 |
+
agents manage calendars, emails, and dining reservations across realistic
|
| 22 |
+
scenarios. The key innovation: the agent's own prior decisions create
|
| 23 |
+
binding future constraints tracked via a **commitment ledger**, and
|
| 24 |
+
violations are penalised regardless of how many turns have elapsed.
|
| 25 |
+
|
| 26 |
+
## Quick Start
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
# Reset to a scenario
|
| 30 |
+
curl -X POST "https://jayant2304-commitment-os.hf.space/reset?task_id=easy_001"
|
| 31 |
+
|
| 32 |
+
# Make a tool call
|
| 33 |
+
curl -X POST "https://jayant2304-commitment-os.hf.space/step" \
|
| 34 |
+
-H "Content-Type: application/json" \
|
| 35 |
+
-d '{"action_type": "view_calendar", "date": "2026-04-25"}'
|
| 36 |
+
|
| 37 |
+
# Get state
|
| 38 |
+
curl "https://jayant2304-commitment-os.hf.space/state"
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## API Endpoints
|
| 42 |
+
|
| 43 |
+
| Endpoint | Method | Description |
|
| 44 |
+
|----------|--------|-------------|
|
| 45 |
+
| `/reset` | POST | Start a new episode (optional: `task_id`, `difficulty`) |
|
| 46 |
+
| `/step` | POST | Execute one tool call |
|
| 47 |
+
| `/state` | GET | Current episode state |
|
| 48 |
+
| `/health` | GET | Health check |
|
| 49 |
+
| `/tasks` | GET | List all available scenarios |
|
| 50 |
+
| `/mcp` | POST | MCP JSON-RPC 2.0 |
|
| 51 |
+
|
| 52 |
+
## 15 Scenarios (5 Easy / 5 Medium / 5 Hard)
|
| 53 |
+
|
| 54 |
+
Scenarios range from simple calendar reschedules to multi-crisis cascades
|
| 55 |
+
with information asymmetry and production incidents interrupting a full day
|
| 56 |
+
of commitments.
|
| 57 |
+
|
| 58 |
+
## Reward Function (5 components)
|
| 59 |
+
|
| 60 |
+
| Component | Weight | Signal |
|
| 61 |
+
|-----------|--------|--------|
|
| 62 |
+
| Constraint Satisfaction | 35% | Binary per-constraint checks |
|
| 63 |
+
| Conflict Resolution | 20% | Calendar free of overlaps |
|
| 64 |
+
| **Commitment Coherence** | **20%** | **Violations tracked via ledger** |
|
| 65 |
+
| Communication Quality | 15% | Keyword matching on emails |
|
| 66 |
+
| Step Efficiency | 10% | Fewer steps = higher score |
|
| 67 |
+
|
| 68 |
+
## What Makes This Novel
|
| 69 |
+
|
| 70 |
+
Existing constraint-satisfaction environments compute dependency graphs
|
| 71 |
+
upfront. CommitmentOS is different: constraints **emerge from the agent's
|
| 72 |
+
own decisions** as the episode unfolds. A meeting scheduled in turn 2
|
| 73 |
+
becomes a binding constraint in turn 7. Breaking it without communication
|
| 74 |
+
is a tracked, penalised violation.
|
| 75 |
+
|
| 76 |
+
This is **temporal commitment coherence** — a capability no existing RL
|
| 77 |
+
environment trains.
|
README.md
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CommitmentOS: Training Temporal Commitment Coherence in LLMs
|
| 2 |
+
|
| 3 |
+
> *The first RL environment that trains LLMs to keep their promises.*
|
| 4 |
+
|
| 5 |
+
**Innovation claim**: The first RL environment for training temporal commitment coherence — where the agent's own prior decisions create binding future constraints, tracked and penalised across multi-turn episodes.
|
| 6 |
+
|
| 7 |
+
**Theme**: Primary 3.2 (Personal Tasks) + Secondary Theme 2 (Long-Horizon Planning)
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
## Architecture
|
| 12 |
+
|
| 13 |
+
```
|
| 14 |
+
┌──────────────── Client ────────────────┐ ┌────────────── CommitmentOS Server ──────────────┐
|
| 15 |
+
│ │ │ │
|
| 16 |
+
│ inference.py ──HTTP──▶ POST /reset │────▶│ FastAPI App │
|
| 17 |
+
│ (LLM agent) HTTP──▶ POST /step │ │ │ │
|
| 18 |
+
│ HTTP──▶ GET /state │ │ ▼ │
|
| 19 |
+
│ │ │ CommitmentEnvironment │
|
| 20 |
+
│ train_grpo.py │ │ ├── WorldState (calendar, contacts, │
|
| 21 |
+
│ (GRPO+TRL) │ │ │ restaurants, inbox) │
|
| 22 |
+
│ │ │ ├── CommitmentLedger (tracks promises) │
|
| 23 |
+
│ │ │ └── Grader (5-component reward) │
|
| 24 |
+
└────────────────────────────────────────┘ └─────────────────────────────────────────────────┘
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
## Why CommitmentOS is Novel
|
| 28 |
+
|
| 29 |
+
Existing constraint-satisfaction environments (GAP, LGC-MARL, NeMo Gym, PEARL) compute dependency graphs **upfront**. CommitmentOS is fundamentally different:
|
| 30 |
+
|
| 31 |
+
- **Constraints emerge from the agent's own decisions** as the episode unfolds
|
| 32 |
+
- A meeting scheduled in turn 2 becomes a **binding constraint** in turn 7
|
| 33 |
+
- Breaking it without communication is a **tracked, penalised violation**
|
| 34 |
+
- The commitment ledger persists across the full episode — the agent must remember what it promised
|
| 35 |
+
|
| 36 |
+
This is **temporal commitment coherence** — a capability no existing RL environment trains.
|
| 37 |
+
|
| 38 |
+
---
|
| 39 |
+
|
| 40 |
+
## Quick Start
|
| 41 |
+
|
| 42 |
+
### Local Development
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
cd commitment_os
|
| 46 |
+
|
| 47 |
+
# Create virtual environment
|
| 48 |
+
python3 -m venv .venv && source .venv/bin/activate
|
| 49 |
+
pip install -r requirements.txt
|
| 50 |
+
|
| 51 |
+
# Start server
|
| 52 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860 --reload
|
| 53 |
+
|
| 54 |
+
# Run tests
|
| 55 |
+
pip install pytest httpx
|
| 56 |
+
pytest tests/ -v
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
### Docker
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
docker build -t commitment-os .
|
| 63 |
+
docker run -p 7860:7860 commitment-os
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### API Usage
|
| 67 |
+
|
| 68 |
+
```bash
|
| 69 |
+
# Reset to a scenario
|
| 70 |
+
curl -X POST "http://localhost:7860/reset?task_id=easy_001"
|
| 71 |
+
|
| 72 |
+
# Make a tool call (multi-turn — one per step)
|
| 73 |
+
curl -X POST "http://localhost:7860/step" \
|
| 74 |
+
-H "Content-Type: application/json" \
|
| 75 |
+
-d '{"action": {"action_type": "view_calendar", "date": "2026-04-25"}}'
|
| 76 |
+
|
| 77 |
+
# Get state
|
| 78 |
+
curl "http://localhost:7860/state"
|
| 79 |
+
|
| 80 |
+
# List all scenarios
|
| 81 |
+
curl "http://localhost:7860/tasks"
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
---
|
| 85 |
+
|
| 86 |
+
## Reward Function (5 Components)
|
| 87 |
+
|
| 88 |
+
| Component | Weight | How it's Measured |
|
| 89 |
+
|-----------|--------|-------------------|
|
| 90 |
+
| **Constraint Satisfaction** | 35% | Binary per-constraint checks |
|
| 91 |
+
| **Conflict Resolution** | 20% | Final calendar free of overlapping events |
|
| 92 |
+
| **Commitment Coherence** | 20% | `(total - silent_violations) / total` from ledger |
|
| 93 |
+
| **Communication Quality** | 15% | Keyword matching on sent emails |
|
| 94 |
+
| **Step Efficiency** | 10% | `max(0, 1 - (steps - optimal) × 0.1)` |
|
| 95 |
+
|
| 96 |
+
**Example** (easy_001 — perfect run):
|
| 97 |
+
```
|
| 98 |
+
constraints: 3/3 met → 0.35 × 1.0 = 0.350
|
| 99 |
+
conflicts: 0 overlaps → 0.20 × 1.0 = 0.200
|
| 100 |
+
commitments: 1 honored → 0.20 × 1.0 = 0.200
|
| 101 |
+
emails: Team notified → 0.15 × 1.0 = 0.150
|
| 102 |
+
efficiency: 3 steps (opt 3) → 0.10 × 1.0 = 0.100
|
| 103 |
+
─────────────────────────────────────────────
|
| 104 |
+
total = 0.99 (clamped to [0.01, 0.99])
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
---
|
| 108 |
+
|
| 109 |
+
## 15 Scenarios
|
| 110 |
+
|
| 111 |
+
### Easy (2-4 steps)
|
| 112 |
+
| ID | Description |
|
| 113 |
+
|----|-------------|
|
| 114 |
+
| easy_001 | Double-booked meetings — reschedule by priority |
|
| 115 |
+
| easy_002 | Book dinner with cuisine/price/distance constraints |
|
| 116 |
+
| easy_003 | Check availability and propose meeting slots |
|
| 117 |
+
| easy_004 | Cancel conflicting work meeting for personal appointment |
|
| 118 |
+
| easy_005 | Triage inbox by urgency priority |
|
| 119 |
+
|
| 120 |
+
### Medium (5-8 steps)
|
| 121 |
+
| ID | Description |
|
| 122 |
+
|----|-------------|
|
| 123 |
+
| med_006 | Cascading reschedule chain (A→B→C dependency) |
|
| 124 |
+
| med_007 | Team dinner with 3 dietary + distance + budget constraints |
|
| 125 |
+
| med_008 | Boss's urgent request during client call (commitment conflict) |
|
| 126 |
+
| med_009 | Disambiguate vague "push our thing" across 3 recurring meetings |
|
| 127 |
+
| med_010 | Client visit: conference room + lunch + itinerary |
|
| 128 |
+
|
| 129 |
+
### Hard (8-15 steps)
|
| 130 |
+
| ID | Description |
|
| 131 |
+
|----|-------------|
|
| 132 |
+
| hard_011 | VP investor dinner: cascade, restaurant, multi-party notification |
|
| 133 |
+
| hard_012 | Triple conference room conflict with diplomatic resolution |
|
| 134 |
+
| hard_013 | Triple crisis: cancelled flight + moved board prep + lost reservation |
|
| 135 |
+
| hard_014 | Information asymmetry — schedule without revealing confidential reasons |
|
| 136 |
+
| hard_015 | **SRE Crisis** — production incident interrupts day of commitments |
|
| 137 |
+
|
| 138 |
+
---
|
| 139 |
+
|
| 140 |
+
## Training
|
| 141 |
+
|
| 142 |
+
### GRPO + TRL + LoRA
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
pip install trl transformers peft datasets torch
|
| 146 |
+
|
| 147 |
+
python training/train_grpo.py \
|
| 148 |
+
--model Qwen/Qwen2.5-1.5B-Instruct \
|
| 149 |
+
--epochs 2 \
|
| 150 |
+
--lr 5e-6 \
|
| 151 |
+
--lora_rank 16 \
|
| 152 |
+
--batch_size 4
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
**What improves with training:**
|
| 156 |
+
- Constraint satisfaction score ↑
|
| 157 |
+
- Commitment violation rate ↓
|
| 158 |
+
- Steps per episode ↓
|
| 159 |
+
- Communication quality ↑
|
| 160 |
+
|
| 161 |
+
---
|
| 162 |
+
|
| 163 |
+
## Submission Compliance
|
| 164 |
+
|
| 165 |
+
| Requirement | Status |
|
| 166 |
+
|-------------|--------|
|
| 167 |
+
| reset() / step() / state() | ✅ |
|
| 168 |
+
| openenv.yaml with 15 tasks | ✅ |
|
| 169 |
+
| Programmatic graders, scores ∈ (0, 1) | ✅ |
|
| 170 |
+
| inference.py at root using openai client | ✅ |
|
| 171 |
+
| [START]/[STEP]/[END] log format | ✅ |
|
| 172 |
+
| API_BASE_URL / MODEL_NAME / HF_TOKEN from env | ✅ |
|
| 173 |
+
| Dockerfile builds and responds to /reset | ✅ |
|
| 174 |
+
| pyproject.toml with [project.scripts] | ✅ |
|
| 175 |
+
| uv.lock generated | ✅ |
|
| 176 |
+
| server/app.py main() with if __name__ | ✅ |
|
| 177 |
+
|
| 178 |
+
---
|
| 179 |
+
|
| 180 |
+
## Story Hook
|
| 181 |
+
|
| 182 |
+
> "Every AI assistant today can schedule one meeting. But your real life is never one meeting. CommitmentOS trains AI to juggle the chaos — and penalises it when it breaks its own promises."
|
| 183 |
+
|
| 184 |
+
**Connection to Round 1**: In Round 1, we trained agents to diagnose production incidents. In Round 2, we asked: *what happens when that incident interrupts a day full of commitments?* CommitmentOS was born. Hard scenario `hard_015` directly reuses SRE incident data from Round 1.
|
| 185 |
+
|
| 186 |
+
---
|
| 187 |
+
|
| 188 |
+
## License
|
| 189 |
+
|
| 190 |
+
MIT
|
__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CommitmentOS — Temporal Commitment Coherence RL Environment."""
|
| 2 |
+
|
| 3 |
+
from models import CommitmentAction, CommitmentObservation, CommitmentState
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"CommitmentAction",
|
| 7 |
+
"CommitmentObservation",
|
| 8 |
+
"CommitmentState",
|
| 9 |
+
]
|
conftest.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pytest configuration — ensures the project root is on sys.path for all tests."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
PROJECT_ROOT = str(Path(__file__).resolve().parent)
|
| 9 |
+
if PROJECT_ROOT not in sys.path:
|
| 10 |
+
sys.path.insert(0, PROJECT_ROOT)
|
constants.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Project-wide constants — single source of truth for version and metadata."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
VERSION = "0.1.0"
|
| 6 |
+
PROJECT_NAME = "commitment-os"
|
| 7 |
+
PROJECT_DESCRIPTION = (
|
| 8 |
+
"CommitmentOS: the first RL environment that trains temporal commitment "
|
| 9 |
+
"coherence in LLMs. Agents manage a simulated personal world (calendar, "
|
| 10 |
+
"email, restaurants, contacts) across multi-turn episodes where their own "
|
| 11 |
+
"prior decisions create binding constraints tracked and penalised via a "
|
| 12 |
+
"commitment ledger."
|
| 13 |
+
)
|
| 14 |
+
AUTHOR = "Jayant Aggarwal"
|
inference.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Baseline inference script for CommitmentOS.
|
| 2 |
+
|
| 3 |
+
Uses an OpenAI-compatible LLM to play through all 15 scenarios.
|
| 4 |
+
Multi-turn: the agent gets the briefing, makes tool calls, then submits.
|
| 5 |
+
|
| 6 |
+
Required environment variables:
|
| 7 |
+
API_BASE_URL — OpenAI-compatible endpoint
|
| 8 |
+
MODEL_NAME — model identifier
|
| 9 |
+
HF_TOKEN — API key (also checked as OPENAI_API_KEY)
|
| 10 |
+
ENV_BASE_URL — CommitmentOS server URL (default: HF Space)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
import time
|
| 19 |
+
from typing import Any, Dict, List
|
| 20 |
+
|
| 21 |
+
import requests
|
| 22 |
+
from openai import OpenAI
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Configuration
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
|
| 28 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
|
| 29 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
|
| 30 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or ""
|
| 31 |
+
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://jayant2304-commitment-os.hf.space")
|
| 32 |
+
|
| 33 |
+
MAX_STEPS = 12
|
| 34 |
+
|
| 35 |
+
SYSTEM_PROMPT = """You are an expert executive assistant AI. You manage calendars, emails, and dining reservations.
|
| 36 |
+
|
| 37 |
+
You will be given a scenario briefing describing a situation with calendar conflicts, emails, or planning tasks.
|
| 38 |
+
|
| 39 |
+
For each turn, you must respond with EXACTLY ONE JSON object choosing a tool to call:
|
| 40 |
+
|
| 41 |
+
Available tools:
|
| 42 |
+
- {"action_type": "view_calendar", "date": "2026-04-25"}
|
| 43 |
+
- {"action_type": "check_availability", "person": "Client_Jones"}
|
| 44 |
+
- {"action_type": "search_restaurants", "cuisine": "Italian", "max_price": 50, "dietary": "vegetarian", "max_distance_miles": 3.0, "near_airport": false}
|
| 45 |
+
- {"action_type": "schedule_meeting", "title": "Demo", "date": "2026-04-25", "time": "14:00", "duration_min": 60, "participants": ["Client_Jones"], "location": "Room A"}
|
| 46 |
+
- {"action_type": "reschedule_event", "event_id": "evt_1", "new_time": "15:00"}
|
| 47 |
+
- {"action_type": "cancel_event", "event_id": "evt_1"}
|
| 48 |
+
- {"action_type": "send_email", "to": "VP_Chen", "subject": "Meeting update", "body": "Hi, I need to reschedule..."}
|
| 49 |
+
- {"action_type": "book_restaurant", "restaurant_name": "Sky Lounge"}
|
| 50 |
+
- {"action_type": "submit_plan"}
|
| 51 |
+
|
| 52 |
+
IMPORTANT RULES:
|
| 53 |
+
1. Respond with ONLY a JSON object, no markdown, no explanation
|
| 54 |
+
2. Handle higher-priority items before lower-priority ones
|
| 55 |
+
3. When cancelling or rescheduling commitments, ALWAYS send an email to affected parties BEFORE submitting
|
| 56 |
+
4. Call submit_plan when you have resolved all issues
|
| 57 |
+
5. Never silently drop a commitment — always notify the affected person"""
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
# Logging helpers — exact format required by hackathon evaluator
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
|
| 64 |
+
def log_start(task: str, env: str, model: str) -> None:
|
| 65 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: str | None = None) -> None:
|
| 69 |
+
err = error if error else "null"
|
| 70 |
+
print(f"[STEP] step={step} action={action} reward={reward:.2f} done={'true' if done else 'false'} error={err}", flush=True)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 74 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 75 |
+
print(f"[END] success={'true' if success else 'false'} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
# Environment interaction
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
|
| 82 |
+
def env_reset(task_id: str) -> Dict[str, Any]:
|
| 83 |
+
resp = requests.post(f"{ENV_BASE_URL}/reset", params={"task_id": task_id}, timeout=30)
|
| 84 |
+
resp.raise_for_status()
|
| 85 |
+
data = resp.json()
|
| 86 |
+
return data.get("observation", data)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def env_step(action: Dict[str, Any]) -> Dict[str, Any]:
|
| 90 |
+
resp = requests.post(f"{ENV_BASE_URL}/step", json={"action": action}, timeout=30)
|
| 91 |
+
resp.raise_for_status()
|
| 92 |
+
data = resp.json()
|
| 93 |
+
obs = data.get("observation", data)
|
| 94 |
+
obs["done"] = data.get("done", obs.get("done", False))
|
| 95 |
+
obs["reward"] = data.get("reward", obs.get("reward", 0.0))
|
| 96 |
+
return obs
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_task_ids() -> List[str]:
|
| 100 |
+
resp = requests.get(f"{ENV_BASE_URL}/tasks", timeout=30)
|
| 101 |
+
resp.raise_for_status()
|
| 102 |
+
data = resp.json()
|
| 103 |
+
ids: List[str] = []
|
| 104 |
+
for difficulty in ["easy", "medium", "hard"]:
|
| 105 |
+
ids.extend(data.get(difficulty, []))
|
| 106 |
+
return ids
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ---------------------------------------------------------------------------
|
| 110 |
+
# LLM call
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
|
| 113 |
+
def call_llm(client: OpenAI, messages: List[Dict[str, str]]) -> str:
|
| 114 |
+
response = client.chat.completions.create(
|
| 115 |
+
model=MODEL_NAME,
|
| 116 |
+
messages=messages,
|
| 117 |
+
temperature=0.2,
|
| 118 |
+
max_tokens=512,
|
| 119 |
+
stream=False,
|
| 120 |
+
)
|
| 121 |
+
return response.choices[0].message.content.strip()
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def parse_action(text: str) -> Dict[str, Any]:
|
| 125 |
+
text = text.strip()
|
| 126 |
+
if text.startswith("```"):
|
| 127 |
+
lines = text.split("\n")
|
| 128 |
+
text = "\n".join(lines[1:-1]) if len(lines) > 2 else lines[0]
|
| 129 |
+
try:
|
| 130 |
+
return json.loads(text)
|
| 131 |
+
except json.JSONDecodeError:
|
| 132 |
+
return {"action_type": "submit_plan"}
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# ---------------------------------------------------------------------------
|
| 136 |
+
# Run one task
|
| 137 |
+
# ---------------------------------------------------------------------------
|
| 138 |
+
|
| 139 |
+
def run_task(client: OpenAI, task_id: str) -> Dict[str, Any]:
|
| 140 |
+
rewards: List[float] = []
|
| 141 |
+
steps_taken = 0
|
| 142 |
+
score = 0.01
|
| 143 |
+
success = False
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
obs = env_reset(task_id)
|
| 147 |
+
log_start(task=task_id, env="commitment-os", model=MODEL_NAME)
|
| 148 |
+
|
| 149 |
+
briefing = obs.get("briefing", "")
|
| 150 |
+
calendar = json.dumps(obs.get("calendar_snapshot", []), indent=2)
|
| 151 |
+
inbox = json.dumps(obs.get("inbox", []), indent=2)
|
| 152 |
+
|
| 153 |
+
messages: List[Dict[str, str]] = [
|
| 154 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 155 |
+
{"role": "user", "content": f"SCENARIO: {briefing}\n\nCALENDAR:\n{calendar}\n\nINBOX:\n{inbox}\n\nWhat is your first action?"},
|
| 156 |
+
]
|
| 157 |
+
|
| 158 |
+
for step_num in range(1, MAX_STEPS + 1):
|
| 159 |
+
llm_output = call_llm(client, messages)
|
| 160 |
+
action = parse_action(llm_output)
|
| 161 |
+
|
| 162 |
+
step_data = env_step(action)
|
| 163 |
+
reward = float(step_data.get("reward", 0.0) or 0.0)
|
| 164 |
+
done = step_data.get("done", False)
|
| 165 |
+
steps_taken = step_num
|
| 166 |
+
rewards.append(reward)
|
| 167 |
+
|
| 168 |
+
action_str = json.dumps(action, separators=(",", ":"))
|
| 169 |
+
log_step(step=step_num, action=action_str, reward=reward, done=done)
|
| 170 |
+
|
| 171 |
+
if done:
|
| 172 |
+
score = max(0.01, min(0.99, reward))
|
| 173 |
+
success = score > 0.01
|
| 174 |
+
break
|
| 175 |
+
|
| 176 |
+
tool_result = step_data.get("tool_result", "")
|
| 177 |
+
messages.append({"role": "assistant", "content": llm_output})
|
| 178 |
+
messages.append({"role": "user", "content": f"TOOL RESULT: {tool_result}\n\nWhat is your next action?"})
|
| 179 |
+
|
| 180 |
+
if not done:
|
| 181 |
+
step_data = env_step({"action_type": "submit_plan"})
|
| 182 |
+
reward = float(step_data.get("reward", 0.0) or 0.0)
|
| 183 |
+
steps_taken += 1
|
| 184 |
+
rewards.append(reward)
|
| 185 |
+
score = max(0.01, min(0.99, reward))
|
| 186 |
+
success = score > 0.01
|
| 187 |
+
log_step(step=steps_taken, action='{"action_type":"submit_plan"}', reward=reward, done=True)
|
| 188 |
+
|
| 189 |
+
except Exception as exc:
|
| 190 |
+
steps_taken = max(steps_taken, 1)
|
| 191 |
+
if not rewards:
|
| 192 |
+
rewards.append(0.01)
|
| 193 |
+
log_step(step=steps_taken, action="error", reward=0.01, done=True, error=str(exc))
|
| 194 |
+
|
| 195 |
+
finally:
|
| 196 |
+
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 197 |
+
|
| 198 |
+
return {"task_id": task_id, "reward": score, "success": success}
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# ---------------------------------------------------------------------------
|
| 202 |
+
# Main
|
| 203 |
+
# ---------------------------------------------------------------------------
|
| 204 |
+
|
| 205 |
+
def main() -> None:
|
| 206 |
+
if not API_KEY:
|
| 207 |
+
print("ERROR: Set HF_TOKEN or OPENAI_API_KEY environment variable", file=sys.stderr)
|
| 208 |
+
sys.exit(1)
|
| 209 |
+
|
| 210 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 211 |
+
task_ids = get_task_ids()
|
| 212 |
+
|
| 213 |
+
results: List[Dict[str, Any]] = []
|
| 214 |
+
for tid in task_ids:
|
| 215 |
+
result = run_task(client, tid)
|
| 216 |
+
results.append(result)
|
| 217 |
+
|
| 218 |
+
total = len(results)
|
| 219 |
+
successes = sum(1 for r in results if r["success"])
|
| 220 |
+
mean_reward = sum(r["reward"] for r in results) / total if total > 0 else 0.0
|
| 221 |
+
print(f"\n# Summary: {successes}/{total} tasks succeeded, mean_reward={mean_reward:.3f}", flush=True)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
if __name__ == "__main__":
|
| 225 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API-facing Pydantic models — the public contract of CommitmentOS."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
+
|
| 7 |
+
from pydantic import Field
|
| 8 |
+
from openenv.core.env_server import Action, Observation, State
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CommitmentAction(Action):
|
| 12 |
+
"""Agent's tool call submitted via POST /step.
|
| 13 |
+
|
| 14 |
+
Each step is one tool invocation. The agent fills ``action_type`` and
|
| 15 |
+
the relevant subset of optional parameters for that tool.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
action_type: str = Field(
|
| 19 |
+
...,
|
| 20 |
+
description=(
|
| 21 |
+
"Tool to invoke: 'view_calendar' | 'check_availability' | "
|
| 22 |
+
"'search_restaurants' | 'schedule_meeting' | 'reschedule_event' | "
|
| 23 |
+
"'cancel_event' | 'send_email' | 'submit_plan'"
|
| 24 |
+
),
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# calendar operations
|
| 28 |
+
date: str = Field("", description="ISO date for calendar queries (yyyy-mm-dd)")
|
| 29 |
+
event_id: str = Field("", description="Event ID for reschedule / cancel")
|
| 30 |
+
new_time: str = Field("", description="New start time HH:MM for reschedule")
|
| 31 |
+
title: str = Field("", description="Title for new meetings")
|
| 32 |
+
participants: List[str] = Field(default_factory=list, description="Attendee names")
|
| 33 |
+
time: str = Field("", description="Start time HH:MM for new meetings")
|
| 34 |
+
duration_min: int = Field(60, description="Meeting duration in minutes")
|
| 35 |
+
location: str = Field("", description="Room or location")
|
| 36 |
+
|
| 37 |
+
# contact queries
|
| 38 |
+
person: str = Field("", description="Contact name for availability check")
|
| 39 |
+
|
| 40 |
+
# restaurant search
|
| 41 |
+
cuisine: str = Field("", description="Cuisine filter")
|
| 42 |
+
max_price: int = Field(0, description="Max price per person (0 = no limit)")
|
| 43 |
+
dietary: str = Field("", description="Dietary requirement filter")
|
| 44 |
+
max_distance_miles: float = Field(0.0, description="Max distance (0 = no limit)")
|
| 45 |
+
near_airport: bool = Field(False, description="Filter for airport proximity")
|
| 46 |
+
restaurant_name: str = Field("", description="Specific restaurant to book")
|
| 47 |
+
|
| 48 |
+
# email
|
| 49 |
+
to: str = Field("", description="Recipient name for send_email")
|
| 50 |
+
subject: str = Field("", description="Email subject line")
|
| 51 |
+
body: str = Field("", description="Email body text")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class CommitmentObservation(Observation):
|
| 55 |
+
"""Observation from reset() and step(). Inherits ``done``, ``reward``."""
|
| 56 |
+
|
| 57 |
+
scenario_id: str = Field(default="", description="Current scenario identifier")
|
| 58 |
+
difficulty: str = Field(default="", description="easy | medium | hard")
|
| 59 |
+
briefing: str = Field(default="", description="Scenario description shown on reset")
|
| 60 |
+
tool_result: str = Field(default="", description="Output of the last tool call")
|
| 61 |
+
calendar_snapshot: List[Dict[str, Any]] = Field(
|
| 62 |
+
default_factory=list, description="Current calendar events",
|
| 63 |
+
)
|
| 64 |
+
inbox: List[Dict[str, Any]] = Field(
|
| 65 |
+
default_factory=list, description="Unread inbox emails",
|
| 66 |
+
)
|
| 67 |
+
pending_commitments: int = Field(0, description="Number of active commitments in ledger")
|
| 68 |
+
step_number: int = Field(0, description="Current step within this episode")
|
| 69 |
+
max_steps: int = Field(15, description="Maximum steps before forced submission")
|
| 70 |
+
reward_breakdown: Dict[str, float] = Field(
|
| 71 |
+
default_factory=dict, description="Per-component reward scores",
|
| 72 |
+
)
|
| 73 |
+
feedback: str = Field(default="", description="Human-readable grader feedback")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class CommitmentState(State):
|
| 77 |
+
"""Episode metadata from GET /state."""
|
| 78 |
+
|
| 79 |
+
scenario_id: str = Field(default="", description="Current scenario identifier")
|
| 80 |
+
difficulty: str = Field(default="", description="Current difficulty level")
|
| 81 |
+
completed: bool = Field(default=False, description="Whether episode is finished")
|
| 82 |
+
cumulative_reward: float = Field(default=0.0, description="Sum of rewards this episode")
|
| 83 |
+
commitment_count: int = Field(default=0, description="Total commitments created")
|
| 84 |
+
violation_count: int = Field(default=0, description="Silent commitment violations")
|
| 85 |
+
available_tasks: List[str] = Field(
|
| 86 |
+
default_factory=list, description="All scenario IDs in the dataset",
|
| 87 |
+
)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: commitment-os
|
| 3 |
+
description: >
|
| 4 |
+
CommitmentOS: the first RL environment that trains temporal commitment
|
| 5 |
+
coherence in LLMs. Multi-turn episodes where agents manage calendar,
|
| 6 |
+
email, and dining across scenarios where their own decisions create
|
| 7 |
+
binding constraints tracked via a commitment ledger.
|
| 8 |
+
author: Jayant Aggarwal
|
| 9 |
+
version: 0.1.0
|
| 10 |
+
|
| 11 |
+
action_model: CommitmentAction
|
| 12 |
+
observation_model: CommitmentObservation
|
| 13 |
+
state_model: CommitmentState
|
| 14 |
+
|
| 15 |
+
endpoints:
|
| 16 |
+
reset: POST /reset
|
| 17 |
+
step: POST /step
|
| 18 |
+
state: GET /state
|
| 19 |
+
health: GET /health
|
| 20 |
+
metadata: GET /metadata
|
| 21 |
+
schema: GET /schema
|
| 22 |
+
mcp: POST /mcp
|
| 23 |
+
|
| 24 |
+
tasks:
|
| 25 |
+
- name: easy_001
|
| 26 |
+
difficulty: easy
|
| 27 |
+
description: Resolve double-booked meetings by priority and notify team
|
| 28 |
+
- name: easy_002
|
| 29 |
+
difficulty: easy
|
| 30 |
+
description: Book dinner with cuisine, price, and distance constraints
|
| 31 |
+
- name: easy_003
|
| 32 |
+
difficulty: easy
|
| 33 |
+
description: Check availability and propose meeting slots to client via email
|
| 34 |
+
- name: easy_004
|
| 35 |
+
difficulty: easy
|
| 36 |
+
description: Cancel conflicting work meeting for personal appointment and notify
|
| 37 |
+
- name: easy_005
|
| 38 |
+
difficulty: easy
|
| 39 |
+
description: Triage inbox by urgency and respond to critical emails first
|
| 40 |
+
- name: med_006
|
| 41 |
+
difficulty: medium
|
| 42 |
+
description: Resolve cascading reschedule chain across 3 dependent meetings
|
| 43 |
+
- name: med_007
|
| 44 |
+
difficulty: medium
|
| 45 |
+
description: Plan team dinner with 3 dietary restrictions and multi-constraint search
|
| 46 |
+
- name: med_008
|
| 47 |
+
difficulty: medium
|
| 48 |
+
description: Handle urgent boss request while in a client call without abandoning commitments
|
| 49 |
+
- name: med_009
|
| 50 |
+
difficulty: medium
|
| 51 |
+
description: Disambiguate vague reschedule request across 3 recurring meetings
|
| 52 |
+
- name: med_010
|
| 53 |
+
difficulty: medium
|
| 54 |
+
description: Plan client visit with conference room, lunch, and itinerary dependencies
|
| 55 |
+
- name: hard_011
|
| 56 |
+
difficulty: hard
|
| 57 |
+
description: VP investor dinner with calendar cascade, restaurant constraints, and multi-party notifications
|
| 58 |
+
- name: hard_012
|
| 59 |
+
difficulty: hard
|
| 60 |
+
description: Resolve triple conference room conflict with diplomatic priority-based emails
|
| 61 |
+
- name: hard_013
|
| 62 |
+
difficulty: hard
|
| 63 |
+
description: Triple crisis recovery — cancelled flight, moved board prep, lost restaurant
|
| 64 |
+
- name: hard_014
|
| 65 |
+
difficulty: hard
|
| 66 |
+
description: Navigate information asymmetry — schedule meeting without revealing confidential constraints
|
| 67 |
+
- name: hard_015
|
| 68 |
+
difficulty: hard
|
| 69 |
+
description: Production incident interrupts day of commitments — triage, renegotiate, notify all parties
|
| 70 |
+
|
| 71 |
+
observation_space:
|
| 72 |
+
description: >
|
| 73 |
+
Current scenario context including calendar snapshot, inbox messages,
|
| 74 |
+
tool call results, commitment count, step number, reward breakdown,
|
| 75 |
+
and grader feedback.
|
| 76 |
+
|
| 77 |
+
action_space:
|
| 78 |
+
description: >
|
| 79 |
+
Single tool invocation per step. Agent selects action_type (view_calendar,
|
| 80 |
+
check_availability, search_restaurants, schedule_meeting, reschedule_event,
|
| 81 |
+
cancel_event, send_email, book_restaurant, submit_plan) and fills relevant
|
| 82 |
+
parameters. Episodes are multi-turn with 2-15 steps per scenario.
|
pyproject.toml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "commitment-os"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "CommitmentOS: the first RL environment that trains temporal commitment coherence in LLMs"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
license = {text = "MIT"}
|
| 11 |
+
authors = [
|
| 12 |
+
{name = "Jayant Aggarwal"},
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
dependencies = [
|
| 16 |
+
"openenv-core>=0.2.0",
|
| 17 |
+
"fastapi>=0.110.0",
|
| 18 |
+
"uvicorn[standard]>=0.29.0",
|
| 19 |
+
"pydantic>=2.0.0",
|
| 20 |
+
"python-dotenv>=1.0.0",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
[project.scripts]
|
| 24 |
+
server = "server.app:main"
|
| 25 |
+
|
| 26 |
+
[project.optional-dependencies]
|
| 27 |
+
inference = [
|
| 28 |
+
"openai>=1.0.0",
|
| 29 |
+
"requests>=2.31.0",
|
| 30 |
+
]
|
| 31 |
+
dev = [
|
| 32 |
+
"pytest>=8.0.0",
|
| 33 |
+
"httpx>=0.27.0",
|
| 34 |
+
"openai>=1.0.0",
|
| 35 |
+
"requests>=2.31.0",
|
| 36 |
+
]
|
| 37 |
+
training = [
|
| 38 |
+
"trl>=0.14.0",
|
| 39 |
+
"transformers>=4.45.0",
|
| 40 |
+
"torch>=2.0.0",
|
| 41 |
+
"peft>=0.14.0",
|
| 42 |
+
"datasets>=3.0.0",
|
| 43 |
+
]
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Docker build dependencies — must stay in sync with pyproject.toml [project.dependencies]
|
| 2 |
+
openenv-core>=0.2.0
|
| 3 |
+
fastapi>=0.110.0
|
| 4 |
+
uvicorn[standard]>=0.29.0
|
| 5 |
+
pydantic>=2.0.0
|
| 6 |
+
python-dotenv>=1.0.0
|
server/__init__.py
ADDED
|
File without changes
|
server/app.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI composition root — wires environment, MCP, and custom endpoints."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from openenv.core.env_server import create_fastapi_app
|
| 8 |
+
|
| 9 |
+
from constants import PROJECT_DESCRIPTION, VERSION
|
| 10 |
+
from models import CommitmentAction, CommitmentObservation, CommitmentState
|
| 11 |
+
from server.environment import CommitmentEnvironment
|
| 12 |
+
from server.mcp import router as mcp_router
|
| 13 |
+
from server.tasks import get_scenario_ids_grouped
|
| 14 |
+
|
| 15 |
+
_shared_env = CommitmentEnvironment()
|
| 16 |
+
|
| 17 |
+
app = create_fastapi_app(
|
| 18 |
+
env=lambda: _shared_env,
|
| 19 |
+
action_cls=CommitmentAction,
|
| 20 |
+
observation_cls=CommitmentObservation,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
app.title = "CommitmentOS"
|
| 24 |
+
app.description = PROJECT_DESCRIPTION
|
| 25 |
+
app.version = VERSION
|
| 26 |
+
|
| 27 |
+
app.routes[:] = [
|
| 28 |
+
r for r in app.routes
|
| 29 |
+
if not (hasattr(r, "path") and r.path in ("/state", "/mcp"))
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@app.get("/state", response_model=CommitmentState)
|
| 34 |
+
def get_state() -> CommitmentState:
|
| 35 |
+
return _shared_env.state
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@app.get("/tasks")
|
| 39 |
+
def list_tasks() -> dict[str, list[str]]:
|
| 40 |
+
return get_scenario_ids_grouped()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
app.include_router(mcp_router)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def main() -> None:
|
| 47 |
+
import uvicorn
|
| 48 |
+
|
| 49 |
+
port = int(os.environ.get("PORT", 7860))
|
| 50 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
main()
|
server/domain.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Internal domain types — not exposed via the HTTP API."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Any, Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ---------------------------------------------------------------------------
|
| 12 |
+
# Commitment ledger entry
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class Commitment:
|
| 17 |
+
"""A binding constraint the agent created via its own actions."""
|
| 18 |
+
|
| 19 |
+
turn_created: int
|
| 20 |
+
commitment_type: str # "meeting_scheduled" | "email_promise" | "reservation_made"
|
| 21 |
+
description: str # human-readable: "3pm meeting with Client X"
|
| 22 |
+
constraint: str # machine key: "2026-04-25T15:00"
|
| 23 |
+
to_whom: str # who was promised
|
| 24 |
+
active: bool = True
|
| 25 |
+
renegotiated_at: Optional[int] = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Scenario / task definition
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
class CalendarEvent(BaseModel):
|
| 33 |
+
"""A single calendar entry."""
|
| 34 |
+
|
| 35 |
+
event_id: str = Field(..., description="Unique event identifier")
|
| 36 |
+
title: str = Field(..., description="Event title")
|
| 37 |
+
date: str = Field(..., description="ISO date yyyy-mm-dd")
|
| 38 |
+
time: str = Field(..., description="Start time HH:MM")
|
| 39 |
+
duration_min: int = Field(60, description="Duration in minutes")
|
| 40 |
+
participants: List[str] = Field(default_factory=list)
|
| 41 |
+
location: str = Field("", description="Room or location name")
|
| 42 |
+
priority: str = Field("normal", description="low | normal | high | critical")
|
| 43 |
+
is_personal: bool = Field(False, description="Personal vs work event")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Contact(BaseModel):
|
| 47 |
+
"""A person the agent can interact with."""
|
| 48 |
+
|
| 49 |
+
name: str
|
| 50 |
+
role: str = ""
|
| 51 |
+
email: str = ""
|
| 52 |
+
priority_level: int = Field(1, description="1 (lowest) to 5 (highest)")
|
| 53 |
+
availability: Dict[str, List[str]] = Field(
|
| 54 |
+
default_factory=dict,
|
| 55 |
+
description="date -> list of free time slots e.g. {'2026-04-25': ['09:00','10:00','14:00']}",
|
| 56 |
+
)
|
| 57 |
+
dietary: str = Field("", description="Dietary restrictions if any")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Restaurant(BaseModel):
|
| 61 |
+
"""A restaurant option the agent can search/book."""
|
| 62 |
+
|
| 63 |
+
name: str
|
| 64 |
+
cuisine: str
|
| 65 |
+
price_per_person: int
|
| 66 |
+
distance_miles: float
|
| 67 |
+
dietary_options: List[str] = Field(default_factory=list)
|
| 68 |
+
capacity: int = 20
|
| 69 |
+
hours: str = "11:00-22:00"
|
| 70 |
+
has_private_room: bool = False
|
| 71 |
+
near_airport: bool = False
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class InboxEmail(BaseModel):
|
| 75 |
+
"""An email in the agent's inbox."""
|
| 76 |
+
|
| 77 |
+
email_id: str
|
| 78 |
+
sender: str
|
| 79 |
+
subject: str
|
| 80 |
+
body: str
|
| 81 |
+
urgency: str = Field("normal", description="low | normal | high | critical")
|
| 82 |
+
received_at: str = Field("", description="ISO datetime")
|
| 83 |
+
requires_response: bool = True
|
| 84 |
+
context_hint: str = Field("", description="Hidden hint for grader about what the correct action is")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class ConstraintDef(BaseModel):
|
| 88 |
+
"""A single verifiable constraint for grading."""
|
| 89 |
+
|
| 90 |
+
description: str = Field(..., description="Human-readable: 'Restaurant must have vegan options'")
|
| 91 |
+
check_type: str = Field(..., description="'calendar_no_conflict' | 'restaurant_match' | 'email_sent' | 'event_exists' | 'event_cancelled' | 'priority_order'")
|
| 92 |
+
check_params: Dict[str, Any] = Field(default_factory=dict)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class CommunicationReq(BaseModel):
|
| 96 |
+
"""A required outgoing communication for grading."""
|
| 97 |
+
|
| 98 |
+
to: str = Field(..., description="Recipient name")
|
| 99 |
+
required_keywords: List[str] = Field(default_factory=list, description="Keywords that should appear")
|
| 100 |
+
purpose: str = Field("", description="'notify_reschedule' | 'propose_alternative' | 'acknowledge' | 'renegotiate'")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ScenarioDef(BaseModel):
|
| 104 |
+
"""Complete definition of a single task scenario."""
|
| 105 |
+
|
| 106 |
+
scenario_id: str
|
| 107 |
+
difficulty: str = Field(..., description="easy | medium | hard")
|
| 108 |
+
briefing: str = Field(..., description="The scenario description the agent sees on reset")
|
| 109 |
+
initial_calendar: List[CalendarEvent] = Field(default_factory=list)
|
| 110 |
+
initial_inbox: List[InboxEmail] = Field(default_factory=list)
|
| 111 |
+
available_restaurants: List[Restaurant] = Field(default_factory=list)
|
| 112 |
+
contacts: List[Contact] = Field(default_factory=list)
|
| 113 |
+
constraints: List[ConstraintDef] = Field(default_factory=list)
|
| 114 |
+
priority_ordering: List[str] = Field(
|
| 115 |
+
default_factory=list,
|
| 116 |
+
description="Ordered list from highest to lowest priority contact/event",
|
| 117 |
+
)
|
| 118 |
+
communication_requirements: List[CommunicationReq] = Field(default_factory=list)
|
| 119 |
+
optimal_steps: int = Field(3, description="Minimum steps to solve perfectly")
|
| 120 |
+
max_steps: int = Field(15, description="Maximum allowed steps before timeout")
|
| 121 |
+
|
| 122 |
+
# ground-truth for grading
|
| 123 |
+
expected_final_events: List[str] = Field(
|
| 124 |
+
default_factory=list,
|
| 125 |
+
description="Event IDs that should exist in final calendar",
|
| 126 |
+
)
|
| 127 |
+
expected_cancelled_events: List[str] = Field(
|
| 128 |
+
default_factory=list,
|
| 129 |
+
description="Event IDs that should be cancelled",
|
| 130 |
+
)
|
| 131 |
+
expected_restaurant: str = Field("", description="Name of the correct restaurant pick")
|
server/environment.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CommitmentOS environment — multi-turn personal task management with
|
| 2 |
+
temporal commitment coherence tracking.
|
| 3 |
+
|
| 4 |
+
Episode lifecycle:
|
| 5 |
+
1. reset() -> agent receives scenario briefing + calendar + inbox
|
| 6 |
+
2. step() -> agent makes one tool call per step (done=False)
|
| 7 |
+
3. step(submit_plan) or max_steps reached -> grading + done=True
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import random
|
| 13 |
+
import uuid
|
| 14 |
+
from typing import Any, Optional
|
| 15 |
+
|
| 16 |
+
from openenv.core.env_server import Environment
|
| 17 |
+
from openenv.core.env_server.types import EnvironmentMetadata
|
| 18 |
+
|
| 19 |
+
from constants import AUTHOR, PROJECT_DESCRIPTION, PROJECT_NAME, VERSION
|
| 20 |
+
from models import CommitmentAction, CommitmentObservation, CommitmentState
|
| 21 |
+
from server.domain import ScenarioDef
|
| 22 |
+
from server.world import WorldState
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class CommitmentEnvironment(
|
| 26 |
+
Environment[CommitmentAction, CommitmentObservation, CommitmentState]
|
| 27 |
+
):
|
| 28 |
+
def __init__(self) -> None:
|
| 29 |
+
super().__init__()
|
| 30 |
+
self._world: Optional[WorldState] = None
|
| 31 |
+
self._scenario: Optional[ScenarioDef] = None
|
| 32 |
+
self._episode_id: str = ""
|
| 33 |
+
self._step_count: int = 0
|
| 34 |
+
self._done: bool = False
|
| 35 |
+
self._cumulative_reward: float = 0.0
|
| 36 |
+
self._last_tool_result: str = ""
|
| 37 |
+
self._last_breakdown: dict[str, float] = {}
|
| 38 |
+
self._last_feedback: str = ""
|
| 39 |
+
|
| 40 |
+
# ------------------------------------------------------------------
|
| 41 |
+
# Task selection
|
| 42 |
+
# ------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
def _select_scenario(
|
| 45 |
+
self,
|
| 46 |
+
scenario_id: Optional[str] = None,
|
| 47 |
+
difficulty: Optional[str] = None,
|
| 48 |
+
) -> ScenarioDef:
|
| 49 |
+
from server.tasks import get_all_scenarios, get_scenario, get_scenarios_by_difficulty
|
| 50 |
+
|
| 51 |
+
if scenario_id:
|
| 52 |
+
s = get_scenario(scenario_id)
|
| 53 |
+
if s is None:
|
| 54 |
+
raise ValueError(f"Unknown scenario_id: {scenario_id}")
|
| 55 |
+
return s
|
| 56 |
+
if difficulty:
|
| 57 |
+
candidates = get_scenarios_by_difficulty(difficulty)
|
| 58 |
+
if not candidates:
|
| 59 |
+
raise ValueError(f"No scenarios for difficulty: {difficulty}")
|
| 60 |
+
return random.choice(candidates)
|
| 61 |
+
return random.choice(list(get_all_scenarios().values()))
|
| 62 |
+
|
| 63 |
+
# ------------------------------------------------------------------
|
| 64 |
+
# Core API
|
| 65 |
+
# ------------------------------------------------------------------
|
| 66 |
+
|
| 67 |
+
def reset(
|
| 68 |
+
self,
|
| 69 |
+
seed: Optional[int] = None,
|
| 70 |
+
episode_id: Optional[str] = None,
|
| 71 |
+
**kwargs: Any,
|
| 72 |
+
) -> CommitmentObservation:
|
| 73 |
+
if seed is not None:
|
| 74 |
+
random.seed(seed)
|
| 75 |
+
|
| 76 |
+
scenario = self._select_scenario(
|
| 77 |
+
scenario_id=kwargs.get("scenario_id") or kwargs.get("task_id"),
|
| 78 |
+
difficulty=kwargs.get("difficulty"),
|
| 79 |
+
)
|
| 80 |
+
self._scenario = scenario
|
| 81 |
+
self._world = WorldState(scenario)
|
| 82 |
+
self._episode_id = episode_id or str(uuid.uuid4())
|
| 83 |
+
self._step_count = 0
|
| 84 |
+
self._done = False
|
| 85 |
+
self._cumulative_reward = 0.0
|
| 86 |
+
self._last_tool_result = ""
|
| 87 |
+
self._last_breakdown = {}
|
| 88 |
+
self._last_feedback = "New episode started. Read the briefing and use tools to manage the situation."
|
| 89 |
+
|
| 90 |
+
return self._build_observation(reward=0.0, done=False)
|
| 91 |
+
|
| 92 |
+
def step(
|
| 93 |
+
self,
|
| 94 |
+
action: CommitmentAction,
|
| 95 |
+
timeout_s: Optional[float] = None,
|
| 96 |
+
**kwargs: Any,
|
| 97 |
+
) -> CommitmentObservation:
|
| 98 |
+
if self._world is None or self._scenario is None:
|
| 99 |
+
raise ValueError("No active episode. Call reset() first.")
|
| 100 |
+
if self._done:
|
| 101 |
+
raise ValueError("Episode already completed. Call reset() to start a new one.")
|
| 102 |
+
|
| 103 |
+
self._step_count += 1
|
| 104 |
+
self._world.step_count = self._step_count
|
| 105 |
+
|
| 106 |
+
at = action.action_type.lower().strip()
|
| 107 |
+
|
| 108 |
+
if at == "submit_plan" or self._step_count >= self._scenario.max_steps:
|
| 109 |
+
return self._finish_episode()
|
| 110 |
+
|
| 111 |
+
step_reward = 0.0
|
| 112 |
+
tool_result = self._dispatch_tool(action, at)
|
| 113 |
+
self._last_tool_result = tool_result
|
| 114 |
+
|
| 115 |
+
if "CONFLICT" in tool_result:
|
| 116 |
+
step_reward = -0.05
|
| 117 |
+
elif at in ("schedule_meeting", "reschedule_event", "send_email", "book_restaurant"):
|
| 118 |
+
step_reward = 0.05
|
| 119 |
+
|
| 120 |
+
self._cumulative_reward += step_reward
|
| 121 |
+
self._last_feedback = ""
|
| 122 |
+
self._last_breakdown = {}
|
| 123 |
+
|
| 124 |
+
return self._build_observation(reward=step_reward, done=False)
|
| 125 |
+
|
| 126 |
+
def _finish_episode(self) -> CommitmentObservation:
|
| 127 |
+
from server.graders import grade_scenario
|
| 128 |
+
|
| 129 |
+
assert self._world is not None
|
| 130 |
+
assert self._scenario is not None
|
| 131 |
+
|
| 132 |
+
total_reward, breakdown, feedback = grade_scenario(
|
| 133 |
+
self._scenario, self._world,
|
| 134 |
+
)
|
| 135 |
+
self._done = True
|
| 136 |
+
self._cumulative_reward += total_reward
|
| 137 |
+
self._last_breakdown = breakdown
|
| 138 |
+
self._last_feedback = feedback
|
| 139 |
+
self._last_tool_result = "Plan submitted. Episode graded."
|
| 140 |
+
|
| 141 |
+
return self._build_observation(reward=total_reward, done=True)
|
| 142 |
+
|
| 143 |
+
# ------------------------------------------------------------------
|
| 144 |
+
# Tool dispatch
|
| 145 |
+
# ------------------------------------------------------------------
|
| 146 |
+
|
| 147 |
+
def _dispatch_tool(self, action: CommitmentAction, at: str) -> str:
|
| 148 |
+
assert self._world is not None
|
| 149 |
+
turn = self._step_count
|
| 150 |
+
|
| 151 |
+
if at == "view_calendar":
|
| 152 |
+
return self._world.view_calendar(action.date)
|
| 153 |
+
elif at == "check_availability":
|
| 154 |
+
return self._world.check_availability(action.person)
|
| 155 |
+
elif at == "search_restaurants":
|
| 156 |
+
return self._world.search_restaurants(
|
| 157 |
+
cuisine=action.cuisine,
|
| 158 |
+
max_price=action.max_price,
|
| 159 |
+
dietary=action.dietary,
|
| 160 |
+
max_distance_miles=action.max_distance_miles,
|
| 161 |
+
near_airport=action.near_airport,
|
| 162 |
+
)
|
| 163 |
+
elif at == "schedule_meeting":
|
| 164 |
+
return self._world.schedule_meeting(
|
| 165 |
+
title=action.title,
|
| 166 |
+
date=action.date,
|
| 167 |
+
time=action.time,
|
| 168 |
+
duration_min=action.duration_min,
|
| 169 |
+
participants=action.participants,
|
| 170 |
+
location=action.location,
|
| 171 |
+
turn=turn,
|
| 172 |
+
)
|
| 173 |
+
elif at == "reschedule_event":
|
| 174 |
+
return self._world.reschedule_event(
|
| 175 |
+
event_id=action.event_id,
|
| 176 |
+
new_time=action.new_time,
|
| 177 |
+
turn=turn,
|
| 178 |
+
)
|
| 179 |
+
elif at == "cancel_event":
|
| 180 |
+
return self._world.cancel_event(action.event_id, turn=turn)
|
| 181 |
+
elif at == "send_email":
|
| 182 |
+
return self._world.send_email(
|
| 183 |
+
to=action.to,
|
| 184 |
+
subject=action.subject,
|
| 185 |
+
body=action.body,
|
| 186 |
+
turn=turn,
|
| 187 |
+
)
|
| 188 |
+
elif at == "book_restaurant":
|
| 189 |
+
return self._world.book_restaurant(action.restaurant_name, turn=turn)
|
| 190 |
+
else:
|
| 191 |
+
return f"Unknown action_type: '{at}'. Valid types: view_calendar, check_availability, search_restaurants, schedule_meeting, reschedule_event, cancel_event, send_email, book_restaurant, submit_plan"
|
| 192 |
+
|
| 193 |
+
# ------------------------------------------------------------------
|
| 194 |
+
# Observation builder
|
| 195 |
+
# ------------------------------------------------------------------
|
| 196 |
+
|
| 197 |
+
def _build_observation(self, *, reward: float, done: bool) -> CommitmentObservation:
|
| 198 |
+
assert self._world is not None
|
| 199 |
+
assert self._scenario is not None
|
| 200 |
+
|
| 201 |
+
return CommitmentObservation(
|
| 202 |
+
scenario_id=self._scenario.scenario_id,
|
| 203 |
+
difficulty=self._scenario.difficulty,
|
| 204 |
+
briefing=self._scenario.briefing if self._step_count == 0 else "",
|
| 205 |
+
tool_result=self._last_tool_result,
|
| 206 |
+
calendar_snapshot=self._world.get_calendar_snapshot(),
|
| 207 |
+
inbox=self._world.get_inbox_snapshot(),
|
| 208 |
+
pending_commitments=len(self._world.get_active_commitments()),
|
| 209 |
+
step_number=self._step_count,
|
| 210 |
+
max_steps=self._scenario.max_steps,
|
| 211 |
+
reward=reward,
|
| 212 |
+
reward_breakdown=self._last_breakdown,
|
| 213 |
+
done=done,
|
| 214 |
+
feedback=self._last_feedback,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# ------------------------------------------------------------------
|
| 218 |
+
# State property
|
| 219 |
+
# ------------------------------------------------------------------
|
| 220 |
+
|
| 221 |
+
@property
|
| 222 |
+
def state(self) -> CommitmentState:
|
| 223 |
+
from server.tasks import get_all_scenarios
|
| 224 |
+
|
| 225 |
+
violations = self._world.get_silent_violations() if self._world else []
|
| 226 |
+
return CommitmentState(
|
| 227 |
+
episode_id=self._episode_id,
|
| 228 |
+
step_count=self._step_count,
|
| 229 |
+
scenario_id=self._scenario.scenario_id if self._scenario else "",
|
| 230 |
+
difficulty=self._scenario.difficulty if self._scenario else "",
|
| 231 |
+
completed=self._done,
|
| 232 |
+
cumulative_reward=self._cumulative_reward,
|
| 233 |
+
commitment_count=len(self._world.commitment_ledger) if self._world else 0,
|
| 234 |
+
violation_count=len(violations),
|
| 235 |
+
available_tasks=list(get_all_scenarios().keys()),
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
def get_metadata(self) -> EnvironmentMetadata:
|
| 239 |
+
return EnvironmentMetadata(
|
| 240 |
+
name=PROJECT_NAME,
|
| 241 |
+
description=PROJECT_DESCRIPTION,
|
| 242 |
+
version=VERSION,
|
| 243 |
+
author=AUTHOR,
|
| 244 |
+
)
|
server/graders.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deterministic grading — 5-component reward for CommitmentOS.
|
| 2 |
+
|
| 3 |
+
Components:
|
| 4 |
+
constraint_satisfaction (0.35) — binary per scenario constraint
|
| 5 |
+
conflict_resolution (0.20) — final calendar free of overlaps
|
| 6 |
+
commitment_coherence (0.20) — ledger violations penalised
|
| 7 |
+
communication_quality (0.15) — keyword matching on sent emails
|
| 8 |
+
step_efficiency (0.10) — fewer steps = higher score
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from typing import Any, Dict, List, Tuple
|
| 14 |
+
|
| 15 |
+
from server.domain import ScenarioDef
|
| 16 |
+
from server.world import WorldState, _time_to_min
|
| 17 |
+
|
| 18 |
+
WEIGHTS: Dict[str, float] = {
|
| 19 |
+
"constraint_satisfaction": 0.35,
|
| 20 |
+
"conflict_resolution": 0.20,
|
| 21 |
+
"commitment_coherence": 0.20,
|
| 22 |
+
"communication_quality": 0.15,
|
| 23 |
+
"step_efficiency": 0.10,
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _keyword_score(text: str, keywords: List[str], min_matches: int = 2) -> Tuple[float, List[str]]:
|
| 28 |
+
"""0 hits -> 0.0, < min_matches -> 0.5 (partial), >= min_matches -> 1.0."""
|
| 29 |
+
text_lower = text.lower()
|
| 30 |
+
matched = [kw for kw in keywords if kw.lower() in text_lower]
|
| 31 |
+
if len(matched) == 0:
|
| 32 |
+
return 0.0, matched
|
| 33 |
+
if len(matched) < min_matches:
|
| 34 |
+
return 0.5, matched
|
| 35 |
+
return 1.0, matched
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _check_constraint(constraint, world: WorldState) -> bool:
|
| 39 |
+
"""Evaluate a single ConstraintDef against the world state."""
|
| 40 |
+
ct = constraint.check_type
|
| 41 |
+
params = constraint.check_params
|
| 42 |
+
|
| 43 |
+
if ct == "calendar_no_conflict":
|
| 44 |
+
return _calendar_has_no_overlaps(world)
|
| 45 |
+
|
| 46 |
+
elif ct == "event_exists":
|
| 47 |
+
eid = params.get("event_id", "")
|
| 48 |
+
return eid in world.calendar
|
| 49 |
+
|
| 50 |
+
elif ct == "event_cancelled":
|
| 51 |
+
eid = params.get("event_id", "")
|
| 52 |
+
return eid not in world.calendar
|
| 53 |
+
|
| 54 |
+
elif ct == "email_sent":
|
| 55 |
+
to = params.get("to", "").lower()
|
| 56 |
+
keywords = params.get("keywords", [])
|
| 57 |
+
for em in world.emails_sent:
|
| 58 |
+
if to in em.get("to", "").lower():
|
| 59 |
+
if keywords:
|
| 60 |
+
score, _ = _keyword_score(em.get("body", ""), keywords, min_matches=1)
|
| 61 |
+
if score > 0:
|
| 62 |
+
return True
|
| 63 |
+
else:
|
| 64 |
+
return True
|
| 65 |
+
return False
|
| 66 |
+
|
| 67 |
+
elif ct == "restaurant_match":
|
| 68 |
+
name = params.get("name", "")
|
| 69 |
+
if name:
|
| 70 |
+
return world.booked_restaurant.lower() == name.lower()
|
| 71 |
+
criteria = params.get("criteria", {})
|
| 72 |
+
if not world.booked_restaurant:
|
| 73 |
+
return False
|
| 74 |
+
r = world.restaurants.get(world.booked_restaurant)
|
| 75 |
+
if r is None:
|
| 76 |
+
return False
|
| 77 |
+
if "dietary" in criteria and criteria["dietary"].lower() not in [d.lower() for d in r.dietary_options]:
|
| 78 |
+
return False
|
| 79 |
+
if "max_price" in criteria and r.price_per_person > criteria["max_price"]:
|
| 80 |
+
return False
|
| 81 |
+
if "max_distance" in criteria and r.distance_miles > criteria["max_distance"]:
|
| 82 |
+
return False
|
| 83 |
+
if "near_airport" in criteria and criteria["near_airport"] and not r.near_airport:
|
| 84 |
+
return False
|
| 85 |
+
return True
|
| 86 |
+
|
| 87 |
+
elif ct == "priority_order":
|
| 88 |
+
higher = params.get("higher", "").lower()
|
| 89 |
+
lower = params.get("lower", "").lower()
|
| 90 |
+
higher_kept = any(
|
| 91 |
+
ev.title.lower() == higher or higher in ev.title.lower()
|
| 92 |
+
for ev in world.calendar.values()
|
| 93 |
+
)
|
| 94 |
+
lower_moved = not any(
|
| 95 |
+
ev.title.lower() == lower or lower in ev.title.lower()
|
| 96 |
+
for ev in world.calendar.values()
|
| 97 |
+
) or any(
|
| 98 |
+
em.get("to", "").lower() == lower or lower in em.get("body", "").lower()
|
| 99 |
+
for em in world.emails_sent
|
| 100 |
+
)
|
| 101 |
+
return higher_kept
|
| 102 |
+
|
| 103 |
+
return False
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _calendar_has_no_overlaps(world: WorldState) -> bool:
|
| 107 |
+
events = list(world.calendar.values())
|
| 108 |
+
for i, a in enumerate(events):
|
| 109 |
+
for b in events[i + 1:]:
|
| 110 |
+
if a.date != b.date:
|
| 111 |
+
continue
|
| 112 |
+
a_start = _time_to_min(a.time)
|
| 113 |
+
a_end = a_start + a.duration_min
|
| 114 |
+
b_start = _time_to_min(b.time)
|
| 115 |
+
b_end = b_start + b.duration_min
|
| 116 |
+
if a_start < b_end and b_start < a_end:
|
| 117 |
+
return False
|
| 118 |
+
return True
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _score_constraint_satisfaction(scenario: ScenarioDef, world: WorldState) -> Tuple[float, str]:
|
| 122 |
+
if not scenario.constraints:
|
| 123 |
+
return 1.0, "No constraints defined"
|
| 124 |
+
met = sum(1 for c in scenario.constraints if _check_constraint(c, world))
|
| 125 |
+
total = len(scenario.constraints)
|
| 126 |
+
score = met / total
|
| 127 |
+
return score, f"{met}/{total} constraints met"
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _score_conflict_resolution(world: WorldState) -> Tuple[float, str]:
|
| 131 |
+
clean = _calendar_has_no_overlaps(world)
|
| 132 |
+
return (1.0 if clean else 0.0), ("No calendar conflicts" if clean else "Calendar has overlapping events")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _score_commitment_coherence(world: WorldState) -> Tuple[float, str]:
|
| 136 |
+
total = len(world.commitment_ledger)
|
| 137 |
+
if total == 0:
|
| 138 |
+
return 1.0, "No commitments created"
|
| 139 |
+
violations = world.get_silent_violations()
|
| 140 |
+
silent_count = len(violations)
|
| 141 |
+
|
| 142 |
+
renegotiated = sum(1 for c in world.commitment_ledger if c.renegotiated_at is not None)
|
| 143 |
+
honored = total - silent_count - renegotiated
|
| 144 |
+
|
| 145 |
+
score = (total - silent_count) / total
|
| 146 |
+
parts = []
|
| 147 |
+
if honored > 0:
|
| 148 |
+
parts.append(f"{honored} honored")
|
| 149 |
+
if renegotiated > 0:
|
| 150 |
+
parts.append(f"{renegotiated} renegotiated")
|
| 151 |
+
if silent_count > 0:
|
| 152 |
+
parts.append(f"{silent_count} SILENTLY BROKEN")
|
| 153 |
+
return score, " | ".join(parts) if parts else "OK"
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _score_communication(scenario: ScenarioDef, world: WorldState) -> Tuple[float, str]:
|
| 157 |
+
reqs = scenario.communication_requirements
|
| 158 |
+
if not reqs:
|
| 159 |
+
return 1.0, "No communication requirements"
|
| 160 |
+
|
| 161 |
+
total_score = 0.0
|
| 162 |
+
feedback_parts: List[str] = []
|
| 163 |
+
for req in reqs:
|
| 164 |
+
to_lower = req.to.lower()
|
| 165 |
+
matching_emails = [
|
| 166 |
+
em for em in world.emails_sent
|
| 167 |
+
if to_lower in em.get("to", "").lower()
|
| 168 |
+
]
|
| 169 |
+
if not matching_emails:
|
| 170 |
+
feedback_parts.append(f"MISSING email to {req.to}")
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
best_score = 0.0
|
| 174 |
+
for em in matching_emails:
|
| 175 |
+
body = em.get("body", "") + " " + em.get("subject", "")
|
| 176 |
+
if req.required_keywords:
|
| 177 |
+
ks, matched = _keyword_score(body, req.required_keywords, min_matches=1)
|
| 178 |
+
best_score = max(best_score, ks)
|
| 179 |
+
else:
|
| 180 |
+
best_score = 1.0
|
| 181 |
+
|
| 182 |
+
total_score += best_score
|
| 183 |
+
if best_score >= 1.0:
|
| 184 |
+
feedback_parts.append(f"Email to {req.to}: full credit")
|
| 185 |
+
elif best_score > 0:
|
| 186 |
+
feedback_parts.append(f"Email to {req.to}: partial ({best_score:.1f})")
|
| 187 |
+
else:
|
| 188 |
+
feedback_parts.append(f"Email to {req.to}: missing keywords")
|
| 189 |
+
|
| 190 |
+
score = total_score / len(reqs) if reqs else 1.0
|
| 191 |
+
return score, " | ".join(feedback_parts)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _score_step_efficiency(scenario: ScenarioDef, world: WorldState) -> Tuple[float, str]:
|
| 195 |
+
optimal = scenario.optimal_steps
|
| 196 |
+
actual = world.step_count
|
| 197 |
+
if actual <= optimal:
|
| 198 |
+
return 1.0, f"{actual} steps (optimal: {optimal})"
|
| 199 |
+
penalty = (actual - optimal) * 0.1
|
| 200 |
+
score = max(0.0, 1.0 - penalty)
|
| 201 |
+
return score, f"{actual} steps (optimal: {optimal}, penalty: -{penalty:.1f})"
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def grade_scenario(
|
| 205 |
+
scenario: ScenarioDef,
|
| 206 |
+
world: WorldState,
|
| 207 |
+
) -> Tuple[float, Dict[str, float], str]:
|
| 208 |
+
"""Returns ``(total_reward, breakdown, feedback)``."""
|
| 209 |
+
breakdown: Dict[str, float] = {}
|
| 210 |
+
feedback_parts: List[str] = []
|
| 211 |
+
|
| 212 |
+
cs_score, cs_fb = _score_constraint_satisfaction(scenario, world)
|
| 213 |
+
breakdown["constraint_satisfaction"] = round(cs_score * WEIGHTS["constraint_satisfaction"], 4)
|
| 214 |
+
feedback_parts.append(f"[constraints] {cs_fb}")
|
| 215 |
+
|
| 216 |
+
cr_score, cr_fb = _score_conflict_resolution(world)
|
| 217 |
+
breakdown["conflict_resolution"] = round(cr_score * WEIGHTS["conflict_resolution"], 4)
|
| 218 |
+
feedback_parts.append(f"[conflicts] {cr_fb}")
|
| 219 |
+
|
| 220 |
+
cc_score, cc_fb = _score_commitment_coherence(world)
|
| 221 |
+
breakdown["commitment_coherence"] = round(cc_score * WEIGHTS["commitment_coherence"], 4)
|
| 222 |
+
feedback_parts.append(f"[commitments] {cc_fb}")
|
| 223 |
+
|
| 224 |
+
cq_score, cq_fb = _score_communication(scenario, world)
|
| 225 |
+
breakdown["communication_quality"] = round(cq_score * WEIGHTS["communication_quality"], 4)
|
| 226 |
+
feedback_parts.append(f"[communication] {cq_fb}")
|
| 227 |
+
|
| 228 |
+
se_score, se_fb = _score_step_efficiency(scenario, world)
|
| 229 |
+
breakdown["step_efficiency"] = round(se_score * WEIGHTS["step_efficiency"], 4)
|
| 230 |
+
feedback_parts.append(f"[efficiency] {se_fb}")
|
| 231 |
+
|
| 232 |
+
total_reward = round(sum(breakdown.values()), 4)
|
| 233 |
+
total_reward = max(0.01, min(0.99, total_reward))
|
| 234 |
+
|
| 235 |
+
feedback = " | ".join(feedback_parts)
|
| 236 |
+
return total_reward, breakdown, feedback
|
server/mcp.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MCP JSON-RPC 2.0 endpoint for OpenEnv validator compliance."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, Request
|
| 6 |
+
from fastapi.responses import JSONResponse
|
| 7 |
+
|
| 8 |
+
from constants import PROJECT_NAME, VERSION
|
| 9 |
+
from models import CommitmentAction
|
| 10 |
+
|
| 11 |
+
router = APIRouter()
|
| 12 |
+
|
| 13 |
+
_CAPABILITIES = {
|
| 14 |
+
"tools": {"listChanged": False},
|
| 15 |
+
"resources": {"subscribe": False, "listChanged": False},
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
_TOOLS = [
|
| 19 |
+
{
|
| 20 |
+
"name": "reset",
|
| 21 |
+
"description": "Start a new CommitmentOS episode",
|
| 22 |
+
"inputSchema": CommitmentAction.model_json_schema(),
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"name": "step",
|
| 26 |
+
"description": "Execute one tool call in the current episode",
|
| 27 |
+
"inputSchema": CommitmentAction.model_json_schema(),
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"name": "state",
|
| 31 |
+
"description": "Get current episode state",
|
| 32 |
+
"inputSchema": {"type": "object", "properties": {}},
|
| 33 |
+
},
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _jsonrpc_response(rpc_id: object, result: dict) -> JSONResponse:
|
| 38 |
+
return JSONResponse({"jsonrpc": "2.0", "id": rpc_id, "result": result})
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _jsonrpc_error(rpc_id: object, code: int, message: str) -> JSONResponse:
|
| 42 |
+
return JSONResponse({"jsonrpc": "2.0", "id": rpc_id, "error": {"code": code, "message": message}})
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@router.post("/mcp")
|
| 46 |
+
async def mcp_endpoint(request: Request) -> JSONResponse:
|
| 47 |
+
try:
|
| 48 |
+
body = await request.json()
|
| 49 |
+
except Exception:
|
| 50 |
+
return _jsonrpc_error(None, -32700, "Parse error")
|
| 51 |
+
|
| 52 |
+
rpc_id = body.get("id")
|
| 53 |
+
method = body.get("method", "")
|
| 54 |
+
|
| 55 |
+
if method == "initialize":
|
| 56 |
+
return _jsonrpc_response(rpc_id, {
|
| 57 |
+
"protocolVersion": "2024-11-05",
|
| 58 |
+
"capabilities": _CAPABILITIES,
|
| 59 |
+
"serverInfo": {"name": PROJECT_NAME, "version": VERSION},
|
| 60 |
+
})
|
| 61 |
+
|
| 62 |
+
if method == "tools/list":
|
| 63 |
+
return _jsonrpc_response(rpc_id, {"tools": _TOOLS})
|
| 64 |
+
|
| 65 |
+
return _jsonrpc_error(rpc_id, -32601, f"Method not found: {method}")
|
server/tasks.py
ADDED
|
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario dataset — 15 tasks across 3 difficulty tiers.
|
| 2 |
+
|
| 3 |
+
Each scenario is a validated ``ScenarioDef`` Pydantic model containing the
|
| 4 |
+
initial world state and deterministic grader keys.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Dict, List, Optional
|
| 10 |
+
|
| 11 |
+
from server.domain import (
|
| 12 |
+
CalendarEvent,
|
| 13 |
+
CommunicationReq,
|
| 14 |
+
ConstraintDef,
|
| 15 |
+
Contact,
|
| 16 |
+
InboxEmail,
|
| 17 |
+
Restaurant,
|
| 18 |
+
ScenarioDef,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# ===================================================================
|
| 22 |
+
# EASY — 2-4 tool calls, single constraint domain
|
| 23 |
+
# ===================================================================
|
| 24 |
+
|
| 25 |
+
_EASY_001 = ScenarioDef(
|
| 26 |
+
scenario_id="easy_001",
|
| 27 |
+
difficulty="easy",
|
| 28 |
+
briefing=(
|
| 29 |
+
"You have two meetings at 2:00 PM today (2026-04-25): a 1-on-1 with your boss "
|
| 30 |
+
"VP_Chen and a team standup with 6 people. Both are in different rooms. "
|
| 31 |
+
"VP_Chen's meeting is higher priority. Reschedule the standup to a free slot "
|
| 32 |
+
"and notify the team."
|
| 33 |
+
),
|
| 34 |
+
initial_calendar=[
|
| 35 |
+
CalendarEvent(event_id="evt_1", title="1-on-1 with VP_Chen", date="2026-04-25", time="14:00", duration_min=30, participants=["VP_Chen"], location="Room A", priority="high"),
|
| 36 |
+
CalendarEvent(event_id="evt_2", title="Team Standup", date="2026-04-25", time="14:00", duration_min=30, participants=["Alice", "Bob", "Carol", "Dave", "Eve", "Frank"], location="Room B", priority="normal"),
|
| 37 |
+
CalendarEvent(event_id="evt_3", title="Lunch", date="2026-04-25", time="12:00", duration_min=60, participants=[], priority="low", is_personal=True),
|
| 38 |
+
],
|
| 39 |
+
initial_inbox=[
|
| 40 |
+
InboxEmail(email_id="em_1", sender="VP_Chen", subject="Our 1-on-1 today", body="Looking forward to our 2pm chat. I have some feedback on the Q3 roadmap.", urgency="high"),
|
| 41 |
+
],
|
| 42 |
+
contacts=[
|
| 43 |
+
Contact(name="VP_Chen", role="VP Engineering", priority_level=5),
|
| 44 |
+
Contact(name="Alice", role="Engineer", priority_level=2),
|
| 45 |
+
Contact(name="Team", role="Engineering Team", priority_level=2, email="team@company.com"),
|
| 46 |
+
],
|
| 47 |
+
constraints=[
|
| 48 |
+
ConstraintDef(description="1-on-1 with VP_Chen must remain at 14:00", check_type="event_exists", check_params={"event_id": "evt_1"}),
|
| 49 |
+
ConstraintDef(description="Team standup must not conflict with 1-on-1", check_type="calendar_no_conflict", check_params={}),
|
| 50 |
+
ConstraintDef(description="Team must be notified of reschedule", check_type="email_sent", check_params={"to": "Team", "keywords": ["reschedule", "standup", "move"]}),
|
| 51 |
+
],
|
| 52 |
+
priority_ordering=["VP_Chen", "Team"],
|
| 53 |
+
communication_requirements=[
|
| 54 |
+
CommunicationReq(to="Team", required_keywords=["reschedule", "standup"], purpose="notify_reschedule"),
|
| 55 |
+
],
|
| 56 |
+
optimal_steps=3,
|
| 57 |
+
max_steps=8,
|
| 58 |
+
expected_cancelled_events=[],
|
| 59 |
+
expected_final_events=["evt_1"],
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
_EASY_002 = ScenarioDef(
|
| 63 |
+
scenario_id="easy_002",
|
| 64 |
+
difficulty="easy",
|
| 65 |
+
briefing=(
|
| 66 |
+
"Book a dinner tonight (2026-04-25) for 4 people. Requirements: "
|
| 67 |
+
"Italian cuisine, under $50 per person, within 3 miles. "
|
| 68 |
+
"Search restaurants and book the best match."
|
| 69 |
+
),
|
| 70 |
+
initial_calendar=[
|
| 71 |
+
CalendarEvent(event_id="evt_10", title="Morning Standup", date="2026-04-25", time="09:00", duration_min=30, participants=["Team"]),
|
| 72 |
+
],
|
| 73 |
+
initial_inbox=[
|
| 74 |
+
InboxEmail(email_id="em_10", sender="Alice", subject="Dinner tonight?", body="Can you find us a nice Italian place? Budget is $50/person max. Needs to be close to the office.", urgency="normal"),
|
| 75 |
+
],
|
| 76 |
+
available_restaurants=[
|
| 77 |
+
Restaurant(name="Bella Italia", cuisine="Italian", price_per_person=40, distance_miles=2.0, dietary_options=["vegetarian", "gluten-free"], capacity=30),
|
| 78 |
+
Restaurant(name="Chez Pierre", cuisine="French", price_per_person=80, distance_miles=1.5, dietary_options=["vegetarian"], capacity=40),
|
| 79 |
+
Restaurant(name="Pasta Palace", cuisine="Italian", price_per_person=55, distance_miles=1.0, dietary_options=["vegan", "vegetarian"], capacity=20),
|
| 80 |
+
Restaurant(name="Dragon Wok", cuisine="Chinese", price_per_person=25, distance_miles=4.0, dietary_options=["vegan", "vegetarian"], capacity=50),
|
| 81 |
+
],
|
| 82 |
+
contacts=[
|
| 83 |
+
Contact(name="Alice", role="Friend", priority_level=2),
|
| 84 |
+
],
|
| 85 |
+
constraints=[
|
| 86 |
+
ConstraintDef(description="Restaurant must be Italian", check_type="restaurant_match", check_params={"criteria": {"dietary": ""}}),
|
| 87 |
+
ConstraintDef(description="Restaurant must be under $50/pp", check_type="restaurant_match", check_params={"criteria": {"max_price": 50}}),
|
| 88 |
+
ConstraintDef(description="Restaurant must be within 3 miles", check_type="restaurant_match", check_params={"criteria": {"max_distance": 3.0}}),
|
| 89 |
+
],
|
| 90 |
+
optimal_steps=2,
|
| 91 |
+
max_steps=6,
|
| 92 |
+
expected_restaurant="Bella Italia",
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
_EASY_003 = ScenarioDef(
|
| 96 |
+
scenario_id="easy_003",
|
| 97 |
+
difficulty="easy",
|
| 98 |
+
briefing=(
|
| 99 |
+
"Client_Jones has emailed asking for a meeting this week. Check your "
|
| 100 |
+
"calendar for 2026-04-25 and Client_Jones's availability, then propose "
|
| 101 |
+
"3 available slots via email."
|
| 102 |
+
),
|
| 103 |
+
initial_calendar=[
|
| 104 |
+
CalendarEvent(event_id="evt_20", title="Team Sync", date="2026-04-25", time="10:00", duration_min=60, participants=["Team"]),
|
| 105 |
+
CalendarEvent(event_id="evt_21", title="Lunch", date="2026-04-25", time="12:00", duration_min=60, is_personal=True),
|
| 106 |
+
CalendarEvent(event_id="evt_22", title="Design Review", date="2026-04-25", time="15:00", duration_min=60, participants=["Bob", "Carol"]),
|
| 107 |
+
],
|
| 108 |
+
initial_inbox=[
|
| 109 |
+
InboxEmail(email_id="em_20", sender="Client_Jones", subject="Meeting this week?", body="Hi, I'd love to catch up this week. Do you have any openings? Need about 30 minutes.", urgency="high"),
|
| 110 |
+
],
|
| 111 |
+
contacts=[
|
| 112 |
+
Contact(name="Client_Jones", role="Client", priority_level=4, availability={"2026-04-25": ["09:00", "11:00", "14:00", "16:00"]}),
|
| 113 |
+
],
|
| 114 |
+
constraints=[
|
| 115 |
+
ConstraintDef(description="Email must be sent to Client_Jones", check_type="email_sent", check_params={"to": "Client_Jones", "keywords": ["slot", "available", "meet"]}),
|
| 116 |
+
],
|
| 117 |
+
communication_requirements=[
|
| 118 |
+
CommunicationReq(to="Client_Jones", required_keywords=["available", "slot", "time"], purpose="propose_slots"),
|
| 119 |
+
],
|
| 120 |
+
optimal_steps=3,
|
| 121 |
+
max_steps=8,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
_EASY_004 = ScenarioDef(
|
| 125 |
+
scenario_id="easy_004",
|
| 126 |
+
difficulty="easy",
|
| 127 |
+
briefing=(
|
| 128 |
+
"Your personal doctor appointment at 3:00 PM today (2026-04-25) conflicts "
|
| 129 |
+
"with the weekly team sync. The doctor appointment was booked first and is "
|
| 130 |
+
"important. Cancel the team sync and notify the team."
|
| 131 |
+
),
|
| 132 |
+
initial_calendar=[
|
| 133 |
+
CalendarEvent(event_id="evt_30", title="Weekly Team Sync", date="2026-04-25", time="15:00", duration_min=60, participants=["Team"], priority="normal"),
|
| 134 |
+
CalendarEvent(event_id="evt_31", title="Doctor Appointment", date="2026-04-25", time="15:00", duration_min=60, priority="high", is_personal=True),
|
| 135 |
+
],
|
| 136 |
+
initial_inbox=[],
|
| 137 |
+
contacts=[
|
| 138 |
+
Contact(name="Team", role="Engineering Team", priority_level=2),
|
| 139 |
+
],
|
| 140 |
+
constraints=[
|
| 141 |
+
ConstraintDef(description="Doctor appointment must remain", check_type="event_exists", check_params={"event_id": "evt_31"}),
|
| 142 |
+
ConstraintDef(description="Team sync must be cancelled", check_type="event_cancelled", check_params={"event_id": "evt_30"}),
|
| 143 |
+
ConstraintDef(description="Team must be notified", check_type="email_sent", check_params={"to": "Team", "keywords": ["cancel", "sync"]}),
|
| 144 |
+
],
|
| 145 |
+
communication_requirements=[
|
| 146 |
+
CommunicationReq(to="Team", required_keywords=["cancel", "sync", "apologi"], purpose="notify_reschedule"),
|
| 147 |
+
],
|
| 148 |
+
optimal_steps=2,
|
| 149 |
+
max_steps=6,
|
| 150 |
+
expected_cancelled_events=["evt_30"],
|
| 151 |
+
expected_final_events=["evt_31"],
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
_EASY_005 = ScenarioDef(
|
| 155 |
+
scenario_id="easy_005",
|
| 156 |
+
difficulty="easy",
|
| 157 |
+
briefing=(
|
| 158 |
+
"You have 3 unread emails. Triage them by urgency and respond to the most "
|
| 159 |
+
"urgent one first. VP_Chen's email is critical, Client_Jones is high, "
|
| 160 |
+
"and Alice is normal priority."
|
| 161 |
+
),
|
| 162 |
+
initial_calendar=[],
|
| 163 |
+
initial_inbox=[
|
| 164 |
+
InboxEmail(email_id="em_50", sender="Alice", subject="Lunch tomorrow?", body="Want to grab lunch tomorrow?", urgency="low"),
|
| 165 |
+
InboxEmail(email_id="em_51", sender="Client_Jones", subject="Contract review", body="Please review the attached contract by end of day.", urgency="high"),
|
| 166 |
+
InboxEmail(email_id="em_52", sender="VP_Chen", subject="URGENT: Board deck", body="I need the Q3 numbers for the board deck. Can you send them in the next hour?", urgency="critical"),
|
| 167 |
+
],
|
| 168 |
+
contacts=[
|
| 169 |
+
Contact(name="VP_Chen", role="VP Engineering", priority_level=5),
|
| 170 |
+
Contact(name="Client_Jones", role="Client", priority_level=4),
|
| 171 |
+
Contact(name="Alice", role="Engineer", priority_level=2),
|
| 172 |
+
],
|
| 173 |
+
constraints=[
|
| 174 |
+
ConstraintDef(description="VP_Chen must be responded to", check_type="email_sent", check_params={"to": "VP_Chen", "keywords": ["Q3", "number", "board"]}),
|
| 175 |
+
ConstraintDef(description="Client_Jones must be responded to", check_type="email_sent", check_params={"to": "Client_Jones", "keywords": ["contract", "review"]}),
|
| 176 |
+
],
|
| 177 |
+
communication_requirements=[
|
| 178 |
+
CommunicationReq(to="VP_Chen", required_keywords=["Q3", "numbers", "send"], purpose="acknowledge"),
|
| 179 |
+
CommunicationReq(to="Client_Jones", required_keywords=["contract", "review"], purpose="acknowledge"),
|
| 180 |
+
],
|
| 181 |
+
optimal_steps=2,
|
| 182 |
+
max_steps=6,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# ===================================================================
|
| 186 |
+
# MEDIUM — 5-8 tool calls, cross-domain with commitment tracking
|
| 187 |
+
# ===================================================================
|
| 188 |
+
|
| 189 |
+
_MED_006 = ScenarioDef(
|
| 190 |
+
scenario_id="med_006",
|
| 191 |
+
difficulty="medium",
|
| 192 |
+
briefing=(
|
| 193 |
+
"Meeting A ('Design Review') has been moved from 2:00 PM to 3:00 PM today "
|
| 194 |
+
"(2026-04-25). But you have Meeting B ('Sprint Planning') at 3:00 PM, and "
|
| 195 |
+
"Meeting C ('Demo Prep') at 4:00 PM depends on Sprint Planning's output. "
|
| 196 |
+
"Resolve the cascade: reschedule B without conflicting with C, and notify "
|
| 197 |
+
"all affected parties."
|
| 198 |
+
),
|
| 199 |
+
initial_calendar=[
|
| 200 |
+
CalendarEvent(event_id="evt_40", title="Design Review", date="2026-04-25", time="14:00", duration_min=60, participants=["Bob", "Carol"], priority="high"),
|
| 201 |
+
CalendarEvent(event_id="evt_41", title="Sprint Planning", date="2026-04-25", time="15:00", duration_min=60, participants=["Team"], priority="normal"),
|
| 202 |
+
CalendarEvent(event_id="evt_42", title="Demo Prep", date="2026-04-25", time="16:00", duration_min=60, participants=["Alice", "Dave"], priority="normal"),
|
| 203 |
+
CalendarEvent(event_id="evt_43", title="Morning Standup", date="2026-04-25", time="09:00", duration_min=30, participants=["Team"]),
|
| 204 |
+
],
|
| 205 |
+
initial_inbox=[
|
| 206 |
+
InboxEmail(email_id="em_40", sender="Bob", subject="Design Review moved", body="Hey, I need to push our 2pm design review to 3pm. Apologies for the late change.", urgency="high"),
|
| 207 |
+
],
|
| 208 |
+
contacts=[
|
| 209 |
+
Contact(name="Bob", role="Lead Designer", priority_level=3),
|
| 210 |
+
Contact(name="Team", role="Engineering Team", priority_level=2),
|
| 211 |
+
Contact(name="Alice", role="Engineer", priority_level=2),
|
| 212 |
+
],
|
| 213 |
+
constraints=[
|
| 214 |
+
ConstraintDef(description="Design Review must be at 15:00", check_type="calendar_no_conflict", check_params={}),
|
| 215 |
+
ConstraintDef(description="Sprint Planning must not conflict", check_type="calendar_no_conflict", check_params={}),
|
| 216 |
+
ConstraintDef(description="Demo Prep must remain after Sprint Planning", check_type="event_exists", check_params={"event_id": "evt_42"}),
|
| 217 |
+
ConstraintDef(description="Team notified about Sprint Planning change", check_type="email_sent", check_params={"to": "Team", "keywords": ["sprint", "reschedule", "move"]}),
|
| 218 |
+
],
|
| 219 |
+
communication_requirements=[
|
| 220 |
+
CommunicationReq(to="Team", required_keywords=["sprint", "planning", "reschedule"], purpose="notify_reschedule"),
|
| 221 |
+
],
|
| 222 |
+
optimal_steps=4,
|
| 223 |
+
max_steps=10,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
_MED_007 = ScenarioDef(
|
| 227 |
+
scenario_id="med_007",
|
| 228 |
+
difficulty="medium",
|
| 229 |
+
briefing=(
|
| 230 |
+
"Plan a team dinner for 6 people tonight (2026-04-25). Constraints: "
|
| 231 |
+
"Alice is vegan, Bob has a nut allergy, must be within 3 miles, "
|
| 232 |
+
"under $45 per person, and needs a private room for 6+. "
|
| 233 |
+
"Search restaurants, book the right one, and email the team with details."
|
| 234 |
+
),
|
| 235 |
+
initial_calendar=[
|
| 236 |
+
CalendarEvent(event_id="evt_50", title="Afternoon Focus", date="2026-04-25", time="14:00", duration_min=120),
|
| 237 |
+
],
|
| 238 |
+
initial_inbox=[
|
| 239 |
+
InboxEmail(email_id="em_50", sender="Alice", subject="Dinner tonight", body="Can you book a place? Remember I'm vegan. Bob has a nut allergy. We need a private room.", urgency="normal"),
|
| 240 |
+
],
|
| 241 |
+
available_restaurants=[
|
| 242 |
+
Restaurant(name="Green Garden", cuisine="Mediterranean", price_per_person=38, distance_miles=2.5, dietary_options=["vegan", "nut-free", "vegetarian"], capacity=30, has_private_room=True),
|
| 243 |
+
Restaurant(name="Steak House Prime", cuisine="American", price_per_person=55, distance_miles=1.0, dietary_options=["gluten-free"], capacity=50, has_private_room=True),
|
| 244 |
+
Restaurant(name="Lotus Thai", cuisine="Thai", price_per_person=30, distance_miles=3.5, dietary_options=["vegan", "vegetarian"], capacity=25, has_private_room=False),
|
| 245 |
+
Restaurant(name="Cafe Novo", cuisine="Fusion", price_per_person=42, distance_miles=2.0, dietary_options=["vegan", "nut-free", "gluten-free", "vegetarian"], capacity=15, has_private_room=True),
|
| 246 |
+
Restaurant(name="Burgers & Brew", cuisine="American", price_per_person=20, distance_miles=0.5, dietary_options=["vegetarian"], capacity=40, has_private_room=False),
|
| 247 |
+
],
|
| 248 |
+
contacts=[
|
| 249 |
+
Contact(name="Alice", role="Engineer", priority_level=2, dietary="vegan"),
|
| 250 |
+
Contact(name="Bob", role="Engineer", priority_level=2, dietary="nut-free"),
|
| 251 |
+
Contact(name="Team", role="Engineering Team", priority_level=2),
|
| 252 |
+
],
|
| 253 |
+
constraints=[
|
| 254 |
+
ConstraintDef(description="Restaurant has vegan options", check_type="restaurant_match", check_params={"criteria": {"dietary": "vegan"}}),
|
| 255 |
+
ConstraintDef(description="Restaurant under $45/pp", check_type="restaurant_match", check_params={"criteria": {"max_price": 45}}),
|
| 256 |
+
ConstraintDef(description="Restaurant within 3 miles", check_type="restaurant_match", check_params={"criteria": {"max_distance": 3.0}}),
|
| 257 |
+
ConstraintDef(description="Team notified of dinner details", check_type="email_sent", check_params={"to": "Team", "keywords": ["dinner", "restaurant"]}),
|
| 258 |
+
],
|
| 259 |
+
communication_requirements=[
|
| 260 |
+
CommunicationReq(to="Team", required_keywords=["dinner", "tonight", "restaurant"], purpose="notify_reschedule"),
|
| 261 |
+
],
|
| 262 |
+
optimal_steps=3,
|
| 263 |
+
max_steps=8,
|
| 264 |
+
expected_restaurant="Green Garden",
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
_MED_008 = ScenarioDef(
|
| 268 |
+
scenario_id="med_008",
|
| 269 |
+
difficulty="medium",
|
| 270 |
+
briefing=(
|
| 271 |
+
"You are currently in a client call (Client_Jones) that ends at 3:15 PM. "
|
| 272 |
+
"Your boss VP_Chen just emailed saying 'Need Q3 numbers in 30 minutes — "
|
| 273 |
+
"board meeting moved up.' It's currently 2:45 PM on 2026-04-25. "
|
| 274 |
+
"You cannot leave the client call early. Acknowledge VP_Chen with a "
|
| 275 |
+
"realistic ETA and do NOT cancel the client meeting."
|
| 276 |
+
),
|
| 277 |
+
initial_calendar=[
|
| 278 |
+
CalendarEvent(event_id="evt_60", title="Client Call with Jones", date="2026-04-25", time="14:30", duration_min=45, participants=["Client_Jones"], priority="high"),
|
| 279 |
+
CalendarEvent(event_id="evt_61", title="Focus Time", date="2026-04-25", time="16:00", duration_min=60, priority="low"),
|
| 280 |
+
],
|
| 281 |
+
initial_inbox=[
|
| 282 |
+
InboxEmail(email_id="em_60", sender="VP_Chen", subject="URGENT: Q3 numbers NOW", body="Board meeting moved up. I need the Q3 revenue numbers in the next 30 minutes. This is critical.", urgency="critical"),
|
| 283 |
+
],
|
| 284 |
+
contacts=[
|
| 285 |
+
Contact(name="VP_Chen", role="VP Engineering", priority_level=5),
|
| 286 |
+
Contact(name="Client_Jones", role="Client", priority_level=4),
|
| 287 |
+
],
|
| 288 |
+
constraints=[
|
| 289 |
+
ConstraintDef(description="Client call must NOT be cancelled", check_type="event_exists", check_params={"event_id": "evt_60"}),
|
| 290 |
+
ConstraintDef(description="VP_Chen must be acknowledged", check_type="email_sent", check_params={"to": "VP_Chen", "keywords": ["Q3", "numbers"]}),
|
| 291 |
+
ConstraintDef(description="Realistic ETA communicated", check_type="email_sent", check_params={"to": "VP_Chen", "keywords": ["after", "3:15", "call", "send"]}),
|
| 292 |
+
],
|
| 293 |
+
communication_requirements=[
|
| 294 |
+
CommunicationReq(to="VP_Chen", required_keywords=["Q3", "numbers", "after", "client"], purpose="acknowledge"),
|
| 295 |
+
],
|
| 296 |
+
optimal_steps=2,
|
| 297 |
+
max_steps=6,
|
| 298 |
+
expected_final_events=["evt_60"],
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
_MED_009 = ScenarioDef(
|
| 302 |
+
scenario_id="med_009",
|
| 303 |
+
difficulty="medium",
|
| 304 |
+
briefing=(
|
| 305 |
+
"You received an email from Bob saying 'Can we push our thing to next week?' "
|
| 306 |
+
"You have 3 recurring meetings with Bob: Monday Design Review (evt_70), "
|
| 307 |
+
"Wednesday Code Review (evt_71), and Friday Retrospective (evt_72) — all on "
|
| 308 |
+
"different days this week (2026-04-25 is Friday). Check the context and "
|
| 309 |
+
"determine which meeting Bob means, then confirm via email."
|
| 310 |
+
),
|
| 311 |
+
initial_calendar=[
|
| 312 |
+
CalendarEvent(event_id="evt_70", title="Design Review with Bob", date="2026-04-21", time="10:00", duration_min=60, participants=["Bob"]),
|
| 313 |
+
CalendarEvent(event_id="evt_71", title="Code Review with Bob", date="2026-04-23", time="14:00", duration_min=60, participants=["Bob"]),
|
| 314 |
+
CalendarEvent(event_id="evt_72", title="Retrospective with Bob", date="2026-04-25", time="11:00", duration_min=60, participants=["Bob"]),
|
| 315 |
+
],
|
| 316 |
+
initial_inbox=[
|
| 317 |
+
InboxEmail(email_id="em_70", sender="Bob", subject="Push our thing?", body="Hey, can we push our thing to next week? I'm swamped with the release today.", urgency="normal", context_hint="Bob means the Retrospective (today, Friday) since he says 'today'"),
|
| 318 |
+
],
|
| 319 |
+
contacts=[
|
| 320 |
+
Contact(name="Bob", role="Lead Designer", priority_level=3, availability={"2026-05-02": ["11:00", "14:00"]}),
|
| 321 |
+
],
|
| 322 |
+
constraints=[
|
| 323 |
+
ConstraintDef(description="Bob must be responded to", check_type="email_sent", check_params={"to": "Bob", "keywords": ["retrospective", "next week"]}),
|
| 324 |
+
],
|
| 325 |
+
communication_requirements=[
|
| 326 |
+
CommunicationReq(to="Bob", required_keywords=["retrospective", "next week", "reschedule"], purpose="renegotiate"),
|
| 327 |
+
],
|
| 328 |
+
optimal_steps=4,
|
| 329 |
+
max_steps=10,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
_MED_010 = ScenarioDef(
|
| 333 |
+
scenario_id="med_010",
|
| 334 |
+
difficulty="medium",
|
| 335 |
+
briefing=(
|
| 336 |
+
"Client_Jones is visiting your office tomorrow (2026-04-26). You need to: "
|
| 337 |
+
"(1) book a conference room for a 10 AM demo, "
|
| 338 |
+
"(2) arrange lunch at a restaurant with vegetarian options, "
|
| 339 |
+
"and (3) send Client_Jones an itinerary email with all details."
|
| 340 |
+
),
|
| 341 |
+
initial_calendar=[
|
| 342 |
+
CalendarEvent(event_id="evt_80", title="Team Standup", date="2026-04-26", time="09:00", duration_min=30, participants=["Team"]),
|
| 343 |
+
],
|
| 344 |
+
initial_inbox=[
|
| 345 |
+
InboxEmail(email_id="em_80", sender="Client_Jones", subject="Visit tomorrow", body="Looking forward to the demo tomorrow. Is 10am still good? I'm vegetarian by the way.", urgency="high"),
|
| 346 |
+
],
|
| 347 |
+
available_restaurants=[
|
| 348 |
+
Restaurant(name="Garden Bistro", cuisine="Mediterranean", price_per_person=35, distance_miles=0.5, dietary_options=["vegetarian", "vegan"], capacity=20),
|
| 349 |
+
Restaurant(name="BBQ Pit", cuisine="American BBQ", price_per_person=30, distance_miles=1.0, dietary_options=[], capacity=40),
|
| 350 |
+
],
|
| 351 |
+
contacts=[
|
| 352 |
+
Contact(name="Client_Jones", role="Client", priority_level=4, availability={"2026-04-26": ["10:00", "11:00", "12:00", "13:00"]}, dietary="vegetarian"),
|
| 353 |
+
],
|
| 354 |
+
constraints=[
|
| 355 |
+
ConstraintDef(description="Demo meeting scheduled at 10:00", check_type="calendar_no_conflict", check_params={}),
|
| 356 |
+
ConstraintDef(description="Restaurant has vegetarian options", check_type="restaurant_match", check_params={"criteria": {"dietary": "vegetarian"}}),
|
| 357 |
+
ConstraintDef(description="Client_Jones receives itinerary", check_type="email_sent", check_params={"to": "Client_Jones", "keywords": ["itinerary", "10", "demo", "lunch"]}),
|
| 358 |
+
],
|
| 359 |
+
communication_requirements=[
|
| 360 |
+
CommunicationReq(to="Client_Jones", required_keywords=["itinerary", "demo", "lunch", "10"], purpose="notify_reschedule"),
|
| 361 |
+
],
|
| 362 |
+
optimal_steps=4,
|
| 363 |
+
max_steps=10,
|
| 364 |
+
expected_restaurant="Garden Bistro",
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
# ===================================================================
|
| 368 |
+
# HARD — 8-15 tool calls, full cross-task cascade + SRE crisis
|
| 369 |
+
# ===================================================================
|
| 370 |
+
|
| 371 |
+
_HARD_011 = ScenarioDef(
|
| 372 |
+
scenario_id="hard_011",
|
| 373 |
+
difficulty="hard",
|
| 374 |
+
briefing=(
|
| 375 |
+
"VP_Chen just emailed: an important investor (Investor_Park) is in town tonight "
|
| 376 |
+
"(2026-04-25) and needs a dinner meeting. Investor_Park has a 9:00 PM flight "
|
| 377 |
+
"so dinner must end by 8:00 PM. Investor_Park is vegetarian. Your calendar: "
|
| 378 |
+
"6:00 PM Yoga (personal), 7:00 PM Team Happy Hour (you organised it and "
|
| 379 |
+
"promised the team last week). You must: find a restaurant near the airport "
|
| 380 |
+
"with vegetarian options under $60/pp, handle the calendar conflicts by "
|
| 381 |
+
"priority (investor > happy hour > yoga), and email everyone affected."
|
| 382 |
+
),
|
| 383 |
+
initial_calendar=[
|
| 384 |
+
CalendarEvent(event_id="evt_90", title="Yoga", date="2026-04-25", time="18:00", duration_min=60, priority="low", is_personal=True),
|
| 385 |
+
CalendarEvent(event_id="evt_91", title="Team Happy Hour", date="2026-04-25", time="19:00", duration_min=120, participants=["Team"], priority="normal"),
|
| 386 |
+
CalendarEvent(event_id="evt_92", title="Afternoon Focus", date="2026-04-25", time="14:00", duration_min=120),
|
| 387 |
+
],
|
| 388 |
+
initial_inbox=[
|
| 389 |
+
InboxEmail(email_id="em_90", sender="VP_Chen", subject="Investor dinner TONIGHT", body="Investor_Park is in town tonight only. We need dinner before their 9pm flight. They're vegetarian. Book something near the airport. This is top priority.", urgency="critical"),
|
| 390 |
+
],
|
| 391 |
+
available_restaurants=[
|
| 392 |
+
Restaurant(name="Sky Lounge", cuisine="International", price_per_person=55, distance_miles=1.0, dietary_options=["vegetarian", "vegan", "gluten-free"], capacity=30, near_airport=True, has_private_room=True),
|
| 393 |
+
Restaurant(name="Terminal Grill", cuisine="American", price_per_person=35, distance_miles=0.5, dietary_options=["vegetarian"], capacity=50, near_airport=True),
|
| 394 |
+
Restaurant(name="Downtown Sushi", cuisine="Japanese", price_per_person=45, distance_miles=8.0, dietary_options=["vegetarian"], capacity=20),
|
| 395 |
+
Restaurant(name="Fancy Steak", cuisine="Steakhouse", price_per_person=70, distance_miles=0.8, dietary_options=[], capacity=40, near_airport=True),
|
| 396 |
+
],
|
| 397 |
+
contacts=[
|
| 398 |
+
Contact(name="VP_Chen", role="VP Engineering", priority_level=5),
|
| 399 |
+
Contact(name="Investor_Park", role="Investor", priority_level=5, dietary="vegetarian"),
|
| 400 |
+
Contact(name="Team", role="Engineering Team", priority_level=2),
|
| 401 |
+
],
|
| 402 |
+
constraints=[
|
| 403 |
+
ConstraintDef(description="Restaurant near airport", check_type="restaurant_match", check_params={"criteria": {"near_airport": True}}),
|
| 404 |
+
ConstraintDef(description="Restaurant has vegetarian options", check_type="restaurant_match", check_params={"criteria": {"dietary": "vegetarian"}}),
|
| 405 |
+
ConstraintDef(description="Restaurant under $60/pp", check_type="restaurant_match", check_params={"criteria": {"max_price": 60}}),
|
| 406 |
+
ConstraintDef(description="Yoga cancelled (lower priority)", check_type="event_cancelled", check_params={"event_id": "evt_90"}),
|
| 407 |
+
ConstraintDef(description="Team notified about Happy Hour change", check_type="email_sent", check_params={"to": "Team", "keywords": ["happy hour", "reschedule", "sorry"]}),
|
| 408 |
+
ConstraintDef(description="VP_Chen sent dinner plan", check_type="email_sent", check_params={"to": "VP_Chen", "keywords": ["dinner", "restaurant", "investor"]}),
|
| 409 |
+
],
|
| 410 |
+
communication_requirements=[
|
| 411 |
+
CommunicationReq(to="Team", required_keywords=["happy hour", "reschedule", "sorry", "apologi"], purpose="renegotiate"),
|
| 412 |
+
CommunicationReq(to="VP_Chen", required_keywords=["dinner", "restaurant", "investor", "vegetarian"], purpose="acknowledge"),
|
| 413 |
+
],
|
| 414 |
+
optimal_steps=7,
|
| 415 |
+
max_steps=15,
|
| 416 |
+
expected_restaurant="Sky Lounge",
|
| 417 |
+
expected_cancelled_events=["evt_90"],
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
_HARD_012 = ScenarioDef(
|
| 421 |
+
scenario_id="hard_012",
|
| 422 |
+
difficulty="hard",
|
| 423 |
+
briefing=(
|
| 424 |
+
"Three VPs all want Conference Room Alpha at 2:00 PM today (2026-04-25) for "
|
| 425 |
+
"different meetings. VP_Chen: Board Prep (critical). VP_Lee: Client Demo "
|
| 426 |
+
"(high). VP_Kumar: Team Retro (normal). You must assess priority, keep the "
|
| 427 |
+
"highest-priority meeting in Alpha, propose alternative rooms/times for the "
|
| 428 |
+
"other two, and send diplomatic emails to all three VPs."
|
| 429 |
+
),
|
| 430 |
+
initial_calendar=[
|
| 431 |
+
CalendarEvent(event_id="evt_100", title="Board Prep", date="2026-04-25", time="14:00", duration_min=60, participants=["VP_Chen"], location="Alpha", priority="critical"),
|
| 432 |
+
CalendarEvent(event_id="evt_101", title="Client Demo", date="2026-04-25", time="14:00", duration_min=60, participants=["VP_Lee", "Client_Jones"], location="Alpha", priority="high"),
|
| 433 |
+
CalendarEvent(event_id="evt_102", title="Team Retro", date="2026-04-25", time="14:00", duration_min=60, participants=["VP_Kumar", "Team"], location="Alpha", priority="normal"),
|
| 434 |
+
],
|
| 435 |
+
initial_inbox=[
|
| 436 |
+
InboxEmail(email_id="em_100", sender="Admin", subject="Room conflict alert", body="Conference Room Alpha has 3 bookings at 2pm. Please resolve.", urgency="critical"),
|
| 437 |
+
],
|
| 438 |
+
contacts=[
|
| 439 |
+
Contact(name="VP_Chen", role="VP Engineering", priority_level=5),
|
| 440 |
+
Contact(name="VP_Lee", role="VP Sales", priority_level=4),
|
| 441 |
+
Contact(name="VP_Kumar", role="VP Product", priority_level=3),
|
| 442 |
+
],
|
| 443 |
+
constraints=[
|
| 444 |
+
ConstraintDef(description="Board Prep stays in Alpha at 14:00", check_type="event_exists", check_params={"event_id": "evt_100"}),
|
| 445 |
+
ConstraintDef(description="No calendar conflicts after resolution", check_type="calendar_no_conflict", check_params={}),
|
| 446 |
+
ConstraintDef(description="VP_Lee notified of room change", check_type="email_sent", check_params={"to": "VP_Lee", "keywords": ["room", "move", "demo"]}),
|
| 447 |
+
ConstraintDef(description="VP_Kumar notified of room change", check_type="email_sent", check_params={"to": "VP_Kumar", "keywords": ["room", "move", "retro"]}),
|
| 448 |
+
],
|
| 449 |
+
communication_requirements=[
|
| 450 |
+
CommunicationReq(to="VP_Lee", required_keywords=["room", "move", "alternative", "apologi"], purpose="renegotiate"),
|
| 451 |
+
CommunicationReq(to="VP_Kumar", required_keywords=["room", "move", "alternative", "apologi"], purpose="renegotiate"),
|
| 452 |
+
],
|
| 453 |
+
optimal_steps=6,
|
| 454 |
+
max_steps=15,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
_HARD_013 = ScenarioDef(
|
| 458 |
+
scenario_id="hard_013",
|
| 459 |
+
difficulty="hard",
|
| 460 |
+
briefing=(
|
| 461 |
+
"Triple crisis on 2026-04-25: (1) Your 4:00 PM flight (evt_110) was cancelled — "
|
| 462 |
+
"you need to rebook before the 6:00 PM board prep (evt_111) tomorrow. "
|
| 463 |
+
"(2) Board prep moved from 4:00 PM to 2:00 PM tomorrow (2026-04-26), "
|
| 464 |
+
"conflicting with your lunch with Client_Jones (evt_112). "
|
| 465 |
+
"(3) Your dinner reservation at Downtown Sushi was lost. "
|
| 466 |
+
"Handle all three crises: rebook flight constraints, reschedule lunch "
|
| 467 |
+
"with Client_Jones, find a new dinner restaurant, email all affected parties."
|
| 468 |
+
),
|
| 469 |
+
initial_calendar=[
|
| 470 |
+
CalendarEvent(event_id="evt_110", title="Flight to NYC", date="2026-04-25", time="16:00", duration_min=180, priority="high"),
|
| 471 |
+
CalendarEvent(event_id="evt_111", title="Board Prep", date="2026-04-26", time="16:00", duration_min=120, participants=["VP_Chen"], priority="critical"),
|
| 472 |
+
CalendarEvent(event_id="evt_112", title="Lunch with Client_Jones", date="2026-04-26", time="12:00", duration_min=90, participants=["Client_Jones"], priority="high"),
|
| 473 |
+
CalendarEvent(event_id="evt_113", title="Morning Standup", date="2026-04-26", time="09:00", duration_min=30, participants=["Team"]),
|
| 474 |
+
],
|
| 475 |
+
initial_inbox=[
|
| 476 |
+
InboxEmail(email_id="em_110", sender="Airline", subject="Flight cancelled", body="Your flight at 4:00 PM today has been cancelled. Next available flight: 6:00 PM or 8:00 PM.", urgency="critical"),
|
| 477 |
+
InboxEmail(email_id="em_111", sender="VP_Chen", subject="Board prep moved up", body="Board prep is now at 2pm tomorrow instead of 4pm. Non-negotiable.", urgency="critical"),
|
| 478 |
+
InboxEmail(email_id="em_112", sender="Downtown Sushi", subject="Reservation cancelled", body="We regret to inform you that we had to cancel your reservation due to a private event.", urgency="high"),
|
| 479 |
+
],
|
| 480 |
+
available_restaurants=[
|
| 481 |
+
Restaurant(name="Sakura Garden", cuisine="Japanese", price_per_person=40, distance_miles=2.0, dietary_options=["vegetarian", "vegan"], capacity=25),
|
| 482 |
+
Restaurant(name="Pizza Corner", cuisine="Italian", price_per_person=25, distance_miles=1.0, dietary_options=["vegetarian"], capacity=30),
|
| 483 |
+
],
|
| 484 |
+
contacts=[
|
| 485 |
+
Contact(name="VP_Chen", role="VP Engineering", priority_level=5),
|
| 486 |
+
Contact(name="Client_Jones", role="Client", priority_level=4, availability={"2026-04-26": ["09:30", "10:00", "11:00"]}),
|
| 487 |
+
],
|
| 488 |
+
constraints=[
|
| 489 |
+
ConstraintDef(description="Board Prep rescheduled to 14:00", check_type="calendar_no_conflict", check_params={}),
|
| 490 |
+
ConstraintDef(description="Client_Jones notified of lunch reschedule", check_type="email_sent", check_params={"to": "Client_Jones", "keywords": ["lunch", "reschedule", "move"]}),
|
| 491 |
+
ConstraintDef(description="New dinner restaurant booked", check_type="restaurant_match", check_params={"criteria": {}}),
|
| 492 |
+
ConstraintDef(description="VP_Chen acknowledged board prep change", check_type="email_sent", check_params={"to": "VP_Chen", "keywords": ["board", "prep", "2pm", "confirmed"]}),
|
| 493 |
+
],
|
| 494 |
+
communication_requirements=[
|
| 495 |
+
CommunicationReq(to="Client_Jones", required_keywords=["lunch", "reschedule", "sorry", "alternative"], purpose="renegotiate"),
|
| 496 |
+
CommunicationReq(to="VP_Chen", required_keywords=["board", "prep", "confirmed"], purpose="acknowledge"),
|
| 497 |
+
],
|
| 498 |
+
optimal_steps=8,
|
| 499 |
+
max_steps=15,
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
_HARD_014 = ScenarioDef(
|
| 503 |
+
scenario_id="hard_014",
|
| 504 |
+
difficulty="hard",
|
| 505 |
+
briefing=(
|
| 506 |
+
"VP_Chen asks you to schedule a meeting with Client_Jones 'sometime this week' "
|
| 507 |
+
"(2026-04-21 to 2026-04-25). Client_Jones privately told you they're unavailable "
|
| 508 |
+
"Mon-Wed due to a family emergency — this is confidential. VP_Chen doesn't know. "
|
| 509 |
+
"You must propose Thu/Fri slots without revealing Client_Jones's private reason. "
|
| 510 |
+
"Navigate the information asymmetry diplomatically."
|
| 511 |
+
),
|
| 512 |
+
initial_calendar=[
|
| 513 |
+
CalendarEvent(event_id="evt_120", title="Team Sync", date="2026-04-24", time="10:00", duration_min=60, participants=["Team"]),
|
| 514 |
+
CalendarEvent(event_id="evt_121", title="1-on-1 with VP_Chen", date="2026-04-25", time="14:00", duration_min=30, participants=["VP_Chen"]),
|
| 515 |
+
],
|
| 516 |
+
initial_inbox=[
|
| 517 |
+
InboxEmail(email_id="em_120", sender="VP_Chen", subject="Meeting with Jones", body="Can you set up a meeting with Client_Jones this week? 30 minutes. Any day works for me.", urgency="high"),
|
| 518 |
+
InboxEmail(email_id="em_121", sender="Client_Jones", subject="Availability - confidential", body="I'm dealing with a family emergency Mon-Wed. I'd prefer to keep this private. I'm free Thu after 2pm and all day Friday.", urgency="normal", context_hint="CONFIDENTIAL: do not share reason with VP_Chen"),
|
| 519 |
+
],
|
| 520 |
+
contacts=[
|
| 521 |
+
Contact(name="VP_Chen", role="VP Engineering", priority_level=5, availability={"2026-04-24": ["09:00", "10:00", "14:00", "15:00"], "2026-04-25": ["09:00", "10:00", "15:00", "16:00"]}),
|
| 522 |
+
Contact(name="Client_Jones", role="Client", priority_level=4, availability={"2026-04-24": ["14:00", "15:00", "16:00"], "2026-04-25": ["09:00", "10:00", "11:00", "14:00", "15:00"]}),
|
| 523 |
+
],
|
| 524 |
+
constraints=[
|
| 525 |
+
ConstraintDef(description="Meeting scheduled Thu or Fri only", check_type="calendar_no_conflict", check_params={}),
|
| 526 |
+
ConstraintDef(description="VP_Chen notified of proposed time", check_type="email_sent", check_params={"to": "VP_Chen", "keywords": ["Thursday", "Friday", "Client_Jones", "slot"]}),
|
| 527 |
+
ConstraintDef(description="Client_Jones notified", check_type="email_sent", check_params={"to": "Client_Jones", "keywords": ["meeting", "VP", "time"]}),
|
| 528 |
+
],
|
| 529 |
+
communication_requirements=[
|
| 530 |
+
CommunicationReq(to="VP_Chen", required_keywords=["Thursday", "Friday", "Client_Jones", "available"], purpose="propose_slots"),
|
| 531 |
+
CommunicationReq(to="Client_Jones", required_keywords=["meeting", "time", "VP_Chen"], purpose="propose_slots"),
|
| 532 |
+
],
|
| 533 |
+
optimal_steps=5,
|
| 534 |
+
max_steps=12,
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
_HARD_015 = ScenarioDef(
|
| 538 |
+
scenario_id="hard_015",
|
| 539 |
+
difficulty="hard",
|
| 540 |
+
briefing=(
|
| 541 |
+
"PRODUCTION INCIDENT: At 11:45 AM on 2026-04-25, PagerDuty fires — "
|
| 542 |
+
"payment-service is returning 503s with 94%% error rate. HikariPool connection "
|
| 543 |
+
"pool exhausted. You're the on-call engineer.\n\n"
|
| 544 |
+
"Your existing commitments today:\n"
|
| 545 |
+
"- 12:00 PM: Team lunch at Garden Bistro (you organised, 6 people attending)\n"
|
| 546 |
+
"- 2:00 PM: Client demo with Client_Jones (promised last week)\n"
|
| 547 |
+
"- 3:30 PM: 1-on-1 with VP_Chen\n"
|
| 548 |
+
"- 6:00 PM: Personal dinner reservation\n\n"
|
| 549 |
+
"You must triage the incident (acknowledge, page backup), handle your "
|
| 550 |
+
"commitments (which ones to keep, which to reschedule), and properly "
|
| 551 |
+
"notify everyone affected. The incident is highest priority."
|
| 552 |
+
),
|
| 553 |
+
initial_calendar=[
|
| 554 |
+
CalendarEvent(event_id="evt_130", title="Team Lunch", date="2026-04-25", time="12:00", duration_min=90, participants=["Alice", "Bob", "Carol", "Dave", "Eve", "Frank"], location="Garden Bistro", priority="normal"),
|
| 555 |
+
CalendarEvent(event_id="evt_131", title="Client Demo", date="2026-04-25", time="14:00", duration_min=60, participants=["Client_Jones"], priority="high"),
|
| 556 |
+
CalendarEvent(event_id="evt_132", title="1-on-1 with VP_Chen", date="2026-04-25", time="15:30", duration_min=30, participants=["VP_Chen"], priority="high"),
|
| 557 |
+
CalendarEvent(event_id="evt_133", title="Dinner", date="2026-04-25", time="18:00", duration_min=120, priority="low", is_personal=True),
|
| 558 |
+
],
|
| 559 |
+
initial_inbox=[
|
| 560 |
+
InboxEmail(email_id="em_130", sender="PagerDuty", subject="[CRITICAL] payment-service 503 — 94% error rate", body="payment-service ERROR HikariPool-1 Connection not available, timed out after 30000ms. Active: 10, Idle: 0, Waiting: 47. Circuit breaker OPEN.", urgency="critical"),
|
| 561 |
+
],
|
| 562 |
+
contacts=[
|
| 563 |
+
Contact(name="VP_Chen", role="VP Engineering", priority_level=5),
|
| 564 |
+
Contact(name="Client_Jones", role="Client", priority_level=4),
|
| 565 |
+
Contact(name="Team", role="Engineering Team", priority_level=2),
|
| 566 |
+
Contact(name="Alice", role="Engineer (Backup On-Call)", priority_level=3),
|
| 567 |
+
],
|
| 568 |
+
constraints=[
|
| 569 |
+
ConstraintDef(description="Incident acknowledged via email", check_type="email_sent", check_params={"to": "Team", "keywords": ["incident", "payment", "503"]}),
|
| 570 |
+
ConstraintDef(description="Team lunch cancelled or rescheduled", check_type="event_cancelled", check_params={"event_id": "evt_130"}),
|
| 571 |
+
ConstraintDef(description="Client_Jones notified of demo reschedule", check_type="email_sent", check_params={"to": "Client_Jones", "keywords": ["reschedule", "demo", "apologi"]}),
|
| 572 |
+
ConstraintDef(description="VP_Chen informed of incident", check_type="email_sent", check_params={"to": "VP_Chen", "keywords": ["incident", "payment", "on-call"]}),
|
| 573 |
+
ConstraintDef(description="No unresolved calendar conflicts", check_type="calendar_no_conflict", check_params={}),
|
| 574 |
+
],
|
| 575 |
+
communication_requirements=[
|
| 576 |
+
CommunicationReq(to="Team", required_keywords=["incident", "payment", "cancel", "lunch"], purpose="notify_reschedule"),
|
| 577 |
+
CommunicationReq(to="Client_Jones", required_keywords=["reschedule", "demo", "sorry", "apologi", "production"], purpose="renegotiate"),
|
| 578 |
+
CommunicationReq(to="VP_Chen", required_keywords=["incident", "payment", "1-on-1", "reschedule"], purpose="renegotiate"),
|
| 579 |
+
],
|
| 580 |
+
optimal_steps=8,
|
| 581 |
+
max_steps=15,
|
| 582 |
+
expected_cancelled_events=["evt_130"],
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
# ===================================================================
|
| 587 |
+
# Registry helpers
|
| 588 |
+
# ===================================================================
|
| 589 |
+
|
| 590 |
+
_ALL_SCENARIOS: Dict[str, ScenarioDef] = {
|
| 591 |
+
s.scenario_id: s
|
| 592 |
+
for s in [
|
| 593 |
+
_EASY_001, _EASY_002, _EASY_003, _EASY_004, _EASY_005,
|
| 594 |
+
_MED_006, _MED_007, _MED_008, _MED_009, _MED_010,
|
| 595 |
+
_HARD_011, _HARD_012, _HARD_013, _HARD_014, _HARD_015,
|
| 596 |
+
]
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def get_all_scenarios() -> Dict[str, ScenarioDef]:
|
| 601 |
+
return _ALL_SCENARIOS
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def get_scenario(scenario_id: str) -> Optional[ScenarioDef]:
|
| 605 |
+
return _ALL_SCENARIOS.get(scenario_id)
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def get_scenarios_by_difficulty(difficulty: str) -> List[ScenarioDef]:
|
| 609 |
+
return [s for s in _ALL_SCENARIOS.values() if s.difficulty == difficulty]
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def get_scenario_ids_grouped() -> Dict[str, List[str]]:
|
| 613 |
+
grouped: Dict[str, List[str]] = {"easy": [], "medium": [], "hard": []}
|
| 614 |
+
for s in _ALL_SCENARIOS.values():
|
| 615 |
+
grouped.setdefault(s.difficulty, []).append(s.scenario_id)
|
| 616 |
+
return grouped
|
server/world.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Simulated personal world — calendar, contacts, restaurants, email state."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
from typing import Any, Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
from server.domain import (
|
| 9 |
+
CalendarEvent,
|
| 10 |
+
Commitment,
|
| 11 |
+
Contact,
|
| 12 |
+
InboxEmail,
|
| 13 |
+
Restaurant,
|
| 14 |
+
ScenarioDef,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class WorldState:
|
| 19 |
+
"""Mutable in-memory state for a single episode."""
|
| 20 |
+
|
| 21 |
+
def __init__(self, scenario: ScenarioDef) -> None:
|
| 22 |
+
self.scenario = scenario
|
| 23 |
+
self.calendar: Dict[str, CalendarEvent] = {
|
| 24 |
+
e.event_id: deepcopy(e) for e in scenario.initial_calendar
|
| 25 |
+
}
|
| 26 |
+
self.contacts: Dict[str, Contact] = {
|
| 27 |
+
c.name: deepcopy(c) for c in scenario.contacts
|
| 28 |
+
}
|
| 29 |
+
self.restaurants: Dict[str, Restaurant] = {
|
| 30 |
+
r.name: deepcopy(r) for r in scenario.available_restaurants
|
| 31 |
+
}
|
| 32 |
+
self.inbox: List[InboxEmail] = deepcopy(scenario.initial_inbox)
|
| 33 |
+
self.emails_sent: List[Dict[str, str]] = []
|
| 34 |
+
self.commitment_ledger: List[Commitment] = []
|
| 35 |
+
self.step_count: int = 0
|
| 36 |
+
self.booked_restaurant: str = ""
|
| 37 |
+
self._next_event_id: int = 100
|
| 38 |
+
|
| 39 |
+
# ------------------------------------------------------------------
|
| 40 |
+
# Tool implementations
|
| 41 |
+
# ------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
def view_calendar(self, date: str) -> str:
|
| 44 |
+
events = [
|
| 45 |
+
e for e in self.calendar.values()
|
| 46 |
+
if e.date == date
|
| 47 |
+
]
|
| 48 |
+
if not events:
|
| 49 |
+
return f"No events on {date}."
|
| 50 |
+
events.sort(key=lambda e: e.time)
|
| 51 |
+
lines = [f"Calendar for {date}:"]
|
| 52 |
+
for ev in events:
|
| 53 |
+
parts = ev.participants
|
| 54 |
+
part_str = f" with {', '.join(parts)}" if parts else ""
|
| 55 |
+
loc_str = f" at {ev.location}" if ev.location else ""
|
| 56 |
+
lines.append(
|
| 57 |
+
f" [{ev.event_id}] {ev.time} ({ev.duration_min}min) "
|
| 58 |
+
f"{ev.title}{part_str}{loc_str} "
|
| 59 |
+
f"[priority={ev.priority}]"
|
| 60 |
+
)
|
| 61 |
+
return "\n".join(lines)
|
| 62 |
+
|
| 63 |
+
def check_availability(self, person: str) -> str:
|
| 64 |
+
contact = self.contacts.get(person)
|
| 65 |
+
if contact is None:
|
| 66 |
+
return f"Contact '{person}' not found."
|
| 67 |
+
if not contact.availability:
|
| 68 |
+
return f"{person} has no availability information on file."
|
| 69 |
+
lines = [f"Availability for {person} (role: {contact.role}):"]
|
| 70 |
+
for date, slots in sorted(contact.availability.items()):
|
| 71 |
+
lines.append(f" {date}: {', '.join(slots)}")
|
| 72 |
+
if contact.dietary:
|
| 73 |
+
lines.append(f" Dietary: {contact.dietary}")
|
| 74 |
+
return "\n".join(lines)
|
| 75 |
+
|
| 76 |
+
def search_restaurants(
|
| 77 |
+
self,
|
| 78 |
+
cuisine: str = "",
|
| 79 |
+
max_price: int = 0,
|
| 80 |
+
dietary: str = "",
|
| 81 |
+
max_distance_miles: float = 0.0,
|
| 82 |
+
near_airport: bool = False,
|
| 83 |
+
) -> str:
|
| 84 |
+
matches: List[Restaurant] = []
|
| 85 |
+
for r in self.restaurants.values():
|
| 86 |
+
if cuisine and cuisine.lower() not in r.cuisine.lower():
|
| 87 |
+
continue
|
| 88 |
+
if max_price > 0 and r.price_per_person > max_price:
|
| 89 |
+
continue
|
| 90 |
+
if dietary and dietary.lower() not in [d.lower() for d in r.dietary_options]:
|
| 91 |
+
continue
|
| 92 |
+
if max_distance_miles > 0 and r.distance_miles > max_distance_miles:
|
| 93 |
+
continue
|
| 94 |
+
if near_airport and not r.near_airport:
|
| 95 |
+
continue
|
| 96 |
+
matches.append(r)
|
| 97 |
+
|
| 98 |
+
if not matches:
|
| 99 |
+
return "No restaurants match your criteria."
|
| 100 |
+
lines = ["Matching restaurants:"]
|
| 101 |
+
for r in matches:
|
| 102 |
+
lines.append(
|
| 103 |
+
f" {r.name} — {r.cuisine}, ${r.price_per_person}/pp, "
|
| 104 |
+
f"{r.distance_miles}mi, dietary: {', '.join(r.dietary_options)}, "
|
| 105 |
+
f"capacity: {r.capacity}, hours: {r.hours}"
|
| 106 |
+
f"{', near airport' if r.near_airport else ''}"
|
| 107 |
+
f"{', private room' if r.has_private_room else ''}"
|
| 108 |
+
)
|
| 109 |
+
return "\n".join(lines)
|
| 110 |
+
|
| 111 |
+
def schedule_meeting(
|
| 112 |
+
self,
|
| 113 |
+
title: str,
|
| 114 |
+
date: str,
|
| 115 |
+
time: str,
|
| 116 |
+
duration_min: int = 60,
|
| 117 |
+
participants: Optional[List[str]] = None,
|
| 118 |
+
location: str = "",
|
| 119 |
+
turn: int = 0,
|
| 120 |
+
) -> str:
|
| 121 |
+
conflict = self._find_conflict(date, time, duration_min)
|
| 122 |
+
if conflict is not None:
|
| 123 |
+
return (
|
| 124 |
+
f"CONFLICT: '{title}' at {time} overlaps with "
|
| 125 |
+
f"'{conflict.title}' at {conflict.time}. "
|
| 126 |
+
f"Resolve the conflict first."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
eid = f"evt_{self._next_event_id}"
|
| 130 |
+
self._next_event_id += 1
|
| 131 |
+
event = CalendarEvent(
|
| 132 |
+
event_id=eid,
|
| 133 |
+
title=title,
|
| 134 |
+
date=date,
|
| 135 |
+
time=time,
|
| 136 |
+
duration_min=duration_min,
|
| 137 |
+
participants=participants or [],
|
| 138 |
+
location=location,
|
| 139 |
+
)
|
| 140 |
+
self.calendar[eid] = event
|
| 141 |
+
|
| 142 |
+
self.commitment_ledger.append(Commitment(
|
| 143 |
+
turn_created=turn,
|
| 144 |
+
commitment_type="meeting_scheduled",
|
| 145 |
+
description=f"{time} {title} on {date}",
|
| 146 |
+
constraint=f"{date}T{time}",
|
| 147 |
+
to_whom=", ".join(participants or ["self"]),
|
| 148 |
+
))
|
| 149 |
+
|
| 150 |
+
return f"Meeting scheduled: [{eid}] {date} {time} — {title}"
|
| 151 |
+
|
| 152 |
+
def reschedule_event(self, event_id: str, new_time: str, turn: int = 0) -> str:
|
| 153 |
+
event = self.calendar.get(event_id)
|
| 154 |
+
if event is None:
|
| 155 |
+
return f"Event '{event_id}' not found."
|
| 156 |
+
|
| 157 |
+
conflict = self._find_conflict(event.date, new_time, event.duration_min, exclude=event_id)
|
| 158 |
+
if conflict is not None:
|
| 159 |
+
return (
|
| 160 |
+
f"CONFLICT: moving '{event.title}' to {new_time} would overlap "
|
| 161 |
+
f"with '{conflict.title}' at {conflict.time}."
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
old_time = event.time
|
| 165 |
+
event.time = new_time
|
| 166 |
+
|
| 167 |
+
for c in self.commitment_ledger:
|
| 168 |
+
if c.active and c.constraint == f"{event.date}T{old_time}":
|
| 169 |
+
c.active = False
|
| 170 |
+
c.renegotiated_at = turn
|
| 171 |
+
|
| 172 |
+
self.commitment_ledger.append(Commitment(
|
| 173 |
+
turn_created=turn,
|
| 174 |
+
commitment_type="meeting_scheduled",
|
| 175 |
+
description=f"{new_time} {event.title} on {event.date} (rescheduled from {old_time})",
|
| 176 |
+
constraint=f"{event.date}T{new_time}",
|
| 177 |
+
to_whom=", ".join(event.participants) if event.participants else "self",
|
| 178 |
+
))
|
| 179 |
+
|
| 180 |
+
return f"Rescheduled [{event_id}] '{event.title}' from {old_time} to {new_time}."
|
| 181 |
+
|
| 182 |
+
def cancel_event(self, event_id: str, turn: int = 0) -> str:
|
| 183 |
+
event = self.calendar.pop(event_id, None)
|
| 184 |
+
if event is None:
|
| 185 |
+
return f"Event '{event_id}' not found."
|
| 186 |
+
|
| 187 |
+
for c in self.commitment_ledger:
|
| 188 |
+
if c.active and c.constraint == f"{event.date}T{event.time}":
|
| 189 |
+
if event.is_personal:
|
| 190 |
+
c.active = False
|
| 191 |
+
c.renegotiated_at = turn
|
| 192 |
+
# non-personal cancellations remain active until email is sent
|
| 193 |
+
|
| 194 |
+
return f"Cancelled [{event_id}] '{event.title}' at {event.time} on {event.date}."
|
| 195 |
+
|
| 196 |
+
def send_email(self, to: str, subject: str, body: str, turn: int = 0) -> str:
|
| 197 |
+
self.emails_sent.append({
|
| 198 |
+
"to": to,
|
| 199 |
+
"subject": subject,
|
| 200 |
+
"body": body,
|
| 201 |
+
"turn": turn,
|
| 202 |
+
})
|
| 203 |
+
|
| 204 |
+
body_lower = body.lower()
|
| 205 |
+
renegotiation_keywords = ["reschedule", "move", "cancel", "change", "instead", "alternative", "postpone"]
|
| 206 |
+
is_renegotiation = any(kw in body_lower for kw in renegotiation_keywords)
|
| 207 |
+
|
| 208 |
+
if is_renegotiation:
|
| 209 |
+
for c in self.commitment_ledger:
|
| 210 |
+
if c.active and to.lower() in c.to_whom.lower():
|
| 211 |
+
c.renegotiated_at = turn
|
| 212 |
+
|
| 213 |
+
return f"Email sent to {to}: '{subject}'"
|
| 214 |
+
|
| 215 |
+
def book_restaurant(self, restaurant_name: str, turn: int = 0) -> str:
|
| 216 |
+
r = self.restaurants.get(restaurant_name)
|
| 217 |
+
if r is None:
|
| 218 |
+
return f"Restaurant '{restaurant_name}' not found."
|
| 219 |
+
self.booked_restaurant = restaurant_name
|
| 220 |
+
|
| 221 |
+
self.commitment_ledger.append(Commitment(
|
| 222 |
+
turn_created=turn,
|
| 223 |
+
commitment_type="reservation_made",
|
| 224 |
+
description=f"Reservation at {restaurant_name}",
|
| 225 |
+
constraint=restaurant_name,
|
| 226 |
+
to_whom="group",
|
| 227 |
+
))
|
| 228 |
+
|
| 229 |
+
return f"Reservation confirmed at {restaurant_name}."
|
| 230 |
+
|
| 231 |
+
# ------------------------------------------------------------------
|
| 232 |
+
# Internal helpers
|
| 233 |
+
# ------------------------------------------------------------------
|
| 234 |
+
|
| 235 |
+
def _find_conflict(
|
| 236 |
+
self, date: str, time: str, duration_min: int, exclude: str = "",
|
| 237 |
+
) -> Optional[CalendarEvent]:
|
| 238 |
+
new_start = _time_to_min(time)
|
| 239 |
+
new_end = new_start + duration_min
|
| 240 |
+
for eid, ev in self.calendar.items():
|
| 241 |
+
if eid == exclude:
|
| 242 |
+
continue
|
| 243 |
+
if ev.date != date:
|
| 244 |
+
continue
|
| 245 |
+
ev_start = _time_to_min(ev.time)
|
| 246 |
+
ev_end = ev_start + ev.duration_min
|
| 247 |
+
if new_start < ev_end and new_end > ev_start:
|
| 248 |
+
return ev
|
| 249 |
+
return None
|
| 250 |
+
|
| 251 |
+
def get_calendar_snapshot(self) -> List[Dict[str, Any]]:
|
| 252 |
+
return [ev.model_dump() for ev in sorted(self.calendar.values(), key=lambda e: (e.date, e.time))]
|
| 253 |
+
|
| 254 |
+
def get_inbox_snapshot(self) -> List[Dict[str, Any]]:
|
| 255 |
+
return [e.model_dump(exclude={"context_hint"}) for e in self.inbox]
|
| 256 |
+
|
| 257 |
+
def get_active_commitments(self) -> List[Commitment]:
|
| 258 |
+
return [c for c in self.commitment_ledger if c.active]
|
| 259 |
+
|
| 260 |
+
def get_silent_violations(self) -> List[Commitment]:
|
| 261 |
+
"""Commitments that are still active but whose constraint no longer holds."""
|
| 262 |
+
violations: List[Commitment] = []
|
| 263 |
+
for c in self.commitment_ledger:
|
| 264 |
+
if not c.active:
|
| 265 |
+
continue
|
| 266 |
+
if c.renegotiated_at is not None:
|
| 267 |
+
continue
|
| 268 |
+
if c.commitment_type == "meeting_scheduled":
|
| 269 |
+
time_key = c.constraint
|
| 270 |
+
parts = time_key.split("T")
|
| 271 |
+
if len(parts) == 2:
|
| 272 |
+
date_str, time_str = parts
|
| 273 |
+
found = any(
|
| 274 |
+
ev.date == date_str and ev.time == time_str
|
| 275 |
+
for ev in self.calendar.values()
|
| 276 |
+
)
|
| 277 |
+
if not found:
|
| 278 |
+
has_email = any(
|
| 279 |
+
c.to_whom.lower() in em.get("to", "").lower()
|
| 280 |
+
for em in self.emails_sent
|
| 281 |
+
)
|
| 282 |
+
if not has_email:
|
| 283 |
+
violations.append(c)
|
| 284 |
+
return violations
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _time_to_min(t: str) -> int:
|
| 288 |
+
"""Convert 'HH:MM' to minutes since midnight."""
|
| 289 |
+
parts = t.split(":")
|
| 290 |
+
return int(parts[0]) * 60 + int(parts[1])
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_environment.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Comprehensive test suite for CommitmentOS.
|
| 2 |
+
|
| 3 |
+
Tests cover:
|
| 4 |
+
- Grader (perfect/partial/zero for each component)
|
| 5 |
+
- Environment lifecycle (reset/step/state/multi-turn)
|
| 6 |
+
- Commitment ledger (creation, violation, renegotiation)
|
| 7 |
+
- Task dataset integrity
|
| 8 |
+
- API endpoints
|
| 9 |
+
- Difficulty verification
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 18 |
+
|
| 19 |
+
import json
|
| 20 |
+
from typing import Any, Dict
|
| 21 |
+
|
| 22 |
+
import pytest
|
| 23 |
+
|
| 24 |
+
from models import CommitmentAction, CommitmentObservation, CommitmentState
|
| 25 |
+
from server.domain import CalendarEvent, ConstraintDef, ScenarioDef
|
| 26 |
+
from server.environment import CommitmentEnvironment
|
| 27 |
+
from server.graders import (
|
| 28 |
+
_calendar_has_no_overlaps,
|
| 29 |
+
_keyword_score,
|
| 30 |
+
_score_commitment_coherence,
|
| 31 |
+
_score_conflict_resolution,
|
| 32 |
+
_score_step_efficiency,
|
| 33 |
+
grade_scenario,
|
| 34 |
+
)
|
| 35 |
+
from server.tasks import get_all_scenarios, get_scenario, get_scenarios_by_difficulty
|
| 36 |
+
from server.world import WorldState, _time_to_min
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ===================================================================
|
| 40 |
+
# Fixtures
|
| 41 |
+
# ===================================================================
|
| 42 |
+
|
| 43 |
+
@pytest.fixture
|
| 44 |
+
def env() -> CommitmentEnvironment:
|
| 45 |
+
return CommitmentEnvironment()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@pytest.fixture
|
| 49 |
+
def easy_env(env: CommitmentEnvironment) -> CommitmentEnvironment:
|
| 50 |
+
env.reset(task_id="easy_001")
|
| 51 |
+
return env
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ===================================================================
|
| 55 |
+
# 1. Task dataset integrity
|
| 56 |
+
# ===================================================================
|
| 57 |
+
|
| 58 |
+
class TestTaskDataset:
|
| 59 |
+
def test_15_scenarios_loaded(self) -> None:
|
| 60 |
+
scenarios = get_all_scenarios()
|
| 61 |
+
assert len(scenarios) == 15
|
| 62 |
+
|
| 63 |
+
def test_5_easy_5_medium_5_hard(self) -> None:
|
| 64 |
+
for difficulty, count in [("easy", 5), ("medium", 5), ("hard", 5)]:
|
| 65 |
+
tasks = get_scenarios_by_difficulty(difficulty)
|
| 66 |
+
assert len(tasks) == count, f"Expected {count} {difficulty} tasks, got {len(tasks)}"
|
| 67 |
+
|
| 68 |
+
def test_each_scenario_has_required_fields(self) -> None:
|
| 69 |
+
for sid, scenario in get_all_scenarios().items():
|
| 70 |
+
assert scenario.scenario_id == sid
|
| 71 |
+
assert scenario.difficulty in ("easy", "medium", "hard")
|
| 72 |
+
assert len(scenario.briefing) > 20, f"{sid}: briefing too short"
|
| 73 |
+
assert scenario.optimal_steps >= 2, f"{sid}: optimal_steps too low"
|
| 74 |
+
assert scenario.max_steps >= scenario.optimal_steps
|
| 75 |
+
assert len(scenario.constraints) >= 1, f"{sid}: no constraints defined"
|
| 76 |
+
|
| 77 |
+
def test_scenario_ids_unique(self) -> None:
|
| 78 |
+
ids = list(get_all_scenarios().keys())
|
| 79 |
+
assert len(ids) == len(set(ids))
|
| 80 |
+
|
| 81 |
+
def test_get_scenario_returns_none_for_missing(self) -> None:
|
| 82 |
+
assert get_scenario("nonexistent_999") is None
|
| 83 |
+
|
| 84 |
+
def test_get_scenario_returns_correct(self) -> None:
|
| 85 |
+
s = get_scenario("easy_001")
|
| 86 |
+
assert s is not None
|
| 87 |
+
assert s.difficulty == "easy"
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ===================================================================
|
| 91 |
+
# 2. Grader unit tests
|
| 92 |
+
# ===================================================================
|
| 93 |
+
|
| 94 |
+
class TestKeywordScore:
|
| 95 |
+
def test_full_match(self) -> None:
|
| 96 |
+
score, matched = _keyword_score("I need to reschedule the standup meeting", ["reschedule", "standup"], min_matches=2)
|
| 97 |
+
assert score == 1.0
|
| 98 |
+
assert len(matched) == 2
|
| 99 |
+
|
| 100 |
+
def test_partial_match(self) -> None:
|
| 101 |
+
score, matched = _keyword_score("I need to reschedule", ["reschedule", "standup"], min_matches=2)
|
| 102 |
+
assert score == 0.5
|
| 103 |
+
assert len(matched) == 1
|
| 104 |
+
|
| 105 |
+
def test_no_match(self) -> None:
|
| 106 |
+
score, matched = _keyword_score("Hello world", ["reschedule", "standup"], min_matches=2)
|
| 107 |
+
assert score == 0.0
|
| 108 |
+
assert len(matched) == 0
|
| 109 |
+
|
| 110 |
+
def test_case_insensitive(self) -> None:
|
| 111 |
+
score, _ = _keyword_score("RESCHEDULE THE STANDUP", ["reschedule", "standup"], min_matches=2)
|
| 112 |
+
assert score == 1.0
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class TestCalendarConflicts:
|
| 116 |
+
def test_no_conflicts(self) -> None:
|
| 117 |
+
scenario = get_scenario("easy_002")
|
| 118 |
+
assert scenario is not None
|
| 119 |
+
world = WorldState(scenario)
|
| 120 |
+
assert _calendar_has_no_overlaps(world) is True
|
| 121 |
+
|
| 122 |
+
def test_conflict_detected(self) -> None:
|
| 123 |
+
scenario = get_scenario("easy_001")
|
| 124 |
+
assert scenario is not None
|
| 125 |
+
world = WorldState(scenario)
|
| 126 |
+
assert _calendar_has_no_overlaps(world) is False
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class TestCommitmentCoherence:
|
| 130 |
+
def test_no_commitments_full_score(self) -> None:
|
| 131 |
+
scenario = get_scenario("easy_005")
|
| 132 |
+
assert scenario is not None
|
| 133 |
+
world = WorldState(scenario)
|
| 134 |
+
score, _ = _score_commitment_coherence(world)
|
| 135 |
+
assert score == 1.0
|
| 136 |
+
|
| 137 |
+
def test_honored_commitment(self, env: CommitmentEnvironment) -> None:
|
| 138 |
+
env.reset(task_id="easy_001")
|
| 139 |
+
env.step(CommitmentAction(action_type="reschedule_event", event_id="evt_2", new_time="15:00"))
|
| 140 |
+
assert env._world is not None
|
| 141 |
+
score, feedback = _score_commitment_coherence(env._world)
|
| 142 |
+
assert score == 1.0
|
| 143 |
+
|
| 144 |
+
def test_silent_violation_detected(self, env: CommitmentEnvironment) -> None:
|
| 145 |
+
env.reset(task_id="easy_001")
|
| 146 |
+
env.step(CommitmentAction(action_type="schedule_meeting", title="New Meeting", date="2026-04-25", time="16:00", participants=["Alice"]))
|
| 147 |
+
assert env._world is not None
|
| 148 |
+
env._world.calendar.pop("evt_100", None)
|
| 149 |
+
for c in env._world.commitment_ledger:
|
| 150 |
+
if c.commitment_type == "meeting_scheduled" and "16:00" in c.constraint:
|
| 151 |
+
event_key = c.constraint
|
| 152 |
+
for eid, ev in list(env._world.calendar.items()):
|
| 153 |
+
if ev.time == "16:00" and ev.date == "2026-04-25" and ev.title == "New Meeting":
|
| 154 |
+
del env._world.calendar[eid]
|
| 155 |
+
break
|
| 156 |
+
violations = env._world.get_silent_violations()
|
| 157 |
+
assert len(violations) >= 1
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class TestStepEfficiency:
|
| 161 |
+
def test_optimal_steps(self) -> None:
|
| 162 |
+
scenario = get_scenario("easy_001")
|
| 163 |
+
assert scenario is not None
|
| 164 |
+
world = WorldState(scenario)
|
| 165 |
+
world.step_count = 3
|
| 166 |
+
score, _ = _score_step_efficiency(scenario, world)
|
| 167 |
+
assert score == 1.0
|
| 168 |
+
|
| 169 |
+
def test_over_optimal(self) -> None:
|
| 170 |
+
scenario = get_scenario("easy_001")
|
| 171 |
+
assert scenario is not None
|
| 172 |
+
world = WorldState(scenario)
|
| 173 |
+
world.step_count = 8
|
| 174 |
+
score, _ = _score_step_efficiency(scenario, world)
|
| 175 |
+
assert score == 0.5
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ===================================================================
|
| 179 |
+
# 3. Environment lifecycle
|
| 180 |
+
# ===================================================================
|
| 181 |
+
|
| 182 |
+
class TestEnvironmentLifecycle:
|
| 183 |
+
def test_reset_returns_observation(self, env: CommitmentEnvironment) -> None:
|
| 184 |
+
obs = env.reset(task_id="easy_001")
|
| 185 |
+
assert isinstance(obs, CommitmentObservation)
|
| 186 |
+
assert obs.scenario_id == "easy_001"
|
| 187 |
+
assert obs.done is False
|
| 188 |
+
assert obs.reward == 0.0
|
| 189 |
+
assert len(obs.briefing) > 0
|
| 190 |
+
|
| 191 |
+
def test_step_before_reset_raises(self, env: CommitmentEnvironment) -> None:
|
| 192 |
+
with pytest.raises(ValueError, match="No active episode"):
|
| 193 |
+
env.step(CommitmentAction(action_type="view_calendar", date="2026-04-25"))
|
| 194 |
+
|
| 195 |
+
def test_step_after_done_raises(self, env: CommitmentEnvironment) -> None:
|
| 196 |
+
env.reset(task_id="easy_001")
|
| 197 |
+
env.step(CommitmentAction(action_type="submit_plan"))
|
| 198 |
+
with pytest.raises(ValueError, match="already completed"):
|
| 199 |
+
env.step(CommitmentAction(action_type="view_calendar", date="2026-04-25"))
|
| 200 |
+
|
| 201 |
+
def test_state_property(self, env: CommitmentEnvironment) -> None:
|
| 202 |
+
env.reset(task_id="easy_001")
|
| 203 |
+
state = env.state
|
| 204 |
+
assert isinstance(state, CommitmentState)
|
| 205 |
+
assert state.scenario_id == "easy_001"
|
| 206 |
+
assert state.completed is False
|
| 207 |
+
assert len(state.available_tasks) == 15
|
| 208 |
+
|
| 209 |
+
def test_multi_turn_episode(self, env: CommitmentEnvironment) -> None:
|
| 210 |
+
env.reset(task_id="easy_001")
|
| 211 |
+
obs = env.step(CommitmentAction(action_type="view_calendar", date="2026-04-25"))
|
| 212 |
+
assert obs.done is False
|
| 213 |
+
assert obs.step_number == 1
|
| 214 |
+
|
| 215 |
+
obs = env.step(CommitmentAction(action_type="reschedule_event", event_id="evt_2", new_time="15:00"))
|
| 216 |
+
assert obs.done is False
|
| 217 |
+
assert obs.step_number == 2
|
| 218 |
+
|
| 219 |
+
obs = env.step(CommitmentAction(action_type="submit_plan"))
|
| 220 |
+
assert obs.done is True
|
| 221 |
+
assert obs.reward > 0
|
| 222 |
+
|
| 223 |
+
def test_max_steps_auto_submits(self, env: CommitmentEnvironment) -> None:
|
| 224 |
+
env.reset(task_id="easy_002")
|
| 225 |
+
for _ in range(20):
|
| 226 |
+
obs = env.step(CommitmentAction(action_type="view_calendar", date="2026-04-25"))
|
| 227 |
+
if obs.done:
|
| 228 |
+
break
|
| 229 |
+
assert obs.done is True
|
| 230 |
+
|
| 231 |
+
def test_reset_clears_state(self, env: CommitmentEnvironment) -> None:
|
| 232 |
+
env.reset(task_id="easy_001")
|
| 233 |
+
env.step(CommitmentAction(action_type="view_calendar", date="2026-04-25"))
|
| 234 |
+
env.reset(task_id="easy_002")
|
| 235 |
+
assert env.state.scenario_id == "easy_002"
|
| 236 |
+
assert env.state.step_count == 0
|
| 237 |
+
|
| 238 |
+
def test_unknown_action_type(self, env: CommitmentEnvironment) -> None:
|
| 239 |
+
env.reset(task_id="easy_001")
|
| 240 |
+
obs = env.step(CommitmentAction(action_type="fly_to_moon"))
|
| 241 |
+
assert "Unknown action_type" in obs.tool_result
|
| 242 |
+
|
| 243 |
+
def test_random_reset(self, env: CommitmentEnvironment) -> None:
|
| 244 |
+
obs = env.reset(seed=42)
|
| 245 |
+
assert obs.scenario_id in get_all_scenarios()
|
| 246 |
+
|
| 247 |
+
def test_difficulty_filter_reset(self, env: CommitmentEnvironment) -> None:
|
| 248 |
+
obs = env.reset(difficulty="hard", seed=1)
|
| 249 |
+
assert obs.difficulty == "hard"
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# ===================================================================
|
| 253 |
+
# 4. World simulation (tool functions)
|
| 254 |
+
# ===================================================================
|
| 255 |
+
|
| 256 |
+
class TestWorldTools:
|
| 257 |
+
def test_view_calendar(self) -> None:
|
| 258 |
+
scenario = get_scenario("easy_001")
|
| 259 |
+
assert scenario is not None
|
| 260 |
+
world = WorldState(scenario)
|
| 261 |
+
result = world.view_calendar("2026-04-25")
|
| 262 |
+
assert "evt_1" in result
|
| 263 |
+
assert "14:00" in result
|
| 264 |
+
|
| 265 |
+
def test_view_calendar_empty(self) -> None:
|
| 266 |
+
scenario = get_scenario("easy_001")
|
| 267 |
+
assert scenario is not None
|
| 268 |
+
world = WorldState(scenario)
|
| 269 |
+
result = world.view_calendar("2099-01-01")
|
| 270 |
+
assert "No events" in result
|
| 271 |
+
|
| 272 |
+
def test_check_availability(self) -> None:
|
| 273 |
+
scenario = get_scenario("easy_003")
|
| 274 |
+
assert scenario is not None
|
| 275 |
+
world = WorldState(scenario)
|
| 276 |
+
result = world.check_availability("Client_Jones")
|
| 277 |
+
assert "09:00" in result
|
| 278 |
+
|
| 279 |
+
def test_check_availability_unknown(self) -> None:
|
| 280 |
+
scenario = get_scenario("easy_001")
|
| 281 |
+
assert scenario is not None
|
| 282 |
+
world = WorldState(scenario)
|
| 283 |
+
result = world.check_availability("NonExistentPerson")
|
| 284 |
+
assert "not found" in result
|
| 285 |
+
|
| 286 |
+
def test_search_restaurants_filters(self) -> None:
|
| 287 |
+
scenario = get_scenario("med_007")
|
| 288 |
+
assert scenario is not None
|
| 289 |
+
world = WorldState(scenario)
|
| 290 |
+
result = world.search_restaurants(dietary="vegan", max_price=45, max_distance_miles=3.0)
|
| 291 |
+
assert "Green Garden" in result
|
| 292 |
+
assert "Steak House Prime" not in result
|
| 293 |
+
|
| 294 |
+
def test_schedule_meeting_creates_commitment(self) -> None:
|
| 295 |
+
scenario = get_scenario("easy_002")
|
| 296 |
+
assert scenario is not None
|
| 297 |
+
world = WorldState(scenario)
|
| 298 |
+
result = world.schedule_meeting("Test Meeting", "2026-04-25", "14:00", turn=1)
|
| 299 |
+
assert "scheduled" in result.lower()
|
| 300 |
+
assert len(world.commitment_ledger) == 1
|
| 301 |
+
assert world.commitment_ledger[0].commitment_type == "meeting_scheduled"
|
| 302 |
+
|
| 303 |
+
def test_schedule_meeting_conflict(self) -> None:
|
| 304 |
+
scenario = get_scenario("easy_001")
|
| 305 |
+
assert scenario is not None
|
| 306 |
+
world = WorldState(scenario)
|
| 307 |
+
result = world.schedule_meeting("Conflicting", "2026-04-25", "14:00", turn=1)
|
| 308 |
+
assert "CONFLICT" in result
|
| 309 |
+
|
| 310 |
+
def test_reschedule_event(self) -> None:
|
| 311 |
+
scenario = get_scenario("easy_001")
|
| 312 |
+
assert scenario is not None
|
| 313 |
+
world = WorldState(scenario)
|
| 314 |
+
result = world.reschedule_event("evt_2", "15:00", turn=1)
|
| 315 |
+
assert "Rescheduled" in result
|
| 316 |
+
assert world.calendar["evt_2"].time == "15:00"
|
| 317 |
+
|
| 318 |
+
def test_cancel_event(self) -> None:
|
| 319 |
+
scenario = get_scenario("easy_001")
|
| 320 |
+
assert scenario is not None
|
| 321 |
+
world = WorldState(scenario)
|
| 322 |
+
result = world.cancel_event("evt_2", turn=1)
|
| 323 |
+
assert "Cancelled" in result
|
| 324 |
+
assert "evt_2" not in world.calendar
|
| 325 |
+
|
| 326 |
+
def test_send_email(self) -> None:
|
| 327 |
+
scenario = get_scenario("easy_001")
|
| 328 |
+
assert scenario is not None
|
| 329 |
+
world = WorldState(scenario)
|
| 330 |
+
result = world.send_email("Team", "Hello", "Testing email body", turn=1)
|
| 331 |
+
assert "sent" in result.lower()
|
| 332 |
+
assert len(world.emails_sent) == 1
|
| 333 |
+
|
| 334 |
+
def test_book_restaurant(self) -> None:
|
| 335 |
+
scenario = get_scenario("easy_002")
|
| 336 |
+
assert scenario is not None
|
| 337 |
+
world = WorldState(scenario)
|
| 338 |
+
result = world.book_restaurant("Bella Italia", turn=1)
|
| 339 |
+
assert "confirmed" in result.lower()
|
| 340 |
+
assert world.booked_restaurant == "Bella Italia"
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# ===================================================================
|
| 344 |
+
# 5. Commitment ledger behaviour
|
| 345 |
+
# ===================================================================
|
| 346 |
+
|
| 347 |
+
class TestCommitmentLedger:
|
| 348 |
+
def test_schedule_creates_commitment(self) -> None:
|
| 349 |
+
scenario = get_scenario("easy_002")
|
| 350 |
+
assert scenario is not None
|
| 351 |
+
world = WorldState(scenario)
|
| 352 |
+
world.schedule_meeting("Test", "2026-04-25", "10:00", turn=1)
|
| 353 |
+
assert len(world.commitment_ledger) == 1
|
| 354 |
+
c = world.commitment_ledger[0]
|
| 355 |
+
assert c.turn_created == 1
|
| 356 |
+
assert c.active is True
|
| 357 |
+
assert c.renegotiated_at is None
|
| 358 |
+
|
| 359 |
+
def test_reschedule_marks_old_renegotiated(self) -> None:
|
| 360 |
+
scenario = get_scenario("easy_001")
|
| 361 |
+
assert scenario is not None
|
| 362 |
+
world = WorldState(scenario)
|
| 363 |
+
world.reschedule_event("evt_2", "15:00", turn=1)
|
| 364 |
+
renegotiated = [c for c in world.commitment_ledger if c.renegotiated_at is not None]
|
| 365 |
+
assert len(renegotiated) == 0 # initial events don't create ledger entries
|
| 366 |
+
new_commits = [c for c in world.commitment_ledger if c.active]
|
| 367 |
+
assert len(new_commits) >= 1
|
| 368 |
+
|
| 369 |
+
def test_email_renegotiation_detection(self) -> None:
|
| 370 |
+
scenario = get_scenario("easy_001")
|
| 371 |
+
assert scenario is not None
|
| 372 |
+
world = WorldState(scenario)
|
| 373 |
+
world.schedule_meeting("Important", "2026-04-25", "16:00", participants=["Alice"], turn=1)
|
| 374 |
+
world.send_email("Alice", "Change of plans", "I need to reschedule our meeting", turn=2)
|
| 375 |
+
renegotiated = [c for c in world.commitment_ledger if c.renegotiated_at is not None]
|
| 376 |
+
assert len(renegotiated) >= 1
|
| 377 |
+
|
| 378 |
+
def test_cancel_personal_marks_renegotiated(self) -> None:
|
| 379 |
+
scenario = get_scenario("easy_001")
|
| 380 |
+
assert scenario is not None
|
| 381 |
+
world = WorldState(scenario)
|
| 382 |
+
# evt_3 is Lunch (personal)
|
| 383 |
+
world.cancel_event("evt_3", turn=1)
|
| 384 |
+
# Personal cancellations are auto-OK
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
# ===================================================================
|
| 388 |
+
# 6. Full scenario scoring
|
| 389 |
+
# ===================================================================
|
| 390 |
+
|
| 391 |
+
class TestFullScoring:
|
| 392 |
+
def test_perfect_easy_001(self, env: CommitmentEnvironment) -> None:
|
| 393 |
+
env.reset(task_id="easy_001")
|
| 394 |
+
env.step(CommitmentAction(action_type="reschedule_event", event_id="evt_2", new_time="15:00"))
|
| 395 |
+
env.step(CommitmentAction(action_type="send_email", to="Team", subject="Standup moved", body="Hi team, I've rescheduled the standup to 3:00 PM. Sorry for the move."))
|
| 396 |
+
obs = env.step(CommitmentAction(action_type="submit_plan"))
|
| 397 |
+
assert obs.done is True
|
| 398 |
+
assert obs.reward >= 0.85
|
| 399 |
+
|
| 400 |
+
def test_zero_effort_gets_low_score(self, env: CommitmentEnvironment) -> None:
|
| 401 |
+
env.reset(task_id="easy_001")
|
| 402 |
+
obs = env.step(CommitmentAction(action_type="submit_plan"))
|
| 403 |
+
assert obs.done is True
|
| 404 |
+
assert obs.reward <= 0.50
|
| 405 |
+
|
| 406 |
+
def test_hard_011_perfect_run(self, env: CommitmentEnvironment) -> None:
|
| 407 |
+
env.reset(task_id="hard_011")
|
| 408 |
+
env.step(CommitmentAction(action_type="view_calendar", date="2026-04-25"))
|
| 409 |
+
env.step(CommitmentAction(action_type="cancel_event", event_id="evt_90"))
|
| 410 |
+
env.step(CommitmentAction(action_type="search_restaurants", dietary="vegetarian", near_airport=True, max_price=60))
|
| 411 |
+
env.step(CommitmentAction(action_type="book_restaurant", restaurant_name="Sky Lounge"))
|
| 412 |
+
env.step(CommitmentAction(action_type="send_email", to="Team", subject="Happy Hour Rescheduled", body="Sorry team, I need to reschedule the happy hour to Thursday. An investor dinner came up tonight. Apologies!"))
|
| 413 |
+
env.step(CommitmentAction(action_type="send_email", to="VP_Chen", subject="Investor dinner plan", body="I've booked Sky Lounge for dinner tonight with Investor_Park. Vegetarian options available, near the airport."))
|
| 414 |
+
obs = env.step(CommitmentAction(action_type="submit_plan"))
|
| 415 |
+
assert obs.done is True
|
| 416 |
+
assert obs.reward >= 0.85
|
| 417 |
+
|
| 418 |
+
def test_hard_015_sre_crisis(self, env: CommitmentEnvironment) -> None:
|
| 419 |
+
env.reset(task_id="hard_015")
|
| 420 |
+
env.step(CommitmentAction(action_type="view_calendar", date="2026-04-25"))
|
| 421 |
+
env.step(CommitmentAction(action_type="cancel_event", event_id="evt_130"))
|
| 422 |
+
env.step(CommitmentAction(action_type="send_email", to="Team", subject="Lunch cancelled - incident", body="Team, I'm cancelling our lunch due to a production incident. Payment service returning 503s. Will handle this first."))
|
| 423 |
+
env.step(CommitmentAction(action_type="send_email", to="Client_Jones", subject="Demo reschedule needed", body="Hi Client_Jones, I sincerely apologize but I need to reschedule our demo. We have a production incident with the payment system. Can we find another time this week?"))
|
| 424 |
+
env.step(CommitmentAction(action_type="send_email", to="VP_Chen", subject="Incident + 1-on-1", body="VP_Chen, we have a production incident — payment service is returning 503s. I'm on-call and handling it. May need to reschedule our 1-on-1 depending on resolution time."))
|
| 425 |
+
obs = env.step(CommitmentAction(action_type="submit_plan"))
|
| 426 |
+
assert obs.done is True
|
| 427 |
+
assert obs.reward >= 0.60
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# ===================================================================
|
| 431 |
+
# 7. Reward clamping
|
| 432 |
+
# ===================================================================
|
| 433 |
+
|
| 434 |
+
class TestRewardClamping:
|
| 435 |
+
def test_reward_never_zero(self, env: CommitmentEnvironment) -> None:
|
| 436 |
+
env.reset(task_id="easy_001")
|
| 437 |
+
obs = env.step(CommitmentAction(action_type="submit_plan"))
|
| 438 |
+
assert obs.reward >= 0.01
|
| 439 |
+
|
| 440 |
+
def test_reward_never_one(self, env: CommitmentEnvironment) -> None:
|
| 441 |
+
env.reset(task_id="easy_001")
|
| 442 |
+
env.step(CommitmentAction(action_type="reschedule_event", event_id="evt_2", new_time="15:00"))
|
| 443 |
+
env.step(CommitmentAction(action_type="send_email", to="Team", subject="Standup moved", body="Hi team, the standup is rescheduled to 3pm. Sorry for the move."))
|
| 444 |
+
obs = env.step(CommitmentAction(action_type="submit_plan"))
|
| 445 |
+
assert obs.reward <= 0.99
|
| 446 |
+
assert obs.reward > 0.01
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
# ===================================================================
|
| 450 |
+
# 8. Time utility
|
| 451 |
+
# ===================================================================
|
| 452 |
+
|
| 453 |
+
class TestTimeUtil:
|
| 454 |
+
def test_time_to_min(self) -> None:
|
| 455 |
+
assert _time_to_min("00:00") == 0
|
| 456 |
+
assert _time_to_min("09:30") == 570
|
| 457 |
+
assert _time_to_min("14:00") == 840
|
| 458 |
+
assert _time_to_min("23:59") == 1439
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
# ===================================================================
|
| 462 |
+
# 9. API endpoint tests (via TestClient)
|
| 463 |
+
# ===================================================================
|
| 464 |
+
|
| 465 |
+
class TestAPI:
|
| 466 |
+
@pytest.fixture
|
| 467 |
+
def client(self):
|
| 468 |
+
from fastapi.testclient import TestClient
|
| 469 |
+
from server.app import app
|
| 470 |
+
return TestClient(app)
|
| 471 |
+
|
| 472 |
+
def test_health(self, client) -> None:
|
| 473 |
+
resp = client.get("/health")
|
| 474 |
+
assert resp.status_code == 200
|
| 475 |
+
|
| 476 |
+
def test_tasks(self, client) -> None:
|
| 477 |
+
resp = client.get("/tasks")
|
| 478 |
+
assert resp.status_code == 200
|
| 479 |
+
data = resp.json()
|
| 480 |
+
assert len(data["easy"]) == 5
|
| 481 |
+
assert len(data["medium"]) == 5
|
| 482 |
+
assert len(data["hard"]) == 5
|
| 483 |
+
|
| 484 |
+
def test_reset_step_state(self, client) -> None:
|
| 485 |
+
resp = client.post("/reset", params={"task_id": "easy_001"})
|
| 486 |
+
assert resp.status_code == 200
|
| 487 |
+
|
| 488 |
+
resp = client.post("/step", json={"action": {"action_type": "view_calendar", "date": "2026-04-25"}})
|
| 489 |
+
assert resp.status_code == 200
|
| 490 |
+
data = resp.json()
|
| 491 |
+
assert data.get("done") is False
|
| 492 |
+
|
| 493 |
+
resp = client.get("/state")
|
| 494 |
+
assert resp.status_code == 200
|
| 495 |
+
state = resp.json()
|
| 496 |
+
assert "step_count" in state
|
| 497 |
+
|
| 498 |
+
def test_mcp_initialize(self, client) -> None:
|
| 499 |
+
resp = client.post("/mcp", json={
|
| 500 |
+
"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {},
|
| 501 |
+
})
|
| 502 |
+
assert resp.status_code == 200
|
| 503 |
+
data = resp.json()
|
| 504 |
+
assert data["result"]["serverInfo"]["name"] == "commitment-os"
|
| 505 |
+
|
| 506 |
+
def test_mcp_tools_list(self, client) -> None:
|
| 507 |
+
resp = client.post("/mcp", json={
|
| 508 |
+
"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {},
|
| 509 |
+
})
|
| 510 |
+
assert resp.status_code == 200
|
| 511 |
+
tools = resp.json()["result"]["tools"]
|
| 512 |
+
assert len(tools) == 3
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
# ===================================================================
|
| 516 |
+
# 10. Metadata
|
| 517 |
+
# ===================================================================
|
| 518 |
+
|
| 519 |
+
class TestMetadata:
|
| 520 |
+
def test_get_metadata(self, env: CommitmentEnvironment) -> None:
|
| 521 |
+
meta = env.get_metadata()
|
| 522 |
+
assert meta.name == "commitment-os"
|
| 523 |
+
assert "Jayant" in meta.author
|
training/__init__.py
ADDED
|
File without changes
|
training/env_factory.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Environment factory for TRL GRPOTrainer integration.
|
| 2 |
+
|
| 3 |
+
Wraps CommitmentOS as a callable that accepts model completions and
|
| 4 |
+
returns rewards, making it compatible with TRL's ``environment_factory``
|
| 5 |
+
pattern for multi-turn RL training.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import sys
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any, Dict, List, Optional
|
| 14 |
+
|
| 15 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 16 |
+
|
| 17 |
+
from server.domain import ScenarioDef
|
| 18 |
+
from server.environment import CommitmentEnvironment
|
| 19 |
+
from server.tasks import get_all_scenarios
|
| 20 |
+
from models import CommitmentAction
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
TOOL_DESCRIPTIONS = """Available tools (respond with JSON):
|
| 24 |
+
- {"action_type": "view_calendar", "date": "2026-04-25"}
|
| 25 |
+
- {"action_type": "check_availability", "person": "Name"}
|
| 26 |
+
- {"action_type": "search_restaurants", "cuisine": "...", "max_price": 50, "dietary": "..."}
|
| 27 |
+
- {"action_type": "schedule_meeting", "title": "...", "date": "...", "time": "HH:MM", "participants": [...]}
|
| 28 |
+
- {"action_type": "reschedule_event", "event_id": "evt_X", "new_time": "HH:MM"}
|
| 29 |
+
- {"action_type": "cancel_event", "event_id": "evt_X"}
|
| 30 |
+
- {"action_type": "send_email", "to": "Name", "subject": "...", "body": "..."}
|
| 31 |
+
- {"action_type": "book_restaurant", "restaurant_name": "..."}
|
| 32 |
+
- {"action_type": "submit_plan"}"""
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def build_system_prompt() -> str:
|
| 36 |
+
return (
|
| 37 |
+
"You are an expert executive assistant AI managing calendars, emails, and "
|
| 38 |
+
"dining reservations. For each turn, respond with EXACTLY ONE JSON tool call.\n\n"
|
| 39 |
+
f"{TOOL_DESCRIPTIONS}\n\n"
|
| 40 |
+
"Rules:\n"
|
| 41 |
+
"1. Respond with ONLY JSON, no markdown or explanation\n"
|
| 42 |
+
"2. Handle higher-priority items first\n"
|
| 43 |
+
"3. When cancelling/rescheduling commitments, ALWAYS email affected parties\n"
|
| 44 |
+
"4. Call submit_plan when all issues are resolved\n"
|
| 45 |
+
"5. Never silently drop a commitment"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def build_initial_prompt(scenario: ScenarioDef) -> str:
|
| 50 |
+
"""Build the user message for the first turn of an episode."""
|
| 51 |
+
from server.world import WorldState
|
| 52 |
+
|
| 53 |
+
world = WorldState(scenario)
|
| 54 |
+
calendar = json.dumps(world.get_calendar_snapshot(), indent=2)
|
| 55 |
+
inbox = json.dumps(world.get_inbox_snapshot(), indent=2)
|
| 56 |
+
|
| 57 |
+
return (
|
| 58 |
+
f"SCENARIO: {scenario.briefing}\n\n"
|
| 59 |
+
f"CALENDAR:\n{calendar}\n\n"
|
| 60 |
+
f"INBOX:\n{inbox}\n\n"
|
| 61 |
+
"What is your first action? Respond with a JSON tool call."
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def parse_action_from_text(text: str) -> Dict[str, Any]:
|
| 66 |
+
"""Extract a JSON action from model output, with fallback to submit."""
|
| 67 |
+
text = text.strip()
|
| 68 |
+
if text.startswith("```"):
|
| 69 |
+
lines = text.split("\n")
|
| 70 |
+
text = "\n".join(lines[1:-1]) if len(lines) > 2 else text
|
| 71 |
+
try:
|
| 72 |
+
data = json.loads(text)
|
| 73 |
+
if isinstance(data, dict) and "action_type" in data:
|
| 74 |
+
return data
|
| 75 |
+
except (json.JSONDecodeError, ValueError):
|
| 76 |
+
pass
|
| 77 |
+
|
| 78 |
+
for line in text.split("\n"):
|
| 79 |
+
line = line.strip()
|
| 80 |
+
if line.startswith("{"):
|
| 81 |
+
try:
|
| 82 |
+
data = json.loads(line)
|
| 83 |
+
if isinstance(data, dict) and "action_type" in data:
|
| 84 |
+
return data
|
| 85 |
+
except (json.JSONDecodeError, ValueError):
|
| 86 |
+
continue
|
| 87 |
+
|
| 88 |
+
return {"action_type": "submit_plan"}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class CommitmentOSEnvFactory:
|
| 92 |
+
"""Wraps CommitmentOS for use with TRL's GRPOTrainer.
|
| 93 |
+
|
| 94 |
+
Usage with TRL::
|
| 95 |
+
|
| 96 |
+
from training.env_factory import CommitmentOSEnvFactory
|
| 97 |
+
|
| 98 |
+
factory = CommitmentOSEnvFactory(max_turns=8)
|
| 99 |
+
|
| 100 |
+
trainer = GRPOTrainer(
|
| 101 |
+
...
|
| 102 |
+
environment_factory=factory,
|
| 103 |
+
)
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
max_turns: int = 8,
|
| 109 |
+
scenario_ids: Optional[List[str]] = None,
|
| 110 |
+
) -> None:
|
| 111 |
+
self.max_turns = max_turns
|
| 112 |
+
self.scenario_ids = scenario_ids or list(get_all_scenarios().keys())
|
| 113 |
+
self.system_prompt = build_system_prompt()
|
| 114 |
+
|
| 115 |
+
def __call__(self, completions: List[str], **kwargs: Any) -> List[float]:
|
| 116 |
+
"""Evaluate a batch of model completions.
|
| 117 |
+
|
| 118 |
+
Each completion is treated as a full multi-turn transcript where
|
| 119 |
+
each line is one JSON action. Returns a list of final rewards.
|
| 120 |
+
"""
|
| 121 |
+
rewards: List[float] = []
|
| 122 |
+
for completion in completions:
|
| 123 |
+
reward = self._evaluate_single(completion)
|
| 124 |
+
rewards.append(reward)
|
| 125 |
+
return rewards
|
| 126 |
+
|
| 127 |
+
def _evaluate_single(self, completion: str) -> float:
|
| 128 |
+
import random
|
| 129 |
+
|
| 130 |
+
env = CommitmentEnvironment()
|
| 131 |
+
scenario_id = random.choice(self.scenario_ids)
|
| 132 |
+
env.reset(task_id=scenario_id)
|
| 133 |
+
|
| 134 |
+
actions = completion.strip().split("\n")
|
| 135 |
+
last_reward = 0.01
|
| 136 |
+
|
| 137 |
+
for i, action_text in enumerate(actions[: self.max_turns]):
|
| 138 |
+
action_dict = parse_action_from_text(action_text)
|
| 139 |
+
try:
|
| 140 |
+
action = CommitmentAction(**action_dict)
|
| 141 |
+
obs = env.step(action)
|
| 142 |
+
last_reward = obs.reward
|
| 143 |
+
if obs.done:
|
| 144 |
+
break
|
| 145 |
+
except Exception:
|
| 146 |
+
continue
|
| 147 |
+
|
| 148 |
+
if not env._done:
|
| 149 |
+
obs = env.step(CommitmentAction(action_type="submit_plan"))
|
| 150 |
+
last_reward = obs.reward
|
| 151 |
+
|
| 152 |
+
return float(last_reward)
|
| 153 |
+
|
| 154 |
+
def get_prompt(self, scenario_id: Optional[str] = None) -> List[Dict[str, str]]:
|
| 155 |
+
"""Build chat messages for a scenario."""
|
| 156 |
+
import random
|
| 157 |
+
from server.tasks import get_scenario
|
| 158 |
+
|
| 159 |
+
sid = scenario_id or random.choice(self.scenario_ids)
|
| 160 |
+
scenario = get_scenario(sid)
|
| 161 |
+
if scenario is None:
|
| 162 |
+
raise ValueError(f"Unknown scenario: {sid}")
|
| 163 |
+
|
| 164 |
+
return [
|
| 165 |
+
{"role": "system", "content": self.system_prompt},
|
| 166 |
+
{"role": "user", "content": build_initial_prompt(scenario)},
|
| 167 |
+
]
|
training/train_grpo.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GRPO training script for CommitmentOS.
|
| 2 |
+
|
| 3 |
+
Uses TRL's GRPOTrainer with LoRA to train Qwen2.5-1.5B-Instruct on
|
| 4 |
+
temporal commitment coherence tasks.
|
| 5 |
+
|
| 6 |
+
Designed for Google Colab A100 or similar GPU environments.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python training/train_grpo.py [--model MODEL] [--epochs N] [--lr LR]
|
| 10 |
+
|
| 11 |
+
Environment variables:
|
| 12 |
+
HF_TOKEN — HuggingFace token for model upload (optional)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import random
|
| 21 |
+
import sys
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any, Dict, List, Optional
|
| 24 |
+
|
| 25 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def parse_args() -> argparse.Namespace:
|
| 29 |
+
parser = argparse.ArgumentParser(description="GRPO training for CommitmentOS")
|
| 30 |
+
parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct", help="Base model")
|
| 31 |
+
parser.add_argument("--epochs", type=int, default=2, help="Number of training epochs")
|
| 32 |
+
parser.add_argument("--lr", type=float, default=5e-6, help="Learning rate")
|
| 33 |
+
parser.add_argument("--batch_size", type=int, default=4, help="Per-device batch size")
|
| 34 |
+
parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (-1 for full epochs)")
|
| 35 |
+
parser.add_argument("--lora_rank", type=int, default=16, help="LoRA rank")
|
| 36 |
+
parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
|
| 37 |
+
parser.add_argument("--output_dir", default="./training_output", help="Output directory")
|
| 38 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Push model to HuggingFace Hub")
|
| 39 |
+
parser.add_argument("--hub_model_id", default="jayant2304/commitmentos-qwen-grpo", help="HF Hub model ID")
|
| 40 |
+
parser.add_argument("--num_scenarios", type=int, default=15, help="Number of scenarios to use")
|
| 41 |
+
parser.add_argument("--max_turns", type=int, default=8, help="Max turns per episode")
|
| 42 |
+
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size (completions per prompt)")
|
| 43 |
+
return parser.parse_args()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def build_dataset(num_scenarios: int = 15) -> List[Dict[str, Any]]:
|
| 47 |
+
"""Build training dataset from CommitmentOS scenarios."""
|
| 48 |
+
from server.tasks import get_all_scenarios
|
| 49 |
+
from training.env_factory import build_initial_prompt, build_system_prompt
|
| 50 |
+
|
| 51 |
+
scenarios = list(get_all_scenarios().values())[:num_scenarios]
|
| 52 |
+
system_prompt = build_system_prompt()
|
| 53 |
+
dataset: List[Dict[str, Any]] = []
|
| 54 |
+
|
| 55 |
+
for scenario in scenarios:
|
| 56 |
+
user_msg = build_initial_prompt(scenario)
|
| 57 |
+
dataset.append({
|
| 58 |
+
"prompt": [
|
| 59 |
+
{"role": "system", "content": system_prompt},
|
| 60 |
+
{"role": "user", "content": user_msg},
|
| 61 |
+
],
|
| 62 |
+
"scenario_id": scenario.scenario_id,
|
| 63 |
+
"difficulty": scenario.difficulty,
|
| 64 |
+
})
|
| 65 |
+
|
| 66 |
+
return dataset
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def reward_function(completions: List[str], **kwargs: Any) -> List[float]:
|
| 70 |
+
"""Reward function for GRPO — evaluates completions against CommitmentOS."""
|
| 71 |
+
from training.env_factory import CommitmentOSEnvFactory
|
| 72 |
+
|
| 73 |
+
factory = CommitmentOSEnvFactory(max_turns=8)
|
| 74 |
+
return factory(completions)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def main() -> None:
|
| 78 |
+
args = parse_args()
|
| 79 |
+
|
| 80 |
+
try:
|
| 81 |
+
import torch
|
| 82 |
+
from datasets import Dataset
|
| 83 |
+
from peft import LoraConfig
|
| 84 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 85 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 86 |
+
except ImportError as e:
|
| 87 |
+
print(f"Missing training dependency: {e}")
|
| 88 |
+
print("Install with: pip install trl transformers peft datasets torch")
|
| 89 |
+
sys.exit(1)
|
| 90 |
+
|
| 91 |
+
print(f"Loading model: {args.model}")
|
| 92 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
| 93 |
+
if tokenizer.pad_token is None:
|
| 94 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 95 |
+
|
| 96 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 97 |
+
args.model,
|
| 98 |
+
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
| 99 |
+
device_map="auto" if torch.cuda.is_available() else None,
|
| 100 |
+
trust_remote_code=True,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
lora_config = LoraConfig(
|
| 104 |
+
r=args.lora_rank,
|
| 105 |
+
lora_alpha=args.lora_alpha,
|
| 106 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
| 107 |
+
lora_dropout=0.05,
|
| 108 |
+
task_type="CAUSAL_LM",
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
print("Building dataset...")
|
| 112 |
+
raw_data = build_dataset(args.num_scenarios)
|
| 113 |
+
dataset = Dataset.from_list(raw_data)
|
| 114 |
+
|
| 115 |
+
training_config = GRPOConfig(
|
| 116 |
+
output_dir=args.output_dir,
|
| 117 |
+
num_train_epochs=args.epochs,
|
| 118 |
+
max_steps=args.max_steps,
|
| 119 |
+
per_device_train_batch_size=args.batch_size,
|
| 120 |
+
learning_rate=args.lr,
|
| 121 |
+
logging_steps=1,
|
| 122 |
+
save_steps=50,
|
| 123 |
+
bf16=torch.cuda.is_available(),
|
| 124 |
+
gradient_accumulation_steps=2,
|
| 125 |
+
warmup_ratio=0.1,
|
| 126 |
+
max_completion_length=512,
|
| 127 |
+
num_generations=args.group_size,
|
| 128 |
+
report_to="none",
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
print("Initialising GRPOTrainer...")
|
| 132 |
+
trainer = GRPOTrainer(
|
| 133 |
+
model=model,
|
| 134 |
+
config=training_config,
|
| 135 |
+
train_dataset=dataset,
|
| 136 |
+
processing_class=tokenizer,
|
| 137 |
+
reward_funcs=reward_function,
|
| 138 |
+
peft_config=lora_config,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
print("Starting training...")
|
| 142 |
+
trainer.train()
|
| 143 |
+
|
| 144 |
+
print(f"Saving model to {args.output_dir}")
|
| 145 |
+
trainer.save_model(args.output_dir)
|
| 146 |
+
tokenizer.save_pretrained(args.output_dir)
|
| 147 |
+
|
| 148 |
+
if args.push_to_hub:
|
| 149 |
+
hf_token = os.getenv("HF_TOKEN", "")
|
| 150 |
+
if hf_token:
|
| 151 |
+
print(f"Pushing to hub: {args.hub_model_id}")
|
| 152 |
+
trainer.push_to_hub(args.hub_model_id, token=hf_token)
|
| 153 |
+
else:
|
| 154 |
+
print("HF_TOKEN not set — skipping hub push")
|
| 155 |
+
|
| 156 |
+
print("Training complete!")
|
| 157 |
+
|
| 158 |
+
save_training_metrics(trainer, args.output_dir)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def save_training_metrics(trainer: Any, output_dir: str) -> None:
|
| 162 |
+
"""Save training metrics to JSON for plotting training curves."""
|
| 163 |
+
output_path = Path(output_dir)
|
| 164 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 165 |
+
|
| 166 |
+
history = trainer.state.log_history if hasattr(trainer.state, "log_history") else []
|
| 167 |
+
metrics_file = output_path / "training_metrics.json"
|
| 168 |
+
with open(metrics_file, "w") as f:
|
| 169 |
+
json.dump(history, f, indent=2)
|
| 170 |
+
print(f"Training metrics saved to {metrics_file}")
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
if __name__ == "__main__":
|
| 174 |
+
main()
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|