Upload folder using huggingface_hub
Browse files- .env.example +7 -0
- .github/workflows/ci.yml +20 -0
- .gitignore +24 -0
- Dockerfile +19 -0
- README.md +200 -0
- app.py +37 -0
- baseline_expected_scores.json +29 -0
- inference.py +300 -0
- openenv.yaml +31 -0
- pre_submission_validate.py +352 -0
- pyproject.toml +29 -0
- requirements.txt +2 -0
- scripts/bootstrap_remotes.sh +83 -0
- scripts/pre_validation_script.sh +185 -0
- scripts/run_baseline.py +197 -0
- scripts/sample_inference_script.sh +188 -0
- scripts/validate_env.py +30 -0
- src/support_triage_openenv/__init__.py +6 -0
- src/support_triage_openenv/env.py +229 -0
- src/support_triage_openenv/graders.py +64 -0
- src/support_triage_openenv/models.py +53 -0
- src/support_triage_openenv/tasks.py +153 -0
- tasks/TASKS.md +25 -0
- tests/test_api.py +27 -0
- tests/test_env.py +47 -0
.env.example
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Mandatory hackathon environment variables
|
| 2 |
+
API_BASE_URL=https://your-openai-compatible-endpoint/v1
|
| 3 |
+
MODEL_NAME=your-model-id
|
| 4 |
+
HF_TOKEN=your-api-key
|
| 5 |
+
|
| 6 |
+
# Optional for validator remote ping
|
| 7 |
+
SPACE_URL=https://your-space-name.hf.space
|
.github/workflows/ci.yml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: CI
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
pull_request:
|
| 6 |
+
|
| 7 |
+
jobs:
|
| 8 |
+
test:
|
| 9 |
+
runs-on: ubuntu-latest
|
| 10 |
+
steps:
|
| 11 |
+
- uses: actions/checkout@v4
|
| 12 |
+
- uses: actions/setup-python@v5
|
| 13 |
+
with:
|
| 14 |
+
python-version: '3.11'
|
| 15 |
+
- name: Install deps
|
| 16 |
+
run: |
|
| 17 |
+
python -m pip install --upgrade pip
|
| 18 |
+
pip install -r requirements.txt
|
| 19 |
+
- name: Run tests
|
| 20 |
+
run: python -m pytest -q
|
.gitignore
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.so
|
| 5 |
+
.pytest_cache/
|
| 6 |
+
.mypy_cache/
|
| 7 |
+
|
| 8 |
+
# Virtual environments
|
| 9 |
+
.venv/
|
| 10 |
+
venv/
|
| 11 |
+
|
| 12 |
+
# Build artifacts
|
| 13 |
+
build/
|
| 14 |
+
dist/
|
| 15 |
+
*.egg-info/
|
| 16 |
+
|
| 17 |
+
# OS/editor
|
| 18 |
+
.DS_Store
|
| 19 |
+
.vscode/
|
| 20 |
+
.idea/
|
| 21 |
+
|
| 22 |
+
# Runtime artifacts
|
| 23 |
+
*.log
|
| 24 |
+
scores/
|
Dockerfile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 6 |
+
ENV PYTHONUNBUFFERED=1
|
| 7 |
+
|
| 8 |
+
COPY pyproject.toml requirements.txt README.md /app/
|
| 9 |
+
COPY src /app/src
|
| 10 |
+
COPY scripts /app/scripts
|
| 11 |
+
COPY openenv.yaml /app/openenv.yaml
|
| 12 |
+
COPY app.py /app/app.py
|
| 13 |
+
|
| 14 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 15 |
+
pip install --no-cache-dir -r requirements.txt
|
| 16 |
+
|
| 17 |
+
EXPOSE 7860
|
| 18 |
+
|
| 19 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Support Triage OpenEnv
|
| 3 |
+
emoji: "📨"
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: teal
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
- reinforcement-learning
|
| 11 |
+
- customer-support
|
| 12 |
+
license: mit
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# Support Triage OpenEnv
|
| 16 |
+
|
| 17 |
+
A complete, real-world OpenEnv environment for training/evaluating agents on **customer support ticket triage**. The environment simulates what support teams actually do: read inbox tickets, classify urgency/category, draft safe responses, and resolve the right ticket.
|
| 18 |
+
|
| 19 |
+
## Why this environment
|
| 20 |
+
|
| 21 |
+
Most agent benchmarks under-model production support workflows. This environment focuses on practical support operations with:
|
| 22 |
+
- Multi-ticket inbox context selection
|
| 23 |
+
- Policy-compliant communication
|
| 24 |
+
- Priority + escalation decisions
|
| 25 |
+
- Deterministic graders and dense reward shaping
|
| 26 |
+
|
| 27 |
+
## OpenEnv API compliance
|
| 28 |
+
|
| 29 |
+
The environment exposes:
|
| 30 |
+
- `reset(task_id?: str) -> Observation`
|
| 31 |
+
- `step(action: Action) -> (Observation, Reward, done, info)`
|
| 32 |
+
- `state() -> dict`
|
| 33 |
+
|
| 34 |
+
Typed Pydantic models:
|
| 35 |
+
- `Observation`: [`src/support_triage_openenv/models.py`](src/support_triage_openenv/models.py)
|
| 36 |
+
- `Action`: [`src/support_triage_openenv/models.py`](src/support_triage_openenv/models.py)
|
| 37 |
+
- `Reward`: [`src/support_triage_openenv/models.py`](src/support_triage_openenv/models.py)
|
| 38 |
+
|
| 39 |
+
Metadata:
|
| 40 |
+
- `openenv.yaml`
|
| 41 |
+
|
| 42 |
+
## Action space
|
| 43 |
+
|
| 44 |
+
`Action` model fields:
|
| 45 |
+
- `action_type`: one of `read_ticket | classify_ticket | draft_reply | resolve_ticket`
|
| 46 |
+
- `ticket_id`: required for `read_ticket`, `classify_ticket`, `resolve_ticket`
|
| 47 |
+
- `priority`: optional enum `low | medium | high | urgent`
|
| 48 |
+
- `category`: optional enum `account | billing | technical | abuse | general`
|
| 49 |
+
- `needs_escalation`: optional bool
|
| 50 |
+
- `message`: text for `draft_reply`
|
| 51 |
+
|
| 52 |
+
## Observation space
|
| 53 |
+
|
| 54 |
+
`Observation` includes:
|
| 55 |
+
- `task_id`, `objective`, `step_count`, `max_steps`
|
| 56 |
+
- `inbox`: ticket metadata list (`ticket_id`, subject, tier, age, read flag)
|
| 57 |
+
- `current_ticket_content`: only visible after reading selected ticket
|
| 58 |
+
- `latest_system_note`: feedback from last step
|
| 59 |
+
- `score_hint`: partial grader components (`read`, `classify`, `reply`, `resolve`)
|
| 60 |
+
|
| 61 |
+
## Tasks and difficulty
|
| 62 |
+
|
| 63 |
+
1. `easy_password_reset` (Easy)
|
| 64 |
+
- Correctly process account lockout and send secure reset guidance.
|
| 65 |
+
|
| 66 |
+
2. `medium_billing_dispute` (Medium)
|
| 67 |
+
- Investigate duplicate billing with context ticket and provide policy-compliant refund timeline.
|
| 68 |
+
|
| 69 |
+
3. `hard_outage_incident` (Hard)
|
| 70 |
+
- Handle a high-stakes outage report requiring multi-ticket context, urgent escalation, and careful incident messaging.
|
| 71 |
+
|
| 72 |
+
Each task has deterministic grading in `support_triage_openenv.graders.grade_task`, returning a score `0.0-1.0`.
|
| 73 |
+
|
| 74 |
+
## Reward design
|
| 75 |
+
|
| 76 |
+
Reward is shaped and meaningful across the trajectory:
|
| 77 |
+
- Positive dense signal from partial grader progress (read/context, classification fields, reply quality, resolve correctness)
|
| 78 |
+
- Penalties for invalid actions, repeated loops, and malformed steps
|
| 79 |
+
- Final step guarantees score alignment with deterministic grader output
|
| 80 |
+
|
| 81 |
+
## Project structure
|
| 82 |
+
|
| 83 |
+
- `src/support_triage_openenv/env.py` - environment implementation
|
| 84 |
+
- `src/support_triage_openenv/models.py` - typed OpenEnv models
|
| 85 |
+
- `src/support_triage_openenv/tasks.py` - task specs (easy/medium/hard)
|
| 86 |
+
- `src/support_triage_openenv/graders.py` - deterministic grader logic
|
| 87 |
+
- `scripts/run_baseline.py` - OpenAI baseline inference runner
|
| 88 |
+
- `scripts/validate_env.py` - tests + optional `openenv validate`
|
| 89 |
+
- `app.py` - FastAPI app for HF Space runtime
|
| 90 |
+
- `Dockerfile` - containerized deployment
|
| 91 |
+
|
| 92 |
+
## Setup
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
cd /path/to/meta_hackathon
|
| 96 |
+
python3 -m venv .venv
|
| 97 |
+
source .venv/bin/activate
|
| 98 |
+
pip install -r requirements.txt
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## Run tests
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
python -m pytest -q
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
## Run baseline
|
| 108 |
+
|
| 109 |
+
OpenAI model baseline:
|
| 110 |
+
|
| 111 |
+
```bash
|
| 112 |
+
export API_BASE_URL=https://your-openai-compatible-endpoint/v1
|
| 113 |
+
export MODEL_NAME=your-model-id
|
| 114 |
+
export HF_TOKEN=your-api-key
|
| 115 |
+
python inference.py --mode openai --output scores/inference_scores.json
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
Deterministic heuristic baseline:
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
python inference.py --mode heuristic --output scores/inference_scores.json
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
Outputs JSON report to `scores/inference_scores.json` and structured stdout logs with `[START]`, `[STEP]`, `[END]`.
|
| 125 |
+
|
| 126 |
+
## Run API locally
|
| 127 |
+
|
| 128 |
+
```bash
|
| 129 |
+
uvicorn app:app --host 0.0.0.0 --port 7860
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
Endpoints:
|
| 133 |
+
- `GET /health`
|
| 134 |
+
- `POST /reset`
|
| 135 |
+
- `POST /step`
|
| 136 |
+
- `GET /state`
|
| 137 |
+
|
| 138 |
+
## Docker
|
| 139 |
+
|
| 140 |
+
```bash
|
| 141 |
+
docker build -t support-triage-openenv .
|
| 142 |
+
docker run --rm -p 7860:7860 support-triage-openenv
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
## Hugging Face Space deployment
|
| 146 |
+
|
| 147 |
+
- Create a **Docker Space**.
|
| 148 |
+
- Push this repository to the Space.
|
| 149 |
+
- Keep `README.md` frontmatter tags including `openenv`.
|
| 150 |
+
- Space serves the API on port `7860`.
|
| 151 |
+
|
| 152 |
+
## One-command remote bootstrap
|
| 153 |
+
|
| 154 |
+
If you want this local repo to automatically create and push to both GitHub + HF:
|
| 155 |
+
|
| 156 |
+
```bash
|
| 157 |
+
export GITHUB_USERNAME=your_github_user
|
| 158 |
+
export GITHUB_TOKEN=your_github_pat
|
| 159 |
+
export HF_USERNAME=your_hf_user
|
| 160 |
+
export HF_TOKEN=your_hf_token
|
| 161 |
+
bash scripts/bootstrap_remotes.sh support-triage-openenv
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
## Baseline scores (heuristic reproducible)
|
| 165 |
+
|
| 166 |
+
Generated with:
|
| 167 |
+
|
| 168 |
+
```bash
|
| 169 |
+
python inference.py --mode heuristic --output scores/inference_scores.json
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
- `easy_password_reset`: grader `1.0`, reward `1.0`
|
| 173 |
+
- `medium_billing_dispute`: grader `1.0`, reward `1.0`
|
| 174 |
+
- `hard_outage_incident`: grader `1.0`, reward `1.0`
|
| 175 |
+
- Overall average grader score: `1.0`
|
| 176 |
+
- Tracked reference artifact: `baseline_expected_scores.json`
|
| 177 |
+
|
| 178 |
+
## Pre-submission validator
|
| 179 |
+
|
| 180 |
+
Run full strict validation (all disqualification gates):
|
| 181 |
+
|
| 182 |
+
```bash
|
| 183 |
+
python pre_submission_validate.py --space-url https://your-space-name.hf.space
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
Local-only run while iterating (skips Docker daemon + remote space ping):
|
| 187 |
+
|
| 188 |
+
```bash
|
| 189 |
+
python pre_submission_validate.py --skip-docker --skip-space
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
Run organizer-provided script directly (integrated path):
|
| 193 |
+
|
| 194 |
+
```bash
|
| 195 |
+
bash scripts/pre_validation_script.sh https://your-space-name.hf.space .
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
Notes:
|
| 199 |
+
- `scripts/sample_inference_script.sh` is kept as organizer reference.
|
| 200 |
+
- Root `inference.py` is aligned to the required `[START]`, `[STEP]`, `[END]` line format.
|
app.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from fastapi import FastAPI
|
| 4 |
+
|
| 5 |
+
from support_triage_openenv.env import SupportTriageEnv
|
| 6 |
+
from support_triage_openenv.models import Action
|
| 7 |
+
|
| 8 |
+
app = FastAPI(title="Support Triage OpenEnv", version="0.1.0")
|
| 9 |
+
env = SupportTriageEnv()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe for the Space runtime; always reports ok."""
    return {"status": "ok"}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@app.post("/reset")
def reset(payload: dict | None = None) -> dict:
    """Reset the environment, optionally selecting a task via ``task_id`` in the body."""
    body = payload or {}
    observation = env.reset(task_id=body.get("task_id"))
    return observation.model_dump()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@app.post("/step")
def step(action: Action) -> dict:
    """Advance the environment by one action and return the full transition."""
    observation, reward, done, info = env.step(action)
    transition = {
        "observation": observation.model_dump(),
        "reward": reward.model_dump(),
        "done": done,
        "info": info,
    }
    return transition
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@app.get("/state")
def state() -> dict:
    """Expose the environment's current internal state snapshot."""
    return env.state()
|
baseline_expected_scores.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"mode": "heuristic",
|
| 3 |
+
"model": "gpt-4.1-mini",
|
| 4 |
+
"avg_grader_score": 1.0,
|
| 5 |
+
"avg_final_reward": 1.0,
|
| 6 |
+
"episodes": [
|
| 7 |
+
{
|
| 8 |
+
"task_id": "easy_password_reset",
|
| 9 |
+
"steps": 4,
|
| 10 |
+
"grader_score": 1.0,
|
| 11 |
+
"reward": 1.0,
|
| 12 |
+
"done_reason": "resolved"
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"task_id": "medium_billing_dispute",
|
| 16 |
+
"steps": 5,
|
| 17 |
+
"grader_score": 1.0,
|
| 18 |
+
"reward": 1.0,
|
| 19 |
+
"done_reason": "resolved"
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"task_id": "hard_outage_incident",
|
| 23 |
+
"steps": 6,
|
| 24 |
+
"grader_score": 1.0,
|
| 25 |
+
"reward": 1.0,
|
| 26 |
+
"done_reason": "resolved"
|
| 27 |
+
}
|
| 28 |
+
]
|
| 29 |
+
}
|
inference.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from dataclasses import asdict, dataclass
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
from openai import OpenAI
|
| 13 |
+
|
| 14 |
+
from support_triage_openenv import Action, SupportTriageEnv
|
| 15 |
+
|
| 16 |
+
# Mandatory variables requested by organizers.
|
| 17 |
+
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 18 |
+
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 19 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 20 |
+
|
| 21 |
+
BENCHMARK = os.getenv("SUPPORT_TRIAGE_BENCHMARK", "support-triage-openenv")
|
| 22 |
+
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.9"))
|
| 23 |
+
|
| 24 |
+
SYSTEM_PROMPT = (
|
| 25 |
+
"You are solving customer support ticket triage. "
|
| 26 |
+
"Return exactly one JSON object with keys: "
|
| 27 |
+
"action_type, ticket_id, priority, category, needs_escalation, message."
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
RULE_POLICY: dict[str, list[dict[str, Any]]] = {
|
| 31 |
+
"easy_password_reset": [
|
| 32 |
+
{"action_type": "read_ticket", "ticket_id": "T-1001"},
|
| 33 |
+
{
|
| 34 |
+
"action_type": "classify_ticket",
|
| 35 |
+
"ticket_id": "T-1001",
|
| 36 |
+
"priority": "medium",
|
| 37 |
+
"category": "account",
|
| 38 |
+
"needs_escalation": False,
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"action_type": "draft_reply",
|
| 42 |
+
"message": (
|
| 43 |
+
"We will send a reset link to your email. For security, confirm the request "
|
| 44 |
+
"from your registered email before using the reset link."
|
| 45 |
+
),
|
| 46 |
+
},
|
| 47 |
+
{"action_type": "resolve_ticket", "ticket_id": "T-1001"},
|
| 48 |
+
],
|
| 49 |
+
"medium_billing_dispute": [
|
| 50 |
+
{"action_type": "read_ticket", "ticket_id": "T-2001"},
|
| 51 |
+
{"action_type": "read_ticket", "ticket_id": "T-2002"},
|
| 52 |
+
{
|
| 53 |
+
"action_type": "classify_ticket",
|
| 54 |
+
"ticket_id": "T-2001",
|
| 55 |
+
"priority": "high",
|
| 56 |
+
"category": "billing",
|
| 57 |
+
"needs_escalation": False,
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"action_type": "draft_reply",
|
| 61 |
+
"message": (
|
| 62 |
+
"We confirmed a duplicate charge. We are issuing a refund and will share the invoice update. "
|
| 63 |
+
"Refund processing typically takes 3-5 business days."
|
| 64 |
+
),
|
| 65 |
+
},
|
| 66 |
+
{"action_type": "resolve_ticket", "ticket_id": "T-2001"},
|
| 67 |
+
],
|
| 68 |
+
"hard_outage_incident": [
|
| 69 |
+
{"action_type": "read_ticket", "ticket_id": "T-3001"},
|
| 70 |
+
{"action_type": "read_ticket", "ticket_id": "T-3002"},
|
| 71 |
+
{"action_type": "read_ticket", "ticket_id": "T-3003"},
|
| 72 |
+
{
|
| 73 |
+
"action_type": "classify_ticket",
|
| 74 |
+
"ticket_id": "T-3001",
|
| 75 |
+
"priority": "urgent",
|
| 76 |
+
"category": "technical",
|
| 77 |
+
"needs_escalation": True,
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"action_type": "draft_reply",
|
| 81 |
+
"message": (
|
| 82 |
+
"We have escalated this incident and are investigating now. "
|
| 83 |
+
"The status page will carry updates while we continue incident response."
|
| 84 |
+
),
|
| 85 |
+
},
|
| 86 |
+
{"action_type": "resolve_ticket", "ticket_id": "T-3001"},
|
| 87 |
+
],
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass
class EpisodeResult:
    """Summary record for one evaluated episode, serialized into the score report."""

    task_id: str                # task identifier that was played
    steps: int                  # number of env.step calls taken
    score: float                # final grader score, clamped to [0, 1]
    success: bool               # True when score met SUCCESS_SCORE_THRESHOLD
    final_reward: float         # reward of the terminal step
    rewards: list[float]        # per-step reward trace
    fallback_count: int         # times the LLM policy fell back to the heuristic
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def log_start(task: str, env: str, model: str) -> None:
    """Emit the standardized ``[START]`` episode header line."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
|
| 107 |
+
error_val = error if error else "null"
|
| 108 |
+
done_val = str(done).lower()
|
| 109 |
+
print(
|
| 110 |
+
f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
|
| 111 |
+
flush=True,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Emit the standardized ``[END]`` summary line with the full reward trace."""
    trace = ",".join(f"{value:.2f}" for value in rewards)
    flag = str(success).lower()
    print(f"[END] success={flag} steps={steps} score={score:.3f} rewards={trace}", flush=True)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _extract_json(text: str) -> str:
    """Return the substring spanning the first ``{`` through the last ``}``.

    Raises ValueError when no plausible JSON object delimiters are present.
    Intentionally tolerant of chatter before/after the object in LLM replies.
    """
    stripped = text.strip()
    open_idx = stripped.find("{")
    close_idx = stripped.rfind("}")
    if open_idx < 0 or close_idx <= open_idx:
        raise ValueError("No JSON object found in model response")
    return stripped[open_idx : close_idx + 1]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def heuristic_action(task_id: str, step_idx: int) -> Action:
    """Return the scripted next action for *task_id*, clamping past-plan indices.

    Steps beyond the end of the scripted plan repeat the plan's final action.
    """
    plan = RULE_POLICY[task_id]
    clamped = min(step_idx, len(plan) - 1)
    return Action.model_validate(plan[clamped])
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def llm_action(client: OpenAI, observation: dict[str, Any], state: dict[str, Any]) -> Action:
    """Ask the chat model for the next action and parse its JSON reply into an Action."""
    user_payload = {
        "instruction": "Pick the best next single action to maximize final task score.",
        "observation": observation,
        "state": state,
    }
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(user_payload, ensure_ascii=True)},
        ],
        temperature=0,
        max_tokens=220,
        stream=False,
    )
    raw = (completion.choices[0].message.content or "").strip()
    # Tolerate prose around the object; _extract_json isolates the JSON span.
    return Action.model_validate(json.loads(_extract_json(raw)))
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def action_to_str(action: Action) -> str:
    """Render an Action as a compact one-token string for the [STEP] log line."""
    kind = action.action_type
    if kind == "read_ticket":
        return f"read_ticket({action.ticket_id})"
    if kind == "classify_ticket":
        escalate = str(bool(action.needs_escalation)).lower()
        return f"classify_ticket({action.ticket_id},{action.priority},{action.category},{escalate})"
    if kind == "draft_reply":
        # Only the trimmed message length is logged, never its content.
        body = (action.message or "").strip()
        return f"draft_reply(len={len(body)})"
    if kind == "resolve_ticket":
        return f"resolve_ticket({action.ticket_id})"
    return kind
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def run_episode(
    env: SupportTriageEnv,
    task_id: str,
    mode: str,
    client: OpenAI | None,
    started_at: float,
    runtime_limit_seconds: int,
) -> EpisodeResult:
    """Play one episode of *task_id* to completion and summarize it.

    Emits the [START]/[STEP]/[END] log lines required by the validator and
    enforces the shared wall-clock budget measured from *started_at*.
    Raises TimeoutError when the budget is exhausted.
    """
    obs = env.reset(task_id)
    done = False
    info: dict[str, Any] = {}
    reward_trace: list[float] = []
    steps_taken = 0
    fallback_count = 0
    last_reward = 0.0

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    while True:
        elapsed = time.monotonic() - started_at
        if elapsed > runtime_limit_seconds:
            raise TimeoutError(f"Runtime exceeded {runtime_limit_seconds}s")

        step_idx = env.state()["step_count"]

        # Pick the next action: scripted policy, or LLM with scripted fallback.
        if mode == "heuristic":
            action = heuristic_action(task_id, step_idx)
        else:
            assert client is not None
            try:
                action = llm_action(client, obs.model_dump(), env.state())
            except Exception:
                fallback_count += 1
                action = heuristic_action(task_id, step_idx)

        err: str | None = None
        try:
            obs, reward, done, info = env.step(action)
            last_reward = float(reward.value)
        except Exception as exc:
            # A failing step terminates the episode with zero reward for that step.
            err = str(exc)
            last_reward = 0.0
            done = True

        steps_taken = step_idx + 1
        reward_trace.append(last_reward)

        log_step(
            step=steps_taken,
            action=action_to_str(action),
            reward=last_reward,
            done=done,
            error=err,
        )

        if done:
            break

    # Final score comes from the terminal step's grader info, clamped to [0, 1].
    score = max(0.0, min(1.0, float(info.get("grader_score", 0.0))))
    success = score >= SUCCESS_SCORE_THRESHOLD
    log_end(success=success, steps=steps_taken, score=score, rewards=reward_trace)

    return EpisodeResult(
        task_id=task_id,
        steps=steps_taken,
        score=round(score, 4),
        success=success,
        final_reward=round(last_reward, 4),
        rewards=[round(r, 4) for r in reward_trace],
        fallback_count=fallback_count,
    )
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def main() -> None:
    """CLI entry point: run the requested episodes and write a JSON score report.

    Modes:
        openai    -- query the configured OpenAI-compatible endpoint (requires HF_TOKEN)
        heuristic -- deterministic scripted policy, no network access

    Writes the summary to ``--output`` (parent directories created as needed).
    """
    parser = argparse.ArgumentParser(description="Submission inference script.")
    parser.add_argument("--mode", choices=["openai", "heuristic"], default="openai")
    parser.add_argument("--output", default="scores/inference_scores.json")
    parser.add_argument("--runtime-limit-seconds", type=int, default=1200)
    parser.add_argument("--task-id", default="", help="Optional single task id; default runs all tasks")
    args = parser.parse_args()

    if args.mode == "openai" and not HF_TOKEN:
        raise RuntimeError("HF_TOKEN is required for --mode openai")

    env = SupportTriageEnv()
    task_ids = [args.task_id] if args.task_id else env.task_ids

    # Fail fast on unknown task ids BEFORE any episode runs, so a typo cannot
    # burn API budget on earlier tasks and then abort mid-run.
    for task_id in task_ids:
        if task_id not in env.task_ids:
            raise ValueError(f"Unknown task_id '{task_id}'")

    client = None
    if args.mode == "openai":
        client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    started_at = time.monotonic()
    episodes: list[EpisodeResult] = []
    for task_id in task_ids:
        episodes.append(
            run_episode(
                env=env,
                task_id=task_id,
                mode=args.mode,
                client=client,
                started_at=started_at,
                runtime_limit_seconds=args.runtime_limit_seconds,
            )
        )

    # Guard against ZeroDivisionError if the environment exposes no tasks.
    episode_count = len(episodes) or 1
    summary = {
        "mode": args.mode,
        "api_base_url": API_BASE_URL,
        "model_name": MODEL_NAME,
        "avg_score": round(sum(e.score for e in episodes) / episode_count, 4),
        "avg_final_reward": round(sum(e.final_reward for e in episodes) / episode_count, 4),
        "total_steps": int(sum(e.steps for e in episodes)),
        "episodes": [asdict(e) for e in episodes],
    }

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
if __name__ == "__main__":
|
| 300 |
+
main()
|
openenv.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: support-triage-openenv
|
| 2 |
+
version: 0.1.0
|
| 3 |
+
entrypoint: src/support_triage_openenv/env.py:SupportTriageEnv
|
| 4 |
+
description: |
|
| 5 |
+
Real-world customer support ticket triage environment with deterministic tasks,
|
| 6 |
+
typed action/observation/reward models, and dense reward shaping.
|
| 7 |
+
license: mit
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
- reinforcement-learning
|
| 11 |
+
- customer-support
|
| 12 |
+
maintainers:
|
| 13 |
+
- name: meta-hackathon-team
|
| 14 |
+
api:
|
| 15 |
+
reset: "reset(task_id?: str) -> Observation"
|
| 16 |
+
step: "step(action: Action) -> tuple[Observation, Reward, bool, info]"
|
| 17 |
+
state: "state() -> dict"
|
| 18 |
+
models:
|
| 19 |
+
observation: support_triage_openenv.models.Observation
|
| 20 |
+
action: support_triage_openenv.models.Action
|
| 21 |
+
reward: support_triage_openenv.models.Reward
|
| 22 |
+
tasks:
|
| 23 |
+
- id: easy_password_reset
|
| 24 |
+
difficulty: easy
|
| 25 |
+
grader: support_triage_openenv.graders.grade_task
|
| 26 |
+
- id: medium_billing_dispute
|
| 27 |
+
difficulty: medium
|
| 28 |
+
grader: support_triage_openenv.graders.grade_task
|
| 29 |
+
- id: hard_outage_incident
|
| 30 |
+
difficulty: hard
|
| 31 |
+
grader: support_triage_openenv.graders.grade_task
|
pre_submission_validate.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import importlib
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import subprocess
|
| 10 |
+
import sys
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
import httpx
|
| 16 |
+
import yaml
|
| 17 |
+
from pydantic import BaseModel
|
| 18 |
+
|
| 19 |
+
# Make both the src/ package layout and the repo root importable so that
# `support_triage_openenv` and the root-level `inference` module resolve
# without installing the package first.
ROOT = Path(__file__).resolve().parent
if str(ROOT / "src") not in sys.path:
    sys.path.insert(0, str(ROOT / "src"))
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
| 24 |
+
|
| 25 |
+
# Organizer-mandated stdout log formats emitted by inference.py, one line each:
#   [START] task=<id> env=<name> model=<model>
START_RE = re.compile(r"^\[START\] task=([^ ]+) env=([^ ]+) model=(.+)$")
#   [STEP] step=<n> action=<...> reward=<x.xx> done=<true|false> error=<...>
STEP_RE = re.compile(r"^\[STEP\] step=(\d+) action=(.+) reward=([0-9]+\.[0-9]{2}) done=(true|false) error=(.+)$")
#   [END] success=<true|false> steps=<n> score=<x.xxx> rewards=<comma-separated floats>
END_RE = re.compile(r"^\[END\] success=(true|false) steps=(\d+) score=([0-9]+\.[0-9]{3}) rewards=([0-9\.,-]*)$")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class CheckResult:
    """Outcome of a single pre-submission checklist item."""

    name: str  # human-readable check name shown in the report
    passed: bool  # True when the check succeeded (or was deliberately skipped)
    detail: str  # diagnostics / explanation printed next to PASS or FAIL
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def run_command(cmd: list[str], timeout: int = 300) -> tuple[int, str, str]:
    """Run *cmd* from the repo root and return (returncode, stdout, stderr)."""
    completed = subprocess.run(
        cmd,
        cwd=ROOT,
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return completed.returncode, completed.stdout, completed.stderr
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def check_env_config() -> CheckResult:
    """Verify that the mandatory hackathon environment variables are set."""
    check = "Env vars configured"
    absent = [var for var in ("API_BASE_URL", "MODEL_NAME", "HF_TOKEN") if not os.getenv(var)]
    if absent:
        return CheckResult(check, False, f"Missing: {', '.join(absent)}")
    return CheckResult(check, True, "API_BASE_URL, MODEL_NAME, HF_TOKEN are set")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def check_inference_file() -> CheckResult:
    """Check that a root-level inference.py exists and carries the required markers.

    The organizer contract requires the OpenAI client import, the three
    mandatory environment variable names, and the [START]/[STEP]/[END]
    stdout log prefixes to appear verbatim in the script.
    """
    check = "Root inference.py"
    script = ROOT / "inference.py"
    if not script.exists():
        return CheckResult(check, False, "inference.py missing at repo root")

    contents = script.read_text(encoding="utf-8")
    needed = (
        "from openai import OpenAI",
        "API_BASE_URL",
        "MODEL_NAME",
        "HF_TOKEN",
        "[START] task=",
        "[STEP] step=",
        "[END] success=",
    )
    missing = [snippet for snippet in needed if snippet not in contents]
    if missing:
        return CheckResult(check, False, f"Missing required content: {missing}")

    return CheckResult(check, True, "Found required script name, env vars, OpenAI client, and organizer log format")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def check_openenv_compliance() -> CheckResult:
    """Validate the OpenEnv contract declared in openenv.yaml.

    Confirms that the config exposes the required keys, that the entrypoint
    resolves to a class with callable reset/step/state, that the declared
    observation/action/reward refs resolve to Pydantic BaseModel subclasses,
    and that a reset + one step round-trip returns correctly typed values.
    """
    name = "OpenEnv compliance"
    cfg_path = ROOT / "openenv.yaml"
    if not cfg_path.exists():
        return CheckResult(name, False, "openenv.yaml not found")

    cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
    # safe_load returns None for an empty file and scalars for junk content;
    # fail early instead of raising TypeError on the key checks below.
    if not isinstance(cfg, dict):
        return CheckResult(name, False, "openenv.yaml must contain a top-level mapping")
    for key in ["entrypoint", "models", "tasks", "api"]:
        if key not in cfg:
            return CheckResult(name, False, f"Missing key in openenv.yaml: {key}")

    entrypoint = cfg["entrypoint"]
    if ":" not in entrypoint:
        return CheckResult(name, False, "Entrypoint must be <path>:<ClassName>")

    fs_path, class_name = entrypoint.split(":", 1)
    # BUGFIX: str.replace(".py", "") also mangled ".py" occurring mid-path
    # (e.g. "src/pyenv/env.py" -> "src.env"); only strip a trailing suffix.
    module_name = fs_path.replace("/", ".").removesuffix(".py")
    module = importlib.import_module(module_name)
    env_cls = getattr(module, class_name, None)
    if env_cls is None:
        return CheckResult(name, False, f"Entrypoint class not found: {class_name}")

    env = env_cls()
    for method_name in ["reset", "step", "state"]:
        if not callable(getattr(env, method_name, None)):
            return CheckResult(name, False, f"Missing callable method: {method_name}")

    model_refs = cfg.get("models", {})
    for model_name in ["observation", "action", "reward"]:
        dotted = model_refs.get(model_name)
        if not dotted or "." not in dotted:
            return CheckResult(name, False, f"Invalid model ref for {model_name}: {dotted}")
        mod_name, cls_name = dotted.rsplit(".", 1)
        cls = getattr(importlib.import_module(mod_name), cls_name, None)
        if cls is None or not issubclass(cls, BaseModel):
            return CheckResult(name, False, f"{dotted} must resolve to Pydantic BaseModel")

    # Smoke-test a reset + single step using the first declared task.
    obs = env.reset(cfg["tasks"][0]["id"])
    if not isinstance(obs, BaseModel):
        return CheckResult(name, False, "reset() must return typed model")

    action_mod_name, action_cls_name = model_refs["action"].rsplit(".", 1)
    action_cls = getattr(importlib.import_module(action_mod_name), action_cls_name)
    action = action_cls(action_type="read_ticket", ticket_id="T-1001")
    obs2, reward, done, info = env.step(action)

    if not isinstance(obs2, BaseModel):
        return CheckResult(name, False, "step() observation must be typed model")
    if not isinstance(reward, BaseModel):
        return CheckResult(name, False, "step() reward must be typed model")
    if not isinstance(done, bool):
        return CheckResult(name, False, "step() done must be bool")
    if not isinstance(info, dict):
        return CheckResult(name, False, "step() info must be dict")
    if not isinstance(env.state(), dict):
        return CheckResult(name, False, "state() must return dict")

    return CheckResult(name, True, "openenv.yaml + typed models + reset/step/state validated")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def check_task_graders() -> CheckResult:
    """Run the rule-based policy through every task and validate the grading.

    Ensures the environment exposes at least three tasks, that every step
    reward and each final grader score lies in [0, 1], and that every episode
    actually terminates. A hard step cap guards against a policy/environment
    mismatch that would otherwise hang the validator forever (the original
    `while not done` loop had no exit if the env never set done=True).
    """
    name = "3+ tasks with graders"
    inference = importlib.import_module("inference")
    env_mod = importlib.import_module("support_triage_openenv.env")
    action_mod = importlib.import_module("support_triage_openenv.models")

    env = env_mod.SupportTriageEnv()
    task_ids = env.task_ids
    if len(task_ids) < 3:
        return CheckResult(name, False, f"Expected >=3 tasks, got {len(task_ids)}")

    max_steps = 100  # generous bound: real episodes take a handful of steps
    details: list[str] = []
    for task_id in task_ids:
        env.reset(task_id)
        done = False
        info: dict[str, Any] = {}
        steps_taken = 0

        while not done:
            if steps_taken >= max_steps:
                return CheckResult(name, False, f"Task {task_id} did not terminate within {max_steps} steps")
            step_idx = env.state()["step_count"]
            # Past the end of the scripted policy, keep replaying its last action.
            policy = inference.RULE_POLICY[task_id]
            raw_action = policy[min(step_idx, len(policy) - 1)]
            action = action_mod.Action.model_validate(raw_action)
            _, reward, done, info = env.step(action)
            steps_taken += 1
            reward_value = float(reward.value)
            if not (0.0 <= reward_value <= 1.0):
                return CheckResult(name, False, f"Reward out of range in {task_id}: {reward_value}")

        grader_score = float(info.get("grader_score", -1.0))
        if not (0.0 <= grader_score <= 1.0):
            return CheckResult(name, False, f"Grader out of range in {task_id}: {grader_score}")

        details.append(f"{task_id}:{grader_score:.4f}")

    return CheckResult(name, True, " | ".join(details))
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _validate_log_sequence(lines: list[str]) -> tuple[bool, str]:
    """Validate that stdout lines form well-ordered [START]/[STEP]*/[END] episodes.

    Returns ``(ok, detail)``. Enforces strict START -> STEP* -> END ordering,
    the organizer regex formats, rewards/scores within [0, 1], and that the
    step and reward counts reported on the [END] line match the [STEP] lines
    actually observed. Any other stdout content is treated as a violation.
    """
    if not lines:
        return False, "No stdout lines from inference.py"

    # Two-state machine: waiting for a [START], or inside an episode where
    # either [STEP] or [END] is acceptable next.
    phase = "need_start"
    steps_seen = 0
    episodes = 0

    for line in lines:
        if line.startswith("[START]"):
            if phase != "need_start":
                return False, "[START] appeared before previous episode ended"
            if not START_RE.match(line):
                return False, f"Invalid [START] format: {line}"
            phase = "need_step_or_end"
            steps_seen = 0
            continue

        if line.startswith("[STEP]"):
            if phase != "need_step_or_end":
                return False, "[STEP] appeared before [START]"
            m = STEP_RE.match(line)
            if not m:
                return False, f"Invalid [STEP] format: {line}"
            reward = float(m.group(3))
            if reward < 0.0 or reward > 1.0:
                return False, f"[STEP] reward out of range: {reward}"
            steps_seen += 1
            continue

        if line.startswith("[END]"):
            if phase != "need_step_or_end":
                return False, "[END] appeared before [START]"
            m = END_RE.match(line)
            if not m:
                return False, f"Invalid [END] format: {line}"
            end_steps = int(m.group(2))
            score = float(m.group(3))
            rewards_blob = m.group(4)

            # The [END] summary must agree with the [STEP] lines we counted.
            if end_steps != steps_seen:
                return False, f"[END] steps mismatch: expected {steps_seen}, got {end_steps}"
            if score < 0.0 or score > 1.0:
                return False, f"[END] score out of range: {score}"

            # Per-step rewards are comma-separated on the [END] line; the
            # count must match and each value must itself be in [0, 1].
            rewards = [r for r in rewards_blob.split(",") if r != ""]
            if len(rewards) != steps_seen:
                return False, f"[END] rewards count mismatch: expected {steps_seen}, got {len(rewards)}"
            for r in rewards:
                rv = float(r)
                if rv < 0.0 or rv > 1.0:
                    return False, f"[END] reward out of range: {rv}"

            episodes += 1
            phase = "need_start"
            continue

        # Any non-log line on stdout violates the organizer contract.
        return False, f"Unexpected stdout line (must be START/STEP/END only): {line}"

    if phase != "need_start":
        return False, "Missing [END] for final episode"
    if episodes == 0:
        return False, "No complete episodes found"

    return True, f"Validated {episodes} episode log sequences"
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def check_inference_repro() -> CheckResult:
    """Run inference.py in heuristic mode and validate its JSON output and logs."""
    check = "Baseline reproduces"
    output_path = ROOT / "scores" / "inference_scores.json"
    code, out, err = run_command(
        [sys.executable, "inference.py", "--mode", "heuristic", "--output", str(output_path)],
        timeout=120,
    )
    if code != 0:
        return CheckResult(check, False, f"inference.py failed: {err.strip() or out.strip()}")

    if not output_path.exists():
        return CheckResult(check, False, "scores/inference_scores.json was not created")

    try:
        payload = json.loads(output_path.read_text(encoding="utf-8"))
    except Exception as exc:
        return CheckResult(check, False, f"Invalid JSON output: {exc}")

    for key in ("avg_score", "avg_final_reward", "episodes"):
        if key not in payload:
            return CheckResult(check, False, f"Missing key in output JSON: {key}")

    # The script's entire stdout must be a valid episode log sequence.
    stdout_lines = [line.strip() for line in out.splitlines() if line.strip()]
    ok, detail = _validate_log_sequence(stdout_lines)
    if not ok:
        return CheckResult(check, False, detail)

    return CheckResult(check, True, f"inference.py completed and wrote {output_path.relative_to(ROOT)}; {detail}")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def check_docker_build(skip: bool) -> CheckResult:
    """Build the Docker image unless the caller asked to skip this check."""
    check = "Dockerfile builds"
    if skip:
        return CheckResult(check, True, "Skipped by --skip-docker")

    code, out, err = run_command(
        ["docker", "build", "-t", "support-triage-openenv:presubmit", "."],
        timeout=900,
    )
    if code == 0:
        return CheckResult(check, True, "docker build succeeded")

    # Surface only the last output line; full build logs are too noisy here.
    lines = (err or out).strip().splitlines()
    summary = lines[-1] if lines else "docker build failed"
    return CheckResult(check, False, summary)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def check_space_ping(space_url: str | None, skip: bool) -> CheckResult:
    """Ping the deployed HF Space and verify that POST /reset round-trips."""
    check = "HF Space deploys + ping"
    if skip:
        return CheckResult(check, True, "Skipped by --skip-space")
    if not space_url:
        return CheckResult(check, False, "Provide --space-url (or use --skip-space for local-only checks)")

    base = space_url.rstrip("/")
    try:
        with httpx.Client(timeout=20.0) as client:
            response = client.post(f"{base}/reset", json={"task_id": "easy_password_reset"})
            if response.status_code != 200:
                return CheckResult(check, False, f"POST /reset returned {response.status_code}")

            body = response.json()
            if body.get("task_id") != "easy_password_reset":
                return CheckResult(check, False, "reset() payload missing expected task_id")
    except Exception as exc:
        # Network errors, DNS failures, JSON decode errors — all mean the
        # Space is not usable, so report rather than crash the validator.
        return CheckResult(check, False, f"Ping failed: {exc}")

    return CheckResult(check, True, f"{base} returned 200 and reset() works")
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def check_organizer_script(space_url: str | None, skip: bool) -> CheckResult:
    """Run the organizer-provided pre-validation script against the Space."""
    check = "Organizer pre-validation script"
    if skip:
        return CheckResult(check, True, "Skipped")

    script_path = ROOT / "scripts" / "pre_validation_script.sh"
    if not script_path.exists():
        return CheckResult(check, False, "scripts/pre_validation_script.sh not found")
    if not space_url:
        return CheckResult(check, False, "Requires --space-url")

    code, out, err = run_command(["bash", str(script_path), space_url, str(ROOT)], timeout=1800)
    if code == 0:
        return CheckResult(check, True, "Organizer script passed")

    # Show only the last few lines of combined output as the failure detail.
    tail = (out + "\n" + err).strip().splitlines()[-5:]
    return CheckResult(check, False, " | ".join(tail) if tail else "script failed")
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def run_all(args: argparse.Namespace) -> list[CheckResult]:
    """Execute every checklist item and return the results in report order."""
    # The organizer script needs both Docker and the remote Space, so skipping
    # either of those (or asking explicitly) skips it as well.
    skip_organizer = args.skip_organizer_script or args.skip_space or args.skip_docker
    checks = [
        check_env_config(),
        check_inference_file(),
        check_openenv_compliance(),
        check_task_graders(),
        check_inference_repro(),
        check_docker_build(skip=args.skip_docker),
        check_space_ping(space_url=args.space_url, skip=args.skip_space),
        check_organizer_script(space_url=args.space_url, skip=skip_organizer),
    ]
    return checks
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def main() -> None:
    """CLI entry point: run all checks, print a report, exit 1 on any failure."""
    parser = argparse.ArgumentParser(description="Pre-submission validator for Meta HF hackathon OpenEnv env.")
    parser.add_argument("--space-url", default=os.getenv("SPACE_URL"), help="Deployed HF Space URL for ping checks")
    parser.add_argument("--skip-docker", action="store_true", help="Skip docker build check")
    parser.add_argument("--skip-space", action="store_true", help="Skip remote Space ping check")
    parser.add_argument("--skip-organizer-script", action="store_true", help="Skip organizer-provided pre-validation script")

    results = run_all(parser.parse_args())

    print("\n=== Pre-Submission Checklist Report ===")
    for result in results:
        print(f"[{'PASS' if result.passed else 'FAIL'}] {result.name}: {result.detail}")

    failures = [r for r in results if not r.passed]
    print("\nSummary:")
    print(f"- Total checks: {len(results)}")
    print(f"- Passed: {len(results) - len(failures)}")
    print(f"- Failed: {len(failures)}")

    if failures:
        sys.exit(1)


if __name__ == "__main__":
    main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Build and packaging metadata for the support-triage OpenEnv environment.
[build-system]
requires = ["setuptools>=68", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "support-triage-openenv"
version = "0.1.0"
description = "OpenEnv-compliant customer support triage RL environment"
readme = "README.md"
requires-python = ">=3.10"
authors = [{ name = "meta-hackathon-team" }]
# Runtime dependencies: typed models, YAML config, LLM client, HTTP server.
dependencies = [
    "pydantic>=2.7.0",
    "PyYAML>=6.0.1",
    "openai>=1.40.0",
    "fastapi>=0.115.0",
    "uvicorn>=0.30.0",
]

# Extras needed only for local development and CI.
[project.optional-dependencies]
dev = [
    "pytest>=8.2.0",
]

# src/ layout: packages live under src/ and are auto-discovered there.
[tool.setuptools]
package-dir = {"" = "src"}

[tool.setuptools.packages.find]
where = ["src"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Install the project in editable mode (runtime deps come from pyproject.toml).
-e .
# Test runner used by CI and local development.
pytest>=8.2.0
|
scripts/bootstrap_remotes.sh
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# bootstrap_remotes.sh — create the GitHub repo and the HF Docker Space for
# this project, then push the current `main` branch to both.
#
# Usage: bootstrap_remotes.sh [repo_name]
# Requires env vars: GITHUB_TOKEN, GITHUB_USERNAME, HF_TOKEN, HF_USERNAME.
# Optional: GITHUB_VISIBILITY=private (default: public).
set -euo pipefail

REPO_NAME="${1:-support-triage-openenv}"
GITHUB_VISIBILITY="${GITHUB_VISIBILITY:-public}"

# Fail fast when any required credential is missing.
if [[ -z "${GITHUB_TOKEN:-}" || -z "${GITHUB_USERNAME:-}" ]]; then
  echo "Missing GITHUB_TOKEN or GITHUB_USERNAME" >&2
  exit 1
fi

if [[ -z "${HF_TOKEN:-}" || -z "${HF_USERNAME:-}" ]]; then
  echo "Missing HF_TOKEN or HF_USERNAME" >&2
  exit 1
fi

if ! command -v huggingface-cli >/dev/null 2>&1; then
  echo "huggingface-cli is required but not installed." >&2
  exit 1
fi

# 1) Create GitHub repo via REST API (works without gh CLI)
create_payload=$(cat <<JSON
{
  "name": "${REPO_NAME}",
  "private": $( [[ "${GITHUB_VISIBILITY}" == "private" ]] && echo true || echo false ),
  "description": "OpenEnv customer support triage environment",
  "auto_init": false
}
JSON
)

curl -sS -o /tmp/github_repo_create.json -w "%{http_code}" \
  -H "Accept: application/vnd.github+json" \
  -H "Authorization: Bearer ${GITHUB_TOKEN}" \
  -H "X-GitHub-Api-Version: 2022-11-28" \
  https://api.github.com/user/repos \
  -d "${create_payload}" >/tmp/github_repo_create_status.txt

# 201 = created; 422 = already exists — both are fine (script is rerunnable).
status_code="$(cat /tmp/github_repo_create_status.txt)"
if [[ "${status_code}" != "201" && "${status_code}" != "422" ]]; then
  echo "GitHub repo creation failed with HTTP ${status_code}" >&2
  cat /tmp/github_repo_create.json >&2
  exit 1
fi

if git remote get-url origin >/dev/null 2>&1; then
  git remote set-url origin "https://github.com/${GITHUB_USERNAME}/${REPO_NAME}.git"
else
  git remote add origin "https://github.com/${GITHUB_USERNAME}/${REPO_NAME}.git"
fi

# Authenticated push URL
git remote set-url --push origin "https://${GITHUB_USERNAME}:${GITHUB_TOKEN}@github.com/${GITHUB_USERNAME}/${REPO_NAME}.git"

git push -u origin main

# Reset push URL to tokenless remote after push
git remote set-url --push origin "https://github.com/${GITHUB_USERNAME}/${REPO_NAME}.git"

# 2) Create HF Docker Space repo and push
huggingface-cli repo create "${HF_USERNAME}/${REPO_NAME}" \
  --repo-type space \
  --space_sdk docker \
  --token "${HF_TOKEN}" \
  --exist-ok >/tmp/hf_repo_create.log

if git remote get-url huggingface >/dev/null 2>&1; then
  git remote set-url huggingface "https://huggingface.co/spaces/${HF_USERNAME}/${REPO_NAME}"
else
  git remote add huggingface "https://huggingface.co/spaces/${HF_USERNAME}/${REPO_NAME}"
fi

git remote set-url --push huggingface "https://user:${HF_TOKEN}@huggingface.co/spaces/${HF_USERNAME}/${REPO_NAME}"

git push -u huggingface main

# Reset push URL to tokenless remote after push
git remote set-url --push huggingface "https://huggingface.co/spaces/${HF_USERNAME}/${REPO_NAME}"

echo "Completed."
echo "GitHub: https://github.com/${GITHUB_USERNAME}/${REPO_NAME}"
echo "HF Space: https://huggingface.co/spaces/${HF_USERNAME}/${REPO_NAME}"
|
scripts/pre_validation_script.sh
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
#
# validate-submission.sh — OpenEnv Submission Validator
#
# Checks that your HF Space is live, Docker image builds, and openenv validate passes.
#
# Prerequisites:
#   - Docker: https://docs.docker.com/get-docker/
#   - openenv-core: pip install openenv-core
#   - curl (usually pre-installed)
#
# Run:
#   curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
#
# Or download and run locally:
#   chmod +x validate-submission.sh
#   ./validate-submission.sh <ping_url> [repo_dir]
#
# Arguments:
#   ping_url   Your HuggingFace Space URL (e.g. https://your-space.hf.space)
#   repo_dir   Path to your repo (default: current directory)
#
# Examples:
#   ./validate-submission.sh https://my-team.hf.space
#   ./validate-submission.sh https://my-team.hf.space ./my-repo
#

set -uo pipefail

DOCKER_BUILD_TIMEOUT=600
# Colorized output only when stdout is a TTY.
if [ -t 1 ]; then
  RED='\033[0;31m'
  GREEN='\033[0;32m'
  YELLOW='\033[1;33m'
  BOLD='\033[1m'
  NC='\033[0m'
else
  RED='' GREEN='' YELLOW='' BOLD='' NC=''
fi

# Run a command under a time limit, falling back to a background watcher
# process when neither GNU timeout nor gtimeout exists (e.g. stock macOS).
run_with_timeout() {
  local secs="$1"; shift
  if command -v timeout &>/dev/null; then
    timeout "$secs" "$@"
  elif command -v gtimeout &>/dev/null; then
    gtimeout "$secs" "$@"
  else
    "$@" &
    local pid=$!
    ( sleep "$secs" && kill "$pid" 2>/dev/null ) &
    local watcher=$!
    wait "$pid" 2>/dev/null
    local rc=$?
    kill "$watcher" 2>/dev/null
    wait "$watcher" 2>/dev/null
    return $rc
  fi
}

# mktemp wrapper that works with both GNU and BSD mktemp.
portable_mktemp() {
  local prefix="${1:-validate}"
  mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
}

CLEANUP_FILES=()
cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
trap cleanup EXIT

PING_URL="${1:-}"
REPO_DIR="${2:-.}"

if [ -z "$PING_URL" ]; then
  printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
  printf "\n"
  printf "  ping_url   Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
  printf "  repo_dir   Path to your repo (default: current directory)\n"
  exit 1
fi

if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
  printf "Error: directory '%s' not found\n" "${2:-.}"
  exit 1
fi
PING_URL="${PING_URL%/}"
export PING_URL
PASS=0

log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
fail() { log "${RED}FAILED${NC} -- $1"; }
hint() { printf "  ${YELLOW}Hint:${NC} %b\n" "$1"; }
stop_at() {
  printf "\n"
  printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
  exit 1
}

printf "\n"
printf "${BOLD}========================================${NC}\n"
printf "${BOLD} OpenEnv Submission Validator${NC}\n"
printf "${BOLD}========================================${NC}\n"
log "Repo: $REPO_DIR"
log "Ping URL: $PING_URL"
printf "\n"

# --- Step 1/3: the deployed Space must answer POST /reset with HTTP 200 ---
log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."

CURL_OUTPUT=$(portable_mktemp "validate-curl")
CLEANUP_FILES+=("$CURL_OUTPUT")
HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
  -H "Content-Type: application/json" -d '{}' \
  "$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")

if [ "$HTTP_CODE" = "200" ]; then
  pass "HF Space is live and responds to /reset"
elif [ "$HTTP_CODE" = "000" ]; then
  fail "HF Space not reachable (connection failed or timed out)"
  hint "Check your network connection and that the Space is running."
  hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
  stop_at "Step 1"
else
  fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
  hint "Make sure your Space is running and the URL is correct."
  hint "Try opening $PING_URL in your browser first."
  stop_at "Step 1"
fi

# --- Step 2/3: the Docker image must build from the repo's Dockerfile ---
log "${BOLD}Step 2/3: Running docker build${NC} ..."

if ! command -v docker &>/dev/null; then
  fail "docker command not found"
  hint "Install Docker: https://docs.docker.com/get-docker/"
  stop_at "Step 2"
fi

if [ -f "$REPO_DIR/Dockerfile" ]; then
  DOCKER_CONTEXT="$REPO_DIR"
elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
  DOCKER_CONTEXT="$REPO_DIR/server"
else
  fail "No Dockerfile found in repo root or server/ directory"
  stop_at "Step 2"
fi

log "  Found Dockerfile in $DOCKER_CONTEXT"

BUILD_OK=false
BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true

if [ "$BUILD_OK" = true ]; then
  pass "Docker build succeeded"
else
  fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
  printf "%s\n" "$BUILD_OUTPUT" | tail -20
  stop_at "Step 2"
fi

# --- Step 3/3: `openenv validate` must pass from inside the repo ---
log "${BOLD}Step 3/3: Running openenv validate${NC} ..."

if ! command -v openenv &>/dev/null; then
  fail "openenv command not found"
  hint "Install it: pip install openenv-core"
  stop_at "Step 3"
fi

VALIDATE_OK=false
VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true

if [ "$VALIDATE_OK" = true ]; then
  pass "openenv validate passed"
  [ -n "$VALIDATE_OUTPUT" ] && log "  $VALIDATE_OUTPUT"
else
  fail "openenv validate failed"
  printf "%s\n" "$VALIDATE_OUTPUT"
  stop_at "Step 3"
fi

printf "\n"
printf "${BOLD}========================================${NC}\n"
printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
printf "${BOLD}========================================${NC}\n"
printf "\n"

exit 0
|
scripts/run_baseline.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from dataclasses import asdict, dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
from openai import OpenAI
|
| 12 |
+
|
| 13 |
+
from support_triage_openenv import Action, SupportTriageEnv
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
SYSTEM_PROMPT = """You are an agent solving a customer-support triage environment.
|
| 17 |
+
Return exactly one JSON object for the next action with keys:
|
| 18 |
+
- action_type: read_ticket | classify_ticket | draft_reply | resolve_ticket
|
| 19 |
+
- ticket_id (required for read/classify/resolve)
|
| 20 |
+
- priority, category, needs_escalation (for classify)
|
| 21 |
+
- message (for draft_reply)
|
| 22 |
+
No markdown, no extra text."""
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class EpisodeResult:
    """Summary of one completed environment episode."""

    # Identifier of the task that was run.
    task_id: str
    # Number of env steps taken before the episode terminated.
    steps: int
    # Final grader score reported by the env's step info.
    grader_score: float
    # Reward value from the last step of the episode.
    reward: float
    # Termination reason string from step info (e.g. "resolved", "max_steps").
    done_reason: str
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
RULE_POLICY: dict[str, list[dict[str, Any]]] = {
|
| 35 |
+
"easy_password_reset": [
|
| 36 |
+
{"action_type": "read_ticket", "ticket_id": "T-1001"},
|
| 37 |
+
{
|
| 38 |
+
"action_type": "classify_ticket",
|
| 39 |
+
"ticket_id": "T-1001",
|
| 40 |
+
"priority": "medium",
|
| 41 |
+
"category": "account",
|
| 42 |
+
"needs_escalation": False,
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"action_type": "draft_reply",
|
| 46 |
+
"message": (
|
| 47 |
+
"We will send a reset link to your email. For security, confirm the request "
|
| 48 |
+
"from your registered email before using the reset link."
|
| 49 |
+
),
|
| 50 |
+
},
|
| 51 |
+
{"action_type": "resolve_ticket", "ticket_id": "T-1001"},
|
| 52 |
+
],
|
| 53 |
+
"medium_billing_dispute": [
|
| 54 |
+
{"action_type": "read_ticket", "ticket_id": "T-2001"},
|
| 55 |
+
{"action_type": "read_ticket", "ticket_id": "T-2002"},
|
| 56 |
+
{
|
| 57 |
+
"action_type": "classify_ticket",
|
| 58 |
+
"ticket_id": "T-2001",
|
| 59 |
+
"priority": "high",
|
| 60 |
+
"category": "billing",
|
| 61 |
+
"needs_escalation": False,
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"action_type": "draft_reply",
|
| 65 |
+
"message": (
|
| 66 |
+
"We confirmed a duplicate charge. We are issuing a refund and will share the invoice update. "
|
| 67 |
+
"Refund processing typically takes 3-5 business days."
|
| 68 |
+
),
|
| 69 |
+
},
|
| 70 |
+
{"action_type": "resolve_ticket", "ticket_id": "T-2001"},
|
| 71 |
+
],
|
| 72 |
+
"hard_outage_incident": [
|
| 73 |
+
{"action_type": "read_ticket", "ticket_id": "T-3001"},
|
| 74 |
+
{"action_type": "read_ticket", "ticket_id": "T-3002"},
|
| 75 |
+
{"action_type": "read_ticket", "ticket_id": "T-3003"},
|
| 76 |
+
{
|
| 77 |
+
"action_type": "classify_ticket",
|
| 78 |
+
"ticket_id": "T-3001",
|
| 79 |
+
"priority": "urgent",
|
| 80 |
+
"category": "technical",
|
| 81 |
+
"needs_escalation": True,
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"action_type": "draft_reply",
|
| 85 |
+
"message": (
|
| 86 |
+
"We have escalated this incident and are investigating now. "
|
| 87 |
+
"The status page will carry updates while we continue incident response."
|
| 88 |
+
),
|
| 89 |
+
},
|
| 90 |
+
{"action_type": "resolve_ticket", "ticket_id": "T-3001"},
|
| 91 |
+
],
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _extract_json(text: str) -> str:
|
| 96 |
+
text = text.strip()
|
| 97 |
+
start = text.find("{")
|
| 98 |
+
end = text.rfind("}")
|
| 99 |
+
if start == -1 or end == -1 or end <= start:
|
| 100 |
+
raise ValueError("No JSON object found in model response")
|
| 101 |
+
return text[start : end + 1]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def llm_action(client: OpenAI, model: str, observation: dict[str, Any], state: dict[str, Any]) -> Action:
    """Ask the model for the next action given the current observation/state.

    The observation and visible env state are serialized into one JSON user
    message; the model is expected to reply with a single JSON action object
    (contract described in SYSTEM_PROMPT).

    Raises:
        ValueError: if the response contains no JSON object.
        pydantic.ValidationError: if the JSON does not match the Action schema.
    """
    user_prompt = json.dumps(
        {
            "observation": observation,
            "state": state,
            "instruction": "Pick the best next single action to maximize final score.",
        },
        ensure_ascii=True,
    )

    # Greedy decoding (temperature=0, top_p=1) keeps baseline runs reproducible.
    response = client.responses.create(
        model=model,
        temperature=0,
        top_p=1,
        input=[
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
        ],
    )

    raw = response.output_text or ""
    # Tolerate stray prose around the JSON object in the model output.
    payload = json.loads(_extract_json(raw))
    return Action.model_validate(payload)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def heuristic_action(task_id: str, step_idx: int) -> Action:
    """Return the scripted action for *task_id* at *step_idx*.

    Looks up the fixed plan in RULE_POLICY; once the plan is exhausted the
    final action is repeated so the episode can still terminate.
    """
    scripted_plan = RULE_POLICY[task_id]
    clamped = step_idx if step_idx < len(scripted_plan) else len(scripted_plan) - 1
    return Action.model_validate(scripted_plan[clamped])
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def run_episode(env: SupportTriageEnv, task_id: str, mode: str, model: str, client: OpenAI | None) -> EpisodeResult:
    """Run a single episode of *task_id* to completion and summarize it.

    mode == "heuristic" follows the scripted RULE_POLICY; any other mode
    queries the LLM via `llm_action`, falling back to the scripted action
    on any failure so runs never abort mid-episode.
    """
    obs = env.reset(task_id)
    done = False
    info: dict[str, Any] = {}
    reward_value = 0.0

    while not done:
        # step_count reflects the number of actions already taken, so it
        # doubles as the index into the scripted plan.
        step_idx = env.state()["step_count"]
        if mode == "heuristic":
            action = heuristic_action(task_id, step_idx)
        else:
            assert client is not None
            try:
                action = llm_action(client, model, obs.model_dump(), env.state())
            except Exception:
                # Deterministic fallback keeps run alive for reproducible scoring.
                action = heuristic_action(task_id, step_idx)

        obs, reward, done, info = env.step(action)
        reward_value = reward.value

    # The env terminates episodes at max_steps, so the loop always runs at
    # least once and `info` is populated with the final step's data here.
    return EpisodeResult(
        task_id=task_id,
        steps=env.state()["step_count"],
        grader_score=float(info["grader_score"]),
        reward=reward_value,
        done_reason=str(info["done_reason"]),
    )
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def main() -> None:
    """CLI entry point: run every task once and write a score summary JSON.

    Requires OPENAI_API_KEY when --mode openai is selected; the summary is
    written to --output and also printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Run baseline on support-triage-openenv tasks.")
    parser.add_argument("--mode", choices=["openai", "heuristic"], default="openai")
    parser.add_argument("--model", default="gpt-4.1-mini")
    parser.add_argument("--output", default="scores/baseline_scores.json")
    args = parser.parse_args()

    client = None
    if args.mode == "openai":
        if not os.getenv("OPENAI_API_KEY"):
            raise RuntimeError("OPENAI_API_KEY is required for --mode openai")
        client = OpenAI()

    env = SupportTriageEnv()
    episode_results = []
    for current_task in env.task_ids:
        episode_results.append(run_episode(env, current_task, args.mode, args.model, client))

    episode_count = len(episode_results)
    summary = {
        "mode": args.mode,
        "model": args.model,
        "avg_grader_score": round(sum(r.grader_score for r in episode_results) / episode_count, 4),
        "avg_final_reward": round(sum(r.reward for r in episode_results) / episode_count, 4),
        "episodes": [asdict(r) for r in episode_results],
    }

    destination = Path(args.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(summary, indent=2)
    destination.write_text(serialized, encoding="utf-8")

    print(serialized)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
if __name__ == "__main__":
|
| 197 |
+
main()
|
scripts/sample_inference_script.sh
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference Script Example
|
| 3 |
+
===================================
|
| 4 |
+
MANDATORY
|
| 5 |
+
- Before submitting, ensure the following variables are defined in your environment configuration:
|
| 6 |
+
API_BASE_URL The API endpoint for the LLM.
|
| 7 |
+
MODEL_NAME The model identifier to use for inference.
|
| 8 |
+
HF_TOKEN Your Hugging Face / API key.
|
| 9 |
+
LOCAL_IMAGE_NAME The name of the local image to use for the environment if you are using from_docker_image()
|
| 10 |
+
method
|
| 11 |
+
|
| 12 |
+
- Defaults are set only for API_BASE_URL and MODEL_NAME
|
| 13 |
+
(and should reflect your active inference setup):
|
| 14 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "<your-active-endpoint>")
|
| 15 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "<your-active-model>")
|
| 16 |
+
|
| 17 |
+
- The inference script must be named `inference.py` and placed in the root directory of the project
|
| 18 |
+
- Participants must use OpenAI Client for all LLM calls using above variables
|
| 19 |
+
|
| 20 |
+
STDOUT FORMAT
|
| 21 |
+
- The script must emit exactly three line types to stdout, in this order:
|
| 22 |
+
|
| 23 |
+
[START] task=<task_name> env=<benchmark> model=<model_name>
|
| 24 |
+
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 25 |
+
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
|
| 26 |
+
|
| 27 |
+
Rules:
|
| 28 |
+
- One [START] line at episode begin.
|
| 29 |
+
- One [STEP] line per step, immediately after env.step() returns.
|
| 30 |
+
- One [END] line after env.close(), always emitted (even on exception).
|
| 31 |
+
- reward and rewards are formatted to 2 decimal places.
|
| 32 |
+
- done and success are lowercase booleans: true or false.
|
| 33 |
+
- error is the raw last_action_error string, or null if none.
|
| 34 |
+
- All fields on a single line with no newlines within a line.
|
| 35 |
+
- Each task should return a score in [0, 1]
|
| 36 |
+
|
| 37 |
+
Example:
|
| 38 |
+
[START] task=click-test env=miniwob model=Qwen3-VL-30B
|
| 39 |
+
[STEP] step=1 action=click('123') reward=0.00 done=false error=null
|
| 40 |
+
[STEP] step=2 action=fill('456','text') reward=0.00 done=false error=null
|
| 41 |
+
[STEP] step=3 action=click('789') reward=1.00 done=true error=null
|
| 42 |
+
[END] success=true steps=3 score=1.00 rewards=0.00,0.00,1.00
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
import asyncio
|
| 46 |
+
import os
|
| 47 |
+
import textwrap
|
| 48 |
+
from typing import List, Optional
|
| 49 |
+
|
| 50 |
+
from openai import OpenAI
|
| 51 |
+
|
| 52 |
+
from my_env_v4 import MyEnvV4Action, MyEnvV4Env
|
| 53 |
+
IMAGE_NAME = os.getenv("IMAGE_NAME") # If you are using docker image
|
| 54 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 55 |
+
|
| 56 |
+
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
| 57 |
+
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 58 |
+
TASK_NAME = os.getenv("MY_ENV_V4_TASK", "echo")
|
| 59 |
+
BENCHMARK = os.getenv("MY_ENV_V4_BENCHMARK", "my_env_v4")
|
| 60 |
+
MAX_STEPS = 8
|
| 61 |
+
TEMPERATURE = 0.7
|
| 62 |
+
MAX_TOKENS = 150
|
| 63 |
+
SUCCESS_SCORE_THRESHOLD = 0.1 # normalized score in [0, 1]
|
| 64 |
+
|
| 65 |
+
# Max possible reward: each token contributes 0.1, across all steps
|
| 66 |
+
_MAX_REWARD_PER_STEP = MAX_TOKENS * 0.1
|
| 67 |
+
MAX_TOTAL_REWARD = MAX_STEPS * _MAX_REWARD_PER_STEP
|
| 68 |
+
|
| 69 |
+
SYSTEM_PROMPT = textwrap.dedent(
|
| 70 |
+
"""
|
| 71 |
+
You are interacting with a simple echo environment.
|
| 72 |
+
Each turn you must send a message. The environment will echo it back.
|
| 73 |
+
Reward is proportional to message length: reward = len(message) * 0.1
|
| 74 |
+
Your goal is to maximize total reward by sending meaningful, substantive messages.
|
| 75 |
+
Reply with exactly one message string — no quotes, no prefixes, just the message text.
|
| 76 |
+
"""
|
| 77 |
+
).strip()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def log_start(task: str, env: str, model: str) -> None:
    """Emit the mandatory [START] line announcing the episode."""
    header = f"[START] task={task} env={env} model={model}"
    print(header, flush=True)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one [STEP] line in the validator's required single-line format.

    A falsy *error* is rendered as the literal string "null"; *done* is
    rendered as a lowercase boolean.
    """
    fields = (
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error if error else 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the final [END] summary line (always exactly one line)."""
    reward_list = ",".join(format(value, ".2f") for value in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={reward_list}",
        flush=True,
    )
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def build_user_prompt(step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
    """Build the per-turn user message shown to the model.

    Includes the current step number, the last echoed message/reward, and
    up to the four most recent history entries ("None" when empty).
    """
    history_block = "\n".join(history[-4:]) if history else "None"
    # NOTE(review): dedent() runs after f-string interpolation, so a
    # multi-line history_block (whose continuation lines start at column 0)
    # defeats the common-prefix removal and leaves the template indentation
    # in the output — confirm this is acceptable for the model prompt.
    return textwrap.dedent(
        f"""
        Step: {step}
        Last echoed message: {last_echoed!r}
        Last reward: {last_reward:.2f}
        Previous steps:
        {history_block}
        Send your next message.
        """
    ).strip()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_model_message(client: OpenAI, step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
    """Query the chat model for the next message to send to the environment.

    Returns "hello" as a safe fallback when the response is empty or the
    request fails, so the episode can always continue.
    """
    user_prompt = build_user_prompt(step, last_echoed, last_reward, history)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()
        return text if text else "hello"
    except Exception as exc:
        # Broad catch is deliberate: any transport/model error degrades to
        # the fallback message instead of aborting the episode.
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return "hello"
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
async def main() -> None:
    """Run one episode against the echo environment, emitting the
    [START]/[STEP]/[END] lines required by the validator.

    The [END] line is written in the `finally` block so it is always
    emitted exactly once, even when the episode raises.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Spins up the environment container from the configured image.
    env = await MyEnvV4Env.from_docker_image(IMAGE_NAME)

    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)

    try:
        result = await env.reset()  # OpenENV.reset()
        last_echoed = result.observation.echoed_message
        last_reward = 0.0

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            message = get_model_message(client, step, last_echoed, last_reward, history)

            result = await env.step(MyEnvV4Action(message=message))
            obs = result.observation

            # Missing reward is treated as zero so logging never fails.
            reward = result.reward or 0.0
            done = result.done
            error = None

            rewards.append(reward)
            steps_taken = step
            last_echoed = obs.echoed_message
            last_reward = reward

            log_step(step=step, action=message, reward=reward, done=done, error=error)

            history.append(f"Step {step}: {message!r} -> reward {reward:+.2f}")

            if done:
                break

        # Normalize cumulative reward to [0, 1] against the theoretical maximum.
        score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
        score = min(max(score, 0.0), 1.0)  # clamp to [0, 1]
        success = score >= SUCCESS_SCORE_THRESHOLD

    finally:
        try:
            await env.close()
        except Exception as e:
            print(f"[DEBUG] env.close() error (container cleanup): {e}", flush=True)
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
if __name__ == "__main__":
|
| 188 |
+
asyncio.run(main())
|
scripts/validate_env.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import shutil
|
| 5 |
+
import subprocess
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def run_local_checks() -> None:
    """Run the local test suite, echoing each command before executing it.

    Raises subprocess.CalledProcessError when any check exits non-zero.
    """
    for command in (["python", "-m", "pytest", "-q"],):
        print("$ " + " ".join(command))
        subprocess.run(command, check=True)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def run_openenv_validate_if_available() -> None:
    """Run `openenv validate openenv.yaml` when the CLI is installed.

    Prints a notice and returns early when the `openenv` binary is absent,
    so local validation still works without the optional tooling.
    """
    if shutil.which("openenv") is None:
        print("openenv CLI not found; skipped `openenv validate`.")
        return

    command = ["openenv", "validate", "openenv.yaml"]
    print("$ " + " ".join(command))
    subprocess.run(command, check=True)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if __name__ == "__main__":
|
| 29 |
+
run_local_checks()
|
| 30 |
+
run_openenv_validate_if_available()
|
src/support_triage_openenv/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Support triage OpenEnv package."""
|
| 2 |
+
|
| 3 |
+
from .env import SupportTriageEnv
|
| 4 |
+
from .models import Action, Observation, Reward
|
| 5 |
+
|
| 6 |
+
__all__ = ["SupportTriageEnv", "Observation", "Action", "Reward"]
|
src/support_triage_openenv/env.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from .graders import grade_task
|
| 7 |
+
from .models import Action, Observation, Reward, StepInfo, TicketView
|
| 8 |
+
from .tasks import TaskSpec, get_tasks
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SupportTriageEnv:
    """
    OpenEnv-compatible environment for customer support ticket triage.

    API:
    - reset(task_id: str | None = None) -> Observation
    - step(action: Action) -> tuple[Observation, Reward, bool, dict[str, Any]]
    - state() -> dict[str, Any]
    """

    def __init__(self) -> None:
        # Task registry keyed by task_id plus a stable ordering for
        # round-robin selection when reset() is called without a task_id.
        self._tasks: dict[str, TaskSpec] = {t.task_id: t for t in get_tasks()}
        self._task_order = [t.task_id for t in get_tasks()]
        self._task_index = 0
        self._current_task: TaskSpec | None = None
        self._state: dict[str, Any] = {}

    @property
    def task_ids(self) -> list[str]:
        # Copy so callers cannot mutate the internal ordering.
        return list(self._task_order)

    def reset(self, task_id: str | None = None) -> Observation:
        """Start a fresh episode for *task_id* (round-robin when None).

        Raises ValueError for unknown task ids.
        """
        if task_id is None:
            task_id = self._task_order[self._task_index % len(self._task_order)]
            self._task_index += 1
        if task_id not in self._tasks:
            raise ValueError(f"Unknown task_id '{task_id}'. Available: {sorted(self._tasks.keys())}")

        self._current_task = self._tasks[task_id]
        # Per-episode mutable state; keys are read by grade_task and state().
        self._state = {
            "step_count": 0,
            "read_ticket_ids": set(),
            "selected_ticket_id": None,
            "classification": None,
            "draft_reply": None,
            "resolved": False,
            "resolved_ticket_id": None,
            "invalid_actions": 0,
            "repeat_actions": 0,
            "action_history": [],
            "last_note": "Environment reset.",
            "done": False,
            "done_reason": "ongoing",
        }
        return self._build_observation()

    def step(self, action: Action) -> tuple[Observation, Reward, bool, dict[str, Any]]:
        """Apply *action*, returning (observation, reward, done, info).

        Raises RuntimeError when called before reset() or after the
        episode has finished.
        """
        if self._current_task is None:
            raise RuntimeError("Call reset() before step().")
        if self._state["done"]:
            raise RuntimeError("Episode already done. Call reset() for a new episode.")

        task = self._current_task
        st = self._state
        st["step_count"] += 1

        # Track exact back-to-back repeats via the serialized action.
        action_fingerprint = action.model_dump_json()
        if st["action_history"] and st["action_history"][-1] == action_fingerprint:
            st["repeat_actions"] += 1
        st["action_history"].append(action_fingerprint)

        valid_ticket_ids = {t["ticket_id"] for t in task.tickets}
        step_penalty = 0.0

        # Ticket-addressed actions must reference a real ticket; three
        # invalid actions terminate the episode.
        if action.action_type in {"read_ticket", "classify_ticket", "resolve_ticket"}:
            if not action.ticket_id or action.ticket_id not in valid_ticket_ids:
                st["invalid_actions"] += 1
                st["last_note"] = "Invalid or missing ticket_id."
                step_penalty -= 0.03
                if st["invalid_actions"] >= 3:
                    st["done"] = True
                    st["done_reason"] = "invalid_action"
                return self._assemble_step_response(step_penalty)

        if action.action_type == "read_ticket":
            st["read_ticket_ids"].add(action.ticket_id)
            st["selected_ticket_id"] = action.ticket_id
            st["last_note"] = f"Read ticket {action.ticket_id}."

        elif action.action_type == "classify_ticket":
            # Classifying the wrong ticket is allowed but mildly penalized;
            # the latest classification always overwrites the previous one.
            if action.ticket_id != task.target_ticket_id:
                step_penalty -= 0.01
            st["classification"] = {
                "ticket_id": action.ticket_id,
                "priority": action.priority,
                "category": action.category,
                "needs_escalation": action.needs_escalation,
            }
            st["last_note"] = f"Saved classification for {action.ticket_id}."

        elif action.action_type == "draft_reply":
            text = (action.message or "").strip()
            if not text:
                st["invalid_actions"] += 1
                st["last_note"] = "Draft reply is empty."
                step_penalty -= 0.02
            else:
                st["draft_reply"] = text
                st["last_note"] = "Draft reply saved."

        elif action.action_type == "resolve_ticket":
            # Resolving any valid ticket ends the episode; grading decides
            # whether it was the correct target.
            st["resolved"] = True
            st["resolved_ticket_id"] = action.ticket_id
            st["done"] = True
            st["done_reason"] = "resolved"
            st["last_note"] = f"Resolved ticket {action.ticket_id}."

        else:
            st["invalid_actions"] += 1
            st["last_note"] = f"Unknown action {action.action_type}."
            step_penalty -= 0.03

        # Hard step budget guarantees termination.
        if st["step_count"] >= task.max_steps and not st["done"]:
            st["done"] = True
            st["done_reason"] = "max_steps"
            st["last_note"] = "Reached max_steps."

        if st["repeat_actions"] > 0:
            step_penalty -= min(0.04, 0.01 * st["repeat_actions"])

        return self._assemble_step_response(step_penalty)

    def state(self) -> dict[str, Any]:
        """Return a deep-copied, JSON-friendly snapshot of episode state."""
        if self._current_task is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        visible = copy.deepcopy(self._state)
        # Sets are not JSON-serializable; expose a sorted list instead.
        visible["read_ticket_ids"] = sorted(list(visible["read_ticket_ids"]))
        visible["task_id"] = self._current_task.task_id
        return visible

    def _build_observation(self) -> Observation:
        """Construct the agent-visible Observation from current state."""
        assert self._current_task is not None
        task = self._current_task
        st = self._state

        # Only reveal the full body of the currently selected ticket, and
        # only once it has actually been read.
        content = None
        if st.get("selected_ticket_id") in st["read_ticket_ids"]:
            ticket = next(t for t in task.tickets if t["ticket_id"] == st["selected_ticket_id"])
            content = ticket["content"]

        inbox = [
            TicketView(
                ticket_id=t["ticket_id"],
                subject=t["subject"],
                customer_tier=t["customer_tier"],
                age_minutes=t["age_minutes"],
                read=t["ticket_id"] in st["read_ticket_ids"],
            )
            for t in task.tickets
        ]

        # Partial grades are surfaced as a shaping hint for the agent.
        partial = grade_task(task, st)
        return Observation(
            task_id=task.task_id,
            objective=task.objective,
            step_count=st["step_count"],
            max_steps=task.max_steps,
            inbox=inbox,
            current_ticket_content=content,
            latest_system_note=st.get("last_note", ""),
            score_hint={
                "read": partial.read_score,
                "classify": partial.classify_score,
                "reply": partial.reply_score,
                "resolve": partial.resolve_score,
            },
        )

    def _assemble_step_response(self, step_penalty: float) -> tuple[Observation, Reward, bool, dict[str, Any]]:
        """Compute the shaped reward and package the step() return tuple.

        *step_penalty* is the (non-positive) penalty accrued this step; it
        is combined with cumulative invalid/repeat-action penalties.
        """
        assert self._current_task is not None
        task = self._current_task
        st = self._state

        grade = grade_task(task, st)
        # Scale down intermediate progress so the final grade dominates.
        progress_signal = 0.75 * grade.total
        penalty_total = 0.0

        penalties: dict[str, float] = {}
        if st["invalid_actions"]:
            penalties["invalid_actions"] = round(min(0.2, 0.04 * st["invalid_actions"]), 4)
            penalty_total += penalties["invalid_actions"]
        if st["repeat_actions"]:
            penalties["repetition"] = round(min(0.15, 0.02 * st["repeat_actions"]), 4)
            penalty_total += penalties["repetition"]
        if step_penalty < 0:
            penalties["step_penalty"] = round(abs(step_penalty), 4)
            penalty_total += abs(step_penalty)

        reward_value = progress_signal - penalty_total
        # On termination the reward never undercuts the raw grader score.
        if st["done"]:
            reward_value = max(reward_value, grade.total)

        reward_value = max(0.0, min(1.0, reward_value))

        reward = Reward(
            value=round(reward_value, 4),
            components={
                "progress_signal": round(progress_signal, 4),
                "grade_total": grade.total,
                "read_score": grade.read_score,
                "classify_score": grade.classify_score,
                "reply_score": grade.reply_score,
                "resolve_score": grade.resolve_score,
                "penalty_total": round(penalty_total, 4),
            },
            reasoning="Shaped reward from grader progress with penalties for invalid or looping actions.",
        )

        info = StepInfo(
            task_id=task.task_id,
            done_reason=st["done_reason"],
            grader_score=grade.total,
            reward_components=reward.components,
            penalties=penalties,
            state_snapshot=self.state(),
        ).model_dump()

        obs = self._build_observation()
        return obs, reward, st["done"], info
src/support_triage_openenv/graders.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
from .tasks import TaskSpec
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class GradeBreakdown:
    """Per-dimension grading result for a triage episode."""

    # Credit for reading the target ticket and required context tickets.
    read_score: float
    # Fraction of classification fields (priority/category/escalation) correct.
    classify_score: float
    # Keyword coverage of the draft reply minus forbidden-keyword penalty.
    reply_score: float
    # 1.0 only when exactly the target ticket was resolved, else 0.0.
    resolve_score: float
    # Weighted combination of the above, clamped to [0, 1].
    total: float
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _keyword_coverage(message: str, required: tuple[str, ...]) -> float:
|
| 18 |
+
if not required:
|
| 19 |
+
return 1.0
|
| 20 |
+
lowered = message.lower()
|
| 21 |
+
found = sum(1 for k in required if k.lower() in lowered)
|
| 22 |
+
return found / len(required)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _forbidden_penalty(message: str, forbidden: tuple[str, ...]) -> float:
|
| 26 |
+
lowered = message.lower()
|
| 27 |
+
count = sum(1 for k in forbidden if k.lower() in lowered)
|
| 28 |
+
return min(1.0, 0.5 * count)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def grade_task(task: TaskSpec, env_state: dict) -> GradeBreakdown:
    """Score the episode state against *task*'s rubric.

    Four dimensions are scored independently and combined with fixed
    weights (read 0.2, classify 0.35, reply 0.3, resolve 0.15); the total
    is clamped to [0, 1] and all values are rounded to 4 decimals.
    """
    # Reading: 60% for opening the target ticket, 40% for coverage of the
    # required context tickets (full credit when no context is required).
    read_target = 1.0 if task.target_ticket_id in env_state["read_ticket_ids"] else 0.0
    context_hits = sum(1 for tid in task.required_context_ticket_ids if tid in env_state["read_ticket_ids"])
    context_total = len(task.required_context_ticket_ids)
    context_score = context_hits / context_total if context_total else 1.0
    read_score = 0.6 * read_target + 0.4 * context_score

    # Classification: one point per correct field out of three.
    classification = env_state.get("classification") or {}
    fields_correct = 0
    fields_total = 3
    fields_correct += int(classification.get("priority") == task.expected_priority)
    fields_correct += int(classification.get("category") == task.expected_category)
    fields_correct += int(classification.get("needs_escalation") == task.expected_escalation)
    classify_score = fields_correct / fields_total

    # Reply: keyword coverage minus forbidden-keyword penalty, floored at 0.
    draft = env_state.get("draft_reply") or ""
    keyword_score = _keyword_coverage(draft, task.required_reply_keywords)
    forbidden_penalty = _forbidden_penalty(draft, task.forbidden_reply_keywords)
    reply_score = max(0.0, keyword_score - forbidden_penalty)

    # Resolution: all-or-nothing — must resolve exactly the target ticket.
    resolved = bool(env_state.get("resolved"))
    resolved_target = env_state.get("resolved_ticket_id") == task.target_ticket_id
    resolve_score = 1.0 if resolved and resolved_target else 0.0

    total = (0.2 * read_score) + (0.35 * classify_score) + (0.3 * reply_score) + (0.15 * resolve_score)
    total = max(0.0, min(1.0, total))

    return GradeBreakdown(
        read_score=round(read_score, 4),
        classify_score=round(classify_score, 4),
        reply_score=round(reply_score, 4),
        resolve_score=round(resolve_score, 4),
        total=round(total, 4),
    )
|
src/support_triage_openenv/models.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Literal
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Closed vocabulary of actions the agent may submit to the environment.
ActionType = Literal["read_ticket", "classify_ticket", "draft_reply", "resolve_ticket"]
# Ticket urgency levels, listed from least to most severe.
PriorityType = Literal["low", "medium", "high", "urgent"]
# Supported triage categories for ticket classification.
CategoryType = Literal["account", "billing", "technical", "abuse", "general"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TicketView(BaseModel):
    """Summary of one support ticket as listed in the agent's inbox."""

    ticket_id: str  # unique ticket identifier, e.g. "T-1001"
    subject: str
    customer_tier: Literal["free", "pro", "enterprise"]
    age_minutes: int  # ticket age in minutes
    read: bool = False  # flips to True once the agent has opened the ticket
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Observation(BaseModel):
    """What the agent observes after a reset or a step."""

    task_id: str
    objective: str  # natural-language description of the task goal
    step_count: int  # steps taken so far this episode
    max_steps: int  # step budget for the episode
    inbox: list[TicketView]
    # Body of the ticket opened by the last read action, if any.
    # NOTE(review): presumably set by read_ticket — confirm in env.py.
    current_ticket_content: str | None = None
    latest_system_note: str = ""  # environment feedback message, if any
    score_hint: dict[str, float] = Field(default_factory=dict)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Action(BaseModel):
    """One agent action; only the fields relevant to action_type are used."""

    action_type: ActionType
    ticket_id: str | None = None  # read_ticket / classify_ticket / resolve_ticket
    priority: PriorityType | None = None  # classify_ticket
    category: CategoryType | None = None  # classify_ticket
    needs_escalation: bool | None = None  # classify_ticket
    message: str | None = None  # draft_reply
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Reward(BaseModel):
    """Scalar reward with a per-component breakdown."""

    value: float = Field(ge=0.0, le=1.0)  # total reward, validated to [0, 1]
    components: dict[str, float] = Field(default_factory=dict)  # per-rubric-part scores
    reasoning: str = ""  # human-readable explanation of the score
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class StepInfo(BaseModel):
    """Diagnostic metadata returned alongside each step result."""

    task_id: str
    done_reason: Literal["ongoing", "resolved", "max_steps", "invalid_action"]
    grader_score: float  # grader total for the current episode state
    reward_components: dict[str, float]
    penalties: dict[str, float]  # penalties applied this step, keyed by reason
    state_snapshot: dict[str, Any]  # serialized environment state for debugging
|
src/support_triage_openenv/tasks.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass(frozen=True)
class TaskSpec:
    """Immutable definition of one triage scenario and its grading rubric."""

    task_id: str
    difficulty: str  # "easy" | "medium" | "hard"
    title: str
    objective: str  # shown to the agent as the episode goal
    max_steps: int  # episode step budget
    target_ticket_id: str  # the ticket the agent must classify and resolve
    required_context_ticket_ids: tuple[str, ...]  # extra tickets to read for full read credit
    expected_priority: str
    expected_category: str
    expected_escalation: bool
    required_reply_keywords: tuple[str, ...]  # phrases the draft reply must contain
    forbidden_reply_keywords: tuple[str, ...]  # phrases that penalize the reply
    tickets: tuple[dict[str, Any], ...]  # inbox contents for the scenario
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_tasks() -> list[TaskSpec]:
    """Return the built-in task catalogue (one easy, one medium, one hard task).

    A fresh list is built on each call; the TaskSpec entries are frozen
    dataclasses and safe to share between callers.
    """
    return [
        # Easy: single relevant ticket, no context reading required.
        TaskSpec(
            task_id="easy_password_reset",
            difficulty="easy",
            title="Password reset triage",
            objective=(
                "Resolve customer lockout ticket by selecting correct category/priority and drafting "
                "a secure response that includes a reset link workflow."
            ),
            max_steps=10,
            target_ticket_id="T-1001",
            required_context_ticket_ids=(),
            expected_priority="medium",
            expected_category="account",
            expected_escalation=False,
            required_reply_keywords=("reset link", "security", "confirm", "email"),
            forbidden_reply_keywords=("share your password",),
            tickets=(
                {
                    "ticket_id": "T-1001",
                    "subject": "Cannot log in after phone change",
                    "customer_tier": "pro",
                    "age_minutes": 33,
                    "content": (
                        "I switched phones and now MFA fails. I need urgent access to my dashboard. "
                        "Please help me reset safely."
                    ),
                },
                # Distractor ticket — not part of the rubric.
                {
                    "ticket_id": "T-1002",
                    "subject": "Feature request: dark mode",
                    "customer_tier": "free",
                    "age_minutes": 250,
                    "content": "Could you add dark mode next quarter?",
                },
            ),
        ),
        # Medium: one context ticket must be read before replying.
        TaskSpec(
            task_id="medium_billing_dispute",
            difficulty="medium",
            title="Billing dispute and partial refund",
            objective=(
                "Assess a duplicate charge complaint, inspect context ticket, classify correctly, "
                "and draft a policy-compliant refund response."
            ),
            max_steps=12,
            target_ticket_id="T-2001",
            required_context_ticket_ids=("T-2002",),
            expected_priority="high",
            expected_category="billing",
            expected_escalation=False,
            required_reply_keywords=("duplicate charge", "refund", "3-5 business days", "invoice"),
            forbidden_reply_keywords=("guaranteed immediate refund",),
            tickets=(
                {
                    "ticket_id": "T-2001",
                    "subject": "Charged twice this month",
                    "customer_tier": "enterprise",
                    "age_minutes": 85,
                    "content": (
                        "We were charged twice for March. Finance needs confirmation and refund timeline today."
                    ),
                },
                # Required context: gateway evidence for the duplicate charge.
                {
                    "ticket_id": "T-2002",
                    "subject": "Billing system log",
                    "customer_tier": "enterprise",
                    "age_minutes": 80,
                    "content": "Payment gateway shows one charge capture and one duplicate authorization hold.",
                },
                # Distractor ticket — not part of the rubric.
                {
                    "ticket_id": "T-2003",
                    "subject": "Onboarding docs typo",
                    "customer_tier": "pro",
                    "age_minutes": 700,
                    "content": "There is a typo in page 3 of setup docs.",
                },
            ),
        ),
        # Hard: multi-ticket evidence gathering plus mandatory escalation.
        TaskSpec(
            task_id="hard_outage_incident",
            difficulty="hard",
            title="Incident comms under pressure",
            objective=(
                "Handle a potential security outage report by collecting key evidence from related tickets, "
                "setting urgent escalation, and drafting a safe incident response message without over-promising."
            ),
            max_steps=14,
            target_ticket_id="T-3001",
            required_context_ticket_ids=("T-3002", "T-3003"),
            expected_priority="urgent",
            expected_category="technical",
            expected_escalation=True,
            required_reply_keywords=("incident", "investigating", "status page", "escalated"),
            forbidden_reply_keywords=("issue is fully resolved", "ignore this"),
            tickets=(
                {
                    "ticket_id": "T-3001",
                    "subject": "API returning 500 for all EU requests",
                    "customer_tier": "enterprise",
                    "age_minutes": 18,
                    "content": (
                        "Since 08:10 UTC every API call fails. We suspect a region outage and possible data inconsistency."
                    ),
                },
                # Required context: monitoring evidence.
                {
                    "ticket_id": "T-3002",
                    "subject": "SOC alert summary",
                    "customer_tier": "enterprise",
                    "age_minutes": 15,
                    "content": "Monitoring confirms spike in error rate and elevated auth failures in eu-west-1.",
                },
                # Required context: comms guidance for the public message.
                {
                    "ticket_id": "T-3003",
                    "subject": "Status page draft",
                    "customer_tier": "enterprise",
                    "age_minutes": 11,
                    "content": "Public message should acknowledge incident, investigation, and next update ETA.",
                },
                # Distractor ticket — not part of the rubric.
                {
                    "ticket_id": "T-3004",
                    "subject": "Question about annual billing",
                    "customer_tier": "free",
                    "age_minutes": 1440,
                    "content": "Can I switch to annual plan later?",
                },
            ),
        ),
    ]
|
tasks/TASKS.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Task Details
|
| 2 |
+
|
| 3 |
+
## easy_password_reset
|
| 4 |
+
Objective: resolve a lockout request safely.
|
| 5 |
+
Success signals:
|
| 6 |
+
- Read the target ticket.
|
| 7 |
+
- Classify as `priority=medium`, `category=account`, `needs_escalation=False`.
|
| 8 |
+
- Draft reply mentions reset link and security confirmation.
|
| 9 |
+
- Resolve correct ticket.
|
| 10 |
+
|
| 11 |
+
## medium_billing_dispute
|
| 12 |
+
Objective: handle duplicate charge dispute with refund communication.
|
| 13 |
+
Success signals:
|
| 14 |
+
- Read `T-2001` and context `T-2002`.
|
| 15 |
+
- Classify as `priority=high`, `category=billing`, `needs_escalation=False`.
|
| 16 |
+
- Reply references duplicate charge, refund, invoice, and `3-5 business days`.
|
| 17 |
+
- Resolve `T-2001`.
|
| 18 |
+
|
| 19 |
+
## hard_outage_incident
|
| 20 |
+
Objective: process potential incident/outage with escalation and careful external comms.
|
| 21 |
+
Success signals:
|
| 22 |
+
- Read the target ticket `T-3001` plus context tickets `T-3002` and `T-3003`.
|
| 23 |
+
- Classify as `priority=urgent`, `category=technical`, `needs_escalation=True`.
|
| 24 |
+
- Reply includes incident acknowledgement, active investigation, status page updates, escalation.
|
| 25 |
+
- Resolve `T-3001`.
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi.testclient import TestClient
|
| 2 |
+
|
| 3 |
+
from app import app
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_health_and_reset_flow():
    """End-to-end smoke test: health check, reset, one step, then state query."""
    client = TestClient(app)

    health_resp = client.get('/health')
    assert health_resp.status_code == 200
    assert health_resp.json()['status'] == 'ok'

    reset_resp = client.post('/reset', json={'task_id': 'easy_password_reset'})
    assert reset_resp.status_code == 200
    reset_body = reset_resp.json()
    assert reset_body['task_id'] == 'easy_password_reset'
    assert reset_body['step_count'] == 0

    step_resp = client.post('/step', json={'action_type': 'read_ticket', 'ticket_id': 'T-1001'})
    assert step_resp.status_code == 200
    step_body = step_resp.json()
    assert step_body['done'] is False
    assert step_body['observation']['step_count'] == 1

    state_resp = client.get('/state')
    assert state_resp.status_code == 200
    assert state_resp.json()['task_id'] == 'easy_password_reset'
|
tests/test_env.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from support_triage_openenv.env import SupportTriageEnv
|
| 2 |
+
from support_triage_openenv.models import Action
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def test_reset_and_state_cycle():
    """Reset returns a zero-step observation and state() reports an open episode."""
    environment = SupportTriageEnv()
    observation = environment.reset("easy_password_reset")
    assert observation.task_id == "easy_password_reset"
    assert observation.step_count == 0
    assert environment.state()["done"] is False
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_easy_task_can_reach_high_score():
    """A scripted optimal trajectory on the easy task scores at least 0.9."""
    environment = SupportTriageEnv()
    environment.reset("easy_password_reset")

    reply_text = (
        "We will send a reset link to your email. For security, please confirm the request "
        "from your device after receiving the reset link."
    )
    scripted_actions = [
        Action(action_type="read_ticket", ticket_id="T-1001"),
        Action(
            action_type="classify_ticket",
            ticket_id="T-1001",
            priority="medium",
            category="account",
            needs_escalation=False,
        ),
        Action(action_type="draft_reply", message=reply_text),
    ]
    for action in scripted_actions:
        environment.step(action)

    _, reward, done, info = environment.step(Action(action_type="resolve_ticket", ticket_id="T-1001"))
    assert done is True
    assert info["grader_score"] >= 0.9
    assert reward.value >= 0.9
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_invalid_ticket_penalty_and_done_guard():
    """Reading a nonexistent ticket lowers the reward and records a penalty."""
    environment = SupportTriageEnv()
    environment.reset("medium_billing_dispute")
    _, reward, _, info = environment.step(
        Action(action_type="read_ticket", ticket_id="NOT-REAL")
    )
    assert reward.value < 0.5
    assert info["penalties"]
|