#!/usr/bin/env python3 from __future__ import annotations import argparse import json import os import time from dataclasses import asdict, dataclass from pathlib import Path from typing import Any from openai import OpenAI from support_triage_openenv import Action, SupportTriageEnv # Mandatory variables requested by organizers. API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1" MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct" HF_TOKEN = os.getenv("HF_TOKEN") BENCHMARK = os.getenv("SUPPORT_TRIAGE_BENCHMARK", "support-triage-openenv") SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.9")) SYSTEM_PROMPT = ( "You are solving customer support ticket triage. " "Return exactly one JSON object with keys: " "action_type, ticket_id, priority, category, needs_escalation, message." ) RULE_POLICY: dict[str, list[dict[str, Any]]] = { "easy_password_reset": [ {"action_type": "read_ticket", "ticket_id": "T-1001"}, { "action_type": "classify_ticket", "ticket_id": "T-1001", "priority": "medium", "category": "account", "needs_escalation": False, }, { "action_type": "draft_reply", "message": ( "We will send a reset link to your email. For security, confirm the request " "from your registered email before using the reset link." ), }, {"action_type": "resolve_ticket", "ticket_id": "T-1001"}, ], "medium_billing_dispute": [ {"action_type": "read_ticket", "ticket_id": "T-2001"}, {"action_type": "read_ticket", "ticket_id": "T-2002"}, { "action_type": "classify_ticket", "ticket_id": "T-2001", "priority": "high", "category": "billing", "needs_escalation": False, }, { "action_type": "draft_reply", "message": ( "We confirmed a duplicate charge. We are issuing a refund and will share the invoice update. " "Refund processing typically takes 3-5 business days." ), }, {"action_type": "resolve_ticket", "ticket_id": "T-2001"}, ], "hard_outage_incident": [ {"action_type": "read_ticket", "ticket_id": "T-3001"}, {"action_type": "read_ticket", "ticket_id": "T-3002"}, {"action_type": "read_ticket", "ticket_id": "T-3003"}, { "action_type": "classify_ticket", "ticket_id": "T-3001", "priority": "urgent", "category": "technical", "needs_escalation": True, }, { "action_type": "draft_reply", "message": ( "We have escalated this incident and are investigating now. " "The status page will carry updates while we continue incident response." ), }, {"action_type": "resolve_ticket", "ticket_id": "T-3001"}, ], } @dataclass class EpisodeResult: task_id: str steps: int score: float success: bool final_reward: float rewards: list[float] fallback_count: int def log_start(task: str, env: str, model: str) -> None: print(f"[START] task={task} env={env} model={model}", flush=True) def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None: error_val = error if error else "null" done_val = str(done).lower() print( f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True, ) def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None: rewards_str = ",".join(f"{r:.2f}" for r in rewards) print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True) def _extract_json(text: str) -> str: text = text.strip() start = text.find("{") end = text.rfind("}") if start == -1 or end == -1 or end <= start: raise ValueError("No JSON object found in model response") return text[start : end + 1] def heuristic_action(task_id: str, step_idx: int) -> Action: plan = RULE_POLICY[task_id] idx = min(step_idx, len(plan) - 1) return Action.model_validate(plan[idx]) def llm_action(client: OpenAI, observation: dict[str, Any], state: dict[str, Any]) -> Action: prompt = json.dumps( { "instruction": "Pick the best next single action to maximize final task score.", "observation": observation, "state": state, }, ensure_ascii=True, ) completion = client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt}, ], temperature=0, max_tokens=220, stream=False, ) text = (completion.choices[0].message.content or "").strip() payload = json.loads(_extract_json(text)) return Action.model_validate(payload) def action_to_str(action: Action) -> str: if action.action_type == "read_ticket": return f"read_ticket({action.ticket_id})" if action.action_type == "classify_ticket": return ( f"classify_ticket({action.ticket_id},{action.priority},{action.category}," f"{str(bool(action.needs_escalation)).lower()})" ) if action.action_type == "draft_reply": length = len((action.message or "").strip()) return f"draft_reply(len={length})" if action.action_type == "resolve_ticket": return f"resolve_ticket({action.ticket_id})" return action.action_type def run_episode( env: SupportTriageEnv, task_id: str, mode: str, client: OpenAI | None, started_at: float, runtime_limit_seconds: int, ) -> EpisodeResult: obs = env.reset(task_id) done = False info: dict[str, Any] = {} rewards: list[float] = [] steps_taken = 0 fallback_count = 0 success = False score = 0.0 final_reward = 0.0 log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME) while not done: if time.monotonic() - started_at > runtime_limit_seconds: raise TimeoutError(f"Runtime exceeded {runtime_limit_seconds}s") step_idx = env.state()["step_count"] if mode == "heuristic": action = heuristic_action(task_id, step_idx) else: assert client is not None try: action = llm_action(client, obs.model_dump(), env.state()) except Exception: fallback_count += 1 action = heuristic_action(task_id, step_idx) step_error: str | None = None try: obs, reward, done, info = env.step(action) reward_value = float(reward.value) except Exception as exc: step_error = str(exc) reward_value = 0.0 done = True steps_taken = step_idx + 1 rewards.append(reward_value) final_reward = reward_value log_step( step=steps_taken, action=action_to_str(action), reward=reward_value, done=done, error=step_error, ) if done: break score = max(0.0, min(1.0, float(info.get("grader_score", 0.0)))) success = score >= SUCCESS_SCORE_THRESHOLD log_end(success=success, steps=steps_taken, score=score, rewards=rewards) return EpisodeResult( task_id=task_id, steps=steps_taken, score=round(score, 4), success=success, final_reward=round(final_reward, 4), rewards=[round(r, 4) for r in rewards], fallback_count=fallback_count, ) def main() -> None: parser = argparse.ArgumentParser(description="Submission inference script.") parser.add_argument("--mode", choices=["openai", "heuristic"], default="openai") parser.add_argument("--output", default="scores/inference_scores.json") parser.add_argument("--runtime-limit-seconds", type=int, default=1200) parser.add_argument("--task-id", default="", help="Optional single task id; default runs all tasks") args = parser.parse_args() if args.mode == "openai" and not HF_TOKEN: raise RuntimeError("HF_TOKEN is required for --mode openai") env = SupportTriageEnv() task_ids = [args.task_id] if args.task_id else env.task_ids client = None if args.mode == "openai": client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN) started_at = time.monotonic() episodes: list[EpisodeResult] = [] for task_id in task_ids: if task_id not in env.task_ids: raise ValueError(f"Unknown task_id '{task_id}'") episodes.append( run_episode( env=env, task_id=task_id, mode=args.mode, client=client, started_at=started_at, runtime_limit_seconds=args.runtime_limit_seconds, ) ) summary = { "mode": args.mode, "api_base_url": API_BASE_URL, "model_name": MODEL_NAME, "avg_score": round(sum(e.score for e in episodes) / len(episodes), 4), "avg_final_reward": round(sum(e.final_reward for e in episodes) / len(episodes), 4), "total_steps": int(sum(e.steps for e in episodes)), "episodes": [asdict(e) for e in episodes], } output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") if __name__ == "__main__": main()