| |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import time |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
| from typing import Any |
|
|
| from openai import OpenAI |
|
|
| from support_triage_openenv import Action, SupportTriageEnv |
|
|
| |
# Inference endpoint configuration; env vars override the defaults.
# API_BASE_URL: OpenAI-compatible router endpoint (defaults to the HF router).
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
# MODEL_NAME: chat model served behind API_BASE_URL.
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
# HF_TOKEN: API key; required only for --mode openai (checked in main()).
HF_TOKEN = os.getenv("HF_TOKEN")


# Benchmark label emitted in [START] log lines.
BENCHMARK = os.getenv("SUPPORT_TRIAGE_BENCHMARK", "support-triage-openenv")
# Episode counts as a success when its clamped grader score meets this threshold.
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.9"))


# System prompt instructing the model to answer with exactly one JSON action
# object; llm_action() parses the reply with _extract_json + Action.model_validate.
SYSTEM_PROMPT = (
    "You are solving customer support ticket triage. "
    "Return exactly one JSON object with keys: "
    "action_type, ticket_id, priority, category, needs_escalation, message."
)
|
|
# Scripted per-task action plans, keyed by task id.
# Used by heuristic mode and as the fallback when the LLM call fails
# (see heuristic_action); each entry is a dict validated into an Action.
# When the episode runs longer than a plan, the last action is repeated
# (heuristic_action clamps the step index).
RULE_POLICY: dict[str, list[dict[str, Any]]] = {
    # Single-ticket account issue: read, classify, reply, resolve.
    "easy_password_reset": [
        {"action_type": "read_ticket", "ticket_id": "T-1001"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-1001",
            "priority": "medium",
            "category": "account",
            "needs_escalation": False,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We will send a reset link to your email. For security, confirm the request "
                "from your registered email before using the reset link."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-1001"},
    ],
    # Two related tickets; only the primary (T-2001) is classified/resolved.
    "medium_billing_dispute": [
        {"action_type": "read_ticket", "ticket_id": "T-2001"},
        {"action_type": "read_ticket", "ticket_id": "T-2002"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-2001",
            "priority": "high",
            "category": "billing",
            "needs_escalation": False,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We confirmed a duplicate charge. We are issuing a refund and will share the invoice update. "
                "Refund processing typically takes 3-5 business days."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-2001"},
    ],
    # Three tickets; incident is escalated (needs_escalation=True, urgent).
    "hard_outage_incident": [
        {"action_type": "read_ticket", "ticket_id": "T-3001"},
        {"action_type": "read_ticket", "ticket_id": "T-3002"},
        {"action_type": "read_ticket", "ticket_id": "T-3003"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-3001",
            "priority": "urgent",
            "category": "technical",
            "needs_escalation": True,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We have escalated this incident and are investigating now. "
                "The status page will carry updates while we continue incident response."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-3001"},
    ],
}
|
|
|
|
@dataclass
class EpisodeResult:
    """Per-episode evaluation summary; serialized via asdict() into the output JSON."""

    task_id: str  # task identifier the episode was run with
    steps: int  # number of env.step() calls taken (including a failed final step)
    score: float  # grader score clamped to [0, 1], rounded to 4 decimals
    success: bool  # True when score >= SUCCESS_SCORE_THRESHOLD
    final_reward: float  # reward of the last step (0.0 if that step raised)
    rewards: list[float]  # per-step rewards, rounded to 4 decimals
    fallback_count: int  # times the LLM failed and the scripted policy was used instead
|
|
|
|
def log_start(task: str, env: str, model: str) -> None:
    """Emit the machine-parseable [START] marker line and flush immediately."""
    fields = " ".join([f"task={task}", f"env={env}", f"model={model}"])
    print(f"[START] {fields}", flush=True)
|
|
|
|
| def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None: |
| error_val = error if error else "null" |
| done_val = str(done).lower() |
| print( |
| f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", |
| flush=True, |
| ) |
|
|
|
|
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Emit the final [END] summary line (score to 3 decimals, rewards comma-joined)."""
    joined = ",".join(format(r, ".2f") for r in rewards)
    fields = " ".join(
        [
            f"success={str(success).lower()}",
            f"steps={steps}",
            f"score={score:.3f}",
            f"rewards={joined}",
        ]
    )
    print(f"[END] {fields}", flush=True)
|
|
|
|
| def _extract_json(text: str) -> str: |
| text = text.strip() |
| start = text.find("{") |
| end = text.rfind("}") |
| if start == -1 or end == -1 or end <= start: |
| raise ValueError("No JSON object found in model response") |
| return text[start : end + 1] |
|
|
|
|
def heuristic_action(task_id: str, step_idx: int) -> Action:
    """Return the scripted RULE_POLICY action for *task_id* at *step_idx*.

    If the episode outruns the plan, the plan's final action is repeated
    (the index is clamped to the last entry).
    """
    scripted = RULE_POLICY[task_id]
    last = len(scripted) - 1
    clamped = step_idx if step_idx < last else last
    return Action.model_validate(scripted[clamped])
|
|
|
|
def llm_action(client: OpenAI, observation: dict[str, Any], state: dict[str, Any]) -> Action:
    """Ask the chat model for the next action and parse its JSON reply.

    Raises on API errors, missing/invalid JSON, or schema mismatch; the
    caller falls back to the scripted policy in that case.
    """
    user_payload = {
        "instruction": "Pick the best next single action to maximize final task score.",
        "observation": observation,
        "state": state,
    }
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": json.dumps(user_payload, ensure_ascii=True)},
    ]
    # temperature=0 keeps runs as deterministic as the endpoint allows.
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        temperature=0,
        max_tokens=220,
        stream=False,
    )
    raw = (completion.choices[0].message.content or "").strip()
    payload = json.loads(_extract_json(raw))
    return Action.model_validate(payload)
|
|
|
|
def action_to_str(action: Action) -> str:
    """Render an Action as a compact token for [STEP] log lines.

    Unknown action types fall through to the bare action_type string.
    """
    kind = action.action_type
    if kind == "read_ticket":
        return f"read_ticket({action.ticket_id})"
    if kind == "classify_ticket":
        escalation = str(bool(action.needs_escalation)).lower()
        return f"classify_ticket({action.ticket_id},{action.priority},{action.category},{escalation})"
    if kind == "draft_reply":
        # Log only the reply length, not its content, to keep log lines short.
        body = (action.message or "").strip()
        return f"draft_reply(len={len(body)})"
    if kind == "resolve_ticket":
        return f"resolve_ticket({action.ticket_id})"
    return kind
|
|
|
|
def run_episode(
    env: SupportTriageEnv,
    task_id: str,
    mode: str,
    client: OpenAI | None,
    started_at: float,
    runtime_limit_seconds: int,
) -> EpisodeResult:
    """Run one episode of *task_id* to completion and return its scored summary.

    Args:
        env: environment instance; reset here and stepped until done.
        task_id: key into RULE_POLICY / env.task_ids.
        mode: "heuristic" uses the scripted policy; anything else queries the LLM.
        client: OpenAI client; required (asserted) when mode is not "heuristic".
        started_at: time.monotonic() timestamp shared across episodes.
        runtime_limit_seconds: global wall-clock budget relative to started_at.

    Raises:
        TimeoutError: when the shared runtime budget is exhausted.
    """
    obs = env.reset(task_id)
    done = False
    info: dict[str, Any] = {}
    rewards: list[float] = []
    steps_taken = 0
    fallback_count = 0
    success = False
    score = 0.0
    final_reward = 0.0


    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)


    while not done:
        # Budget check is global across episodes (started_at comes from main()).
        if time.monotonic() - started_at > runtime_limit_seconds:
            raise TimeoutError(f"Runtime exceeded {runtime_limit_seconds}s")


        # Step index is read from env state before stepping; presumably it
        # starts at 0 and increments per step — TODO confirm against env impl.
        step_idx = env.state()["step_count"]


        if mode == "heuristic":
            action = heuristic_action(task_id, step_idx)
        else:
            assert client is not None
            try:
                action = llm_action(client, obs.model_dump(), env.state())
            except Exception:
                # Any LLM failure (API, JSON, validation) falls back to the
                # scripted policy so the episode can still finish.
                fallback_count += 1
                action = heuristic_action(task_id, step_idx)


        step_error: str | None = None
        try:
            obs, reward, done, info = env.step(action)
            reward_value = float(reward.value)
        except Exception as exc:
            # A failed step ends the episode with zero reward; note that
            # `info` keeps its previous value (possibly {}), so the final
            # score below defaults to 0.0.
            step_error = str(exc)
            reward_value = 0.0
            done = True


        steps_taken = step_idx + 1
        rewards.append(reward_value)
        final_reward = reward_value


        log_step(
            step=steps_taken,
            action=action_to_str(action),
            reward=reward_value,
            done=done,
            error=step_error,
        )


        # NOTE(review): redundant with the `while not done` condition, but
        # kept to make the loop exit explicit.
        if done:
            break


    # Clamp the grader score to [0, 1]; missing score counts as 0.0.
    score = max(0.0, min(1.0, float(info.get("grader_score", 0.0))))
    success = score >= SUCCESS_SCORE_THRESHOLD
    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)


    return EpisodeResult(
        task_id=task_id,
        steps=steps_taken,
        score=round(score, 4),
        success=success,
        final_reward=round(final_reward, 4),
        rewards=[round(r, 4) for r in rewards],
        fallback_count=fallback_count,
    )
|
|
|
|
def main() -> None:
    """CLI entry point: run one or all tasks and write a JSON score summary."""
    parser = argparse.ArgumentParser(description="Submission inference script.")
    parser.add_argument("--mode", choices=["openai", "heuristic"], default="openai")
    parser.add_argument("--output", default="scores/inference_scores.json")
    parser.add_argument("--runtime-limit-seconds", type=int, default=1200)
    parser.add_argument("--task-id", default="", help="Optional single task id; default runs all tasks")
    args = parser.parse_args()

    use_llm = args.mode == "openai"
    if use_llm and not HF_TOKEN:
        raise RuntimeError("HF_TOKEN is required for --mode openai")

    env = SupportTriageEnv()
    # A single --task-id overrides the default of running every task.
    selected = [args.task_id] if args.task_id else env.task_ids

    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN) if use_llm else None

    # One shared wall-clock budget across all episodes.
    started_at = time.monotonic()
    episodes: list[EpisodeResult] = []
    for tid in selected:
        if tid not in env.task_ids:
            raise ValueError(f"Unknown task_id '{tid}'")
        result = run_episode(
            env=env,
            task_id=tid,
            mode=args.mode,
            client=client,
            started_at=started_at,
            runtime_limit_seconds=args.runtime_limit_seconds,
        )
        episodes.append(result)

    episode_count = len(episodes)
    summary = {
        "mode": args.mode,
        "api_base_url": API_BASE_URL,
        "model_name": MODEL_NAME,
        "avg_score": round(sum(e.score for e in episodes) / episode_count, 4),
        "avg_final_reward": round(sum(e.final_reward for e in episodes) / episode_count, 4),
        "total_steps": int(sum(e.steps for e in episodes)),
        "episodes": [asdict(e) for e in episodes],
    }

    destination = Path(args.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
|
|
|
# Standard script entry guard: run main() only when executed directly.
if __name__ == "__main__":
    main()
|
|