Spaces:
Sleeping
Sleeping
Commit ·
a3d65ce
0
Parent(s):
Initial commit
Browse files- .gitignore +3 -0
- Dockerfile +27 -0
- README.md +12 -0
- __init__.py +11 -0
- baseline.py +194 -0
- client.py +33 -0
- graders.py +136 -0
- gradio_ui.py +211 -0
- models.py +119 -0
- openenv.yaml +57 -0
- openenv_stub/openenv/__init__.py +0 -0
- openenv_stub/openenv/core/__init__.py +3 -0
- openenv_stub/openenv/core/env_client.py +12 -0
- openenv_stub/openenv/core/env_server/__init__.py +0 -0
- openenv_stub/openenv/core/env_server/http_server.py +13 -0
- openenv_stub/openenv/core/env_server/interfaces.py +28 -0
- openenv_stub/openenv/core/env_server/types.py +33 -0
- pyproject.toml +26 -0
- run_tests.py +276 -0
- server/__init__.py +4 -0
- server/app.py +28 -0
- server/requirements.txt +7 -0
- server/support_environment.py +281 -0
- tests/__init__.py +0 -0
- tests/conftest.py +6 -0
- tests/test_environment.py +191 -0
- tests/test_graders.py +121 -0
- tickets.py +84 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
Dockerfile
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ── Dockerfile: Customer Support Ticket Resolution Environment ──
|
| 2 |
+
FROM python:3.11-slim
|
| 3 |
+
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# System deps
|
| 7 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 8 |
+
build-essential curl && rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
# Python deps
|
| 11 |
+
COPY server/requirements.txt /app/requirements.txt
|
| 12 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
+
|
| 14 |
+
# Copy source
|
| 15 |
+
COPY . /app/support_ticket_env
|
| 16 |
+
ENV PYTHONPATH=/app
|
| 17 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 18 |
+
|
| 19 |
+
# HF Spaces uses port 7860
|
| 20 |
+
EXPOSE 7860
|
| 21 |
+
|
| 22 |
+
# Health check
|
| 23 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 24 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 25 |
+
|
| 26 |
+
CMD ["uvicorn", "support_ticket_env.server.app:app", \
|
| 27 |
+
"--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Support Ticket Env
|
| 3 |
+
emoji: 🎫
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Customer Support Ticket Resolution Environment
|
| 11 |
+
|
| 12 |
+
A real-world OpenEnv environment for AI agent training.
|
__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Customer Support Ticket Resolution — OpenEnv Environment."""
|
| 2 |
+
|
| 3 |
+
from support_ticket_env.models import SupportAction, SupportObservation, SupportState
|
| 4 |
+
from support_ticket_env.client import SupportTicketEnv
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"SupportAction",
|
| 8 |
+
"SupportObservation",
|
| 9 |
+
"SupportState",
|
| 10 |
+
"SupportTicketEnv",
|
| 11 |
+
]
|
baseline.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
baseline.py — Baseline inference script for the Support Ticket Environment.
|
| 4 |
+
|
| 5 |
+
Runs an OpenAI-compatible model against all 3 tasks and reports scores.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
OPENAI_API_KEY=sk-... python baseline.py --base-url http://localhost:7860
|
| 9 |
+
|
| 10 |
+
Environment variables:
|
| 11 |
+
OPENAI_API_KEY : required
|
| 12 |
+
OPENAI_BASE_URL : optional override (default https://api.openai.com/v1)
|
| 13 |
+
OPENAI_MODEL : optional model name (default gpt-4o-mini)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import asyncio
|
| 20 |
+
import re
|
| 21 |
+
|
| 22 |
+
from openai import AsyncOpenAI
|
| 23 |
+
from support_ticket_env.client import SupportTicketEnv
|
| 24 |
+
from support_ticket_env.models import SupportAction
|
| 25 |
+
|
| 26 |
+
# ─────────────────────────── Config ────────────────────────────
|
| 27 |
+
|
| 28 |
+
VALID_CATEGORIES = ["billing", "technical", "account", "general", "refund"]
|
| 29 |
+
VALID_ACTIONS = ["classify", "reply", "escalate", "close"]
|
| 30 |
+
|
| 31 |
+
SYSTEM_PROMPT = """You are a customer support AI agent operating in a ticket triage environment.
|
| 32 |
+
|
| 33 |
+
On each turn you receive a JSON observation with:
|
| 34 |
+
- ticket_text : the customer's message
|
| 35 |
+
- feedback : what happened last step
|
| 36 |
+
- task_id : 1=classify only, 2=classify then act, 3=full resolution
|
| 37 |
+
|
| 38 |
+
You must respond with a JSON object (no markdown) matching this schema:
|
| 39 |
+
{
|
| 40 |
+
"action_type": "classify" | "reply" | "escalate" | "close",
|
| 41 |
+
"category": "billing" | "technical" | "account" | "general" | "refund", // only for classify
|
| 42 |
+
"reply_text": "...", // only for reply
|
| 43 |
+
"reason": "..." // optional
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
Strategy:
|
| 47 |
+
- For task 1: only classify (use action_type="classify" with a category).
|
| 48 |
+
- For task 2: first classify, then choose the best action.
|
| 49 |
+
- For task 3: classify each ticket, then reply/escalate/close as appropriate.
|
| 50 |
+
|
| 51 |
+
Always produce valid JSON and nothing else.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def parse_llm_response(text: str) -> dict:
    """Extract the JSON payload from an LLM reply.

    Leading/trailing whitespace and a surrounding Markdown code fence
    (```` ``` ```` or ```` ```json ````) are removed first.  If the remainder
    is not valid JSON on its own, the text is scanned for the first position
    at which a complete JSON object can be decoded, so replies that wrap the
    object in prose (even prose containing braces) are still accepted.

    Args:
        text: Raw model output.  ``None`` is treated like an empty string.

    Returns:
        The decoded JSON value (a dict for well-formed replies).

    Raises:
        json.JSONDecodeError: If no JSON value can be recovered.
    """
    cleaned = (text or "").strip()
    # Strip ```json ... ``` fences
    cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
    cleaned = re.sub(r"\s*```$", "", cleaned)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Fallback: decode the first complete JSON object embedded in the
        # text.  raw_decode stops at the end of that object, so trailing
        # prose (including stray braces) no longer breaks the parse.
        decoder = json.JSONDecoder()
        start = cleaned.find("{")
        while start != -1:
            try:
                obj, _end = decoder.raw_decode(cleaned, start)
            except json.JSONDecodeError:
                start = cleaned.find("{", start + 1)
                continue
            return obj
        # Nothing decodable: surface the original parse error.
        raise
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
async def run_task(
    env_base_url: str,
    llm: AsyncOpenAI,
    model: str,
    task_id: int,
    seed: int = 42,
    max_steps: int = 10,
) -> float:
    """Play one episode of ``task_id`` and return the summed reward.

    Every turn serialises the latest observation to JSON, appends it to the
    running chat transcript, asks the model for the next action and forwards
    that action to the environment.  The episode ends when the environment
    reports ``done``, when ``max_steps`` turns have been played, or as soon
    as the model's reply cannot be parsed or does not validate as a
    ``SupportAction``.

    Returns:
        The sum of the per-step rewards, rounded to four decimal places.
    """
    transcript = [{"role": "system", "content": SYSTEM_PROMPT}]
    reward_sum = 0.0

    async with SupportTicketEnv(base_url=env_base_url) as env:
        outcome = await env.reset(task_id=task_id, seed=seed)
        obs = outcome.observation

        for turn in range(1, max_steps + 1):
            # The model sees the observation as pretty-printed JSON.
            payload = {
                "ticket_id": obs.ticket_id,
                "ticket_text": obs.ticket_text,
                "task_id": obs.task_id,
                "current_category": obs.current_category,
                "resolved": obs.resolved,
                "step_count": obs.step_count,
                "feedback": obs.feedback,
            }
            transcript.append(
                {"role": "user", "content": json.dumps(payload, indent=2)}
            )

            completion = await llm.chat.completions.create(
                model=model,
                messages=transcript,
                temperature=0.0,
                max_tokens=256,
            )
            reply = completion.choices[0].message.content
            transcript.append({"role": "assistant", "content": reply})

            # A reply that is not JSON ends the episode early.
            try:
                raw_action = parse_llm_response(reply)
            except Exception as e:
                print(f" [step {turn}] Failed to parse LLM response: {e}")
                break

            # So does JSON that does not match the action schema.
            try:
                action = SupportAction(**raw_action)
            except Exception as e:
                print(f" [step {turn}] Invalid action schema: {e}")
                break

            outcome = await env.step(action)
            obs = outcome.observation
            step_reward = outcome.reward or 0.0
            reward_sum += step_reward

            category_suffix = f"/{action.category}" if action.category else ""
            print(
                f" [step {turn}] action={action.action_type}"
                + category_suffix
                + f" reward={step_reward:.3f} feedback={obs.feedback[:60]}"
            )

            if outcome.done:
                break

    return round(reward_sum, 4)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
async def main(env_base_url: str, model: str, seeds: list[int]) -> None:
    """Run the baseline on tasks 1-3 for every seed and print a summary.

    For each task, one episode is played per seed via ``run_task``; the
    per-seed totals, the per-task average and the overall average of the
    task averages are printed to stdout.

    Args:
        env_base_url: Base URL of the running environment server.
        model: Model name passed to the chat-completions endpoint.
        seeds: Seeds to evaluate; must contain at least one value.

    Raises:
        ValueError: If ``seeds`` is empty (no average could be computed).
    """
    if not seeds:
        raise ValueError("at least one seed is required")

    api_key = os.environ.get("OPENAI_API_KEY", "not-needed")
    openai_base = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")

    llm = AsyncOpenAI(api_key=api_key, base_url=openai_base)

    results = {}
    for task_id in [1, 2, 3]:
        task_scores = []
        print(f"\n{'='*60}")
        # Show every seed that will be run, not just the first one.
        print(f" TASK {task_id} (seeds={seeds})")
        print(f"{'='*60}")
        for seed in seeds:
            score = await run_task(env_base_url, llm, model, task_id, seed=seed)
            task_scores.append(score)
            print(f" → total_reward for seed {seed}: {score}")
        avg = round(sum(task_scores) / len(task_scores), 4)
        results[f"task{task_id}"] = {"scores": task_scores, "avg": avg}
        print(f" ► Average: {avg}")

    print("\n" + "="*60)
    print(" BASELINE SUMMARY")
    print("="*60)
    for k, v in results.items():
        print(f" {k}: avg={v['avg']:.4f} scores={v['scores']}")
    overall = round(
        sum(v["avg"] for v in results.values()) / len(results), 4
    )
    print(f"\n Overall avg: {overall:.4f}")
    print("="*60)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
    # Command-line entry point: collect the server URL, model name and seed
    # list (each falling back to an environment variable or a built-in
    # default) and run the async driver.
    parser = argparse.ArgumentParser(description="Baseline inference for support_ticket_env")

    default_env_url = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
    default_model = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")

    parser.add_argument(
        "--base-url",
        default=default_env_url,
        help="Base URL of the running environment server",
    )
    parser.add_argument(
        "--model",
        default=default_model,
        help="OpenAI model name",
    )
    parser.add_argument(
        "--seeds",
        type=int,
        nargs="+",
        default=[42, 7, 123],
        help="Random seeds for reproducibility",
    )

    args = parser.parse_args()
    asyncio.run(main(args.base_url, args.model, args.seeds))
|
client.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Client for the Customer Support Ticket Resolution Environment.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from openenv.core.env_client import EnvClient
|
| 6 |
+
from support_ticket_env.models import SupportAction, SupportObservation, SupportState
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SupportTicketEnv(EnvClient[SupportAction, SupportObservation, SupportState]):
    """
    Client-side handle for the Support Ticket Resolution environment.

    The class only supplies the three (de)serialisation hooks below; the
    transport and the ``reset``/``step`` API come from ``EnvClient``.

    Async usage:
        async with SupportTicketEnv(base_url="http://localhost:7860") as env:
            result = await env.reset(task_id=1)
            result = await env.step(SupportAction(action_type="classify", category="billing"))

    Sync usage:
        with SupportTicketEnv(base_url="http://localhost:7860").sync() as env:
            result = env.reset(task_id=2)
            result = env.step(SupportAction(action_type="classify", category="technical"))
            result = env.step(SupportAction(action_type="escalate"))
    """

    def _parse_action(self, action: SupportAction) -> dict:
        """Serialise ``action`` to a plain dict via its ``model_dump``."""
        payload = action.model_dump()
        return payload

    def _parse_result(self, data: dict) -> SupportObservation:
        """Build an observation from a server response.

        If the response nests the fields under an ``"observation"`` key that
        sub-dict is used; otherwise the response itself is taken as the
        observation fields.
        """
        fields = data["observation"] if "observation" in data else data
        return SupportObservation(**fields)

    def _parse_state(self, data: dict) -> SupportState:
        """Build a ``SupportState`` directly from the response fields."""
        state = SupportState(**data)
        return state
|
graders.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Graders for all three tasks.
|
| 3 |
+
|
| 4 |
+
Each grader returns a float in [0.0, 1.0].
|
| 5 |
+
|
| 6 |
+
Task 1 – Classification (easy)
|
| 7 |
+
- 1.0 : correct category
|
| 8 |
+
- 0.0 : wrong category
|
| 9 |
+
|
| 10 |
+
Task 2 – Action Selection (medium)
|
| 11 |
+
- 1.0 : correct action
|
| 12 |
+
- 0.5 : partially correct (e.g., escalate vs reply both defensible)
|
| 13 |
+
- 0.0 : clearly wrong (e.g., close an unsolved ticket)
|
| 14 |
+
|
| 15 |
+
Task 3 – Full Resolution (hard)
|
| 16 |
+
Combines classification + action + reply quality into a single score.
|
| 17 |
+
Rewards partial progress so the agent gets signal throughout the trajectory.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
from typing import Dict, Any
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ─────────────────────────── helpers ───────────────────────────
|
| 25 |
+
|
| 26 |
+
# Pairs of actions that are considered "close enough" for partial credit
|
| 27 |
+
_PARTIAL_CREDIT_PAIRS = {
|
| 28 |
+
frozenset({"reply", "escalate"}), # borderline tickets
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
_KEYWORD_REWARDS: Dict[str, list[str]] = {
|
| 32 |
+
"billing": ["refund", "charge", "invoice", "payment", "billing"],
|
| 33 |
+
"account": ["password", "login", "account", "cancel", "subscription"],
|
| 34 |
+
"technical": ["engineering", "escalate", "bug", "crash", "error", "fix"],
|
| 35 |
+
"refund": ["refund", "return", "credit", "process"],
|
| 36 |
+
"general": ["hours", "contact", "phone", "information", "help"],
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _reply_quality(reply_text: str, category: str) -> float:
    """Score a reply by counting category keywords it mentions.

    Each keyword from ``_KEYWORD_REWARDS[category]`` found (case-insensitive
    substring match) in ``reply_text`` is worth 0.1, capped at 0.5.  An empty
    reply or a category without keywords scores 0.0.
    """
    if not reply_text:
        return 0.0

    haystack = reply_text.lower()
    matched = 0
    for keyword in _KEYWORD_REWARDS.get(category, []):
        if keyword in haystack:
            matched += 1

    # The cap leaves the remaining credit to action correctness.
    return min(0.5, matched * 0.1)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ─────────────────────────── Task 1 ────────────────────────────
|
| 52 |
+
|
| 53 |
+
def grade_task1(
    predicted_category: str,
    correct_category: str,
) -> float:
    """Return 1.0 when the predicted category equals the correct one, else 0.0."""
    is_match = predicted_category == correct_category
    if is_match:
        return 1.0
    return 0.0
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ─────────────────────────── Task 2 ────────────────────────────
|
| 62 |
+
|
| 63 |
+
def grade_task2(
    action_type: str,
    correct_action: str,
    category: str | None = None,
) -> float:
    """
    Score the chosen action against the expected one.

    Returns 1.0 for an exact match, 0.5 when the chosen/expected pair is
    listed in ``_PARTIAL_CREDIT_PAIRS`` (a defensible alternative), and 0.0
    for every other mismatch, including a wrong ``"close"``.  ``category``
    is accepted but does not influence the score.
    """
    if action_type == correct_action:
        return 1.0

    # The pair is unordered: reply-vs-escalate counts either way round.
    is_defensible = frozenset({action_type, correct_action}) in _PARTIAL_CREDIT_PAIRS
    return 0.5 if is_defensible else 0.0
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ─────────────────────────── Task 3 ────────────────────────────
|
| 89 |
+
|
| 90 |
+
def grade_task3(
    classified_correctly: bool,
    action_correct: bool,
    action_partial: bool,
    reply_text: str | None,
    category: str,
    resolved: bool,
    steps_taken: int,
    max_steps: int = 5,
) -> float:
    """
    Multi-step resolution reward with partial progress.

    Breakdown:
        0.20 – classification correct
        0.40 – action correct (0.20 if only partially correct)
        0.25 – reply quality (keyword overlap, via ``_reply_quality``)
        0.15 – efficiency bonus (fewer steps → higher bonus)

    The efficiency bonus is granted only when the ticket is resolved within
    ``max_steps``.  It falls linearly from the full 0.15 at one step to 0 at
    ``max_steps`` steps and is clamped to that range.  With a one-step budget
    (``max_steps == 1``) a resolved ticket earns the full bonus.

    Returns:
        The total score, capped at 1.0 and rounded to four decimal places.
    """
    score = 0.0

    if classified_correctly:
        score += 0.20

    if action_correct:
        score += 0.40
    elif action_partial:
        score += 0.20

    if reply_text:
        score += _reply_quality(reply_text, category) * 0.5  # scaled to 0.25 max

    # Efficiency: full 0.15 for 1 step, 0 for max_steps steps
    if resolved and steps_taken <= max_steps:
        span = max_steps - 1
        if span == 0:
            # One-step budget: the linear formula would divide by zero.
            efficiency = 1.0
        else:
            efficiency = (max_steps - steps_taken) / span
        # Clamp so a step count below 1 cannot exceed the 0.15 cap.
        efficiency = min(1.0, max(0.0, efficiency))
        score += 0.15 * efficiency

    return round(min(1.0, score), 4)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# ─────────────────────────── Penalty ───────────────────────────
|
| 131 |
+
|
| 132 |
+
def loop_penalty(step_count: int, max_steps: int = 10) -> float:
    """Penalise episodes that run past ``max_steps``.

    Returns 0.0 while ``step_count`` is within the budget, otherwise -0.05
    for every step beyond it.
    """
    overshoot = step_count - max_steps
    if overshoot <= 0:
        return 0.0
    return -0.05 * overshoot
|
gradio_ui.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
gradio_ui.py — Interactive Gradio web interface for the Support Ticket Environment.
|
| 3 |
+
|
| 4 |
+
Allows human exploration and debugging without writing code.
|
| 5 |
+
Launched automatically when ENABLE_WEB_INTERFACE=true or run directly.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python support_ticket_env/gradio_ui.py
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import sys
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 16 |
+
STUB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "openenv_stub")
|
| 17 |
+
sys.path.insert(0, STUB)
|
| 18 |
+
sys.path.insert(0, ROOT)
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
import gradio as gr
|
| 22 |
+
except ImportError:
|
| 23 |
+
print("gradio not installed. Run: pip install gradio")
|
| 24 |
+
sys.exit(1)
|
| 25 |
+
|
| 26 |
+
from support_ticket_env.server.support_environment import SupportTicketEnvironment
|
| 27 |
+
from support_ticket_env.models import SupportAction
|
| 28 |
+
|
| 29 |
+
# ─── shared env instance ────────────────────────────────────────
|
| 30 |
+
_env = SupportTicketEnvironment()
|
| 31 |
+
_history: list[dict] = []
|
| 32 |
+
_current_obs = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ─── helpers ────────────────────────────────────────────────────
|
| 36 |
+
|
| 37 |
+
def _format_history() -> str:
    """Render the module-level ``_history`` list as Markdown.

    Each recorded step becomes a two-line entry (action and signed reward,
    then the feedback as a block quote); entries are separated by blank
    lines.  A missing reward is shown as an em dash.  With no history a
    placeholder line is returned.
    """
    if not _history:
        return "_No actions yet._"

    entries = []
    for step_no, record in enumerate(_history, start=1):
        reward = record["reward"]
        reward_str = "—" if reward is None else f"{reward:+.3f}"
        header = f"**Step {step_no}** | `{record['action']}` → reward `{reward_str}`\n"
        quote = f"> {record['feedback']}"
        entries.append(header + quote)
    return "\n\n".join(entries)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _obs_to_display(obs) -> tuple[str, str, str]:
    """Turn an observation into the three Markdown strings shown in the UI.

    Returns ``(ticket, status, score)``: the ticket id and text, a one-line
    status (task, step, category or ``unknown``, resolved mark), and the last
    score together with the signed reward (a missing reward prints as 0).
    """
    category_label = obs.current_category or 'unknown'
    resolved_mark = '✅' if obs.resolved else '⬜'
    last_reward = obs.reward or 0.0

    ticket_md = f"**[{obs.ticket_id}]** {obs.ticket_text}"
    status_md = " | ".join(
        [
            f"Task **{obs.task_id}**",
            f"Step **{obs.step_count}**",
            f"Category: `{category_label}`",
            f"Resolved: {resolved_mark}",
        ]
    )
    score_md = f"Last step score: **{obs.score:.3f}** | reward: **{last_reward:+.3f}**"
    return ticket_md, status_md, score_md
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ─── UI callbacks ────────────────────────────────────────────────
|
| 63 |
+
|
| 64 |
+
def do_reset(task_id: int, seed: int):
    """Start a new episode and return fresh values for the UI outputs.

    Clears the action history, resets the shared environment with the given
    task and seed, and stores the first observation.  The returned 7-tuple
    is ordered: ticket, status, score, history, step-button update
    (enabled), feedback text, done-checkbox update (cleared).
    """
    global _history, _current_obs
    _history = []
    _current_obs = _env.reset(task_id=task_id, seed=seed)

    ticket_md, status_md, score_md = _obs_to_display(_current_obs)
    history_md = _format_history()
    enable_step = gr.update(interactive=True)
    clear_done = gr.update(value=False)  # done flag

    return (
        ticket_md,
        status_md,
        score_md,
        history_md,
        enable_step,
        _current_obs.feedback,
        clear_done,
    )
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def do_step(action_type: str, category: str, reply_text: str, reason: str):
    """Apply one action to the shared environment and refresh the UI.

    Only the fields relevant to ``action_type`` are forwarded: ``category``
    for ``classify``, ``reply_text`` for ``reply``; ``reason`` is forwarded
    whenever it is non-empty.  The step is appended to the action history.

    Returns:
        A 6-tuple ordered: ticket, status, score, history, feedback text,
        done-checkbox update.  If no episode has been started, or the action
        fails validation, the tuple carries an explanatory message instead
        and the environment is not stepped.
    """
    global _current_obs
    if _current_obs is None:
        return (
            "⚠️ Please reset the environment first.",
            "", "", _format_history(), "", gr.update(value=False),
        )

    # Build action
    kwargs = {"action_type": action_type}
    if action_type == "classify" and category:
        kwargs["category"] = category
    if action_type == "reply" and reply_text:
        kwargs["reply_text"] = reply_text
    if reason:
        kwargs["reason"] = reason

    try:
        action = SupportAction(**kwargs)
    except Exception as e:
        # Show the validation error in the status slot; leave the episode as is.
        return (
            _current_obs.ticket_text,
            f"❌ Invalid action: {e}", "",
            _format_history(), "", gr.update(value=False),
        )

    obs = _env.step(action)
    _current_obs = obs

    _history.append({
        "action": f"{action_type}" + (f"/{category}" if category and action_type == "classify" else ""),
        "reward": obs.reward,
        "feedback": obs.feedback,
    })

    ticket, status, score = _obs_to_display(obs)
    return (
        ticket, status, score,
        _format_history(),
        obs.feedback,
        gr.update(value=obs.done),
    )
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def do_state():
    """Return selected fields of the environment state as indented JSON."""
    state = _env.state
    exposed_fields = (
        "episode_id",
        "step_count",
        "task_id",
        "ticket_id",
        "correct_category",
        "correct_action",
        "classified",
        "resolved",
        "total_reward",
        "tickets_resolved",
        "tickets_total",
    )
    # Key order follows the tuple above.
    snapshot = {name: getattr(state, name) for name in exposed_fields}
    return json.dumps(snapshot, indent=2)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ─── UI layout ──────────────────────────────────────────────────
|
| 142 |
+
|
| 143 |
+
DESCRIPTION = """
|
| 144 |
+
# 🎫 Customer Support Ticket Resolution Environment
|
| 145 |
+
|
| 146 |
+
An **OpenEnv** environment for training AI agents to handle customer support tickets.
|
| 147 |
+
|
| 148 |
+
**Tasks:** 1 = Classify · 2 = Classify + Action · 3 = Full Queue Resolution
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
with gr.Blocks(title="Support Ticket Env", theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: episode setup followed by the action form.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Episode Setup")
            task_slider = gr.Slider(1, 3, value=1, step=1, label="Task ID")
            seed_input = gr.Number(value=42, label="Seed", precision=0)
            reset_btn = gr.Button("🔄 Reset Episode", variant="primary")

            gr.Markdown("### 🎬 Take Action")
            action_type = gr.Radio(
                ["classify", "reply", "escalate", "close"],
                value="classify",
                label="Action Type",
            )
            category_dd = gr.Dropdown(
                ["billing", "technical", "account", "general", "refund"],
                label="Category (for classify)",
                value=None,
            )
            reply_box = gr.Textbox(label="Reply Text (for reply)", lines=3)
            reason_box = gr.Textbox(label="Reason (optional)")
            step_btn = gr.Button("▶️ Step", variant="secondary")
            state_btn = gr.Button("🔍 Show State")

        # Right column: current ticket, history and raw state.
        # NOTE(review): the history/state widgets are placed in this column
        # on the assumption that they follow the ticket widgets — confirm
        # against the intended layout.
        with gr.Column(scale=2):
            gr.Markdown("### 📬 Current Ticket")
            ticket_display = gr.Markdown("_Reset to start._")
            status_display = gr.Markdown("")
            score_display = gr.Markdown("")
            feedback_box = gr.Textbox(label="Last Feedback", interactive=False)
            done_checkbox = gr.Checkbox(label="Episode Done", interactive=False)

            gr.Markdown("### 📜 Action History")
            history_display = gr.Markdown("_No actions yet._")

            gr.Markdown("### 🗂️ Raw State (JSON)")
            state_output = gr.Code(language="json", label="state()")

    # Event wiring.  Output order must match the tuples returned by the
    # callbacks: do_reset returns 7 values, do_step returns 6.
    reset_btn.click(
        fn=do_reset,
        inputs=[task_slider, seed_input],
        outputs=[
            ticket_display,
            status_display,
            score_display,
            history_display,
            step_btn,
            feedback_box,
            done_checkbox,
        ],
    )
    step_btn.click(
        fn=do_step,
        inputs=[action_type, category_dd, reply_box, reason_box],
        outputs=[
            ticket_display,
            status_display,
            score_display,
            history_display,
            feedback_box,
            done_checkbox,
        ],
    )
    state_btn.click(fn=do_state, inputs=[], outputs=[state_output])


if __name__ == "__main__":
    # Standalone launch on all interfaces, port 7861, without a public share link.
    demo.launch(server_name="0.0.0.0", server_port=7861, share=False)
|
models.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Typed models for the Customer Support Ticket Resolution Environment.
|
| 3 |
+
Works with pydantic (production) or stdlib (offline/testing).
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from pydantic import BaseModel, ConfigDict
|
| 10 |
+
_USE_PYDANTIC = True
|
| 11 |
+
except ImportError:
|
| 12 |
+
_USE_PYDANTIC = False
|
| 13 |
+
|
| 14 |
+
# ── import base classes from openenv (or stub) ──────────────────
|
| 15 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ═══════════════════════════════════════════════════════════════
|
| 19 |
+
# Action
|
| 20 |
+
# ═══════════════════════════════════════════════════════════════
|
| 21 |
+
|
| 22 |
+
if _USE_PYDANTIC:
    class SupportAction(Action, BaseModel):  # type: ignore[misc]
        """Agent action, validated by pydantic.

        ``action_type`` is required; ``category`` is meant for ``classify``
        and ``reply_text`` for ``reply``.  Unknown fields are rejected
        (``extra="forbid"``).
        """

        model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
        metadata: Dict[str, Any] = {}
        action_type: Literal["classify", "reply", "escalate", "close"]
        category: Optional[
            Literal["billing", "technical", "account", "general", "refund"]
        ] = None
        reply_text: Optional[str] = None
        reason: Optional[str] = None

        def model_dump(self, **kw):
            """Return the action as a plain dict (delegates to the parent)."""
            return super().model_dump(**kw)
else:
    _VALID_ACTION_TYPES = {"classify", "reply", "escalate", "close"}
    _VALID_CATEGORIES = {"billing", "technical", "account", "general", "refund"}

    class SupportAction(Action):  # type: ignore[no-redef]
        """Agent action for environments without pydantic.

        Checks ``action_type`` and ``category`` against the same value sets
        as the pydantic variant and raises ``ValueError`` on a mismatch.
        """

        def __init__(self, **kwargs):
            action_type = kwargs.get("action_type")
            if action_type not in _VALID_ACTION_TYPES:
                raise ValueError(f"Invalid action_type: {action_type!r}")
            category = kwargs.get("category")
            if category is not None and category not in _VALID_CATEGORIES:
                raise ValueError(f"Invalid category: {category!r}")
            self.action_type = action_type
            self.category = category
            self.reply_text = kwargs.get("reply_text")
            self.reason = kwargs.get("reason")
            self.metadata = kwargs.get("metadata", {})

        def model_dump(self, **kw):
            """Return the action fields as a plain dict.

            Provided so code that serialises actions with ``model_dump()``
            also works without pydantic.  Keyword options are accepted for
            signature parity with the pydantic variant and are ignored.
            NOTE(review): assumes the ``Action`` base has no richer
            ``model_dump`` of its own — confirm.
            """
            return {
                "metadata": self.metadata,
                "action_type": self.action_type,
                "category": self.category,
                "reply_text": self.reply_text,
                "reason": self.reason,
            }
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ═══════════════════════════════════════════════════════════════
|
| 55 |
+
# Observation
|
| 56 |
+
# ═══════════════════════════════════════════════════════════════
|
| 57 |
+
|
| 58 |
+
if _USE_PYDANTIC:
    class SupportObservation(Observation, BaseModel):  # type: ignore[misc]
        """Observation returned to the agent after ``reset`` and each ``step``.

        Unknown fields are rejected (``extra="forbid"``).
        """

        model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
        done: bool = False
        reward: Optional[float] = None
        metadata: Dict[str, Any] = {}
        ticket_id: str = ""
        ticket_text: str = ""
        task_id: int = 1
        current_category: Optional[str] = None
        resolved: bool = False
        step_count: int = 0
        feedback: str = ""
        score: float = 0.0
else:
    class SupportObservation(Observation):  # type: ignore[no-redef]
        """Stdlib fallback for ``SupportObservation`` (no pydantic installed).

        Accepts the same field names with the same defaults as the pydantic
        variant.

        Raises:
            ValueError: If a keyword argument is not a known field.
        """

        def __init__(self, **kwargs):
            self.done = kwargs.pop("done", False)
            self.reward = kwargs.pop("reward", None)
            self.metadata = kwargs.pop("metadata", {})
            self.ticket_id = kwargs.pop("ticket_id", "")
            self.ticket_text = kwargs.pop("ticket_text", "")
            self.task_id = kwargs.pop("task_id", 1)
            self.current_category = kwargs.pop("current_category", None)
            self.resolved = kwargs.pop("resolved", False)
            self.step_count = kwargs.pop("step_count", 0)
            self.feedback = kwargs.pop("feedback", "")
            self.score = kwargs.pop("score", 0.0)
            # Anything still left was not a declared field; mirror
            # extra="forbid" instead of discarding it silently.
            if kwargs:
                raise ValueError(
                    f"Unexpected field(s) for SupportObservation: {sorted(kwargs)}"
                )
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ═══════════════════════════════════════════════════════════════
|
| 89 |
+
# State
|
| 90 |
+
# ═══════════════════════════════════════════════════════════════
|
| 91 |
+
|
| 92 |
+
if _USE_PYDANTIC:
    class SupportState(State, BaseModel):  # type: ignore[misc]
        """Episode state snapshot, including the ticket's expected answers.

        Extra fields are accepted (``extra="allow"``).
        """

        model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
        episode_id: Optional[str] = None
        step_count: int = 0
        task_id: int = 1
        ticket_id: str = ""
        correct_category: str = ""
        correct_action: str = ""
        classified: bool = False
        resolved: bool = False
        total_reward: float = 0.0
        tickets_resolved: int = 0
        tickets_total: int = 1
else:
    class SupportState(State):  # type: ignore[no-redef]
        """Stdlib fallback for ``SupportState`` (no pydantic installed).

        Accepts the same field names with the same defaults as the pydantic
        variant; any additional keyword argument is stored as an attribute.
        """

        def __init__(self, **kwargs):
            self.episode_id = kwargs.pop("episode_id", None)
            self.step_count = kwargs.pop("step_count", 0)
            self.task_id = kwargs.pop("task_id", 1)
            self.ticket_id = kwargs.pop("ticket_id", "")
            self.correct_category = kwargs.pop("correct_category", "")
            self.correct_action = kwargs.pop("correct_action", "")
            self.classified = kwargs.pop("classified", False)
            self.resolved = kwargs.pop("resolved", False)
            self.total_reward = kwargs.pop("total_reward", 0.0)
            self.tickets_resolved = kwargs.pop("tickets_resolved", 0)
            self.tickets_total = kwargs.pop("tickets_total", 1)
            # Mirror extra="allow": keep undeclared fields rather than
            # dropping them.
            for name, value in kwargs.items():
                setattr(self, name, value)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: support_ticket_env
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
A real-world customer support ticket triage environment.
|
| 5 |
+
An AI agent acts as a support executive: it classifies incoming tickets,
|
| 6 |
+
selects the correct action (reply / escalate / close), and resolves
|
| 7 |
+
multi-ticket queues efficiently.
|
| 8 |
+
author: OpenEnv Hackathon Entry
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
+
- customer-support
|
| 12 |
+
- triage
|
| 13 |
+
- nlp
|
| 14 |
+
- real-world
|
| 15 |
+
tasks:
|
| 16 |
+
- id: 1
|
| 17 |
+
name: Classification
|
| 18 |
+
difficulty: easy
|
| 19 |
+
description: >
|
| 20 |
+
Given a customer ticket, predict the correct category
|
| 21 |
+
(billing | technical | account | general | refund).
|
| 22 |
+
score_range: [0.0, 1.0]
|
| 23 |
+
- id: 2
|
| 24 |
+
name: Action Selection
|
| 25 |
+
difficulty: medium
|
| 26 |
+
description: >
|
| 27 |
+
First classify the ticket, then choose the best action:
|
| 28 |
+
reply, escalate, or close.
|
| 29 |
+
score_range: [0.0, 1.0]
|
| 30 |
+
- id: 3
|
| 31 |
+
name: Full Resolution
|
| 32 |
+
difficulty: hard
|
| 33 |
+
description: >
|
| 34 |
+
Handle a queue of 3 tickets. For each ticket classify it,
|
| 35 |
+
choose the right action, and (if replying) craft a relevant reply.
|
| 36 |
+
Bonus for fewer steps.
|
| 37 |
+
score_range: [0.0, 1.0]
|
| 38 |
+
action_space:
|
| 39 |
+
type: SupportAction
|
| 40 |
+
fields:
|
| 41 |
+
action_type: "classify | reply | escalate | close"
|
| 42 |
+
category: "billing | technical | account | general | refund (required for classify)"
|
| 43 |
+
reply_text: "string (required for reply)"
|
| 44 |
+
reason: "optional justification"
|
| 45 |
+
observation_space:
|
| 46 |
+
type: SupportObservation
|
| 47 |
+
fields:
|
| 48 |
+
ticket_id: string
|
| 49 |
+
ticket_text: string
|
| 50 |
+
task_id: integer
|
| 51 |
+
current_category: "string | null"
|
| 52 |
+
resolved: boolean
|
| 53 |
+
step_count: integer
|
| 54 |
+
feedback: string
|
| 55 |
+
score: float
|
| 56 |
+
docker_image: "support-ticket-env:latest"
|
| 57 |
+
hf_space: "openenv/support-ticket-env"
|
openenv_stub/openenv/__init__.py
ADDED
|
File without changes
|
openenv_stub/openenv/core/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openenv.core.env_client import EnvClient
|
| 2 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 3 |
+
from openenv.core.env_server.interfaces import Environment
|
openenv_stub/openenv/core/env_client.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stub for openenv.core.env_client."""
|
| 2 |
+
from abc import ABC
|
| 3 |
+
from typing import Generic, TypeVar
|
| 4 |
+
|
| 5 |
+
# Type parameters for the action, observation and state types of a client.
ActT = TypeVar("ActT")
ObsT = TypeVar("ObsT")
StateT = TypeVar("StateT")


class EnvClient(ABC, Generic[ActT, ObsT, StateT]):
    """Minimal stand-in for the OpenEnv client base class."""

    def __init__(self, base_url: str, **kwargs):
        """Store ``base_url``; extra keyword arguments are accepted and ignored."""
        self.base_url = base_url
|
openenv_stub/openenv/core/env_server/__init__.py
ADDED
|
File without changes
|
openenv_stub/openenv/core/env_server/http_server.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stub for openenv.core.env_server.http_server."""
|
| 2 |
+
from typing import Any, Callable, Optional, Type
|
| 3 |
+
from openenv.core.env_server.types import Action, Observation
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def create_app(env, action_cls, observation_cls, env_name=None, max_concurrent_envs=1, **kwargs):
    """Stub app factory: a bare FastAPI app, or None when FastAPI is missing.

    Only ``env_name`` is used (as the app title, defaulting to
    "SupportTicketEnv"). The environment, the action/observation classes,
    the concurrency limit and any extra keyword arguments are accepted for
    signature compatibility and otherwise ignored; no routes are registered.
    """
    title = env_name or "SupportTicketEnv"
    try:
        from fastapi import FastAPI

        return FastAPI(title=title)
    except ImportError:
        return None
|
openenv_stub/openenv/core/env_server/interfaces.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stub for openenv.core.env_server.interfaces."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Any, Optional
|
| 5 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Environment(ABC):
    """Abstract environment interface (offline stub).

    Subclasses must implement ``reset``, ``step`` and the ``state`` property;
    ``close`` is an optional hook that does nothing by default.
    """

    # Capability flag that subclasses may override.
    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(self, transform=None, rubric=None):
        """Store the optional ``transform`` and ``rubric`` objects.

        ``transform`` used to be accepted and then discarded; it is now kept
        on the instance so code that passes one can read it back.
        """
        self.transform = transform
        self.rubric = rubric

    @abstractmethod
    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> Observation:
        """Start a new episode and return its first observation."""
        ...

    @abstractmethod
    def step(self, action: Action, timeout_s: Optional[float] = None, **kwargs) -> Observation:
        """Apply ``action`` and return the resulting observation."""
        ...

    @property
    @abstractmethod
    def state(self) -> State:
        """Return the current state of the environment."""
        ...

    def close(self) -> None:
        """Release resources; the default implementation does nothing."""
        pass
|
openenv_stub/openenv/core/env_server/types.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Offline stub for openenv.core.env_server.types.
|
| 3 |
+
Uses stdlib only — no pydantic required.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
from typing import Any, Dict, Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Action:
    """Plain-attribute stand-in for the OpenEnv ``Action`` base class."""

    def __init__(self, **kwargs):
        """Store ``metadata`` (default ``{}``) and every other keyword as an attribute."""
        self.metadata = kwargs.pop("metadata", {})
        for field, value in kwargs.items():
            setattr(self, field, value)

    def model_dump(self):
        """Return a shallow copy of the instance attributes as a dict."""
        return dict(vars(self))
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Observation:
    """Plain-attribute stand-in for the OpenEnv ``Observation`` base class."""

    def __init__(self, **kwargs):
        """Set ``done``, ``reward`` and ``metadata`` (with defaults), then any other keyword."""
        defaults = (("done", False), ("reward", None), ("metadata", {}))
        for field, fallback in defaults:
            setattr(self, field, kwargs.pop(field, fallback))
        for field, value in kwargs.items():
            setattr(self, field, value)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class State:
    """Plain-attribute stand-in for the OpenEnv ``State`` base class."""

    def __init__(self, **kwargs):
        """Set ``episode_id`` and ``step_count`` (with defaults), then any other keyword."""
        episode_id = kwargs.pop("episode_id", None)
        step_count = kwargs.pop("step_count", 0)
        self.episode_id = episode_id
        self.step_count = step_count
        for field, value in kwargs.items():
            setattr(self, field, value)
|
pyproject.toml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "support-ticket-env"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "Customer Support Ticket Resolution — OpenEnv Environment"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
license = { text = "MIT" }
|
| 12 |
+
dependencies = [
|
| 13 |
+
"openenv-core>=0.2.1",
|
| 14 |
+
"fastapi>=0.104.0",
|
| 15 |
+
"uvicorn[standard]>=0.24.0",
|
| 16 |
+
"pydantic>=2.0.0",
|
| 17 |
+
"openai>=1.0.0",
|
| 18 |
+
"pyyaml>=6.0",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
[project.optional-dependencies]
|
| 22 |
+
dev = ["pytest>=7.0", "pytest-asyncio", "httpx"]
|
| 23 |
+
|
| 24 |
+
[tool.setuptools.packages.find]
|
| 25 |
+
where = ["."]
|
| 26 |
+
include = ["support_ticket_env*"]
|
run_tests.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
run_tests.py — Self-contained test runner for support_ticket_env.
|
| 4 |
+
Runs all test cases using only the Python standard library.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python run_tests.py
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import sys
|
| 11 |
+
import os
|
| 12 |
+
import traceback
|
| 13 |
+
from typing import Callable, List, Tuple
|
| 14 |
+
|
| 15 |
+
# ─── path setup ────────────────────────────────────────────────
|
| 16 |
+
ROOT = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
STUB = os.path.join(ROOT, "openenv_stub")
|
| 18 |
+
sys.path.insert(0, STUB)
|
| 19 |
+
sys.path.insert(0, ROOT)
|
| 20 |
+
|
| 21 |
+
# ─── minimal test framework ────────────────────────────────────
|
| 22 |
+
_tests: List[Tuple[str, Callable]] = []
|
| 23 |
+
_passed = 0
|
| 24 |
+
_failed = 0
|
| 25 |
+
_errors = 0
|
| 26 |
+
|
| 27 |
+
def test(fn: Callable) -> Callable:
    """Decorator: register ``fn`` in ``_tests`` under its qualified name and return it unchanged."""
    _tests.append((fn.__qualname__, fn))
    return fn
|
| 30 |
+
|
| 31 |
+
def assert_eq(a, b, msg=""):
    """Raise AssertionError when ``a != b``; ``b`` is reported as the expected value."""
    if a != b:
        detail = f"expected {b!r}, got {a!r}"
        raise AssertionError(f"{msg} | {detail}")
|
| 34 |
+
|
| 35 |
+
def assert_true(val, msg=""):
    """Raise AssertionError when ``val`` is falsy, using ``msg`` if one is given."""
    if val:
        return
    raise AssertionError(msg or f"Expected truthy, got {val!r}")
|
| 38 |
+
|
| 39 |
+
def assert_in_range(val, lo, hi, msg=""):
    """Raise AssertionError unless ``lo <= val <= hi`` (both bounds inclusive)."""
    if lo <= val <= hi:
        return
    raise AssertionError(msg or f"Expected {val!r} in [{lo}, {hi}]")
|
| 42 |
+
|
| 43 |
+
# ─────────────────────────────── imports ───────────────────────
|
| 44 |
+
from support_ticket_env.graders import (
|
| 45 |
+
grade_task1, grade_task2, grade_task3, loop_penalty,
|
| 46 |
+
)
|
| 47 |
+
from support_ticket_env.server.support_environment import SupportTicketEnvironment
|
| 48 |
+
from support_ticket_env.models import SupportAction
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def make_env():
|
| 52 |
+
return SupportTicketEnvironment()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ════════════════════════════════════════════════════════════════
|
| 56 |
+
# GRADER TESTS
|
| 57 |
+
# ════════════════════════════════════════════════════════════════
|
| 58 |
+
|
| 59 |
+
@test
|
| 60 |
+
def test_grade1_correct():
|
| 61 |
+
assert_eq(grade_task1("billing", "billing"), 1.0)
|
| 62 |
+
|
| 63 |
+
@test
|
| 64 |
+
def test_grade1_wrong():
|
| 65 |
+
assert_eq(grade_task1("technical", "billing"), 0.0)
|
| 66 |
+
|
| 67 |
+
@test
|
| 68 |
+
def test_grade1_all_categories():
|
| 69 |
+
for cat in ["billing", "technical", "account", "general", "refund"]:
|
| 70 |
+
assert_eq(grade_task1(cat, cat), 1.0, f"cat={cat}")
|
| 71 |
+
|
| 72 |
+
@test
|
| 73 |
+
def test_grade1_empty():
|
| 74 |
+
assert_eq(grade_task1("", "billing"), 0.0)
|
| 75 |
+
|
| 76 |
+
@test
|
| 77 |
+
def test_grade2_exact_reply():
|
| 78 |
+
assert_eq(grade_task2("reply", "reply"), 1.0)
|
| 79 |
+
|
| 80 |
+
@test
|
| 81 |
+
def test_grade2_exact_escalate():
|
| 82 |
+
assert_eq(grade_task2("escalate", "escalate"), 1.0)
|
| 83 |
+
|
| 84 |
+
@test
|
| 85 |
+
def test_grade2_exact_close():
|
| 86 |
+
assert_eq(grade_task2("close", "close"), 1.0)
|
| 87 |
+
|
| 88 |
+
@test
|
| 89 |
+
def test_grade2_partial_reply_escalate():
|
| 90 |
+
assert_eq(grade_task2("reply", "escalate"), 0.5)
|
| 91 |
+
assert_eq(grade_task2("escalate", "reply"), 0.5)
|
| 92 |
+
|
| 93 |
+
@test
|
| 94 |
+
def test_grade2_close_wrong():
|
| 95 |
+
assert_eq(grade_task2("close", "reply"), 0.0)
|
| 96 |
+
|
| 97 |
+
@test
|
| 98 |
+
def test_grade3_perfect():
|
| 99 |
+
score = grade_task3(True, True, False,
|
| 100 |
+
"we will process your refund billing payment",
|
| 101 |
+
"billing", True, 1, 5)
|
| 102 |
+
assert_true(score >= 0.9, f"Expected >=0.9, got {score}")
|
| 103 |
+
|
| 104 |
+
@test
|
| 105 |
+
def test_grade3_capped_at_one():
|
| 106 |
+
score = grade_task3(True, True, False,
|
| 107 |
+
"refund billing payment account cancel subscription",
|
| 108 |
+
"billing", True, 1, 5)
|
| 109 |
+
assert_true(score <= 1.0, f"Score exceeds 1.0: {score}")
|
| 110 |
+
|
| 111 |
+
@test
|
| 112 |
+
def test_grade3_partial_action_less_than_full():
|
| 113 |
+
s_partial = grade_task3(True, False, True, None, "technical", True, 2)
|
| 114 |
+
s_full = grade_task3(True, True, False, None, "technical", True, 2)
|
| 115 |
+
assert_true(s_partial < s_full, f"partial={s_partial} should < full={s_full}")
|
| 116 |
+
|
| 117 |
+
@test
|
| 118 |
+
def test_loop_penalty_none_within_limit():
|
| 119 |
+
assert_eq(loop_penalty(5), 0.0)
|
| 120 |
+
assert_eq(loop_penalty(10), 0.0)
|
| 121 |
+
|
| 122 |
+
@test
|
| 123 |
+
def test_loop_penalty_grows():
|
| 124 |
+
assert_true(loop_penalty(12) < loop_penalty(11))
|
| 125 |
+
assert_true(loop_penalty(11) < 0)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ════════════════════════════════════════════════════════════════
|
| 129 |
+
# ENVIRONMENT TESTS
|
| 130 |
+
# ════════════════════════════════════════════════════════════════
|
| 131 |
+
|
| 132 |
+
@test
|
| 133 |
+
def test_env_reset_task1():
|
| 134 |
+
env = make_env()
|
| 135 |
+
obs = env.reset(task_id=1, seed=42)
|
| 136 |
+
assert_true(obs.ticket_text != "", "ticket_text should not be empty")
|
| 137 |
+
assert_eq(obs.task_id, 1)
|
| 138 |
+
assert_eq(obs.done, False)
|
| 139 |
+
|
| 140 |
+
@test
|
| 141 |
+
def test_env_task1_correct_classification():
|
| 142 |
+
env = make_env()
|
| 143 |
+
env.reset(task_id=1, seed=42)
|
| 144 |
+
state = env.state
|
| 145 |
+
obs = env.step(SupportAction(action_type="classify", category=state.correct_category))
|
| 146 |
+
assert_eq(obs.reward, 1.0)
|
| 147 |
+
assert_eq(obs.done, True)
|
| 148 |
+
|
| 149 |
+
@test
|
| 150 |
+
def test_env_task1_wrong_classification():
|
| 151 |
+
env = make_env()
|
| 152 |
+
env.reset(task_id=1, seed=42)
|
| 153 |
+
state = env.state
|
| 154 |
+
wrong = next(c for c in ["billing","technical","account","general","refund"]
|
| 155 |
+
if c != state.correct_category)
|
| 156 |
+
obs = env.step(SupportAction(action_type="classify", category=wrong))
|
| 157 |
+
assert_eq(obs.reward, 0.0)
|
| 158 |
+
assert_eq(obs.done, True)
|
| 159 |
+
|
| 160 |
+
@test
|
| 161 |
+
def test_env_task2_must_classify_first():
|
| 162 |
+
env = make_env()
|
| 163 |
+
env.reset(task_id=2, seed=42)
|
| 164 |
+
obs = env.step(SupportAction(action_type="escalate"))
|
| 165 |
+
assert_eq(obs.done, False)
|
| 166 |
+
assert_true("classify" in obs.feedback.lower())
|
| 167 |
+
|
| 168 |
+
@test
|
| 169 |
+
def test_env_task2_full_correct_episode():
|
| 170 |
+
env = make_env()
|
| 171 |
+
env.reset(task_id=2, seed=42)
|
| 172 |
+
state = env.state
|
| 173 |
+
env.step(SupportAction(action_type="classify", category=state.correct_category))
|
| 174 |
+
obs = env.step(SupportAction(action_type=state.correct_action))
|
| 175 |
+
assert_eq(obs.done, True)
|
| 176 |
+
assert_true(obs.reward >= 0.5, f"reward={obs.reward}")
|
| 177 |
+
|
| 178 |
+
@test
|
| 179 |
+
def test_env_task3_three_tickets():
|
| 180 |
+
env = make_env()
|
| 181 |
+
env.reset(task_id=3, seed=42)
|
| 182 |
+
assert_eq(env.state.tickets_total, 3)
|
| 183 |
+
|
| 184 |
+
@test
|
| 185 |
+
def test_env_task3_resolves_all():
|
| 186 |
+
env = make_env()
|
| 187 |
+
env.reset(task_id=3, seed=42)
|
| 188 |
+
done = False
|
| 189 |
+
steps = 0
|
| 190 |
+
while not done and steps < 30:
|
| 191 |
+
state = env.state
|
| 192 |
+
if not state.classified:
|
| 193 |
+
action = SupportAction(action_type="classify", category=state.correct_category)
|
| 194 |
+
else:
|
| 195 |
+
ca = state.correct_action
|
| 196 |
+
action = (SupportAction(action_type="reply",
|
| 197 |
+
reply_text=f"Handling your {state.correct_category} issue.")
|
| 198 |
+
if ca == "reply" else SupportAction(action_type=ca))
|
| 199 |
+
obs = env.step(action)
|
| 200 |
+
done = obs.done
|
| 201 |
+
steps += 1
|
| 202 |
+
assert_true(done, "Episode did not finish")
|
| 203 |
+
assert_eq(env.state.tickets_resolved, 3)
|
| 204 |
+
|
| 205 |
+
@test
|
| 206 |
+
def test_env_state_step_count():
|
| 207 |
+
env = make_env()
|
| 208 |
+
env.reset(task_id=1, seed=0)
|
| 209 |
+
assert_eq(env.state.step_count, 0)
|
| 210 |
+
state = env.state
|
| 211 |
+
env.step(SupportAction(action_type="classify", category=state.correct_category))
|
| 212 |
+
assert_eq(env.state.step_count, 1)
|
| 213 |
+
|
| 214 |
+
@test
|
| 215 |
+
def test_env_reward_always_in_range():
|
| 216 |
+
for seed in [0, 1, 2, 42, 99]:
|
| 217 |
+
for task_id in [1, 2, 3]:
|
| 218 |
+
env = make_env()
|
| 219 |
+
env.reset(task_id=task_id, seed=seed)
|
| 220 |
+
state = env.state
|
| 221 |
+
obs = env.step(SupportAction(action_type="classify", category=state.correct_category))
|
| 222 |
+
r = obs.reward or 0.0
|
| 223 |
+
assert_in_range(r, -1.0, 1.0, f"task={task_id} seed={seed} reward={r}")
|
| 224 |
+
|
| 225 |
+
@test
|
| 226 |
+
def test_env_task3_total_reward_positive():
|
| 227 |
+
env = make_env()
|
| 228 |
+
env.reset(task_id=3, seed=7)
|
| 229 |
+
total = 0.0
|
| 230 |
+
done = False
|
| 231 |
+
steps = 0
|
| 232 |
+
while not done and steps < 20:
|
| 233 |
+
state = env.state
|
| 234 |
+
action = (SupportAction(action_type="classify", category=state.correct_category)
|
| 235 |
+
if not state.classified
|
| 236 |
+
else SupportAction(action_type=state.correct_action))
|
| 237 |
+
obs = env.step(action)
|
| 238 |
+
total += obs.reward or 0.0
|
| 239 |
+
done = obs.done
|
| 240 |
+
steps += 1
|
| 241 |
+
assert_true(total > 0.0, f"total_reward={total}")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ════════════════════════════════════════════════════════════════
|
| 245 |
+
# Runner
|
| 246 |
+
# ════════════════════════════════════════════════════════════════
|
| 247 |
+
|
| 248 |
+
def run_all():
    """Run every registered test, print one line per test and a summary.

    The pass/fail/error counters are reset at the start, so calling this
    more than once reports each run on its own. An empty registry is
    handled (the previous unused ``max(...)`` over test names raised
    ``ValueError`` in that case).

    Returns:
        True when no test failed or errored, otherwise False.
    """
    global _passed, _failed, _errors
    _passed = _failed = _errors = 0
    rule = "=" * 60
    print(f"\n{rule}")
    print(f" Running {len(_tests)} tests")
    print(f"{rule}")
    for name, fn in _tests:
        try:
            fn()
            print(f" ✅ {name}")
            _passed += 1
        except AssertionError as e:
            # A failed expectation: report the message only.
            print(f" ❌ {name}")
            print(f" {e}")
            _failed += 1
        except Exception:
            # Anything else is a crash inside the test: show a short traceback
            # and keep running the remaining tests.
            print(f" 💥 {name}")
            traceback.print_exc(limit=3)
            _errors += 1
    total = _passed + _failed + _errors
    print(f"\n{rule}")
    print(f" Results: {_passed}/{total} passed | {_failed} failed | {_errors} errors")
    print(f"{rule}\n")
    return _failed + _errors == 0
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
if __name__ == "__main__":
|
| 275 |
+
ok = run_all()
|
| 276 |
+
sys.exit(0 if ok else 1)
|
server/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from support_ticket_env.server.support_environment import SupportTicketEnvironment
|
| 2 |
+
from support_ticket_env.server.app import app
|
| 3 |
+
|
| 4 |
+
__all__ = ["SupportTicketEnvironment", "app"]
|
server/app.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI application entry point for the Support Ticket Environment.
|
| 3 |
+
Serves the OpenEnv HTTP/WebSocket API and optionally the Gradio UI at /web.
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from openenv.core.env_server.http_server import create_app
|
| 7 |
+
|
| 8 |
+
from support_ticket_env.models import SupportAction, SupportObservation
|
| 9 |
+
from support_ticket_env.server.support_environment import SupportTicketEnvironment
|
| 10 |
+
|
| 11 |
+
app = create_app(
|
| 12 |
+
env=SupportTicketEnvironment,
|
| 13 |
+
action_cls=SupportAction,
|
| 14 |
+
observation_cls=SupportObservation,
|
| 15 |
+
env_name="support_ticket_env",
|
| 16 |
+
max_concurrent_envs=4,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Mount Gradio UI at /web when requested
|
| 20 |
+
if os.getenv("ENABLE_WEB_INTERFACE", "true").lower() == "true":
|
| 21 |
+
try:
|
| 22 |
+
import gradio as gr
|
| 23 |
+
from support_ticket_env.gradio_ui import demo
|
| 24 |
+
import gradio.routes
|
| 25 |
+
app = gr.mount_gradio_app(app, demo, path="/web")
|
| 26 |
+
print("Gradio UI mounted at /web")
|
| 27 |
+
except Exception as e:
|
| 28 |
+
print(f"Gradio UI not mounted: {e}")
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core>=0.2.1
|
| 2 |
+
fastapi>=0.104.0
|
| 3 |
+
uvicorn[standard]>=0.24.0
|
| 4 |
+
pydantic>=2.0.0
|
| 5 |
+
pyyaml>=6.0
|
| 6 |
+
gradio>=4.0.0
|
| 7 |
+
openai>=1.0.0
|
server/support_environment.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Customer Support Ticket Resolution — OpenEnv Environment (server side).
|
| 3 |
+
|
| 4 |
+
Implements the three tasks:
|
| 5 |
+
Task 1 (easy) – Classify a single ticket
|
| 6 |
+
Task 2 (medium) – Classify a ticket, then choose the correct action for it
|
| 7 |
+
Task 3 (hard) – Fully resolve a queue of tickets with minimal steps
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import random
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
from openenv.core.env_server.interfaces import Environment
|
| 16 |
+
from openenv.core.env_server.types import State
|
| 17 |
+
|
| 18 |
+
from support_ticket_env.models import SupportAction, SupportObservation, SupportState
|
| 19 |
+
from support_ticket_env.tickets import TICKETS, TICKET_LOOKUP
|
| 20 |
+
from support_ticket_env.graders import (
|
| 21 |
+
grade_task1,
|
| 22 |
+
grade_task2,
|
| 23 |
+
grade_task3,
|
| 24 |
+
loop_penalty,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SupportTicketEnvironment(Environment):
|
| 29 |
+
"""
|
| 30 |
+
OpenEnv environment that simulates a customer-support triage desk.
|
| 31 |
+
|
| 32 |
+
The task_id (1, 2, or 3) is set when the environment is reset.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 36 |
+
|
| 37 |
+
def __init__(self) -> None:
    """Create the environment with idle defaults; ``reset`` starts an episode."""
    super().__init__()

    # Episode bookkeeping
    self._episode_id: Optional[str] = None
    self._task_id: int = 1
    self._step_count: int = 0
    self._total_reward: float = 0.0

    # Ticket currently being handled and its progress flags
    self._ticket: dict = {}
    self._classified: bool = False
    self._resolved: bool = False

    # Task 3: queue of tickets
    self._queue: list[dict] = []
    self._tickets_total: int = 1
    self._tickets_resolved: int = 0
|
| 51 |
+
|
| 52 |
+
# ──────────────────────── reset ────────────────────────────
|
| 53 |
+
|
| 54 |
+
def reset(
    self,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    task_id: int = 1,
    **kwargs,
) -> SupportObservation:
    """Start a new episode and return its first observation.

    Args:
        seed: Seed for the ticket selection; ``None`` seeds from the system.
        episode_id: Optional identifier stored for this episode.
        task_id: 1, 2 or 3 (coerced with ``int()``). Task 3 draws a queue of
            three distinct tickets; tasks 1 and 2 draw a single ticket.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``task_id`` is not 1, 2 or 3. Previously such a value
            was accepted here and then routed to the task-3 handler by
            ``step`` even though no queue had been built for it.
    """
    task = int(task_id)
    if task not in (1, 2, 3):
        raise ValueError(f"Invalid task_id: {task_id!r} (expected 1, 2 or 3)")

    rng = random.Random(seed)
    self._episode_id = episode_id
    self._task_id = task
    self._step_count = 0
    self._total_reward = 0.0
    self._classified = False
    self._resolved = False
    self._tickets_resolved = 0

    if task == 3:
        # Give the agent a queue of 3 tickets
        self._queue = rng.sample(TICKETS, k=3)
        self._tickets_total = len(self._queue)
        self._ticket = self._queue[0]
    else:
        # Drop any queue left over from an earlier task-3 episode.
        self._queue = []
        self._ticket = rng.choice(TICKETS)
        self._tickets_total = 1

    return self._make_obs(
        feedback="New episode started. Read the ticket and take action.",
        score=0.0,
    )
|
| 84 |
+
|
| 85 |
+
# ──────────────────────── step ─────────────────────────────
|
| 86 |
+
|
| 87 |
+
def step(self, action: SupportAction, **kwargs) -> SupportObservation:  # type: ignore[override]
    """Apply ``action`` for the active task and return the observation.

    The step counter is incremented first. The reward produced by the task
    handler gets ``loop_penalty(step_count)`` added, is clamped to
    [-1.0, 1.0], rounded to 4 decimals and added to the running total.
    """
    self._step_count += 1
    penalty = loop_penalty(self._step_count)

    # Task ids other than 1 and 2 fall through to the task-3 handler.
    if self._task_id == 1:
        handler = self._step_task1
    elif self._task_id == 2:
        handler = self._step_task2
    else:
        handler = self._step_task3
    obs = handler(action)

    # Apply loop penalty on top of step reward
    raw_reward = (obs.reward or 0.0) + penalty
    clamped = max(-1.0, min(1.0, raw_reward))
    obs.reward = round(clamped, 4)

    self._total_reward += obs.reward
    obs.step_count = self._step_count
    return obs
|
| 104 |
+
|
| 105 |
+
# ──────────────────────── Task 1 ───────────────────────────
|
| 106 |
+
|
| 107 |
+
def _step_task1(self, action: SupportAction) -> SupportObservation:
    """Handle a Task 1 step: a single classification attempt.

    A non-``classify`` action is rejected with score 0.0 and the episode
    stays open. A ``classify`` action is graded with ``grade_task1`` and
    always ends the episode, whether the prediction was right or wrong.

    Returns:
        The observation for this step. On a graded attempt it is terminal
        and reports ``resolved=True``.
    """
    if action.action_type != "classify":
        return self._make_obs(
            feedback="Task 1 requires a 'classify' action.",
            score=0.0,
            done=False,
        )

    correct = self._ticket["category"]
    score = grade_task1(
        predicted_category=action.category or "",
        correct_category=correct,
    )
    self._classified = score == 1.0

    if self._classified:
        feedback = f"✅ Correct! Category: '{correct}'."
    else:
        feedback = (
            f"❌ Wrong. You said '{action.category}', correct is '{correct}'."
        )

    # Task 1 is one-shot — agent gets one attempt. Mark the episode
    # resolved *before* building the observation so the terminal
    # observation and the state property agree, as they do in the other
    # task handlers.
    self._resolved = True
    return self._make_obs(feedback=feedback, score=score, done=True)
|
| 135 |
+
|
| 136 |
+
# ──────────────────────── Task 2 ───────────────────────────
|
| 137 |
+
|
| 138 |
+
def _step_task2(self, action: SupportAction) -> SupportObservation:
    """Handle a Task 2 step: classify first, then pick an action.

    Before classification only a ``classify`` action is accepted; it is
    graded with ``grade_task1`` and earns 0.3 times that score. Once the
    ticket is classified, the next action is graded with ``grade_task2``
    and ends the episode.
    """
    ticket = self._ticket

    if self._classified:
        # Second step: choose action
        expected = ticket["correct_action"]
        score = grade_task2(
            action_type=action.action_type,
            correct_action=expected,
            category=ticket["category"],
        )
        if score == 1.0:
            feedback = f"✅ Correct action: '{expected}'."
        elif score == 0.5:
            feedback = (
                f"⚠️ Partial credit. '{action.action_type}' is defensible "
                f"but '{expected}' is preferred."
            )
        else:
            feedback = f"❌ Wrong action. Correct: '{expected}'."

        self._resolved = True
        return self._make_obs(feedback=feedback, score=score, done=True)

    # First step must be classification
    if action.action_type != "classify":
        return self._make_obs(
            feedback="Please classify the ticket first.",
            score=0.0,
        )

    cat_score = grade_task1(action.category or "", ticket["category"])
    verdict = "Correct ✅" if cat_score == 1.0 else "Incorrect ❌"
    self._classified = True
    return self._make_obs(
        feedback=f"Classified as '{action.category}'. {verdict} Now choose an action.",
        score=cat_score * 0.3,  # partial credit toward max 1.0
    )
|
| 178 |
+
|
| 179 |
+
# ──────────────────────── Task 3 ───────────────────────────
|
| 180 |
+
|
| 181 |
+
def _step_task3(self, action: SupportAction) -> SupportObservation:
    """Handle a Task 3 step: work through the ticket queue.

    Each ticket takes two steps. The first must be a ``classify`` action,
    graded with ``grade_task1`` and worth 0.1 times that score. The second
    is the resolving action, graded with ``grade_task3``. After a ticket
    is resolved the next queued ticket becomes active; when none remain
    the episode ends.

    ``MAX_STEPS`` is only forwarded to the grader; this method does not
    itself terminate the episode on a step limit.
    """
    MAX_STEPS = 15
    ticket = self._ticket

    if not self._classified:
        # Must classify first
        if action.action_type != "classify":
            return self._make_obs(
                feedback="Classify the ticket before taking action.",
                score=0.0,
            )
        cat_score = grade_task1(action.category or "", ticket["category"])
        self._classified = True
        # Remember whether the label was actually right. ``_classified``
        # only records that a classification happened, so it cannot be
        # used as the grader's ``classified_correctly`` input.
        self._classification_correct = cat_score == 1.0
        return self._make_obs(
            feedback=(
                f"Classified '{ticket['id']}' as '{action.category}'. "
                f"{'Correct ✅' if cat_score == 1.0 else 'Incorrect ❌'} "
                "Now resolve it."
            ),
            score=cat_score * 0.1,
        )

    # Resolve current ticket
    expected_action = ticket["correct_action"]
    action_correct = action.action_type == expected_action
    pair = frozenset({action.action_type, expected_action})
    # Swapping reply and escalate is treated as a near miss.
    action_partial = (not action_correct) and pair in {
        frozenset({"reply", "escalate"})
    }

    score = grade_task3(
        # Fall back to the old behaviour only if the outcome was never
        # recorded (e.g. ``_classified`` was set outside this method).
        classified_correctly=getattr(
            self, "_classification_correct", self._classified
        ),
        action_correct=action_correct,
        action_partial=action_partial,
        reply_text=action.reply_text,
        category=ticket["category"],
        resolved=True,
        steps_taken=self._step_count,
        max_steps=MAX_STEPS,
    )

    self._tickets_resolved += 1

    # Advance to next ticket in queue
    if self._tickets_resolved < self._tickets_total:
        self._ticket = self._queue[self._tickets_resolved]
        self._classified = False
        feedback = (
            f"Ticket resolved (score {score:.2f}). "
            f"Moving to next ticket ({self._tickets_resolved + 1}/{self._tickets_total})."
        )
        done = False
    else:
        feedback = (
            f"All {self._tickets_total} tickets resolved! "
            f"Episode score: {self._total_reward + score:.2f}"
        )
        done = True
        self._resolved = True

    return self._make_obs(feedback=feedback, score=score, done=done)
|
| 243 |
+
|
| 244 |
+
# ──────────────────────── helpers ──────────────────────────
|
| 245 |
+
|
| 246 |
+
def _make_obs(
    self,
    feedback: str,
    score: float,
    done: bool = False,
) -> SupportObservation:
    """Build an observation for the currently active ticket.

    The ticket's category is included only once the ticket has been
    classified; before that ``current_category`` is ``None``. ``score`` is
    used for both the ``score`` and ``reward`` fields.
    """
    ticket = self._ticket
    visible_category = ticket.get("category") if self._classified else None

    return SupportObservation(
        ticket_id=ticket.get("id", ""),
        ticket_text=ticket.get("text", ""),
        task_id=self._task_id,
        current_category=visible_category,
        resolved=self._resolved,
        step_count=self._step_count,
        feedback=feedback,
        score=score,
        reward=score,
        done=done,
    )
|
| 264 |
+
|
| 265 |
+
# ──────────────────────── state ────────────────────────────
|
| 266 |
+
|
| 267 |
+
@property
def state(self) -> SupportState:
    """Return a snapshot of the episode, including ground-truth labels."""
    ticket = self._ticket

    snapshot = dict(
        episode_id=self._episode_id,
        step_count=self._step_count,
        task_id=self._task_id,
        # Ground truth for the active ticket.
        ticket_id=ticket.get("id", ""),
        correct_category=ticket.get("category", ""),
        correct_action=ticket.get("correct_action", ""),
        # Progress flags and counters.
        classified=self._classified,
        resolved=self._resolved,
        total_reward=self._total_reward,
        tickets_resolved=self._tickets_resolved,
        tickets_total=self._tickets_total,
    )
    return SupportState(**snapshot)
|
tests/__init__.py
ADDED
|
File without changes
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pytest configuration for support_ticket_env tests."""
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# Ensure the package root is on the path when running tests directly
_TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
_PACKAGE_ROOT = os.path.dirname(_TESTS_DIR)
sys.path.insert(0, _PACKAGE_ROOT)
|
tests/test_environment.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for SupportTicketEnvironment — runs the environment directly
|
| 3 |
+
(no HTTP server required).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
from support_ticket_env.server.support_environment import SupportTicketEnvironment
|
| 8 |
+
from support_ticket_env.models import SupportAction
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ─────────────────────────── fixtures ──────────────────────────
|
| 12 |
+
|
| 13 |
+
@pytest.fixture
def env():
    """Provide a newly constructed environment for each test."""
    environment = SupportTicketEnvironment()
    return environment
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ─────────────────────────── Task 1 ────────────────────────────
|
| 19 |
+
|
| 20 |
+
class TestTask1:
    """Task 1: one-shot ticket classification."""

    def test_reset_returns_observation(self, env):
        obs = env.reset(task_id=1, seed=42)
        assert obs.ticket_text
        assert obs.task_id == 1
        assert obs.done is False

    def test_correct_classification(self, env):
        env.reset(task_id=1, seed=42)
        # Find out the correct category via state
        state = env.state
        action = SupportAction(
            action_type="classify",
            category=state.correct_category,
        )
        obs = env.step(action)
        assert obs.reward == 1.0
        assert obs.done is True

    def test_wrong_classification(self, env):
        env.reset(task_id=1, seed=42)
        state = env.state
        wrong_cats = [
            c for c in ["billing", "technical", "account", "general", "refund"]
            if c != state.correct_category
        ]
        action = SupportAction(action_type="classify", category=wrong_cats[0])
        obs = env.step(action)
        assert obs.reward == 0.0
        assert obs.done is True

    def test_non_classify_action_penalised(self, env):
        env.reset(task_id=1, seed=42)
        obs = env.step(SupportAction(action_type="reply", reply_text="hello"))
        # A non-classify action is rejected rather than graded: the
        # episode stays open, nothing is earned, and the feedback tells
        # the agent which action is required.
        assert obs.done is False
        assert obs.reward == 0.0
        assert "classify" in obs.feedback.lower()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ─────────────────────────── Task 2 ────────────────────────────
|
| 59 |
+
|
| 60 |
+
class TestTask2:
    """Task 2: classify the ticket, then choose the handling action."""

    def test_full_correct_episode(self, env):
        env.reset(task_id=2, seed=42)
        truth = env.state

        # Step 1: classify
        classify = SupportAction(
            action_type="classify",
            category=truth.correct_category,
        )
        first = env.step(classify)
        assert first.done is False
        assert first.reward > 0

        # Step 2: correct action
        second = env.step(SupportAction(action_type=truth.correct_action))
        assert second.done is True
        assert second.reward >= 0.5

    def test_must_classify_first(self, env):
        env.reset(task_id=2, seed=7)
        premature = SupportAction(action_type="escalate")
        obs = env.step(premature)
        assert obs.done is False
        assert "classify" in obs.feedback.lower()

    def test_state_reflects_progress(self, env):
        env.reset(task_id=2, seed=7)
        before = env.state
        assert before.classified is False

        classify = SupportAction(
            action_type="classify",
            category=before.correct_category,
        )
        env.step(classify)

        after = env.state
        assert after.classified is True
        assert after.step_count == 1
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ─────────────────────────── Task 3 ────────────────────────────
|
| 99 |
+
|
| 100 |
+
class TestTask3:
    """Task 3: resolve a queue of three tickets end to end."""

    def test_queue_has_three_tickets(self, env):
        env.reset(task_id=3, seed=42)
        snapshot = env.state
        assert snapshot.tickets_total == 3
        assert snapshot.tickets_resolved == 0

    def test_resolve_all_tickets(self, env):
        env.reset(task_id=3, seed=42)
        done = False

        # Play the ground-truth policy, capped at 30 steps.
        for _ in range(30):
            state = env.state
            if state.classified:
                ca = state.correct_action
                if ca == "reply":
                    action = SupportAction(
                        action_type="reply",
                        reply_text=f"We are handling your {state.correct_category} issue.",
                    )
                else:
                    action = SupportAction(action_type=ca)
            else:
                action = SupportAction(
                    action_type="classify",
                    category=state.correct_category,
                )
            done = env.step(action).done
            if done:
                break

        assert done, "Episode should finish after 3 tickets"
        final_state = env.state
        assert final_state.tickets_resolved == 3

    def test_total_reward_positive(self, env):
        env.reset(task_id=3, seed=123)
        total = 0.0
        done = False

        # Same ground-truth policy, without reply text, capped at 20 steps.
        for _ in range(20):
            state = env.state
            if state.classified:
                action = SupportAction(action_type=state.correct_action)
            else:
                action = SupportAction(
                    action_type="classify",
                    category=state.correct_category,
                )
            obs = env.step(action)
            total += obs.reward or 0.0
            done = obs.done
            if done:
                break

        assert total > 0.0
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ─────────────────────────── State API ─────────────────────────
|
| 160 |
+
|
| 161 |
+
class TestStateAPI:
    """The ``state`` property mirrors the environment's bookkeeping."""

    def test_state_after_reset(self, env):
        env.reset(task_id=1, seed=0)
        snapshot = env.state
        assert snapshot.step_count == 0
        assert snapshot.task_id == 1
        assert snapshot.ticket_id != ""

    def test_step_count_increments(self, env):
        env.reset(task_id=1, seed=0)
        truth = env.state
        classify = SupportAction(
            action_type="classify",
            category=truth.correct_category,
        )
        env.step(classify)
        assert env.state.step_count == 1
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# ─────────────────────────── Reward bounds ─────────────────────
|
| 177 |
+
|
| 178 |
+
class TestRewardBounds:
    """Step rewards stay inside [-1.0, 1.0] for every task."""

    def test_reward_in_range(self, env):
        cases = [
            (seed, task_id)
            for seed in [0, 1, 2, 3, 42]
            for task_id in [1, 2, 3]
        ]
        for seed, task_id in cases:
            env.reset(task_id=task_id, seed=seed)
            truth = env.state
            obs = env.step(
                SupportAction(
                    action_type="classify",
                    category=truth.correct_category,
                )
            )
            reward = obs.reward or 0.0
            assert -1.0 <= reward <= 1.0, (
                f"Reward out of bounds: {obs.reward}"
            )
|
tests/test_graders.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for grader functions."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from support_ticket_env.graders import (
|
| 5 |
+
grade_task1,
|
| 6 |
+
grade_task2,
|
| 7 |
+
grade_task3,
|
| 8 |
+
loop_penalty,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestTask1Grader:
    """``grade_task1`` scores 1.0 on a match and 0.0 otherwise."""

    def test_correct_category(self):
        score = grade_task1("billing", "billing")
        assert score == 1.0

    def test_wrong_category(self):
        score = grade_task1("technical", "billing")
        assert score == 0.0

    def test_all_categories(self):
        categories = ["billing", "technical", "account", "general", "refund"]
        for cat in categories:
            assert grade_task1(cat, cat) == 1.0

    def test_empty_prediction(self):
        score = grade_task1("", "billing")
        assert score == 0.0
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TestTask2Grader:
    """``grade_task2``: exact match, reply/escalate near miss, or wrong."""

    def test_exact_match(self):
        for chosen in ("reply", "escalate", "close"):
            assert grade_task2(chosen, chosen) == 1.0

    def test_partial_credit_reply_escalate(self):
        # The reply/escalate swap earns half credit in both directions.
        forward = grade_task2("reply", "escalate")
        assert forward == 0.5
        backward = grade_task2("escalate", "reply")
        assert backward == 0.5

    def test_wrong_action_close(self):
        for expected in ("reply", "escalate"):
            assert grade_task2("close", expected) == 0.0

    def test_classify_when_action_expected(self):
        score = grade_task2("classify", "reply")
        assert score == 0.0
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TestTask3Grader:
    """``grade_task3`` combines classification, action and reply quality."""

    def test_perfect_resolution(self):
        inputs = dict(
            classified_correctly=True,
            action_correct=True,
            action_partial=False,
            reply_text="we will process your refund billing payment",
            category="billing",
            resolved=True,
            steps_taken=1,
            max_steps=5,
        )
        assert grade_task3(**inputs) > 0.9

    def test_no_classification(self):
        inputs = dict(
            classified_correctly=False,
            action_correct=True,
            action_partial=False,
            reply_text="here is the refund",
            category="billing",
            resolved=True,
            steps_taken=2,
        )
        # Should not get the 0.20 classification bonus
        assert grade_task3(**inputs) < 1.0

    def test_partial_action(self):
        # Everything except the action outcome is held constant.
        shared = dict(
            classified_correctly=True,
            reply_text=None,
            category="technical",
            resolved=True,
            steps_taken=2,
        )
        score_partial = grade_task3(
            action_correct=False, action_partial=True, **shared
        )
        score_correct = grade_task3(
            action_correct=True, action_partial=False, **shared
        )
        assert score_partial < score_correct

    def test_score_capped_at_one(self):
        inputs = dict(
            classified_correctly=True,
            action_correct=True,
            action_partial=False,
            reply_text="refund billing payment account cancel subscription",
            category="billing",
            resolved=True,
            steps_taken=1,
            max_steps=5,
        )
        assert grade_task3(**inputs) <= 1.0
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class TestLoopPenalty:
    """``loop_penalty`` is zero up to step 10 and increasingly negative after."""

    def test_no_penalty_within_limit(self):
        for step in (5, 10):
            assert loop_penalty(step) == 0.0

    def test_penalty_beyond_limit(self):
        just_over = loop_penalty(11)
        well_over = loop_penalty(15)
        assert just_over < 0.0
        assert well_over < just_over

    def test_penalty_grows(self):
        earlier = loop_penalty(12)
        later = loop_penalty(14)
        assert later < earlier
|
tickets.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Realistic support ticket dataset with ground-truth labels.
|
| 3 |
+
Each ticket includes:
|
| 4 |
+
- id
|
| 5 |
+
- text : customer message
|
| 6 |
+
- category : ground-truth category
|
| 7 |
+
- correct_action : best first action ("reply" | "escalate" | "close")
|
| 8 |
+
- resolution_hint : ideal reply / close reason (used for reward scoring)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
# Column order for the rows below; it is also the key order of each ticket.
_TICKET_FIELDS = ("id", "text", "category", "correct_action", "resolution_hint")

# One row per ticket: (id, text, category, correct_action, resolution_hint).
_TICKET_ROWS = (
    (
        "T001",
        "Hi, I was charged twice for my subscription this month. Please help!",
        "billing",
        "reply",
        "apologize and initiate refund for duplicate charge",
    ),
    (
        "T002",
        "I cannot log into my account. The password reset email never arrives.",
        "account",
        "reply",
        "guide user to check spam folder and verify email address",
    ),
    (
        "T003",
        "Your app crashes every time I try to upload a file larger than 10 MB.",
        "technical",
        "escalate",
        "escalate to engineering team with crash details",
    ),
    (
        "T004",
        "I'd like a full refund. I haven't used the service at all this month.",
        "refund",
        "reply",
        "verify account activity and process refund per policy",
    ),
    (
        "T005",
        "What are your business hours and do you have a phone number I can call?",
        "general",
        "reply",
        "provide business hours and contact information",
    ),
    (
        "T006",
        "My invoice shows a charge for a plan I never subscribed to.",
        "billing",
        "escalate",
        "escalate potential fraudulent charge to billing team",
    ),
    (
        "T007",
        "How do I cancel my subscription? I can't find the option anywhere.",
        "account",
        "reply",
        "guide user to account settings > subscription > cancel",
    ),
    (
        "T008",
        "The API is returning 500 errors intermittently for the past 2 hours.",
        "technical",
        "escalate",
        "escalate to on-call engineering with timestamps",
    ),
    (
        "T009",
        "Thank you! The issue has been resolved. You guys are awesome.",
        "general",
        "close",
        "acknowledge and close the ticket",
    ),
    (
        "T010",
        "I need an itemised invoice for my company's accounting department.",
        "billing",
        "reply",
        "generate and send itemised invoice to customer email",
    ),
)

TICKETS = [dict(zip(_TICKET_FIELDS, row)) for row in _TICKET_ROWS]

# Index of the same ticket dicts by id.
TICKET_LOOKUP = {ticket["id"]: ticket for ticket in TICKETS}
|