Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- Dockerfile +81 -0
- README.md +249 -4
- __init__.py +11 -0
- client.py +63 -0
- models.py +46 -0
- openenv.yaml +7 -0
- pyproject.toml +45 -0
- server/__init__.py +11 -0
- server/app.py +32 -0
- server/exec_assistant_arena_environment.py +211 -0
- server/requirements.txt +6 -0
- server/reward.py +207 -0
- server/scenario_generator.py +266 -0
Dockerfile
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
RUN apt-get update && \
|
| 20 |
+
apt-get install -y --no-install-recommends git && \
|
| 21 |
+
rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=exec_assistant_arena
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies using uv sync
|
| 42 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 43 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 44 |
+
if [ -f uv.lock ]; then \
|
| 45 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 46 |
+
else \
|
| 47 |
+
uv sync --no-install-project --no-editable; \
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 51 |
+
if [ -f uv.lock ]; then \
|
| 52 |
+
uv sync --frozen --no-editable; \
|
| 53 |
+
else \
|
| 54 |
+
uv sync --no-editable; \
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
# Final runtime stage
|
| 58 |
+
FROM ${BASE_IMAGE}
|
| 59 |
+
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Copy the virtual environment from builder
|
| 63 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 64 |
+
|
| 65 |
+
# Copy the environment code
|
| 66 |
+
COPY --from=builder /app/env /app/env
|
| 67 |
+
|
| 68 |
+
# Set PATH to use the virtual environment
|
| 69 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 70 |
+
|
| 71 |
+
# Set PYTHONPATH so imports work correctly
|
| 72 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 73 |
+
|
| 74 |
+
# Health check
|
| 75 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 76 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 77 |
+
|
| 78 |
+
# Run the FastAPI server
|
| 79 |
+
# The module path is constructed to work with the /app/env structure
|
| 80 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 81 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,255 @@
|
|
| 1 |
---
|
| 2 |
-
title: Exec Assistant Arena
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Exec Assistant Arena Environment Server
|
| 3 |
+
emoji: 🎣
|
| 4 |
+
colorFrom: gray
|
| 5 |
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# Exec Assistant Arena Environment
|
| 15 |
+
|
| 16 |
+
An environment that simulates an executive assistant's morning inbox. The agent uses tool calls to resolve calendar conflicts, draft email replies, infer user preferences, and handle late-breaking changes.
|
| 17 |
+
|
| 18 |
+
## Quick Start
|
| 19 |
+
|
| 20 |
+
The simplest way to use the Exec Assistant Arena environment is through the `ExecAssistantArenaEnv` class:
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
from exec_assistant_arena import AssistantAction, ExecAssistantArenaEnv

try:
    # Create environment from Docker image
    exec_assistant_arenaenv = ExecAssistantArenaEnv.from_docker_image("exec_assistant_arena-env:latest")

    # Reset
    result = exec_assistant_arenaenv.reset()
    print(f"Reset: {result.observation.tool_result}")

    # Call a few tools
    actions = [
        AssistantAction(tool="check_calendar"),
        AssistantAction(tool="check_inbox"),
    ]

    for action in actions:
        result = exec_assistant_arenaenv.step(action)
        print(f"Tool: '{action.tool}'")
        print(f"  → Result: '{result.observation.tool_result}'")
        print(f"  → Conflicts: {result.observation.conflicts}")
        print(f"  → Reward: {result.reward}")

finally:
    # Always clean up
    exec_assistant_arenaenv.close()
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
That's it! The `ExecAssistantArenaEnv.from_docker_image()` method handles:
|
| 49 |
+
- Starting the Docker container
|
| 50 |
+
- Waiting for the server to be ready
|
| 51 |
+
- Connecting to the environment
|
| 52 |
+
- Container cleanup when you call `close()`
|
| 53 |
+
|
| 54 |
+
## Building the Docker Image
|
| 55 |
+
|
| 56 |
+
Before using the environment, you need to build the Docker image:
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
# From project root
|
| 60 |
+
docker build -t exec_assistant_arena-env:latest -f server/Dockerfile .
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Deploying to Hugging Face Spaces
|
| 64 |
+
|
| 65 |
+
You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
# From the environment directory (where openenv.yaml is located)
|
| 69 |
+
openenv push
|
| 70 |
+
|
| 71 |
+
# Or specify options
|
| 72 |
+
openenv push --namespace my-org --private
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
The `openenv push` command will:
|
| 76 |
+
1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
|
| 77 |
+
2. Prepare a custom build for Hugging Face Docker space (enables web interface)
|
| 78 |
+
3. Upload to Hugging Face (ensuring you're logged in)
|
| 79 |
+
|
| 80 |
+
### Prerequisites
|
| 81 |
+
|
| 82 |
+
- Authenticate with Hugging Face: The command will prompt for login if not already authenticated
|
| 83 |
+
|
| 84 |
+
### Options
|
| 85 |
+
|
| 86 |
+
- `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
|
| 87 |
+
- `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
|
| 88 |
+
- `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
|
| 89 |
+
- `--private`: Deploy the space as private (default: public)
|
| 90 |
+
|
| 91 |
+
### Examples
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
# Push to your personal namespace (defaults to username/env-name from openenv.yaml)
|
| 95 |
+
openenv push
|
| 96 |
+
|
| 97 |
+
# Push to a specific repository
|
| 98 |
+
openenv push --repo-id my-org/my-env
|
| 99 |
+
|
| 100 |
+
# Push with a custom base image
|
| 101 |
+
openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
|
| 102 |
+
|
| 103 |
+
# Push as a private space
|
| 104 |
+
openenv push --private
|
| 105 |
+
|
| 106 |
+
# Combine options
|
| 107 |
+
openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
After deployment, your space will be available at:
|
| 111 |
+
`https://huggingface.co/spaces/<repo-id>`
|
| 112 |
+
|
| 113 |
+
The deployed space includes:
|
| 114 |
+
- **Web Interface** at `/web` - Interactive UI for exploring the environment
|
| 115 |
+
- **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
|
| 116 |
+
- **Health Check** at `/health` - Container health monitoring
|
| 117 |
+
- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
|
| 118 |
+
|
| 119 |
+
## Environment Details
|
| 120 |
+
|
| 121 |
+
### Action
|
| 122 |
+
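**AssistantAction**: A single tool call
- `tool` (str) - Tool to invoke: `check_calendar`, `check_inbox`, `reschedule`, `draft_reply`, `delegate_task`, or `done`
- `arguments` (dict) - Tool arguments, e.g. `{'event_id': 'mtg_3', 'new_time': '2pm'}` (defaults to an empty dict)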
|
| 124 |
+
|
| 125 |
+
### Observation
|
| 126 |
+
**AssistantObservation**: The assistant's view of the inbox and calendar after each step
- `inbox_summary` (str) - Current emails/messages
- `calendar_view` (str) - Today's schedule as text
- `pending_tasks` (list[str]) - Unresolved items
- `tool_result` (str) - Output of the last tool call
- `conflicts` (list[str]) - Detected scheduling conflicts
- `reward` (float) - Reward for the last step
- `done` (bool) - Whether the episode has ended
- `metadata` (dict) - Additional info
|
| 132 |
+
|
| 133 |
+
### Reward
|
| 134 |
+
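Rewards are rule-based and decomposed into 6 components for training visibility. For example:
- `check_calendar` / `check_inbox` → free actions (no reward)
- `reschedule` → conflict-resolution reward plus preference-inference reward
- `draft_reply` to an email that was already answered → penalty of 0.2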
|
| 138 |
+
|
| 139 |
+
## Advanced Usage
|
| 140 |
+
|
| 141 |
+
### Connecting to an Existing Server
|
| 142 |
+
|
| 143 |
+
If you already have an Exec Assistant Arena environment server running, you can connect directly:
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
from exec_assistant_arena import AssistantAction, ExecAssistantArenaEnv

# Connect to existing server
exec_assistant_arenaenv = ExecAssistantArenaEnv(base_url="<ENV_HTTP_URL_HERE>")

# Use as normal
result = exec_assistant_arenaenv.reset()
result = exec_assistant_arenaenv.step(AssistantAction(tool="check_inbox"))
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
Note: When connecting to an existing server, `exec_assistant_arenaenv.close()` will NOT stop the server.
|
| 157 |
+
|
| 158 |
+
### Using the Context Manager
|
| 159 |
+
|
| 160 |
+
The client supports context manager usage for automatic connection management:
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
from exec_assistant_arena import AssistantAction, ExecAssistantArenaEnv

# Connect with context manager (auto-connects and closes)
with ExecAssistantArenaEnv(base_url="http://localhost:8000") as env:
    result = env.reset()
    print(f"Reset: {result.observation.tool_result}")
    # Multiple steps with low latency
    for tool in ["check_calendar", "check_inbox"]:
        result = env.step(AssistantAction(tool=tool))
        print(f"Result: {result.observation.tool_result}")
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
The client uses WebSocket connections for:
|
| 176 |
+
- **Lower latency**: No HTTP connection overhead per request
|
| 177 |
+
- **Persistent session**: Server maintains your environment state
|
| 178 |
+
- **Efficient for episodes**: Better for many sequential steps
|
| 179 |
+
|
| 180 |
+
### Concurrent WebSocket Sessions
|
| 181 |
+
|
| 182 |
+
The server supports multiple concurrent WebSocket connections. `server/app.py`
already uses factory mode for this; adjust `max_concurrent_envs` to change the limit:

```python
# In server/app.py - factory mode for concurrent sessions
app = create_app(
    ExecAssistantArenaEnvironment,  # Pass class, not instance
    AssistantAction,
    AssistantObservation,
    env_name="exec_assistant_arena",
    max_concurrent_envs=5,  # Allow 5 concurrent sessions
)
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
Then multiple clients can connect simultaneously:
|
| 196 |
+
|
| 197 |
+
```python
|
| 198 |
+
from exec_assistant_arena import AssistantAction, ExecAssistantArenaEnv
from concurrent.futures import ThreadPoolExecutor

def run_episode(client_id: int):
    with ExecAssistantArenaEnv(base_url="http://localhost:8000") as env:
        result = env.reset()
        for i in range(10):
            result = env.step(AssistantAction(tool="check_calendar"))
        return client_id, result.observation.tool_result
|
| 207 |
+
|
| 208 |
+
# Run 4 episodes concurrently
|
| 209 |
+
with ThreadPoolExecutor(max_workers=4) as executor:
|
| 210 |
+
results = list(executor.map(run_episode, range(4)))
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
## Development & Testing
|
| 214 |
+
|
| 215 |
+
### Direct Environment Testing
|
| 216 |
+
|
| 217 |
+
Test the environment logic directly without starting the HTTP server:
|
| 218 |
+
|
| 219 |
+
```bash
|
| 220 |
+
# From the server directory
|
| 221 |
+
python3 server/exec_assistant_arena_environment.py
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
This verifies that:
|
| 225 |
+
- Environment resets correctly
|
| 226 |
+
- Step executes actions properly
|
| 227 |
+
- State tracking works
|
| 228 |
+
- Rewards are calculated correctly
|
| 229 |
+
|
| 230 |
+
### Running Locally
|
| 231 |
+
|
| 232 |
+
Run the server locally for development:
|
| 233 |
+
|
| 234 |
+
```bash
|
| 235 |
+
uvicorn server.app:app --reload
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
## Project Structure
|
| 239 |
+
|
| 240 |
+
```
|
| 241 |
+
exec_assistant_arena/
|
| 242 |
+
├── .dockerignore # Docker build exclusions
|
| 243 |
+
├── __init__.py # Module exports
|
| 244 |
+
├── README.md # This file
|
| 245 |
+
├── openenv.yaml # OpenEnv manifest
|
| 246 |
+
├── pyproject.toml # Project metadata and dependencies
|
| 247 |
+
├── uv.lock # Locked dependencies (generated)
|
| 248 |
+
├── client.py # ExecAssistantArenaEnv client
|
| 249 |
+
├── models.py # Action and Observation models
|
| 250 |
+
└── server/
|
| 251 |
+
├── __init__.py # Server module exports
|
| 252 |
+
├── exec_assistant_arena_environment.py # Core environment logic
|
| 253 |
+
├── app.py # FastAPI application (HTTP + WebSocket endpoints)
|
| 254 |
+
└── Dockerfile # Container image definition
|
| 255 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Executive Assistant Arena Environment."""
|
| 2 |
+
|
| 3 |
+
from .client import ExecAssistantArenaEnv
|
| 4 |
+
from .models import AssistantAction, AssistantObservation, AssistantState
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"AssistantAction",
|
| 8 |
+
"AssistantObservation",
|
| 9 |
+
"AssistantState",
|
| 10 |
+
"ExecAssistantArenaEnv",
|
| 11 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Executive Assistant Arena Environment Client."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
from openenv.core.client_types import StepResult
|
| 6 |
+
from openenv.core import EnvClient
|
| 7 |
+
|
| 8 |
+
from .models import AssistantAction, AssistantObservation, AssistantState
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ExecAssistantArenaEnv(
    EnvClient[AssistantAction, AssistantObservation, AssistantState]
):
    """
    Client for an Executive Assistant Arena environment server.

    Converts ``AssistantAction`` objects into request payloads and turns the
    server's response payloads back into ``AssistantObservation`` /
    ``AssistantState`` models.

    Example:
        >>> with ExecAssistantArenaEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset(difficulty="medium")
        ...     result = client.step(AssistantAction(tool="check_calendar"))
    """

    def _step_payload(self, action: AssistantAction) -> Dict:
        """Serialize an action as its tool name plus its arguments."""
        return dict(tool=action.tool, arguments=action.arguments)

    def _parse_result(self, payload: Dict) -> StepResult[AssistantObservation]:
        """Build a ``StepResult`` from a step/reset response payload.

        Observation fields come from ``payload["observation"]``; ``reward``
        and ``done`` come from the top level of the payload and are mirrored
        onto both the observation and the result. Missing keys fall back to
        empty strings/lists, ``False`` for ``done`` and ``None`` for
        ``reward``.
        """
        raw_obs = payload.get("observation", {})
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = AssistantObservation(
            inbox_summary=raw_obs.get("inbox_summary", ""),
            calendar_view=raw_obs.get("calendar_view", ""),
            pending_tasks=raw_obs.get("pending_tasks", []),
            tool_result=raw_obs.get("tool_result", ""),
            conflicts=raw_obs.get("conflicts", []),
            done=done,
            reward=reward,
            metadata=raw_obs.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> AssistantState:
        """Build an ``AssistantState`` from a state response payload.

        Every integer counter defaults to 0 when absent, ``cumulative_reward``
        defaults to 0.0 and ``episode_id`` defaults to ``None``.
        """
        counter_names = (
            "step_count",
            "conflicts_resolved",
            "total_conflicts",
            "preferences_inferred",
            "total_preferences",
            "emails_drafted",
            "total_emails",
            "deadlines_met",
            "deadlines_missed",
            "unnecessary_actions",
            "late_changes_handled",
            "total_late_changes",
        )
        counters = {name: payload.get(name, 0) for name in counter_names}
        return AssistantState(
            episode_id=payload.get("episode_id"),
            cumulative_reward=payload.get("cumulative_reward", 0.0),
            **counters,
        )
|
models.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data models for the Executive Assistant Arena Environment."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from pydantic import Field
|
| 5 |
+
|
| 6 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AssistantAction(Action):
    """A single tool call the assistant makes to manage the calendar or email."""

    # Required: the name of the tool to run.
    tool: str = Field(
        default=...,
        description="Tool to invoke: check_calendar, check_inbox, reschedule, draft_reply, delegate_task, done",
    )
    # Optional: per-tool keyword arguments; a fresh empty dict when omitted.
    arguments: dict = Field(
        description="Tool arguments, e.g. {'event_id': 'mtg_3', 'new_time': '2pm'}",
        default_factory=dict,
    )
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class AssistantObservation(Observation):
    """What the assistant sees after a reset or a step."""

    # Text views of the simulated world.
    inbox_summary: str = Field(
        description="Current emails/messages",
        default="",
    )
    calendar_view: str = Field(
        description="Today's schedule as text",
        default="",
    )
    # Work still outstanding.
    pending_tasks: list[str] = Field(
        description="Unresolved items",
        default_factory=list,
    )
    # Feedback from the most recent tool call.
    tool_result: str = Field(
        description="Output of last tool call",
        default="",
    )
    conflicts: list[str] = Field(
        description="Detected scheduling conflicts",
        default_factory=list,
    )
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AssistantState(State):
    """Per-episode progress counters for the assistant environment."""

    # Scheduling conflicts: resolved so far / present in the scenario.
    conflicts_resolved: int = 0
    total_conflicts: int = 0
    # User preferences: inferred so far / present in the scenario.
    preferences_inferred: int = 0
    total_preferences: int = 0
    # Email replies: drafted so far / emails that require a reply.
    emails_drafted: int = 0
    total_emails: int = 0
    # Deadline tracking.
    deadlines_met: int = 0
    deadlines_missed: int = 0
    # Actions counted as unnecessary (e.g. replying to the same email twice).
    unnecessary_actions: int = 0
    # Late-breaking changes: handled so far / present in the scenario.
    late_changes_handled: int = 0
    total_late_changes: int = 0
    # Running sum of reward for the episode.
    cumulative_reward: float = 0.0
|
openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: exec_assistant_arena
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-exec_assistant_arena"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Exec Assistant Arena environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]>=0.2.0",
|
| 21 |
+
# Environment-specific dependencies
|
| 22 |
+
# Add all dependencies needed for your environment here
|
| 23 |
+
# Examples:
|
| 24 |
+
# "numpy>=1.19.0",
|
| 25 |
+
# "torch>=2.0.0",
|
| 26 |
+
# "gymnasium>=0.29.0",
|
| 27 |
+
# "openspiel>=1.0.0",
|
| 28 |
+
# "smolagents>=1.22.0,<2",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[project.scripts]
|
| 38 |
+
# Server entry point - enables running via: uv run --project . server
|
| 39 |
+
# or: python -m exec_assistant_arena.server.app
|
| 40 |
+
server = "exec_assistant_arena.server.app:main"
|
| 41 |
+
|
| 42 |
+
[tool.setuptools]
|
| 43 |
+
include-package-data = true
|
| 44 |
+
packages = ["exec_assistant_arena", "exec_assistant_arena.server"]
|
| 45 |
+
package-dir = { "exec_assistant_arena" = ".", "exec_assistant_arena.server" = "server" }
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Exec Assistant Arena environment server components."""
|
| 8 |
+
|
| 9 |
+
from .exec_assistant_arena_environment import ExecAssistantArenaEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["ExecAssistantArenaEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI application for the Executive Assistant Arena Environment."""
|
| 2 |
+
|
| 3 |
+
try:
|
| 4 |
+
from openenv.core.env_server.http_server import create_app
|
| 5 |
+
except Exception as e:
|
| 6 |
+
raise ImportError(
|
| 7 |
+
"openenv is required. Install with: pip install 'openenv-core[core]>=0.2.1'"
|
| 8 |
+
) from e
|
| 9 |
+
|
| 10 |
+
from models import AssistantAction, AssistantObservation
|
| 11 |
+
from .exec_assistant_arena_environment import ExecAssistantArenaEnvironment
|
| 12 |
+
|
| 13 |
+
app = create_app(
|
| 14 |
+
ExecAssistantArenaEnvironment,
|
| 15 |
+
AssistantAction,
|
| 16 |
+
AssistantObservation,
|
| 17 |
+
env_name="exec_assistant_arena",
|
| 18 |
+
max_concurrent_envs=5,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
    """Run the FastAPI ``app`` under uvicorn.

    Args:
        host: Interface address to bind to.
        port: TCP port to listen on.
    """
    # Imported lazily so importing this module for ``app`` does not require
    # uvicorn to be importable.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the Executive Assistant Arena environment server."
    )
    # main() already accepts a host; expose it on the CLI as well so the bind
    # address can be changed without editing code. The default is unchanged.
    parser.add_argument("--host", default="0.0.0.0", help="Interface to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to listen on")
    args = parser.parse_args()
    main(host=args.host, port=args.port)
|
server/exec_assistant_arena_environment.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Executive Assistant Arena Environment Implementation."""
|
| 2 |
+
|
| 3 |
+
from uuid import uuid4
|
| 4 |
+
|
| 5 |
+
from openenv.core.env_server.interfaces import Environment
|
| 6 |
+
|
| 7 |
+
from models import AssistantAction, AssistantObservation, AssistantState
|
| 8 |
+
from .scenario_generator import generate_scenario, Scenario, CalendarEvent, TIME_SLOTS
|
| 9 |
+
from .reward import score_reschedule, score_email_reply, score_terminal, RewardBreakdown
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ExecAssistantArenaEnvironment(Environment):
|
| 13 |
+
"""
|
| 14 |
+
An environment that simulates a personal assistant's morning inbox.
|
| 15 |
+
|
| 16 |
+
The agent must resolve calendar conflicts, draft email replies,
|
| 17 |
+
infer user preferences, and handle late-breaking changes.
|
| 18 |
+
|
| 19 |
+
Episodes are 10-20 steps. Rewards are rule-based and decomposed
|
| 20 |
+
into 6 components for training visibility.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 24 |
+
|
| 25 |
+
def __init__(self):
    """Create an environment with fresh bookkeeping and no scenario loaded."""
    # Per-episode bookkeeping; reset() re-creates every one of these.
    self.reward_breakdown = RewardBreakdown()
    self.replied_emails: set[str] = set()
    self.late_change_step: int | None = None
    self.late_change_injected = False
    # Stays None until reset() generates a scenario; step() checks for this.
    self.scenario: Scenario | None = None
    # Placeholder state with a fresh episode id and zeroed step counter.
    self._state = AssistantState(episode_id=str(uuid4()), step_count=0)
|
| 32 |
+
|
| 33 |
+
def reset(self, seed=None, difficulty="medium", **kwargs) -> AssistantObservation:
    """Reset the environment with a new procedural scenario.

    Args:
        seed: Optional scenario seed. A string seed is mapped to an integer
            in ``[0, 2**31)`` with CRC32, so the same string gives the same
            seed in every process. Any other value is passed to
            ``generate_scenario`` unchanged.
        difficulty: Difficulty label forwarded to ``generate_scenario``.
        **kwargs: Accepted and ignored.

    Returns:
        The opening observation: inbox, calendar, pending tasks and
        conflicts for the new scenario, plus a greeting in ``tool_result``
        that lists the user's preferences. ``reward`` is 0.0 and ``done``
        is False.
    """
    if isinstance(seed, str):
        # Built-in hash() of a str is salted per interpreter process
        # (PYTHONHASHSEED), so it cannot give reproducible scenarios.
        # CRC32 is stable across runs and platforms.
        import zlib

        seed = zlib.crc32(seed.encode("utf-8")) % (2**31)

    self.scenario = generate_scenario(difficulty, seed)

    # Clear all per-episode bookkeeping.
    self.late_change_injected = False
    self.late_change_step = None
    self.replied_emails = set()
    self.reward_breakdown = RewardBreakdown()

    # Totals are fixed by the scenario; progress counters start at zero.
    self._state = AssistantState(
        episode_id=str(uuid4()),
        step_count=0,
        total_conflicts=len(self.scenario.conflicts),
        # Only emails flagged as requiring a reply count toward the total.
        total_emails=sum(1 for e in self.scenario.emails if e.requires_reply),
        total_preferences=len(self.scenario.preferences),
        total_late_changes=len(self.scenario.late_changes),
    )

    # Build the welcome observation. Each preference is a pair whose second
    # element is the description shown to the agent.
    pref_hints = "\n".join(f" - {desc}" for _, desc in self.scenario.preferences)

    return AssistantObservation(
        inbox_summary=self.scenario.inbox_text(),
        calendar_view=self.scenario.calendar_text(),
        pending_tasks=self.scenario.pending_tasks_text(),
        tool_result=f"Good morning. You have {len(self.scenario.conflicts)} scheduling conflicts and {self._state.total_emails} emails needing replies.\n\nUser preferences:\n{pref_hints}",
        conflicts=self.scenario.conflicts_text(),
        done=False,
        reward=0.0,
    )
|
| 65 |
+
|
| 66 |
+
def step(self, action: AssistantAction, **kwargs) -> AssistantObservation:
    """Process one assistant action and return the resulting observation.

    Args:
        action: Tool call to execute. ``action.tool`` selects the branch
            (check_calendar, check_inbox, reschedule, draft_reply,
            delegate_task, done); ``action.arguments`` holds its arguments
            and may be empty or ``None``.
        **kwargs: Accepted and ignored.

    Returns:
        Observation with refreshed inbox/calendar/task/conflict views, the
        textual tool result, the reward earned by this step, and ``done``
        set once the agent calls ``done`` or the 20-step limit is reached.
        Terminal scoring is applied exactly the same way on both exits.
    """
    if self.scenario is None:
        self.reset()

    self._state.step_count += 1
    reward = 0.0
    tool_result = ""

    # Inject late change at step 7+.
    # NOTE(review): the flag is never cleared during an episode, so at most
    # one late change is injected even when the scenario defines several.
    if self._state.step_count >= 7 and not self.late_change_injected:
        change_desc = self.scenario.inject_late_change()
        if change_desc:
            self.late_change_injected = True
            self.late_change_step = self._state.step_count
            tool_result = f"*** LATE CHANGE: {change_desc} ***\n\n"

    # Process tool call
    tool = action.tool
    args = action.arguments or {}  # tolerate a missing arguments mapping

    if tool == "check_calendar":
        # Free action - no reward
        tool_result += self.scenario.calendar_text()

    elif tool == "check_inbox":
        # Free action
        tool_result += self.scenario.inbox_text()

    elif tool == "reschedule":
        event_id = args.get("event_id", "")
        new_time = args.get("new_time", "")
        conflict_r, pref_r, msg = score_reschedule(
            self.scenario, event_id, new_time, self.scenario.preferences
        )
        reward += conflict_r + pref_r
        self.reward_breakdown.conflict_resolution += conflict_r
        self.reward_breakdown.preference_inference += pref_r
        if conflict_r > 0:
            self._state.conflicts_resolved += 1
        if pref_r > 0:
            self._state.preferences_inferred += 1
        tool_result += msg

    elif tool == "draft_reply":
        email_id = args.get("email_id", "")
        body = args.get("body", "")

        if email_id in self.replied_emails:
            reward += self._penalise_unnecessary_action()
            tool_result += f"Already replied to {email_id}."
        else:
            email_r, pref_r, msg = score_email_reply(
                email_id, body, self.scenario, self.scenario.preferences
            )
            reward += email_r + pref_r
            self.reward_breakdown.email_quality += email_r
            self.reward_breakdown.preference_inference += pref_r
            if pref_r > 0:
                self._state.preferences_inferred += 1

            # Only an email that exists in the scenario counts as replied;
            # an unknown id must not be recorded or counted as drafted.
            # NOTE(review): a reply rejected as too short is still recorded
            # as replied, as before - confirm whether that is intended.
            target = next(
                (e for e in self.scenario.emails if e.email_id == email_id),
                None,
            )
            if target is not None:
                self._state.emails_drafted += 1
                self.replied_emails.add(email_id)
                if target.deadline:
                    # Mark deadline as met. This 0.5 is recorded in the
                    # breakdown only; it is not added to the step reward.
                    self._state.deadlines_met += 1
                    self.reward_breakdown.deadline_adherence += 0.5

            tool_result += msg

    elif tool == "delegate_task":
        task_desc = args.get("task", "")
        to = args.get("to", "")
        if task_desc and to:
            tool_result += f"Delegated '{task_desc}' to {to}."
            # Reward delegation as late-change recovery, but at most once
            # per injected late change so repeated delegation cannot be
            # farmed for reward.
            injected = sum(1 for lc in self.scenario.late_changes if lc.injected)
            if self.late_change_injected and self._state.late_changes_handled < injected:
                reward += 0.5
                self.reward_breakdown.late_change_recovery += 0.5
                self._state.late_changes_handled += 1
        else:
            reward += self._penalise_unnecessary_action()
            tool_result += "Delegate requires 'task' and 'to' arguments."

    elif tool == "done":
        reward += self._finalize_episode()

        tool_result += "Episode complete. Final breakdown:\n"
        tool_result += f" Conflicts resolved: {self._state.conflicts_resolved}/{self._state.total_conflicts}\n"
        tool_result += f" Emails drafted: {self._state.emails_drafted}/{self._state.total_emails}\n"
        tool_result += f" Preferences inferred: {self._state.preferences_inferred}/{self._state.total_preferences}\n"
        tool_result += f" Deadlines met: {self._state.deadlines_met}\n"
        tool_result += f" Late changes handled: {self._state.late_changes_handled}/{self._state.total_late_changes}\n"

    else:
        reward += self._penalise_unnecessary_action()
        tool_result += f"Unknown tool: {tool}. Available: check_calendar, check_inbox, reschedule, draft_reply, delegate_task, done"

    hit_step_limit = self._state.step_count >= 20
    done = tool == "done" or hit_step_limit

    # If we hit max steps without "done", apply the same terminal scoring.
    if hit_step_limit and tool != "done":
        reward += self._finalize_episode()
        tool_result += "\n[Max steps reached - episode terminated]"

    self._state.cumulative_reward += reward

    return AssistantObservation(
        inbox_summary=self.scenario.inbox_text(),
        calendar_view=self.scenario.calendar_text(),
        pending_tasks=self.scenario.pending_tasks_text(),
        tool_result=tool_result,
        conflicts=self.scenario.conflicts_text(),
        done=done,
        reward=reward,
    )

def _penalise_unnecessary_action(self) -> float:
    """Record one wasted action and return its reward contribution (-0.2)."""
    self._state.unnecessary_actions += 1
    self.reward_breakdown.efficiency_penalty -= 0.2
    return -0.2

def _finalize_episode(self) -> float:
    """Apply end-of-episode scoring and return the terminal reward.

    Folds the terminal components into ``reward_breakdown`` so the logged
    breakdown covers both ways an episode can end.
    """
    terminal = score_terminal(self.scenario)

    # score_terminal() charges every reply-requiring email because it cannot
    # see which ones were answered; give that charge back for replied emails
    # (1.0 with a deadline, 0.5 without).
    for email in self.scenario.emails:
        if email.requires_reply and email.email_id in self.replied_emails:
            terminal.deadline_adherence += 1.0 if email.deadline else 0.5

    # Bonus when a late change was injected and the agent reacted to it.
    if self.late_change_injected and self._state.late_changes_handled > 0:
        terminal.late_change_recovery += 2.0

    self.reward_breakdown.deadline_adherence += terminal.deadline_adherence
    self.reward_breakdown.late_change_recovery += terminal.late_change_recovery
    self.reward_breakdown.conflict_resolution += terminal.conflict_resolution

    return terminal.total
|
| 208 |
+
|
| 209 |
+
@property
def state(self) -> AssistantState:
    """Return the environment's current state object (not a copy)."""
    current_state = self._state
    return current_state
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
server/reward.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Decomposed reward computation for the Executive Assistant Arena.
|
| 2 |
+
|
| 3 |
+
All rewards are rule-based and deterministic. No LLM judges.
|
| 4 |
+
Each component is logged separately for W&B tracking.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
from .scenario_generator import Scenario, TIME_SLOTS
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class RewardBreakdown:
    """Reward split into six named components, each defaulting to 0.0."""

    conflict_resolution: float = 0.0
    preference_inference: float = 0.0
    email_quality: float = 0.0
    deadline_adherence: float = 0.0
    efficiency_penalty: float = 0.0
    late_change_recovery: float = 0.0

    @property
    def total(self) -> float:
        """Return the sum of all six components."""
        components = (
            self.conflict_resolution,
            self.preference_inference,
            self.email_quality,
            self.deadline_adherence,
            self.efficiency_penalty,
            self.late_change_recovery,
        )
        return sum(components)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def score_reschedule(
    scenario: Scenario,
    event_id: str,
    new_time: str,
    preferences: list[tuple[str, str]],
) -> tuple[float, float, str]:
    """Score a reschedule action and apply it to the scenario when accepted.

    Args:
        scenario: Scenario whose calendar and conflict list are inspected.
        event_id: Id of the calendar event to move.
        new_time: Target start time; must be one of ``TIME_SLOTS``.
        preferences: ``(pref_id, description)`` pairs for the user.

    Returns:
        ``(conflict_reward, pref_reward, message)``.

        * Unknown event or invalid slot: ``-0.2``; event that cannot be
          rescheduled: ``-0.5``. Nothing is changed.
        * Move that would overlap another event: ``-0.5``. The move is
          rejected, the calendar is left untouched and no preference
          score is given, since nothing actually moved.
        * Accepted move: ``event.time`` is updated; ``1.0`` if the event was
          part of a conflict (its conflict pairs are removed from
          ``scenario.conflicts``), otherwise ``0.0``. Preference bonuses and
          penalties are then evaluated for the new slot.
    """
    event = next((e for e in scenario.calendar if e.event_id == event_id), None)

    if event is None:
        return -0.2, 0.0, f"Event {event_id} not found."

    if not event.can_reschedule:
        return -0.5, 0.0, f"Event {event_id} cannot be rescheduled (high priority)."

    if new_time not in TIME_SLOTS:
        return -0.2, 0.0, f"Invalid time slot: {new_time}."

    time_index = {t: i for i, t in enumerate(TIME_SLOTS)}
    old_time = event.time
    was_in_conflict = any(event_id in pair for pair in scenario.conflicts)

    # Half-open [start, end) slot ranges; one slot is 30 minutes.
    n_start = time_index[new_time]
    n_end = n_start + event.duration_min // 30
    other_spans = [
        (time_index[o.time], time_index[o.time] + o.duration_min // 30)
        for o in scenario.calendar
        if o.event_id != event_id and o.time in time_index
    ]

    creates_new_conflict = any(
        n_start < o_end and o_start < n_end for o_start, o_end in other_spans
    )
    if creates_new_conflict:
        # Rejected: the event keeps its old time, so preference scoring
        # (which describes the outcome of a move) does not apply.
        return -0.5, 0.0, f"Cannot move {event_id} to {new_time} - creates new conflict."

    # Accepted: commit the move.
    event.time = new_time
    if was_in_conflict:
        conflict_reward = 1.0
        # Remove resolved conflicts
        scenario.conflicts = [
            pair for pair in scenario.conflicts if event_id not in pair
        ]
        msg = f"Conflict resolved: {event_id} moved to {new_time}."
    else:
        conflict_reward = 0.0
        msg = f"Moved {event_id} to {new_time} (no conflict impact)."

    # Check preference alignment
    pref_reward = 0.0
    pref_ids = {p[0] for p in preferences}
    early_slots = ("9:00am", "9:30am")
    lunch_slots = ("12:00pm", "12:30pm")

    if "no_early_meetings" in pref_ids and new_time in early_slots:
        pref_reward -= 0.3
        msg += " Warning: user prefers no early meetings."
    if "lunch_block" in pref_ids and new_time in lunch_slots:
        pref_reward -= 0.3
        msg += " Warning: moved into lunch block."
    if (
        "no_early_meetings" in pref_ids
        and old_time in early_slots
        and new_time not in early_slots
    ):
        pref_reward += 0.5
        msg += " Good: moved away from early slot per preference."
    if "buffer_time" in pref_ids or "no_back_to_back" in pref_ids:
        # Back-to-back means one meeting ends exactly when the other starts,
        # which depends on durations and not only on start slots.
        back_to_back = any(
            o_end == n_start or n_end == o_start for o_start, o_end in other_spans
        )
        if back_to_back:
            pref_reward -= 0.3
            msg += " Warning: back-to-back meeting created."

    return conflict_reward, pref_reward, msg
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def score_email_reply(
    email_id: str,
    reply_body: str,
    scenario: Scenario,
    preferences: list[tuple[str, str]],
) -> tuple[float, float, str]:
    """Score an email reply with keyword and tone heuristics.

    Args:
        email_id: Id of the email being answered.
        reply_body: Text of the drafted reply.
        scenario: Scenario holding the emails.
        preferences: ``(pref_id, description)`` pairs for the user.

    Returns:
        ``(email_reward, pref_reward, message)``.

        * Unknown email: ``(-0.2, 0.0, ...)``.
        * Reply shorter than 10 non-blank characters: ``(0.0, 0.0, ...)``.
        * Otherwise ``email_reward`` is the sum of three parts: key points
          addressed (up to 0.4), tone match (up to 0.3) and tone-preference
          alignment (up to 0.3). ``pref_reward`` is 0.5 only when the reply
          matched a tone preference the user actually has.
    """
    email = next((e for e in scenario.emails if e.email_id == email_id), None)

    if email is None:
        return -0.2, 0.0, f"Email {email_id} not found."

    if not reply_body or len(reply_body.strip()) < 10:
        return 0.0, 0.0, "Reply too short."

    reply_lower = reply_body.lower()

    # Score: addresses_issue (0.4), split evenly across the key points.
    addresses_score = 0.0
    for kp in email.key_points:
        # Simple keyword matching: a key point counts as addressed when at
        # least 30% of its words occur (as substrings) in the reply.
        keywords = kp.lower().split()
        matches = sum(1 for kw in keywords if kw in reply_lower)
        if matches >= len(keywords) * 0.3:
            addresses_score += 0.4 / len(email.key_points)

    # Score: tone (0.3)
    formal_markers = ("dear", "regards", "sincerely", "please find", "i would like to")
    informal_markers = ("hey", "hi!", "thanks!", "sounds good", "sure thing", "no worries")

    formal_count = sum(1 for m in formal_markers if m in reply_lower)
    informal_count = sum(1 for m in informal_markers if m in reply_lower)

    tone_score = 0.0
    if email.tone_expected == "formal" and formal_count > informal_count:
        tone_score = 0.3
    elif email.tone_expected == "informal" and informal_count >= formal_count:
        tone_score = 0.3
    elif formal_count == 0 and informal_count == 0:
        tone_score = 0.15  # neutral is ok

    # Score: preference alignment (0.3)
    pref_ids = {p[0] for p in preferences}
    pref_score = 0.0
    matched_tone_preference = False
    if "informal_tone" in pref_ids and informal_count > 0:
        pref_score = 0.3
        matched_tone_preference = True
    elif "formal_tone" in pref_ids and formal_count > 0:
        pref_score = 0.3
        matched_tone_preference = True
    elif "informal_tone" not in pref_ids and "formal_tone" not in pref_ids:
        pref_score = 0.15  # no tone preference

    email_reward = addresses_score + tone_score + pref_score

    # The inference bonus is earned only by honouring an existing tone
    # preference; the neutral 0.15 above is not an inferred preference.
    pref_reward = 0.5 if matched_tone_preference else 0.0

    msg = f"Email reply scored: addresses={addresses_score:.2f}, tone={tone_score:.2f}, pref={pref_score:.2f}"
    return email_reward, pref_reward, msg
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def score_terminal(scenario: Scenario) -> RewardBreakdown:
    """Compute the baseline terminal penalties at episode end.

    This function only sees the scenario, not which emails were answered, so
    it charges every reply-requiring email; the caller has to credit back the
    ones that were actually replied to.

    Args:
        scenario: Scenario in its end-of-episode state.

    Returns:
        A ``RewardBreakdown`` where

        * ``deadline_adherence`` is -1.0 per reply-requiring email with a
          deadline and -0.5 per reply-requiring email whose deadline is None;
        * ``conflict_resolution`` is -0.5 per conflict still listed in
          ``scenario.conflicts``;
        * all other components, including ``late_change_recovery``, stay 0.0.
    """
    breakdown = RewardBreakdown()

    # Email penalties (applied to every email that requires a reply).
    for email in scenario.emails:
        if not email.requires_reply:
            continue
        if email.deadline:
            breakdown.deadline_adherence -= 1.0
        elif email.deadline is None:
            breakdown.deadline_adherence -= 0.5

    # Unresolved conflicts
    breakdown.conflict_resolution -= len(scenario.conflicts) * 0.5

    return breakdown
|
server/scenario_generator.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Procedural scenario generation for the Executive Assistant Arena."""
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Pool of people used as meeting attendees and email senders.
NAMES = [
    "Alice Chen",
    "Bob Martinez",
    "Carol Park",
    "David Kim",
    "Eve Johnson",
    "Frank Lee",
    "Grace Wang",
    "Henry Brown",
    "Irene Davis",
    "Jack Wilson",
]

# Pool of subjects for generated emails.
EMAIL_SUBJECTS = [
    "Q3 Budget Review",
    "Team Offsite Planning",
    "Client Demo Prep",
    "Performance Review Follow-up",
    "Product Launch Timeline",
    "Vendor Contract Renewal",
    "Board Presentation Draft",
    "Hiring Update",
    "Customer Escalation",
    "Partnership Proposal",
]

# Pool of titles for generated calendar events.
MEETING_TYPES = [
    "1:1",
    "team standup",
    "client call",
    "design review",
    "sprint planning",
    "all-hands",
    "interview",
    "lunch meeting",
    "board prep",
    "strategy session",
]

# Valid start times in chronological order, 30 minutes apart.
TIME_SLOTS = [
    # morning
    "9:00am", "9:30am", "10:00am", "10:30am", "11:00am", "11:30am",
    # midday
    "12:00pm", "12:30pm",
    # afternoon
    "1:00pm", "1:30pm", "2:00pm", "2:30pm",
    "3:00pm", "3:30pm", "4:00pm", "4:30pm", "5:00pm",
]

# (pref_id, human-readable description) pairs a scenario can draw from.
PREFERENCES = [
    ("no_early_meetings", "User prefers no meetings before 10am"),
    ("lunch_block", "User always blocks 12pm-1pm for lunch"),
    ("informal_tone", "User prefers informal/casual tone in emails"),
    ("formal_tone", "User prefers formal/professional tone in emails"),
    ("short_meetings", "User prefers 30-min meetings over 60-min"),
    ("no_friday_meetings", "User avoids meetings on Fridays"),
    ("boss_priority", "Meetings with the boss always take priority"),
    ("client_priority", "Client meetings cannot be rescheduled"),
    ("buffer_time", "User needs 15-min buffer between meetings"),
    ("no_back_to_back", "User dislikes back-to-back meetings"),
]

# Kinds of mid-episode disruption a scenario can contain.
LATE_CHANGES = [
    "boss_reschedule",    # Boss moves a meeting to a conflicting time
    "urgent_client",      # Urgent client call appears
    "meeting_cancelled",  # A meeting gets cancelled, opening a slot
    "deadline_moved",     # A deadline moves earlier
]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
class CalendarEvent:
    """One meeting on the calendar."""

    event_id: str
    title: str
    time: str  # start time label
    duration_min: int
    attendees: list[str]
    priority: str  # "high", "medium", "low"
    can_reschedule: bool = True

    def to_text(self) -> str:
        """Render the event as a single human-readable line."""
        when = f"[{self.event_id}] {self.time} ({self.duration_min}min)"
        who = ", ".join(self.attendees)
        return f"{when} - {self.title} with {who} [priority: {self.priority}]"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class Email:
|
| 68 |
+
email_id: str
|
| 69 |
+
sender: str
|
| 70 |
+
subject: str
|
| 71 |
+
body: str
|
| 72 |
+
requires_reply: bool
|
| 73 |
+
tone_expected: str # "formal" or "informal"
|
| 74 |
+
key_points: list[str] # what the reply must address
|
| 75 |
+
deadline: str | None = None
|
| 76 |
+
|
| 77 |
+
def to_text(self) -> str:
|
| 78 |
+
dl = f" [DEADLINE: {self.deadline}]" if self.deadline else ""
|
| 79 |
+
return f"[{self.email_id}] From: {self.sender} | Subject: {self.subject}{dl}\n {self.body}"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclass
|
| 83 |
+
class LateChange:
|
| 84 |
+
change_type: str
|
| 85 |
+
description: str
|
| 86 |
+
affected_event_id: str | None
|
| 87 |
+
new_time: str | None = None
|
| 88 |
+
injected: bool = False
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass
class Scenario:
    """Mutable world state for one episode: calendar, inbox and disruptions."""

    calendar: list[CalendarEvent]
    emails: list[Email]
    preferences: list[tuple[str, str]]  # (pref_id, description)
    late_changes: list[LateChange]
    conflicts: list[tuple[str, str]]  # pairs of conflicting event_ids
    difficulty: str

    def calendar_text(self) -> str:
        """Return all calendar events, one per line."""
        lines = [event.to_text() for event in self.calendar]
        return "\n".join(lines)

    def inbox_text(self) -> str:
        """Return all emails, separated by a blank line."""
        entries = [message.to_text() for message in self.emails]
        return "\n\n".join(entries)

    def conflicts_text(self) -> list[str]:
        """Return one description string per conflicting pair."""
        described = []
        for first_id, second_id in self.conflicts:
            described.append(f"CONFLICT: {first_id} overlaps with {second_id}")
        return described

    def pending_tasks_text(self) -> list[str]:
        """List open work: conflicts first, then emails that require a reply."""
        conflict_tasks = [
            f"Resolve conflict between {first_id} and {second_id}"
            for first_id, second_id in self.conflicts
        ]
        reply_tasks = [
            f"Reply to email {message.email_id} from {message.sender}"
            for message in self.emails
            if message.requires_reply
        ]
        return conflict_tasks + reply_tasks

    def inject_late_change(self) -> str | None:
        """Inject the next un-injected late change. Returns description or None."""
        pending = next((lc for lc in self.late_changes if not lc.injected), None)
        if pending is None:
            return None

        pending.injected = True
        target_id = pending.affected_event_id

        if pending.change_type == "boss_reschedule" and target_id:
            # Move the affected event, then refresh conflicts because the
            # new time may overlap another event.
            for event in self.calendar:
                if event.event_id == target_id and pending.new_time:
                    event.time = pending.new_time
            self._recompute_conflicts()
        elif pending.change_type == "meeting_cancelled" and target_id:
            # Drop the cancelled event; conflicts involving it disappear.
            self.calendar = [
                event for event in self.calendar if event.event_id != target_id
            ]
            self._recompute_conflicts()

        # Other change types only produce the description text.
        return pending.description

    def _recompute_conflicts(self):
        """Rebuild ``conflicts`` from the current calendar."""
        slot_of = {label: position for position, label in enumerate(TIME_SLOTS)}
        found = []
        for idx, first in enumerate(self.calendar):
            # Events whose time is not a known slot are ignored.
            if first.time not in slot_of:
                continue
            first_start = slot_of[first.time]
            first_end = first_start + first.duration_min // 30
            for second in self.calendar[idx + 1:]:
                if second.time not in slot_of:
                    continue
                second_start = slot_of[second.time]
                second_end = second_start + second.duration_min // 30
                # Half-open interval overlap test on 30-minute slots.
                if first_start < second_end and second_start < first_end:
                    found.append((first.event_id, second.event_id))
        self.conflicts = found
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def generate_scenario(difficulty: str = "medium", seed: int | None = None) -> Scenario:
    """Generate a procedural scenario with the given difficulty.

    Args:
        difficulty: ``"easy"`` or ``"medium"``; any other value selects the
            hard preset.
        seed: Seed for the private ``random.Random`` instance. The same seed
            and difficulty produce the same scenario.

    Returns:
        A ``Scenario`` with generated events, emails, preferences and late
        changes, and with ``conflicts`` computed from the generated calendar.
    """
    rng = random.Random(seed)

    # (n_events, n_conflicts, n_emails, n_prefs, n_late) per difficulty.
    presets = {
        "easy": (4, 2, 1, 2, 0),
        "medium": (6, 4, 3, 4, 1),
    }
    hard_preset = (8, 6, 5, 6, 2)
    n_events, n_conflicts, n_emails, n_prefs, n_late = presets.get(
        difficulty, hard_preset
    )

    # Generate calendar events
    people = rng.sample(NAMES, min(n_events + n_emails, len(NAMES)))
    meeting_types = rng.sample(MEETING_TYPES, min(n_events, len(MEETING_TYPES)))

    events = []
    used_slots = []

    for i in range(n_events):
        eid = f"mtg_{i+1}"
        title = meeting_types[i] if i < len(meeting_types) else f"Meeting {i+1}"
        attendee = people[i] if i < len(people) else rng.choice(NAMES)
        duration = rng.choice([30, 60])
        priority = rng.choice(["high", "medium", "low"])
        # Only high-priority events can be locked, with 50% probability.
        can_resched = priority != "high" or rng.random() > 0.5

        if i < n_conflicts and used_slots:
            # Intentionally reuse a start time that is already taken. The
            # first event has nothing to collide with, so n_conflicts is a
            # target rather than a guaranteed conflict count.
            slot = rng.choice(used_slots)
        else:
            slot = rng.choice(TIME_SLOTS)

        used_slots.append(slot)
        events.append(CalendarEvent(
            event_id=eid, title=title, time=slot,
            duration_min=duration, attendees=[attendee],
            priority=priority, can_reschedule=can_resched,
        ))

    # Generate emails
    emails = []
    for i in range(n_emails):
        sender = people[n_events + i] if n_events + i < len(people) else rng.choice(NAMES)
        subject = rng.choice(EMAIL_SUBJECTS)
        tone = rng.choice(["formal", "informal"])
        key_points = [f"Address the {subject.lower()} timeline"]
        if rng.random() > 0.5:
            key_points.append("Confirm next steps")
        deadline = rng.choice(["today", "tomorrow", None])
        body = f"Hi, I wanted to follow up on {subject.lower()}. Could you get back to me{' by ' + deadline if deadline else ''}? Thanks, {sender}"

        emails.append(Email(
            email_id=f"email_{i+1}", sender=sender, subject=subject,
            body=body, requires_reply=True, tone_expected=tone,
            key_points=key_points, deadline=deadline,
        ))

    # Pick preferences
    prefs = rng.sample(PREFERENCES, min(n_prefs, len(PREFERENCES)))

    # Generate late changes
    late_changes = []
    for _ in range(n_late):
        change_type = rng.choice(LATE_CHANGES)
        if change_type == "boss_reschedule" and events:
            target = rng.choice(events)
            new_time = rng.choice([t for t in TIME_SLOTS if t != target.time])
            late_changes.append(LateChange(
                change_type=change_type,
                description=f"URGENT: Boss has rescheduled {target.title} ({target.event_id}) to {new_time}",
                affected_event_id=target.event_id,
                new_time=new_time,
            ))
        elif change_type == "meeting_cancelled" and events:
            target = rng.choice(events)
            late_changes.append(LateChange(
                change_type=change_type,
                description=f"CANCELLED: {target.title} ({target.event_id}) has been cancelled",
                affected_event_id=target.event_id,
            ))
        elif change_type == "urgent_client":
            slot = rng.choice(TIME_SLOTS)
            late_changes.append(LateChange(
                change_type=change_type,
                description=f"URGENT: New client call scheduled at {slot} - must attend",
                affected_event_id=None,
                new_time=slot,
            ))
        else:
            # "deadline_moved", and the fallback for event-based changes when
            # there are no events to target.
            late_changes.append(LateChange(
                change_type="deadline_moved",
                description="URGENT: Q3 report deadline moved to today",
                affected_event_id=None,
            ))

    scenario = Scenario(
        calendar=events, emails=emails, preferences=prefs,
        late_changes=late_changes, conflicts=[], difficulty=difficulty,
    )
    # Compute actual conflicts with the same routine used after late changes,
    # so there is a single definition of "overlap".
    scenario._recompute_conflicts()
    return scenario
|