AlgoCore commited on
Commit
a3d65ce
·
0 Parent(s):

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ── Dockerfile: Customer Support Ticket Resolution Environment ──
2
+ FROM python:3.11-slim
3
+
4
+ WORKDIR /app
5
+
6
+ # System deps
7
+ RUN apt-get update && apt-get install -y --no-install-recommends \
8
+ build-essential curl && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Python deps
11
+ COPY server/requirements.txt /app/requirements.txt
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ # Copy source
15
+ COPY . /app/support_ticket_env
16
+ ENV PYTHONPATH=/app
17
+ ENV ENABLE_WEB_INTERFACE=true
18
+
19
+ # HF Spaces uses port 7860
20
+ EXPOSE 7860
21
+
22
+ # Health check
23
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
24
+ CMD curl -f http://localhost:7860/health || exit 1
25
+
26
+ CMD ["uvicorn", "support_ticket_env.server.app:app", \
27
+ "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Support Ticket Env
3
+ emoji: 🎫
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ # Customer Support Ticket Resolution Environment
11
+
12
+ A real-world OpenEnv environment for AI agent training.
__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Customer Support Ticket Resolution — OpenEnv Environment."""
2
+
3
+ from support_ticket_env.models import SupportAction, SupportObservation, SupportState
4
+ from support_ticket_env.client import SupportTicketEnv
5
+
6
+ __all__ = [
7
+ "SupportAction",
8
+ "SupportObservation",
9
+ "SupportState",
10
+ "SupportTicketEnv",
11
+ ]
baseline.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ baseline.py — Baseline inference script for the Support Ticket Environment.
4
+
5
+ Runs an OpenAI-compatible model against all 3 tasks and reports scores.
6
+
7
+ Usage:
8
+ OPENAI_API_KEY=sk-... python baseline.py --base-url http://localhost:7860
9
+
10
+ Environment variables:
11
+ OPENAI_API_KEY : required
12
+ OPENAI_BASE_URL : optional override (default https://api.openai.com/v1)
13
+ OPENAI_MODEL : optional model name (default gpt-4o-mini)
14
+ """
15
+
16
+ import argparse
17
+ import json
18
+ import os
19
+ import asyncio
20
+ import re
21
+
22
+ from openai import AsyncOpenAI
23
+ from support_ticket_env.client import SupportTicketEnv
24
+ from support_ticket_env.models import SupportAction
25
+
26
+ # ─────────────────────────── Config ────────────────────────────
27
+
28
+ VALID_CATEGORIES = ["billing", "technical", "account", "general", "refund"]
29
+ VALID_ACTIONS = ["classify", "reply", "escalate", "close"]
30
+
31
+ SYSTEM_PROMPT = """You are a customer support AI agent operating in a ticket triage environment.
32
+
33
+ On each turn you receive a JSON observation with:
34
+ - ticket_text : the customer's message
35
+ - feedback : what happened last step
36
+ - task_id : 1=classify only, 2=classify then act, 3=full resolution
37
+
38
+ You must respond with a JSON object (no markdown) matching this schema:
39
+ {
40
+ "action_type": "classify" | "reply" | "escalate" | "close",
41
+ "category": "billing" | "technical" | "account" | "general" | "refund", // only for classify
42
+ "reply_text": "...", // only for reply
43
+ "reason": "..." // optional
44
+ }
45
+
46
+ Strategy:
47
+ - For task 1: only classify (use action_type="classify" with a category).
48
+ - For task 2: first classify, then choose the best action.
49
+ - For task 3: classify each ticket, then reply/escalate/close as appropriate.
50
+
51
+ Always produce valid JSON and nothing else.
52
+ """
53
+
54
+
55
+ def parse_llm_response(text: str) -> dict:
56
+ """Extract JSON from LLM response, stripping markdown fences if present."""
57
+ text = text.strip()
58
+ # Strip ```json ... ``` fences
59
+ text = re.sub(r"^```(?:json)?\s*", "", text)
60
+ text = re.sub(r"\s*```$", "", text)
61
+ try:
62
+ return json.loads(text)
63
+ except json.JSONDecodeError:
64
+ # Fallback: try to extract first JSON object
65
+ match = re.search(r"\{.*\}", text, re.DOTALL)
66
+ if match:
67
+ return json.loads(match.group())
68
+ raise
69
+
70
+
71
+ async def run_task(
72
+ env_base_url: str,
73
+ llm: AsyncOpenAI,
74
+ model: str,
75
+ task_id: int,
76
+ seed: int = 42,
77
+ max_steps: int = 10,
78
+ ) -> float:
79
+ """Run one episode for a given task_id. Returns the total reward."""
80
+ total_reward = 0.0
81
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
82
+
83
+ async with SupportTicketEnv(base_url=env_base_url) as env:
84
+ result = await env.reset(task_id=task_id, seed=seed)
85
+ obs = result.observation
86
+
87
+ for step in range(max_steps):
88
+ # Build user message from observation
89
+ obs_text = json.dumps({
90
+ "ticket_id": obs.ticket_id,
91
+ "ticket_text": obs.ticket_text,
92
+ "task_id": obs.task_id,
93
+ "current_category": obs.current_category,
94
+ "resolved": obs.resolved,
95
+ "step_count": obs.step_count,
96
+ "feedback": obs.feedback,
97
+ }, indent=2)
98
+
99
+ messages.append({"role": "user", "content": obs_text})
100
+
101
+ # Call LLM
102
+ response = await llm.chat.completions.create(
103
+ model=model,
104
+ messages=messages,
105
+ temperature=0.0,
106
+ max_tokens=256,
107
+ )
108
+ assistant_text = response.choices[0].message.content
109
+ messages.append({"role": "assistant", "content": assistant_text})
110
+
111
+ # Parse action
112
+ try:
113
+ action_dict = parse_llm_response(assistant_text)
114
+ except Exception as e:
115
+ print(f" [step {step+1}] Failed to parse LLM response: {e}")
116
+ break
117
+
118
+ try:
119
+ action = SupportAction(**action_dict)
120
+ except Exception as e:
121
+ print(f" [step {step+1}] Invalid action schema: {e}")
122
+ break
123
+
124
+ # Step environment
125
+ result = await env.step(action)
126
+ obs = result.observation
127
+ reward = result.reward or 0.0
128
+ total_reward += reward
129
+
130
+ print(
131
+ f" [step {step+1}] action={action.action_type}"
132
+ + (f"/{action.category}" if action.category else "")
133
+ + f" reward={reward:.3f} feedback={obs.feedback[:60]}"
134
+ )
135
+
136
+ if result.done:
137
+ break
138
+
139
+ return round(total_reward, 4)
140
+
141
+
142
+ async def main(env_base_url: str, model: str, seeds: list[int]) -> None:
143
+ api_key = os.environ.get("OPENAI_API_KEY", "not-needed")
144
+ openai_base = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
145
+
146
+ llm = AsyncOpenAI(api_key=api_key, base_url=openai_base)
147
+
148
+ results = {}
149
+ for task_id in [1, 2, 3]:
150
+ task_scores = []
151
+ print(f"\n{'='*60}")
152
+ print(f" TASK {task_id} (seed={seeds[0]})")
153
+ print(f"{'='*60}")
154
+ for seed in seeds:
155
+ score = await run_task(env_base_url, llm, model, task_id, seed=seed)
156
+ task_scores.append(score)
157
+ print(f" → total_reward for seed {seed}: {score}")
158
+ avg = round(sum(task_scores) / len(task_scores), 4)
159
+ results[f"task{task_id}"] = {"scores": task_scores, "avg": avg}
160
+ print(f" ► Average: {avg}")
161
+
162
+ print("\n" + "="*60)
163
+ print(" BASELINE SUMMARY")
164
+ print("="*60)
165
+ for k, v in results.items():
166
+ print(f" {k}: avg={v['avg']:.4f} scores={v['scores']}")
167
+ overall = round(
168
+ sum(v["avg"] for v in results.values()) / len(results), 4
169
+ )
170
+ print(f"\n Overall avg: {overall:.4f}")
171
+ print("="*60)
172
+
173
+
174
+ if __name__ == "__main__":
175
+ parser = argparse.ArgumentParser(description="Baseline inference for support_ticket_env")
176
+ parser.add_argument(
177
+ "--base-url",
178
+ default=os.environ.get("ENV_BASE_URL", "http://localhost:7860"),
179
+ help="Base URL of the running environment server",
180
+ )
181
+ parser.add_argument(
182
+ "--model",
183
+ default=os.environ.get("OPENAI_MODEL", "gpt-4o-mini"),
184
+ help="OpenAI model name",
185
+ )
186
+ parser.add_argument(
187
+ "--seeds",
188
+ nargs="+",
189
+ type=int,
190
+ default=[42, 7, 123],
191
+ help="Random seeds for reproducibility",
192
+ )
193
+ args = parser.parse_args()
194
+ asyncio.run(main(args.base_url, args.model, args.seeds))
client.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Client for the Customer Support Ticket Resolution Environment.
3
+ """
4
+
5
+ from openenv.core.env_client import EnvClient
6
+ from support_ticket_env.models import SupportAction, SupportObservation, SupportState
7
+
8
+
9
+ class SupportTicketEnv(EnvClient[SupportAction, SupportObservation, SupportState]):
10
+ """
11
+ OpenEnv client for the Support Ticket Resolution environment.
12
+
13
+ Usage (async):
14
+ async with SupportTicketEnv(base_url="http://localhost:8000") as env:
15
+ result = await env.reset(task_id=1)
16
+ result = await env.step(SupportAction(action_type="classify", category="billing"))
17
+
18
+ Usage (sync):
19
+ with SupportTicketEnv(base_url="http://localhost:8000").sync() as env:
20
+ result = env.reset(task_id=2)
21
+ result = env.step(SupportAction(action_type="classify", category="technical"))
22
+ result = env.step(SupportAction(action_type="escalate"))
23
+ """
24
+
25
+ def _parse_action(self, action: SupportAction) -> dict:
26
+ return action.model_dump()
27
+
28
+ def _parse_result(self, data: dict) -> SupportObservation:
29
+ obs_data = data.get("observation", data)
30
+ return SupportObservation(**obs_data)
31
+
32
+ def _parse_state(self, data: dict) -> SupportState:
33
+ return SupportState(**data)
graders.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Graders for all three tasks.
3
+
4
+ Each grader returns a float in [0.0, 1.0].
5
+
6
+ Task 1 – Classification (easy)
7
+ - 1.0 : correct category
8
+ - 0.0 : wrong category
9
+
10
+ Task 2 – Action Selection (medium)
11
+ - 1.0 : correct action
12
+ - 0.5 : partially correct (e.g., escalate vs reply both defensible)
13
+ - 0.0 : clearly wrong (e.g., close an unsolved ticket)
14
+
15
+ Task 3 – Full Resolution (hard)
16
+ Combines classification + action + reply quality into a single score.
17
+ Rewards partial progress so the agent gets signal throughout the trajectory.
18
+ """
19
+
20
+ from __future__ import annotations
21
+ from typing import Dict, Any
22
+
23
+
24
+ # ─────────────────────────── helpers ───────────────────────────
25
+
26
+ # Pairs of actions that are considered "close enough" for partial credit
27
+ _PARTIAL_CREDIT_PAIRS = {
28
+ frozenset({"reply", "escalate"}), # borderline tickets
29
+ }
30
+
31
+ _KEYWORD_REWARDS: Dict[str, list[str]] = {
32
+ "billing": ["refund", "charge", "invoice", "payment", "billing"],
33
+ "account": ["password", "login", "account", "cancel", "subscription"],
34
+ "technical": ["engineering", "escalate", "bug", "crash", "error", "fix"],
35
+ "refund": ["refund", "return", "credit", "process"],
36
+ "general": ["hours", "contact", "phone", "information", "help"],
37
+ }
38
+
39
+
40
+ def _reply_quality(reply_text: str, category: str) -> float:
41
+ """Return 0.0–0.5 based on how relevant the reply text is."""
42
+ if not reply_text:
43
+ return 0.0
44
+ text_lower = reply_text.lower()
45
+ keywords = _KEYWORD_REWARDS.get(category, [])
46
+ hits = sum(1 for kw in keywords if kw in text_lower)
47
+ # cap at 0.5 (the other 0.5 comes from action correctness)
48
+ return min(0.5, hits * 0.1)
49
+
50
+
51
+ # ─────────────────────────── Task 1 ────────────────────────────
52
+
53
+ def grade_task1(
54
+ predicted_category: str,
55
+ correct_category: str,
56
+ ) -> float:
57
+ """Binary classification reward."""
58
+ return 1.0 if predicted_category == correct_category else 0.0
59
+
60
+
61
+ # ─────────────────────────── Task 2 ────────────────────────────
62
+
63
+ def grade_task2(
64
+ action_type: str,
65
+ correct_action: str,
66
+ category: str | None = None,
67
+ ) -> float:
68
+ """
69
+ Action-selection reward.
70
+ Full credit for exact match, partial credit for defensible alternatives.
71
+ Penalises closing an unresolved ticket.
72
+ """
73
+ if action_type == correct_action:
74
+ return 1.0
75
+
76
+ # Partial credit for ambiguous cases
77
+ pair = frozenset({action_type, correct_action})
78
+ if pair in _PARTIAL_CREDIT_PAIRS:
79
+ return 0.5
80
+
81
+ # Closing an unresolved ticket is always wrong
82
+ if action_type == "close":
83
+ return 0.0
84
+
85
+ return 0.0
86
+
87
+
88
+ # ─────────────────────────── Task 3 ────────────────────────────
89
+
90
+ def grade_task3(
91
+ classified_correctly: bool,
92
+ action_correct: bool,
93
+ action_partial: bool,
94
+ reply_text: str | None,
95
+ category: str,
96
+ resolved: bool,
97
+ steps_taken: int,
98
+ max_steps: int = 5,
99
+ ) -> float:
100
+ """
101
+ Multi-step resolution reward with partial progress.
102
+
103
+ Breakdown:
104
+ 0.20 – classification correct
105
+ 0.40 – action correct (0.20 if partial)
106
+ 0.25 – reply quality (NLP keyword overlap)
107
+ 0.15 – efficiency bonus (fewer steps → higher bonus)
108
+ """
109
+ score = 0.0
110
+
111
+ if classified_correctly:
112
+ score += 0.20
113
+
114
+ if action_correct:
115
+ score += 0.40
116
+ elif action_partial:
117
+ score += 0.20
118
+
119
+ if reply_text:
120
+ score += _reply_quality(reply_text, category) * 0.5 # scaled to 0.25 max
121
+
122
+ # Efficiency: full 0.15 for 1 step, 0 for max_steps steps
123
+ if resolved and steps_taken <= max_steps:
124
+ efficiency = max(0.0, (max_steps - steps_taken) / (max_steps - 1))
125
+ score += 0.15 * efficiency
126
+
127
+ return round(min(1.0, score), 4)
128
+
129
+
130
+ # ─────────────────────────── Penalty ───────────────────────────
131
+
132
+ def loop_penalty(step_count: int, max_steps: int = 10) -> float:
133
+ """Return a negative reward if agent is stuck in a loop."""
134
+ if step_count > max_steps:
135
+ return -0.05 * (step_count - max_steps)
136
+ return 0.0
gradio_ui.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ gradio_ui.py — Interactive Gradio web interface for the Support Ticket Environment.
3
+
4
+ Allows human exploration and debugging without writing code.
5
+ Launched automatically when ENABLE_WEB_INTERFACE=true or run directly.
6
+
7
+ Usage:
8
+ python support_ticket_env/gradio_ui.py
9
+ """
10
+
11
+ import json
12
+ import sys
13
+ import os
14
+
15
+ ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
16
+ STUB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "openenv_stub")
17
+ sys.path.insert(0, STUB)
18
+ sys.path.insert(0, ROOT)
19
+
20
+ try:
21
+ import gradio as gr
22
+ except ImportError:
23
+ print("gradio not installed. Run: pip install gradio")
24
+ sys.exit(1)
25
+
26
+ from support_ticket_env.server.support_environment import SupportTicketEnvironment
27
+ from support_ticket_env.models import SupportAction
28
+
29
+ # ─── shared env instance ────────────────────────────────────────
30
+ _env = SupportTicketEnvironment()
31
+ _history: list[dict] = []
32
+ _current_obs = None
33
+
34
+
35
+ # ─── helpers ────────────────────────────────────────────────────
36
+
37
+ def _format_history() -> str:
38
+ if not _history:
39
+ return "_No actions yet._"
40
+ lines = []
41
+ for i, h in enumerate(_history, 1):
42
+ reward_str = f"{h['reward']:+.3f}" if h["reward"] is not None else "—"
43
+ lines.append(
44
+ f"**Step {i}** | `{h['action']}` → reward `{reward_str}`\n"
45
+ f"> {h['feedback']}"
46
+ )
47
+ return "\n\n".join(lines)
48
+
49
+
50
+ def _obs_to_display(obs) -> tuple[str, str, str]:
51
+ """Return (ticket_box, status_box, score_box)."""
52
+ ticket = f"**[{obs.ticket_id}]** {obs.ticket_text}"
53
+ status = (
54
+ f"Task **{obs.task_id}** | Step **{obs.step_count}** | "
55
+ f"Category: `{obs.current_category or 'unknown'}` | "
56
+ f"Resolved: {'✅' if obs.resolved else '⬜'}"
57
+ )
58
+ score = f"Last step score: **{obs.score:.3f}** | reward: **{obs.reward or 0.0:+.3f}**"
59
+ return ticket, status, score
60
+
61
+
62
+ # ─── UI callbacks ────────────────────────────────────────────────
63
+
64
+ def do_reset(task_id: int, seed: int):
65
+ global _history, _current_obs
66
+ _history = []
67
+ obs = _env.reset(task_id=task_id, seed=seed)
68
+ _current_obs = obs
69
+ ticket, status, score = _obs_to_display(obs)
70
+ return (
71
+ ticket, status, score,
72
+ _format_history(),
73
+ gr.update(interactive=True),
74
+ obs.feedback,
75
+ gr.update(value=False), # done flag
76
+ )
77
+
78
+
79
+ def do_step(action_type: str, category: str, reply_text: str, reason: str):
80
+ global _current_obs
81
+ if _current_obs is None:
82
+ return (
83
+ "⚠️ Please reset the environment first.",
84
+ "", "", _format_history(), "", gr.update(value=False),
85
+ )
86
+
87
+ # Build action
88
+ kwargs = {"action_type": action_type}
89
+ if action_type == "classify" and category:
90
+ kwargs["category"] = category
91
+ if action_type == "reply" and reply_text:
92
+ kwargs["reply_text"] = reply_text
93
+ if reason:
94
+ kwargs["reason"] = reason
95
+
96
+ try:
97
+ action = SupportAction(**kwargs)
98
+ except Exception as e:
99
+ return (
100
+ _current_obs.ticket_text,
101
+ f"❌ Invalid action: {e}", "",
102
+ _format_history(), "", gr.update(value=False),
103
+ )
104
+
105
+ obs = _env.step(action)
106
+ _current_obs = obs
107
+
108
+ _history.append({
109
+ "action": f"{action_type}" + (f"/{category}" if category and action_type == "classify" else ""),
110
+ "reward": obs.reward,
111
+ "feedback": obs.feedback,
112
+ })
113
+
114
+ ticket, status, score = _obs_to_display(obs)
115
+ done_msg = "🏁 Episode finished!" if obs.done else ""
116
+ return (
117
+ ticket, status, score,
118
+ _format_history(),
119
+ obs.feedback,
120
+ gr.update(value=obs.done),
121
+ )
122
+
123
+
124
+ def do_state():
125
+ state = _env.state
126
+ return json.dumps({
127
+ "episode_id": state.episode_id,
128
+ "step_count": state.step_count,
129
+ "task_id": state.task_id,
130
+ "ticket_id": state.ticket_id,
131
+ "correct_category": state.correct_category,
132
+ "correct_action": state.correct_action,
133
+ "classified": state.classified,
134
+ "resolved": state.resolved,
135
+ "total_reward": state.total_reward,
136
+ "tickets_resolved": state.tickets_resolved,
137
+ "tickets_total": state.tickets_total,
138
+ }, indent=2)
139
+
140
+
141
+ # ─── UI layout ──────────────────────────────────────────────────
142
+
143
+ DESCRIPTION = """
144
+ # 🎫 Customer Support Ticket Resolution Environment
145
+
146
+ An **OpenEnv** environment for training AI agents to handle customer support tickets.
147
+
148
+ **Tasks:** 1 = Classify · 2 = Classify + Action · 3 = Full Queue Resolution
149
+ """
150
+
151
+ with gr.Blocks(title="Support Ticket Env", theme=gr.themes.Soft()) as demo:
152
+ gr.Markdown(DESCRIPTION)
153
+
154
+ with gr.Row():
155
+ # ── Left panel: controls ────────────────────────────────
156
+ with gr.Column(scale=1):
157
+ gr.Markdown("### ⚙️ Episode Setup")
158
+ task_slider = gr.Slider(1, 3, value=1, step=1, label="Task ID")
159
+ seed_input = gr.Number(value=42, label="Seed", precision=0)
160
+ reset_btn = gr.Button("🔄 Reset Episode", variant="primary")
161
+
162
+ gr.Markdown("### 🎬 Take Action")
163
+ action_type = gr.Radio(
164
+ ["classify", "reply", "escalate", "close"],
165
+ value="classify", label="Action Type",
166
+ )
167
+ category_dd = gr.Dropdown(
168
+ ["billing", "technical", "account", "general", "refund"],
169
+ label="Category (for classify)",
170
+ value=None,
171
+ )
172
+ reply_box = gr.Textbox(label="Reply Text (for reply)", lines=3)
173
+ reason_box = gr.Textbox(label="Reason (optional)")
174
+ step_btn = gr.Button("▶️ Step", variant="secondary")
175
+ state_btn = gr.Button("🔍 Show State")
176
+
177
+ # ── Right panel: observation ────────────────────────────
178
+ with gr.Column(scale=2):
179
+ gr.Markdown("### 📬 Current Ticket")
180
+ ticket_display = gr.Markdown("_Reset to start._")
181
+ status_display = gr.Markdown("")
182
+ score_display = gr.Markdown("")
183
+ feedback_box = gr.Textbox(label="Last Feedback", interactive=False)
184
+ done_checkbox = gr.Checkbox(label="Episode Done", interactive=False)
185
+
186
+ gr.Markdown("### 📜 Action History")
187
+ history_display = gr.Markdown("_No actions yet._")
188
+
189
+ gr.Markdown("### 🗂️ Raw State (JSON)")
190
+ state_output = gr.Code(language="json", label="state()")
191
+
192
+ # ── wire up ─────────────────────────────────────────────────
193
+ reset_btn.click(
194
+ do_reset,
195
+ inputs=[task_slider, seed_input],
196
+ outputs=[ticket_display, status_display, score_display,
197
+ history_display, step_btn, feedback_box, done_checkbox],
198
+ )
199
+ step_btn.click(
200
+ do_step,
201
+ inputs=[action_type, category_dd, reply_box, reason_box],
202
+ outputs=[ticket_display, status_display, score_display,
203
+ history_display, feedback_box, done_checkbox],
204
+ )
205
+ state_btn.click(
206
+ do_state, inputs=[], outputs=[state_output],
207
+ )
208
+
209
+
210
+ if __name__ == "__main__":
211
+ demo.launch(server_name="0.0.0.0", server_port=7861, share=False)
models.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Typed models for the Customer Support Ticket Resolution Environment.
3
+ Works with pydantic (production) or stdlib (offline/testing).
4
+ """
5
+ from __future__ import annotations
6
+ from typing import Any, Dict, List, Literal, Optional
7
+
8
+ try:
9
+ from pydantic import BaseModel, ConfigDict
10
+ _USE_PYDANTIC = True
11
+ except ImportError:
12
+ _USE_PYDANTIC = False
13
+
14
+ # ── import base classes from openenv (or stub) ──────────────────
15
+ from openenv.core.env_server.types import Action, Observation, State
16
+
17
+
18
+ # ═══════════════════════════════════════════════════════════════
19
+ # Action
20
+ # ═══════════════════════════════════════════════════════════════
21
+
22
+ if _USE_PYDANTIC:
23
+ class SupportAction(Action, BaseModel): # type: ignore[misc]
24
+ model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
25
+ metadata: Dict[str, Any] = {}
26
+ action_type: Literal["classify", "reply", "escalate", "close"]
27
+ category: Optional[
28
+ Literal["billing", "technical", "account", "general", "refund"]
29
+ ] = None
30
+ reply_text: Optional[str] = None
31
+ reason: Optional[str] = None
32
+
33
+ def model_dump(self, **kw):
34
+ return super().model_dump(**kw)
35
+ else:
36
+ _VALID_ACTION_TYPES = {"classify", "reply", "escalate", "close"}
37
+ _VALID_CATEGORIES = {"billing", "technical", "account", "general", "refund"}
38
+
39
+ class SupportAction(Action): # type: ignore[no-redef]
40
+ def __init__(self, **kwargs):
41
+ action_type = kwargs.get("action_type")
42
+ if action_type not in _VALID_ACTION_TYPES:
43
+ raise ValueError(f"Invalid action_type: {action_type!r}")
44
+ category = kwargs.get("category")
45
+ if category is not None and category not in _VALID_CATEGORIES:
46
+ raise ValueError(f"Invalid category: {category!r}")
47
+ self.action_type = action_type
48
+ self.category = category
49
+ self.reply_text = kwargs.get("reply_text")
50
+ self.reason = kwargs.get("reason")
51
+ self.metadata = kwargs.get("metadata", {})
52
+
53
+
54
+ # ═══════════════════════════════════════════════════════════════
55
+ # Observation
56
+ # ═══════════════════════════════════════════════════════════════
57
+
58
+ if _USE_PYDANTIC:
59
+ class SupportObservation(Observation, BaseModel): # type: ignore[misc]
60
+ model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
61
+ done: bool = False
62
+ reward: Optional[float] = None
63
+ metadata: Dict[str, Any] = {}
64
+ ticket_id: str = ""
65
+ ticket_text: str = ""
66
+ task_id: int = 1
67
+ current_category: Optional[str] = None
68
+ resolved: bool = False
69
+ step_count: int = 0
70
+ feedback: str = ""
71
+ score: float = 0.0
72
+ else:
73
+ class SupportObservation(Observation): # type: ignore[no-redef]
74
+ def __init__(self, **kwargs):
75
+ self.done = kwargs.pop("done", False)
76
+ self.reward = kwargs.pop("reward", None)
77
+ self.metadata = kwargs.pop("metadata", {})
78
+ self.ticket_id = kwargs.pop("ticket_id", "")
79
+ self.ticket_text = kwargs.pop("ticket_text", "")
80
+ self.task_id = kwargs.pop("task_id", 1)
81
+ self.current_category = kwargs.pop("current_category", None)
82
+ self.resolved = kwargs.pop("resolved", False)
83
+ self.step_count = kwargs.pop("step_count", 0)
84
+ self.feedback = kwargs.pop("feedback", "")
85
+ self.score = kwargs.pop("score", 0.0)
86
+
87
+
88
+ # ═══════════════════════════════════════════════════════════════
89
+ # State
90
+ # ═══════════════════════════════════════════════════════════════
91
+
92
+ if _USE_PYDANTIC:
93
+ class SupportState(State, BaseModel): # type: ignore[misc]
94
+ model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
95
+ episode_id: Optional[str] = None
96
+ step_count: int = 0
97
+ task_id: int = 1
98
+ ticket_id: str = ""
99
+ correct_category: str = ""
100
+ correct_action: str = ""
101
+ classified: bool = False
102
+ resolved: bool = False
103
+ total_reward: float = 0.0
104
+ tickets_resolved: int = 0
105
+ tickets_total: int = 1
106
+ else:
107
+ class SupportState(State): # type: ignore[no-redef]
108
+ def __init__(self, **kwargs):
109
+ self.episode_id = kwargs.pop("episode_id", None)
110
+ self.step_count = kwargs.pop("step_count", 0)
111
+ self.task_id = kwargs.pop("task_id", 1)
112
+ self.ticket_id = kwargs.pop("ticket_id", "")
113
+ self.correct_category = kwargs.pop("correct_category", "")
114
+ self.correct_action = kwargs.pop("correct_action", "")
115
+ self.classified = kwargs.pop("classified", False)
116
+ self.resolved = kwargs.pop("resolved", False)
117
+ self.total_reward = kwargs.pop("total_reward", 0.0)
118
+ self.tickets_resolved = kwargs.pop("tickets_resolved", 0)
119
+ self.tickets_total = kwargs.pop("tickets_total", 1)
openenv.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: support_ticket_env
2
+ version: "1.0.0"
3
+ description: >
4
+ A real-world customer support ticket triage environment.
5
+ An AI agent acts as a support executive: it classifies incoming tickets,
6
+ selects the correct action (reply / escalate / close), and resolves
7
+ multi-ticket queues efficiently.
8
+ author: OpenEnv Hackathon Entry
9
+ tags:
10
+ - openenv
11
+ - customer-support
12
+ - triage
13
+ - nlp
14
+ - real-world
15
+ tasks:
16
+ - id: 1
17
+ name: Classification
18
+ difficulty: easy
19
+ description: >
20
+ Given a customer ticket, predict the correct category
21
+ (billing | technical | account | general | refund).
22
+ score_range: [0.0, 1.0]
23
+ - id: 2
24
+ name: Action Selection
25
+ difficulty: medium
26
+ description: >
27
+ First classify the ticket, then choose the best action:
28
+ reply, escalate, or close.
29
+ score_range: [0.0, 1.0]
30
+ - id: 3
31
+ name: Full Resolution
32
+ difficulty: hard
33
+ description: >
34
+ Handle a queue of 3 tickets. For each ticket classify it,
35
+ choose the right action, and (if replying) craft a relevant reply.
36
+ Bonus for fewer steps.
37
+ score_range: [0.0, 1.0]
38
+ action_space:
39
+ type: SupportAction
40
+ fields:
41
+ action_type: "classify | reply | escalate | close"
42
+ category: "billing | technical | account | general | refund (required for classify)"
43
+ reply_text: "string (required for reply)"
44
+ reason: "optional justification"
45
+ observation_space:
46
+ type: SupportObservation
47
+ fields:
48
+ ticket_id: string
49
+ ticket_text: string
50
+ task_id: integer
51
+ current_category: "string | null"
52
+ resolved: boolean
53
+ step_count: integer
54
+ feedback: string
55
+ score: float
56
+ docker_image: "support-ticket-env:latest"
57
+ hf_space: "openenv/support-ticket-env"
openenv_stub/openenv/__init__.py ADDED
File without changes
openenv_stub/openenv/core/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from openenv.core.env_client import EnvClient
2
+ from openenv.core.env_server.types import Action, Observation, State
3
+ from openenv.core.env_server.interfaces import Environment
openenv_stub/openenv/core/env_client.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stub for openenv.core.env_client."""
2
+ from abc import ABC
3
+ from typing import Generic, TypeVar
4
+
5
+ ActT = TypeVar("ActT")
6
+ ObsT = TypeVar("ObsT")
7
+ StateT = TypeVar("StateT")
8
+
9
+
10
+ class EnvClient(ABC, Generic[ActT, ObsT, StateT]):
11
+ def __init__(self, base_url: str, **kwargs):
12
+ self.base_url = base_url
openenv_stub/openenv/core/env_server/__init__.py ADDED
File without changes
openenv_stub/openenv/core/env_server/http_server.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stub for openenv.core.env_server.http_server."""
2
+ from typing import Any, Callable, Optional, Type
3
+ from openenv.core.env_server.types import Action, Observation
4
+
5
+
6
+ def create_app(env, action_cls, observation_cls, env_name=None, max_concurrent_envs=1, **kwargs):
7
+ """Stub — returns None when FastAPI is not available."""
8
+ try:
9
+ from fastapi import FastAPI
10
+ app = FastAPI(title=env_name or "SupportTicketEnv")
11
+ return app
12
+ except ImportError:
13
+ return None
openenv_stub/openenv/core/env_server/interfaces.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stub for openenv.core.env_server.interfaces."""
2
+ from __future__ import annotations
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Optional
5
+ from openenv.core.env_server.types import Action, Observation, State
6
+
7
+
8
+ class Environment(ABC):
9
+ SUPPORTS_CONCURRENT_SESSIONS: bool = False
10
+
11
+ def __init__(self, transform=None, rubric=None):
12
+ self.rubric = rubric
13
+
14
+ @abstractmethod
15
+ def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> Observation:
16
+ ...
17
+
18
+ @abstractmethod
19
+ def step(self, action: Action, timeout_s: Optional[float] = None, **kwargs) -> Observation:
20
+ ...
21
+
22
+ @property
23
+ @abstractmethod
24
+ def state(self) -> State:
25
+ ...
26
+
27
+ def close(self) -> None:
28
+ pass
openenv_stub/openenv/core/env_server/types.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Offline stub for openenv.core.env_server.types.
3
+ Uses stdlib only — no pydantic required.
4
+ """
5
+ from __future__ import annotations
6
+ from typing import Any, Dict, Optional
7
+
8
+
9
+ class Action:
10
+ def __init__(self, **kwargs):
11
+ self.metadata = kwargs.pop("metadata", {})
12
+ for k, v in kwargs.items():
13
+ setattr(self, k, v)
14
+
15
+ def model_dump(self):
16
+ return {k: v for k, v in vars(self).items()}
17
+
18
+
19
+ class Observation:
20
+ def __init__(self, **kwargs):
21
+ self.done = kwargs.pop("done", False)
22
+ self.reward = kwargs.pop("reward", None)
23
+ self.metadata = kwargs.pop("metadata", {})
24
+ for k, v in kwargs.items():
25
+ setattr(self, k, v)
26
+
27
+
28
+ class State:
29
+ def __init__(self, **kwargs):
30
+ self.episode_id = kwargs.pop("episode_id", None)
31
+ self.step_count = kwargs.pop("step_count", 0)
32
+ for k, v in kwargs.items():
33
+ setattr(self, k, v)
pyproject.toml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.backends.legacy:build"
4
+
5
+ [project]
6
+ name = "support-ticket-env"
7
+ version = "1.0.0"
8
+ description = "Customer Support Ticket Resolution — OpenEnv Environment"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = { text = "MIT" }
12
+ dependencies = [
13
+ "openenv-core>=0.2.1",
14
+ "fastapi>=0.104.0",
15
+ "uvicorn[standard]>=0.24.0",
16
+ "pydantic>=2.0.0",
17
+ "openai>=1.0.0",
18
+ "pyyaml>=6.0",
19
+ ]
20
+
21
+ [project.optional-dependencies]
22
+ dev = ["pytest>=7.0", "pytest-asyncio", "httpx"]
23
+
24
+ [tool.setuptools.packages.find]
25
+ where = ["."]
26
+ include = ["support_ticket_env*"]
run_tests.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ run_tests.py — Self-contained test runner for support_ticket_env.
4
+ Runs all test cases using only the Python standard library.
5
+
6
+ Usage:
7
+ python run_tests.py
8
+ """
9
+
10
+ import sys
11
+ import os
12
+ import traceback
13
+ from typing import Callable, List, Tuple
14
+
15
+ # ─── path setup ────────────────────────────────────────────────
16
+ ROOT = os.path.dirname(os.path.abspath(__file__))
17
+ STUB = os.path.join(ROOT, "openenv_stub")
18
+ sys.path.insert(0, STUB)
19
+ sys.path.insert(0, ROOT)
20
+
21
+ # ─── minimal test framework ────────────────────────────────────
22
+ _tests: List[Tuple[str, Callable]] = []
23
+ _passed = 0
24
+ _failed = 0
25
+ _errors = 0
26
+
27
+ def test(fn: Callable) -> Callable:
28
+ _tests.append((fn.__qualname__, fn))
29
+ return fn
30
+
31
+ def assert_eq(a, b, msg=""):
32
+ if a != b:
33
+ raise AssertionError(f"{msg} | expected {b!r}, got {a!r}")
34
+
35
+ def assert_true(val, msg=""):
36
+ if not val:
37
+ raise AssertionError(msg or f"Expected truthy, got {val!r}")
38
+
39
+ def assert_in_range(val, lo, hi, msg=""):
40
+ if not (lo <= val <= hi):
41
+ raise AssertionError(msg or f"Expected {val!r} in [{lo}, {hi}]")
42
+
43
+ # ─────────────────────────────── imports ───────────────────────
44
+ from support_ticket_env.graders import (
45
+ grade_task1, grade_task2, grade_task3, loop_penalty,
46
+ )
47
+ from support_ticket_env.server.support_environment import SupportTicketEnvironment
48
+ from support_ticket_env.models import SupportAction
49
+
50
+
51
+ def make_env():
52
+ return SupportTicketEnvironment()
53
+
54
+
55
+ # ════════════════════════════════════════════════════════════════
56
+ # GRADER TESTS
57
+ # ════════════════════════════════════════════════════════════════
58
+
59
+ @test
60
+ def test_grade1_correct():
61
+ assert_eq(grade_task1("billing", "billing"), 1.0)
62
+
63
+ @test
64
+ def test_grade1_wrong():
65
+ assert_eq(grade_task1("technical", "billing"), 0.0)
66
+
67
+ @test
68
+ def test_grade1_all_categories():
69
+ for cat in ["billing", "technical", "account", "general", "refund"]:
70
+ assert_eq(grade_task1(cat, cat), 1.0, f"cat={cat}")
71
+
72
+ @test
73
+ def test_grade1_empty():
74
+ assert_eq(grade_task1("", "billing"), 0.0)
75
+
76
+ @test
77
+ def test_grade2_exact_reply():
78
+ assert_eq(grade_task2("reply", "reply"), 1.0)
79
+
80
+ @test
81
+ def test_grade2_exact_escalate():
82
+ assert_eq(grade_task2("escalate", "escalate"), 1.0)
83
+
84
+ @test
85
+ def test_grade2_exact_close():
86
+ assert_eq(grade_task2("close", "close"), 1.0)
87
+
88
+ @test
89
+ def test_grade2_partial_reply_escalate():
90
+ assert_eq(grade_task2("reply", "escalate"), 0.5)
91
+ assert_eq(grade_task2("escalate", "reply"), 0.5)
92
+
93
+ @test
94
+ def test_grade2_close_wrong():
95
+ assert_eq(grade_task2("close", "reply"), 0.0)
96
+
97
+ @test
98
+ def test_grade3_perfect():
99
+ score = grade_task3(True, True, False,
100
+ "we will process your refund billing payment",
101
+ "billing", True, 1, 5)
102
+ assert_true(score >= 0.9, f"Expected >=0.9, got {score}")
103
+
104
+ @test
105
+ def test_grade3_capped_at_one():
106
+ score = grade_task3(True, True, False,
107
+ "refund billing payment account cancel subscription",
108
+ "billing", True, 1, 5)
109
+ assert_true(score <= 1.0, f"Score exceeds 1.0: {score}")
110
+
111
+ @test
112
+ def test_grade3_partial_action_less_than_full():
113
+ s_partial = grade_task3(True, False, True, None, "technical", True, 2)
114
+ s_full = grade_task3(True, True, False, None, "technical", True, 2)
115
+ assert_true(s_partial < s_full, f"partial={s_partial} should < full={s_full}")
116
+
117
+ @test
118
+ def test_loop_penalty_none_within_limit():
119
+ assert_eq(loop_penalty(5), 0.0)
120
+ assert_eq(loop_penalty(10), 0.0)
121
+
122
+ @test
123
+ def test_loop_penalty_grows():
124
+ assert_true(loop_penalty(12) < loop_penalty(11))
125
+ assert_true(loop_penalty(11) < 0)
126
+
127
+
128
+ # ════════════════════════════════════════════════════════════════
129
+ # ENVIRONMENT TESTS
130
+ # ════════════════════════════════════════════════════════════════
131
+
132
+ @test
133
+ def test_env_reset_task1():
134
+ env = make_env()
135
+ obs = env.reset(task_id=1, seed=42)
136
+ assert_true(obs.ticket_text != "", "ticket_text should not be empty")
137
+ assert_eq(obs.task_id, 1)
138
+ assert_eq(obs.done, False)
139
+
140
+ @test
141
+ def test_env_task1_correct_classification():
142
+ env = make_env()
143
+ env.reset(task_id=1, seed=42)
144
+ state = env.state
145
+ obs = env.step(SupportAction(action_type="classify", category=state.correct_category))
146
+ assert_eq(obs.reward, 1.0)
147
+ assert_eq(obs.done, True)
148
+
149
+ @test
150
+ def test_env_task1_wrong_classification():
151
+ env = make_env()
152
+ env.reset(task_id=1, seed=42)
153
+ state = env.state
154
+ wrong = next(c for c in ["billing","technical","account","general","refund"]
155
+ if c != state.correct_category)
156
+ obs = env.step(SupportAction(action_type="classify", category=wrong))
157
+ assert_eq(obs.reward, 0.0)
158
+ assert_eq(obs.done, True)
159
+
160
+ @test
161
+ def test_env_task2_must_classify_first():
162
+ env = make_env()
163
+ env.reset(task_id=2, seed=42)
164
+ obs = env.step(SupportAction(action_type="escalate"))
165
+ assert_eq(obs.done, False)
166
+ assert_true("classify" in obs.feedback.lower())
167
+
168
+ @test
169
+ def test_env_task2_full_correct_episode():
170
+ env = make_env()
171
+ env.reset(task_id=2, seed=42)
172
+ state = env.state
173
+ env.step(SupportAction(action_type="classify", category=state.correct_category))
174
+ obs = env.step(SupportAction(action_type=state.correct_action))
175
+ assert_eq(obs.done, True)
176
+ assert_true(obs.reward >= 0.5, f"reward={obs.reward}")
177
+
178
+ @test
179
+ def test_env_task3_three_tickets():
180
+ env = make_env()
181
+ env.reset(task_id=3, seed=42)
182
+ assert_eq(env.state.tickets_total, 3)
183
+
184
+ @test
185
+ def test_env_task3_resolves_all():
186
+ env = make_env()
187
+ env.reset(task_id=3, seed=42)
188
+ done = False
189
+ steps = 0
190
+ while not done and steps < 30:
191
+ state = env.state
192
+ if not state.classified:
193
+ action = SupportAction(action_type="classify", category=state.correct_category)
194
+ else:
195
+ ca = state.correct_action
196
+ action = (SupportAction(action_type="reply",
197
+ reply_text=f"Handling your {state.correct_category} issue.")
198
+ if ca == "reply" else SupportAction(action_type=ca))
199
+ obs = env.step(action)
200
+ done = obs.done
201
+ steps += 1
202
+ assert_true(done, "Episode did not finish")
203
+ assert_eq(env.state.tickets_resolved, 3)
204
+
205
+ @test
206
+ def test_env_state_step_count():
207
+ env = make_env()
208
+ env.reset(task_id=1, seed=0)
209
+ assert_eq(env.state.step_count, 0)
210
+ state = env.state
211
+ env.step(SupportAction(action_type="classify", category=state.correct_category))
212
+ assert_eq(env.state.step_count, 1)
213
+
214
+ @test
215
+ def test_env_reward_always_in_range():
216
+ for seed in [0, 1, 2, 42, 99]:
217
+ for task_id in [1, 2, 3]:
218
+ env = make_env()
219
+ env.reset(task_id=task_id, seed=seed)
220
+ state = env.state
221
+ obs = env.step(SupportAction(action_type="classify", category=state.correct_category))
222
+ r = obs.reward or 0.0
223
+ assert_in_range(r, -1.0, 1.0, f"task={task_id} seed={seed} reward={r}")
224
+
225
+ @test
226
+ def test_env_task3_total_reward_positive():
227
+ env = make_env()
228
+ env.reset(task_id=3, seed=7)
229
+ total = 0.0
230
+ done = False
231
+ steps = 0
232
+ while not done and steps < 20:
233
+ state = env.state
234
+ action = (SupportAction(action_type="classify", category=state.correct_category)
235
+ if not state.classified
236
+ else SupportAction(action_type=state.correct_action))
237
+ obs = env.step(action)
238
+ total += obs.reward or 0.0
239
+ done = obs.done
240
+ steps += 1
241
+ assert_true(total > 0.0, f"total_reward={total}")
242
+
243
+
244
+ # ════════════════════════════════════════════════════════════════
245
+ # Runner
246
+ # ════════════════════════════════════════════════════════════════
247
+
248
+ def run_all():
249
+ global _passed, _failed, _errors
250
+ width = max(len(name) for name, _ in _tests) + 2
251
+ print(f"\n{'='*60}")
252
+ print(f" Running {len(_tests)} tests")
253
+ print(f"{'='*60}")
254
+ for name, fn in _tests:
255
+ try:
256
+ fn()
257
+ print(f" ✅ {name}")
258
+ _passed += 1
259
+ except AssertionError as e:
260
+ print(f" ❌ {name}")
261
+ print(f" {e}")
262
+ _failed += 1
263
+ except Exception:
264
+ print(f" 💥 {name}")
265
+ traceback.print_exc(limit=3)
266
+ _errors += 1
267
+ total = _passed + _failed + _errors
268
+ print(f"\n{'='*60}")
269
+ print(f" Results: {_passed}/{total} passed | {_failed} failed | {_errors} errors")
270
+ print(f"{'='*60}\n")
271
+ return _failed + _errors == 0
272
+
273
+
274
+ if __name__ == "__main__":
275
+ ok = run_all()
276
+ sys.exit(0 if ok else 1)
server/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from support_ticket_env.server.support_environment import SupportTicketEnvironment
2
+ from support_ticket_env.server.app import app
3
+
4
+ __all__ = ["SupportTicketEnvironment", "app"]
server/app.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI application entry point for the Support Ticket Environment.
3
+ Serves the OpenEnv HTTP/WebSocket API and optionally the Gradio UI at /web.
4
+ """
5
+ import os
6
+ from openenv.core.env_server.http_server import create_app
7
+
8
+ from support_ticket_env.models import SupportAction, SupportObservation
9
+ from support_ticket_env.server.support_environment import SupportTicketEnvironment
10
+
11
+ app = create_app(
12
+ env=SupportTicketEnvironment,
13
+ action_cls=SupportAction,
14
+ observation_cls=SupportObservation,
15
+ env_name="support_ticket_env",
16
+ max_concurrent_envs=4,
17
+ )
18
+
19
+ # Mount Gradio UI at /web when requested
20
+ if os.getenv("ENABLE_WEB_INTERFACE", "true").lower() == "true":
21
+ try:
22
+ import gradio as gr
23
+ from support_ticket_env.gradio_ui import demo
24
+ import gradio.routes
25
+ app = gr.mount_gradio_app(app, demo, path="/web")
26
+ print("Gradio UI mounted at /web")
27
+ except Exception as e:
28
+ print(f"Gradio UI not mounted: {e}")
server/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ openenv-core>=0.2.1
2
+ fastapi>=0.104.0
3
+ uvicorn[standard]>=0.24.0
4
+ pydantic>=2.0.0
5
+ pyyaml>=6.0
6
+ gradio>=4.0.0
7
+ openai>=1.0.0
server/support_environment.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Customer Support Ticket Resolution — OpenEnv Environment (server side).
3
+
4
+ Implements the three tasks:
5
+ Task 1 (easy) – Classify a single ticket
6
+ Task 2 (medium) – Choose the correct action for a classified ticket
7
+ Task 3 (hard) – Fully resolve a queue of tickets with minimal steps
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import random
13
+ from typing import Optional
14
+
15
+ from openenv.core.env_server.interfaces import Environment
16
+ from openenv.core.env_server.types import State
17
+
18
+ from support_ticket_env.models import SupportAction, SupportObservation, SupportState
19
+ from support_ticket_env.tickets import TICKETS, TICKET_LOOKUP
20
+ from support_ticket_env.graders import (
21
+ grade_task1,
22
+ grade_task2,
23
+ grade_task3,
24
+ loop_penalty,
25
+ )
26
+
27
+
28
+ class SupportTicketEnvironment(Environment):
29
+ """
30
+ OpenEnv environment that simulates a customer-support triage desk.
31
+
32
+ The task_id (1, 2, or 3) is set when the environment is reset.
33
+ """
34
+
35
+ SUPPORTS_CONCURRENT_SESSIONS = True
36
+
37
+ def __init__(self) -> None:
38
+ super().__init__()
39
+ self._task_id: int = 1
40
+ self._ticket: dict = {}
41
+ self._classified: bool = False
42
+ self._resolved: bool = False
43
+ self._step_count: int = 0
44
+ self._total_reward: float = 0.0
45
+ self._episode_id: Optional[str] = None
46
+
47
+ # Task 3: queue of tickets
48
+ self._queue: list[dict] = []
49
+ self._tickets_resolved: int = 0
50
+ self._tickets_total: int = 1
51
+
52
+ # ──────────────────────── reset ────────────────────────────
53
+
54
+ def reset(
55
+ self,
56
+ seed: Optional[int] = None,
57
+ episode_id: Optional[str] = None,
58
+ task_id: int = 1,
59
+ **kwargs,
60
+ ) -> SupportObservation:
61
+ rng = random.Random(seed)
62
+ self._episode_id = episode_id
63
+ self._task_id = int(task_id)
64
+ self._step_count = 0
65
+ self._total_reward = 0.0
66
+ self._classified = False
67
+ self._resolved = False
68
+
69
+ if self._task_id == 3:
70
+ # Give the agent a queue of 3 tickets
71
+ self._queue = rng.sample(TICKETS, k=3)
72
+ self._tickets_total = len(self._queue)
73
+ self._tickets_resolved = 0
74
+ self._ticket = self._queue[0]
75
+ else:
76
+ self._ticket = rng.choice(TICKETS)
77
+ self._tickets_total = 1
78
+ self._tickets_resolved = 0
79
+
80
+ return self._make_obs(
81
+ feedback="New episode started. Read the ticket and take action.",
82
+ score=0.0,
83
+ )
84
+
85
+ # ──────────────────────── step ─────────────────────────────
86
+
87
+ def step(self, action: SupportAction, **kwargs) -> SupportObservation: # type: ignore[override]
88
+ self._step_count += 1
89
+ penalty = loop_penalty(self._step_count)
90
+
91
+ if self._task_id == 1:
92
+ obs = self._step_task1(action)
93
+ elif self._task_id == 2:
94
+ obs = self._step_task2(action)
95
+ else:
96
+ obs = self._step_task3(action)
97
+
98
+ # Apply loop penalty on top of step reward
99
+ obs.reward = (obs.reward or 0.0) + penalty
100
+ obs.reward = round(max(-1.0, min(1.0, obs.reward)), 4)
101
+ self._total_reward += obs.reward
102
+ obs.step_count = self._step_count
103
+ return obs
104
+
105
+ # ──────────────────────── Task 1 ───────────────────────────
106
+
107
+ def _step_task1(self, action: SupportAction) -> SupportObservation:
108
+ if action.action_type != "classify":
109
+ return self._make_obs(
110
+ feedback="Task 1 requires a 'classify' action.",
111
+ score=0.0,
112
+ done=False,
113
+ )
114
+
115
+ score = grade_task1(
116
+ predicted_category=action.category or "",
117
+ correct_category=self._ticket["category"],
118
+ )
119
+ self._classified = score == 1.0
120
+ correct = self._ticket["category"]
121
+
122
+ if score == 1.0:
123
+ feedback = f"✅ Correct! Category: '{correct}'."
124
+ done = True
125
+ else:
126
+ feedback = (
127
+ f"❌ Wrong. You said '{action.category}', correct is '{correct}'."
128
+ )
129
+ done = True # Task 1 is one-shot — agent gets one attempt
130
+
131
+ obs = self._make_obs(feedback=feedback, score=score, done=done)
132
+ if done:
133
+ self._resolved = True
134
+ return obs
135
+
136
+ # ──────────────────────── Task 2 ───────────────────────────
137
+
138
+ def _step_task2(self, action: SupportAction) -> SupportObservation:
139
+ # First step must be classification
140
+ if not self._classified:
141
+ if action.action_type != "classify":
142
+ return self._make_obs(
143
+ feedback="Please classify the ticket first.",
144
+ score=0.0,
145
+ )
146
+ cat_score = grade_task1(
147
+ action.category or "", self._ticket["category"]
148
+ )
149
+ self._classified = True
150
+ return self._make_obs(
151
+ feedback=(
152
+ f"Classified as '{action.category}'. "
153
+ f"{'Correct ✅' if cat_score == 1.0 else 'Incorrect ❌'} "
154
+ "Now choose an action."
155
+ ),
156
+ score=cat_score * 0.3, # partial credit toward max 1.0
157
+ )
158
+
159
+ # Second step: choose action
160
+ score = grade_task2(
161
+ action_type=action.action_type,
162
+ correct_action=self._ticket["correct_action"],
163
+ category=self._ticket["category"],
164
+ )
165
+ correct = self._ticket["correct_action"]
166
+ if score == 1.0:
167
+ feedback = f"✅ Correct action: '{correct}'."
168
+ elif score == 0.5:
169
+ feedback = (
170
+ f"⚠️ Partial credit. '{action.action_type}' is defensible "
171
+ f"but '{correct}' is preferred."
172
+ )
173
+ else:
174
+ feedback = f"❌ Wrong action. Correct: '{correct}'."
175
+
176
+ self._resolved = True
177
+ return self._make_obs(feedback=feedback, score=score, done=True)
178
+
179
+ # ──────────────────────── Task 3 ───────────────────────────
180
+
181
+ def _step_task3(self, action: SupportAction) -> SupportObservation:
182
+ MAX_STEPS = 15
183
+
184
+ if not self._classified:
185
+ # Must classify first
186
+ if action.action_type != "classify":
187
+ return self._make_obs(
188
+ feedback="Classify the ticket before taking action.",
189
+ score=0.0,
190
+ )
191
+ cat_score = grade_task1(
192
+ action.category or "", self._ticket["category"]
193
+ )
194
+ self._classified = True
195
+ return self._make_obs(
196
+ feedback=(
197
+ f"Classified '{self._ticket['id']}' as '{action.category}'. "
198
+ f"{'Correct ✅' if cat_score == 1.0 else 'Incorrect ❌'} "
199
+ "Now resolve it."
200
+ ),
201
+ score=cat_score * 0.1,
202
+ )
203
+
204
+ # Resolve current ticket
205
+ action_correct = action.action_type == self._ticket["correct_action"]
206
+ pair = frozenset({action.action_type, self._ticket["correct_action"]})
207
+ action_partial = (not action_correct) and pair in {
208
+ frozenset({"reply", "escalate"})
209
+ }
210
+
211
+ score = grade_task3(
212
+ classified_correctly=self._classified,
213
+ action_correct=action_correct,
214
+ action_partial=action_partial,
215
+ reply_text=action.reply_text,
216
+ category=self._ticket["category"],
217
+ resolved=True,
218
+ steps_taken=self._step_count,
219
+ max_steps=MAX_STEPS,
220
+ )
221
+
222
+ self._tickets_resolved += 1
223
+ correct_action = self._ticket["correct_action"]
224
+
225
+ # Advance to next ticket in queue
226
+ if self._tickets_resolved < self._tickets_total:
227
+ self._ticket = self._queue[self._tickets_resolved]
228
+ self._classified = False
229
+ feedback = (
230
+ f"Ticket resolved (score {score:.2f}). "
231
+ f"Moving to next ticket ({self._tickets_resolved + 1}/{self._tickets_total})."
232
+ )
233
+ done = False
234
+ else:
235
+ feedback = (
236
+ f"All {self._tickets_total} tickets resolved! "
237
+ f"Episode score: {self._total_reward + score:.2f}"
238
+ )
239
+ done = True
240
+ self._resolved = True
241
+
242
+ return self._make_obs(feedback=feedback, score=score, done=done)
243
+
244
+ # ──────────────────────── helpers ──────────────────────────
245
+
246
+ def _make_obs(
247
+ self,
248
+ feedback: str,
249
+ score: float,
250
+ done: bool = False,
251
+ ) -> SupportObservation:
252
+ return SupportObservation(
253
+ ticket_id=self._ticket.get("id", ""),
254
+ ticket_text=self._ticket.get("text", ""),
255
+ task_id=self._task_id,
256
+ current_category=self._ticket.get("category") if self._classified else None,
257
+ resolved=self._resolved,
258
+ step_count=self._step_count,
259
+ feedback=feedback,
260
+ score=score,
261
+ reward=score,
262
+ done=done,
263
+ )
264
+
265
+ # ──────────────────────── state ────────────────────────────
266
+
267
+ @property
268
+ def state(self) -> SupportState:
269
+ return SupportState(
270
+ episode_id=self._episode_id,
271
+ step_count=self._step_count,
272
+ task_id=self._task_id,
273
+ ticket_id=self._ticket.get("id", ""),
274
+ correct_category=self._ticket.get("category", ""),
275
+ correct_action=self._ticket.get("correct_action", ""),
276
+ classified=self._classified,
277
+ resolved=self._resolved,
278
+ total_reward=self._total_reward,
279
+ tickets_resolved=self._tickets_resolved,
280
+ tickets_total=self._tickets_total,
281
+ )
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Pytest configuration for support_ticket_env tests."""
2
+ import sys
3
+ import os
4
+
5
+ # Ensure the package root is on the path when running tests directly
6
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
tests/test_environment.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for SupportTicketEnvironment — runs the environment directly
3
+ (no HTTP server required).
4
+ """
5
+
6
+ import pytest
7
+ from support_ticket_env.server.support_environment import SupportTicketEnvironment
8
+ from support_ticket_env.models import SupportAction
9
+
10
+
11
+ # ─────────────────────────── fixtures ──────────────────────────
12
+
13
+ @pytest.fixture
14
+ def env():
15
+ return SupportTicketEnvironment()
16
+
17
+
18
+ # ─────────────────────────── Task 1 ────────────────────────────
19
+
20
+ class TestTask1:
21
+ def test_reset_returns_observation(self, env):
22
+ obs = env.reset(task_id=1, seed=42)
23
+ assert obs.ticket_text
24
+ assert obs.task_id == 1
25
+ assert obs.done is False
26
+
27
+ def test_correct_classification(self, env):
28
+ obs = env.reset(task_id=1, seed=42)
29
+ # Find out the correct category via state
30
+ state = env.state
31
+ action = SupportAction(
32
+ action_type="classify",
33
+ category=state.correct_category,
34
+ )
35
+ obs = env.step(action)
36
+ assert obs.reward == 1.0
37
+ assert obs.done is True
38
+
39
+ def test_wrong_classification(self, env):
40
+ env.reset(task_id=1, seed=42)
41
+ state = env.state
42
+ wrong_cats = [
43
+ c for c in ["billing", "technical", "account", "general", "refund"]
44
+ if c != state.correct_category
45
+ ]
46
+ action = SupportAction(action_type="classify", category=wrong_cats[0])
47
+ obs = env.step(action)
48
+ assert obs.reward == 0.0
49
+ assert obs.done is True
50
+
51
+ def test_non_classify_action_penalised(self, env):
52
+ env.reset(task_id=1, seed=42)
53
+ obs = env.step(SupportAction(action_type="reply", reply_text="hello"))
54
+ # Should not crash; done might be False and reward 0
55
+ assert obs.reward is not None
56
+
57
+
58
+ # ─────────────────────────── Task 2 ────────────────────────────
59
+
60
+ class TestTask2:
61
+ def test_full_correct_episode(self, env):
62
+ env.reset(task_id=2, seed=42)
63
+ state = env.state
64
+
65
+ # Step 1: classify
66
+ obs = env.step(SupportAction(
67
+ action_type="classify",
68
+ category=state.correct_category,
69
+ ))
70
+ assert obs.done is False
71
+ assert obs.reward > 0
72
+
73
+ # Step 2: correct action
74
+ obs = env.step(SupportAction(action_type=state.correct_action))
75
+ assert obs.done is True
76
+ assert obs.reward >= 0.5
77
+
78
+ def test_must_classify_first(self, env):
79
+ env.reset(task_id=2, seed=7)
80
+ obs = env.step(SupportAction(action_type="escalate"))
81
+ assert obs.done is False
82
+ assert "classify" in obs.feedback.lower()
83
+
84
+ def test_state_reflects_progress(self, env):
85
+ env.reset(task_id=2, seed=7)
86
+ state = env.state
87
+ assert state.classified is False
88
+
89
+ env.step(SupportAction(
90
+ action_type="classify",
91
+ category=state.correct_category,
92
+ ))
93
+ state2 = env.state
94
+ assert state2.classified is True
95
+ assert state2.step_count == 1
96
+
97
+
98
+ # ─────────────────────────── Task 3 ────────────────────────────
99
+
100
+ class TestTask3:
101
+ def test_queue_has_three_tickets(self, env):
102
+ env.reset(task_id=3, seed=42)
103
+ state = env.state
104
+ assert state.tickets_total == 3
105
+ assert state.tickets_resolved == 0
106
+
107
+ def test_resolve_all_tickets(self, env):
108
+ env.reset(task_id=3, seed=42)
109
+ done = False
110
+ steps = 0
111
+
112
+ while not done and steps < 30:
113
+ state = env.state
114
+ if not state.classified:
115
+ action = SupportAction(
116
+ action_type="classify",
117
+ category=state.correct_category,
118
+ )
119
+ else:
120
+ ca = state.correct_action
121
+ if ca == "reply":
122
+ action = SupportAction(
123
+ action_type="reply",
124
+ reply_text=f"We are handling your {state.correct_category} issue.",
125
+ )
126
+ else:
127
+ action = SupportAction(action_type=ca)
128
+ obs = env.step(action)
129
+ done = obs.done
130
+ steps += 1
131
+
132
+ assert done, "Episode should finish after 3 tickets"
133
+ final_state = env.state
134
+ assert final_state.tickets_resolved == 3
135
+
136
+ def test_total_reward_positive(self, env):
137
+ env.reset(task_id=3, seed=123)
138
+ total = 0.0
139
+ done = False
140
+ steps = 0
141
+
142
+ while not done and steps < 20:
143
+ state = env.state
144
+ if not state.classified:
145
+ action = SupportAction(
146
+ action_type="classify",
147
+ category=state.correct_category,
148
+ )
149
+ else:
150
+ action = SupportAction(action_type=state.correct_action)
151
+ obs = env.step(action)
152
+ total += obs.reward or 0.0
153
+ done = obs.done
154
+ steps += 1
155
+
156
+ assert total > 0.0
157
+
158
+
159
+ # ─────────────────────────── State API ─────────────────────────
160
+
161
+ class TestStateAPI:
162
+ def test_state_after_reset(self, env):
163
+ env.reset(task_id=1, seed=0)
164
+ state = env.state
165
+ assert state.step_count == 0
166
+ assert state.task_id == 1
167
+ assert state.ticket_id != ""
168
+
169
+ def test_step_count_increments(self, env):
170
+ env.reset(task_id=1, seed=0)
171
+ state = env.state
172
+ env.step(SupportAction(action_type="classify", category=state.correct_category))
173
+ assert env.state.step_count == 1
174
+
175
+
176
+ # ─────────────────────────── Reward bounds ─────────────────────
177
+
178
+ class TestRewardBounds:
179
+ def test_reward_in_range(self, env):
180
+ for seed in [0, 1, 2, 3, 42]:
181
+ for task_id in [1, 2, 3]:
182
+ env.reset(task_id=task_id, seed=seed)
183
+ state = env.state
184
+ action = SupportAction(
185
+ action_type="classify",
186
+ category=state.correct_category,
187
+ )
188
+ obs = env.step(action)
189
+ assert -1.0 <= (obs.reward or 0.0) <= 1.0, (
190
+ f"Reward out of bounds: {obs.reward}"
191
+ )
tests/test_graders.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for grader functions."""
2
+
3
+ import pytest
4
+ from support_ticket_env.graders import (
5
+ grade_task1,
6
+ grade_task2,
7
+ grade_task3,
8
+ loop_penalty,
9
+ )
10
+
11
+
12
+ class TestTask1Grader:
13
+ def test_correct_category(self):
14
+ assert grade_task1("billing", "billing") == 1.0
15
+
16
+ def test_wrong_category(self):
17
+ assert grade_task1("technical", "billing") == 0.0
18
+
19
+ def test_all_categories(self):
20
+ for cat in ["billing", "technical", "account", "general", "refund"]:
21
+ assert grade_task1(cat, cat) == 1.0
22
+
23
+ def test_empty_prediction(self):
24
+ assert grade_task1("", "billing") == 0.0
25
+
26
+
27
+ class TestTask2Grader:
28
+ def test_exact_match(self):
29
+ assert grade_task2("reply", "reply") == 1.0
30
+ assert grade_task2("escalate", "escalate") == 1.0
31
+ assert grade_task2("close", "close") == 1.0
32
+
33
+ def test_partial_credit_reply_escalate(self):
34
+ score = grade_task2("reply", "escalate")
35
+ assert score == 0.5
36
+ score = grade_task2("escalate", "reply")
37
+ assert score == 0.5
38
+
39
+ def test_wrong_action_close(self):
40
+ assert grade_task2("close", "reply") == 0.0
41
+ assert grade_task2("close", "escalate") == 0.0
42
+
43
+ def test_classify_when_action_expected(self):
44
+ assert grade_task2("classify", "reply") == 0.0
45
+
46
+
47
+ class TestTask3Grader:
48
+ def test_perfect_resolution(self):
49
+ score = grade_task3(
50
+ classified_correctly=True,
51
+ action_correct=True,
52
+ action_partial=False,
53
+ reply_text="we will process your refund billing payment",
54
+ category="billing",
55
+ resolved=True,
56
+ steps_taken=1,
57
+ max_steps=5,
58
+ )
59
+ assert score > 0.9
60
+
61
+ def test_no_classification(self):
62
+ score = grade_task3(
63
+ classified_correctly=False,
64
+ action_correct=True,
65
+ action_partial=False,
66
+ reply_text="here is the refund",
67
+ category="billing",
68
+ resolved=True,
69
+ steps_taken=2,
70
+ )
71
+ # Should not get the 0.20 classification bonus
72
+ assert score < 1.0
73
+
74
+ def test_partial_action(self):
75
+ score_partial = grade_task3(
76
+ classified_correctly=True,
77
+ action_correct=False,
78
+ action_partial=True,
79
+ reply_text=None,
80
+ category="technical",
81
+ resolved=True,
82
+ steps_taken=2,
83
+ )
84
+ score_correct = grade_task3(
85
+ classified_correctly=True,
86
+ action_correct=True,
87
+ action_partial=False,
88
+ reply_text=None,
89
+ category="technical",
90
+ resolved=True,
91
+ steps_taken=2,
92
+ )
93
+ assert score_partial < score_correct
94
+
95
+ def test_score_capped_at_one(self):
96
+ score = grade_task3(
97
+ classified_correctly=True,
98
+ action_correct=True,
99
+ action_partial=False,
100
+ reply_text="refund billing payment account cancel subscription",
101
+ category="billing",
102
+ resolved=True,
103
+ steps_taken=1,
104
+ max_steps=5,
105
+ )
106
+ assert score <= 1.0
107
+
108
+
109
+ class TestLoopPenalty:
110
+ def test_no_penalty_within_limit(self):
111
+ assert loop_penalty(5) == 0.0
112
+ assert loop_penalty(10) == 0.0
113
+
114
+ def test_penalty_beyond_limit(self):
115
+ assert loop_penalty(11) < 0.0
116
+ assert loop_penalty(15) < loop_penalty(11)
117
+
118
+ def test_penalty_grows(self):
119
+ p1 = loop_penalty(12)
120
+ p2 = loop_penalty(14)
121
+ assert p2 < p1
tickets.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Realistic support ticket dataset with ground-truth labels.
3
+ Each ticket includes:
4
+ - id
5
+ - text : customer message
6
+ - category : ground-truth category
7
+ - correct_action : best first action ("reply" | "escalate" | "close")
8
+ - resolution_hint : ideal reply / close reason (used for reward scoring)
9
+ """
10
+
11
+ TICKETS = [
12
+ {
13
+ "id": "T001",
14
+ "text": "Hi, I was charged twice for my subscription this month. Please help!",
15
+ "category": "billing",
16
+ "correct_action": "reply",
17
+ "resolution_hint": "apologize and initiate refund for duplicate charge",
18
+ },
19
+ {
20
+ "id": "T002",
21
+ "text": "I cannot log into my account. The password reset email never arrives.",
22
+ "category": "account",
23
+ "correct_action": "reply",
24
+ "resolution_hint": "guide user to check spam folder and verify email address",
25
+ },
26
+ {
27
+ "id": "T003",
28
+ "text": "Your app crashes every time I try to upload a file larger than 10 MB.",
29
+ "category": "technical",
30
+ "correct_action": "escalate",
31
+ "resolution_hint": "escalate to engineering team with crash details",
32
+ },
33
+ {
34
+ "id": "T004",
35
+ "text": "I'd like a full refund. I haven't used the service at all this month.",
36
+ "category": "refund",
37
+ "correct_action": "reply",
38
+ "resolution_hint": "verify account activity and process refund per policy",
39
+ },
40
+ {
41
+ "id": "T005",
42
+ "text": "What are your business hours and do you have a phone number I can call?",
43
+ "category": "general",
44
+ "correct_action": "reply",
45
+ "resolution_hint": "provide business hours and contact information",
46
+ },
47
+ {
48
+ "id": "T006",
49
+ "text": "My invoice shows a charge for a plan I never subscribed to.",
50
+ "category": "billing",
51
+ "correct_action": "escalate",
52
+ "resolution_hint": "escalate potential fraudulent charge to billing team",
53
+ },
54
+ {
55
+ "id": "T007",
56
+ "text": "How do I cancel my subscription? I can't find the option anywhere.",
57
+ "category": "account",
58
+ "correct_action": "reply",
59
+ "resolution_hint": "guide user to account settings > subscription > cancel",
60
+ },
61
+ {
62
+ "id": "T008",
63
+ "text": "The API is returning 500 errors intermittently for the past 2 hours.",
64
+ "category": "technical",
65
+ "correct_action": "escalate",
66
+ "resolution_hint": "escalate to on-call engineering with timestamps",
67
+ },
68
+ {
69
+ "id": "T009",
70
+ "text": "Thank you! The issue has been resolved. You guys are awesome.",
71
+ "category": "general",
72
+ "correct_action": "close",
73
+ "resolution_hint": "acknowledge and close the ticket",
74
+ },
75
+ {
76
+ "id": "T010",
77
+ "text": "I need an itemised invoice for my company's accounting department.",
78
+ "category": "billing",
79
+ "correct_action": "reply",
80
+ "resolution_hint": "generate and send itemised invoice to customer email",
81
+ },
82
+ ]
83
+
84
+ TICKET_LOOKUP = {t["id"]: t for t in TICKETS}