Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import time | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| from openai import OpenAI | |
| from support_triage_openenv import Action, SupportTriageEnv | |
# Mandatory variables requested by organizers.
# Endpoint and model are environment-overridable; the fallbacks are the
# organizer-provided defaults.
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
# Hugging Face token; only required when running with --mode openai (checked in main()).
HF_TOKEN = os.getenv("HF_TOKEN")
# Benchmark identifier echoed in the [START] log line.
BENCHMARK = os.getenv("SUPPORT_TRIAGE_BENCHMARK", "support-triage-openenv")
# An episode counts as a success when its clamped grader score reaches this threshold.
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.9"))
# System prompt telling the model to answer with exactly one JSON action object.
SYSTEM_PROMPT = (
    "You are solving customer support ticket triage. "
    "Return exactly one JSON object with keys: "
    "action_type, ticket_id, priority, category, needs_escalation, message."
)
# Hand-written fallback plans: one ordered list of action payloads per task id.
# Used directly in --mode heuristic, and as the per-step fallback whenever the
# LLM call fails or returns an unparseable action (see run_episode).
# Every plan follows the same shape: read the relevant ticket(s), classify the
# primary ticket, draft a customer reply, then resolve the primary ticket.
RULE_POLICY: dict[str, list[dict[str, Any]]] = {
    # Single-ticket account task: password reset.
    "easy_password_reset": [
        {"action_type": "read_ticket", "ticket_id": "T-1001"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-1001",
            "priority": "medium",
            "category": "account",
            "needs_escalation": False,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We will send a reset link to your email. For security, confirm the request "
                "from your registered email before using the reset link."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-1001"},
    ],
    # Two-ticket billing task; T-2001 is the primary ticket.
    "medium_billing_dispute": [
        {"action_type": "read_ticket", "ticket_id": "T-2001"},
        {"action_type": "read_ticket", "ticket_id": "T-2002"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-2001",
            "priority": "high",
            "category": "billing",
            "needs_escalation": False,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We confirmed a duplicate charge. We are issuing a refund and will share the invoice update. "
                "Refund processing typically takes 3-5 business days."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-2001"},
    ],
    # Three-ticket technical incident; requires escalation.
    "hard_outage_incident": [
        {"action_type": "read_ticket", "ticket_id": "T-3001"},
        {"action_type": "read_ticket", "ticket_id": "T-3002"},
        {"action_type": "read_ticket", "ticket_id": "T-3003"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-3001",
            "priority": "urgent",
            "category": "technical",
            "needs_escalation": True,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We have escalated this incident and are investigating now. "
                "The status page will carry updates while we continue incident response."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-3001"},
    ],
    # Single-ticket general request: trial extension.
    "easy_trial_extension": [
        {"action_type": "read_ticket", "ticket_id": "T-4001"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-4001",
            "priority": "low",
            "category": "general",
            "needs_escalation": False,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We can review a trial extension based on eligibility. "
                "Please check billing settings before the next renewal so the account stays aligned."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-4001"},
    ],
    # Two-ticket abuse report; requires escalation.
    "medium_abuse_phishing": [
        {"action_type": "read_ticket", "ticket_id": "T-5001"},
        {"action_type": "read_ticket", "ticket_id": "T-5002"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-5001",
            "priority": "high",
            "category": "abuse",
            "needs_escalation": True,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We are escalating this phishing report to the abuse team. "
                "Please preserve evidence such as headers while we review blocked indicators and sender details."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-5001"},
    ],
    # Three-ticket privacy/deletion request; requires escalation.
    "hard_privacy_deletion": [
        {"action_type": "read_ticket", "ticket_id": "T-6001"},
        {"action_type": "read_ticket", "ticket_id": "T-6002"},
        {"action_type": "read_ticket", "ticket_id": "T-6003"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-6001",
            "priority": "high",
            "category": "account",
            "needs_escalation": True,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We have routed the data deletion request to the privacy team. "
                "Identity verification is required, and completion is normally within 30 days."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-6001"},
    ],
}
@dataclass
class EpisodeResult:
    """Summary of a single evaluated episode.

    Instances are built with keyword arguments in run_episode() and
    serialized through dataclasses.asdict() in main(), so the @dataclass
    decorator is required: without it the class has no generated __init__
    and asdict() raises TypeError.
    """

    # Scenario identifier (a key of RULE_POLICY / env.task_ids).
    task_id: str
    # Number of environment steps taken before the episode ended.
    steps: int
    # Clamped grader score in [0, 1], rounded to 4 decimals.
    score: float
    # True when score >= SUCCESS_SCORE_THRESHOLD.
    success: bool
    # Reward of the final step, rounded to 4 decimals.
    final_reward: float
    # Per-step reward trace, each rounded to 4 decimals.
    rewards: list[float]
    # How many times the LLM failed and the scripted plan was used instead.
    fallback_count: int
def log_start(task: str, env: str, model: str) -> None:
    """Print the episode [START] banner, flushed so logs stream live."""
    banner = f"[START] task={task} env={env} model={model}"
    print(banner, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
    """Print one [STEP] trace line; empty/None errors render as 'null'."""
    err_text = error or "null"
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={err_text}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Print the [END] summary line with the comma-joined reward trace."""
    formatted = [f"{r:.2f}" for r in rewards]
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={','.join(formatted)}",
        flush=True,
    )
def _extract_json(text: str) -> str:
    """Return the outermost '{...}' substring of *text*.

    The span runs from the first '{' to the last '}' of the stripped input.
    Raises ValueError when no such brace-delimited span exists.
    """
    stripped = text.strip()
    first = stripped.find("{")
    last = stripped.rfind("}")
    if first < 0 or last <= first:
        raise ValueError("No JSON object found in model response")
    return stripped[first : last + 1]
def heuristic_action(task_id: str, step_idx: int) -> Action:
    """Return the scripted action for *task_id* at *step_idx*.

    Steps beyond the end of the plan repeat the plan's final action.
    """
    plan = RULE_POLICY[task_id]
    clamped = min(step_idx, len(plan) - 1)
    return Action.model_validate(plan[clamped])
def llm_action(client: OpenAI, observation: dict[str, Any], state: dict[str, Any]) -> Action:
    """Ask the chat model for the next action and parse it into an Action.

    Raises on transport errors, missing JSON, or schema-invalid payloads;
    the caller (run_episode) falls back to the scripted plan in that case.
    """
    user_payload = {
        "instruction": "Pick the best next single action to maximize final task score.",
        "observation": observation,
        "state": state,
    }
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(user_payload, ensure_ascii=True)},
        ],
        temperature=0,
        max_tokens=220,
        stream=False,
    )
    raw = (completion.choices[0].message.content or "").strip()
    return Action.model_validate(json.loads(_extract_json(raw)))
def action_to_str(action: Action) -> str:
    """Render *action* as a compact one-line token for [STEP] log lines."""
    kind = action.action_type
    if kind == "read_ticket":
        return f"read_ticket({action.ticket_id})"
    if kind == "classify_ticket":
        flag = str(bool(action.needs_escalation)).lower()
        return f"classify_ticket({action.ticket_id},{action.priority},{action.category},{flag})"
    if kind == "draft_reply":
        body = (action.message or "").strip()
        return f"draft_reply(len={len(body)})"
    if kind == "resolve_ticket":
        return f"resolve_ticket({action.ticket_id})"
    # Unknown action types fall through to their raw name.
    return kind
def run_episode(
    env: SupportTriageEnv,
    task_id: str,
    mode: str,
    client: OpenAI | None,
    started_at: float,
    runtime_limit_seconds: int,
) -> EpisodeResult:
    """Run a single episode for *task_id* and return its summary.

    mode == "heuristic" replays the scripted RULE_POLICY plan; any other
    mode asks the LLM for each action and falls back to the scripted plan
    when the LLM call fails (each such fallback is counted).

    Raises TimeoutError when the wall-clock budget, measured from
    *started_at* (shared across all episodes), is exhausted.
    """
    obs = env.reset(task_id)
    done = False
    info: dict[str, Any] = {}
    rewards: list[float] = []
    steps_taken = 0
    fallback_count = 0
    success = False
    score = 0.0
    final_reward = 0.0
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    while not done:
        # Budget check first: started_at is set once in main(), so this
        # limits total runtime across all episodes, not per episode.
        if time.monotonic() - started_at > runtime_limit_seconds:
            raise TimeoutError(f"Runtime exceeded {runtime_limit_seconds}s")
        step_idx = env.state()["step_count"]
        if mode == "heuristic":
            action = heuristic_action(task_id, step_idx)
        else:
            assert client is not None
            try:
                action = llm_action(client, obs.model_dump(), env.state())
            except Exception:
                # Any LLM failure (transport, bad JSON, schema mismatch)
                # degrades to the scripted plan so the episode can finish.
                fallback_count += 1
                action = heuristic_action(task_id, step_idx)
        step_error: str | None = None
        try:
            obs, reward, done, info = env.step(action)
            reward_value = float(reward.value)
        except Exception as exc:
            # A failing env.step() terminates the episode with zero reward;
            # the error text is surfaced in the [STEP] log line below.
            step_error = str(exc)
            reward_value = 0.0
            done = True
        steps_taken = step_idx + 1
        rewards.append(reward_value)
        final_reward = reward_value
        log_step(
            step=steps_taken,
            action=action_to_str(action),
            reward=reward_value,
            done=done,
            error=step_error,
        )
        if done:
            break
    # grader_score comes from the final step's info dict (empty if the step
    # raised); clamp it into [0, 1] before comparing with the threshold.
    score = max(0.0, min(1.0, float(info.get("grader_score", 0.0))))
    success = score >= SUCCESS_SCORE_THRESHOLD
    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return EpisodeResult(
        task_id=task_id,
        steps=steps_taken,
        score=round(score, 4),
        success=success,
        final_reward=round(final_reward, 4),
        rewards=[round(r, 4) for r in rewards],
        fallback_count=fallback_count,
    )
def main() -> None:
    """CLI entry point: run all tasks (or one) and write a JSON score summary."""
    parser = argparse.ArgumentParser(description="Submission inference script.")
    parser.add_argument("--mode", choices=["openai", "heuristic"], default="openai")
    parser.add_argument("--output", default="scores/inference_scores.json")
    parser.add_argument("--runtime-limit-seconds", type=int, default=1200)
    parser.add_argument("--task-id", default="", help="Optional single task id; default runs all tasks")
    args = parser.parse_args()

    # The OpenAI-compatible router needs a token; fail fast before any work.
    if args.mode == "openai" and not HF_TOKEN:
        raise RuntimeError("HF_TOKEN is required for --mode openai")

    env = SupportTriageEnv()
    selected = [args.task_id] if args.task_id else env.task_ids
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN) if args.mode == "openai" else None

    # Single clock shared by every episode so the runtime limit is global.
    clock_start = time.monotonic()
    results: list[EpisodeResult] = []
    for tid in selected:
        if tid not in env.task_ids:
            raise ValueError(f"Unknown task_id '{tid}'")
        results.append(
            run_episode(
                env=env,
                task_id=tid,
                mode=args.mode,
                client=client,
                started_at=clock_start,
                runtime_limit_seconds=args.runtime_limit_seconds,
            )
        )

    summary = {
        "mode": args.mode,
        "api_base_url": API_BASE_URL,
        "model_name": MODEL_NAME,
        "avg_score": round(sum(r.score for r in results) / len(results), 4),
        "avg_final_reward": round(sum(r.final_reward for r in results) / len(results), 4),
        "total_steps": int(sum(r.steps for r in results)),
        "episodes": [asdict(r) for r in results],
    }

    destination = Path(args.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(json.dumps(summary, indent=2), encoding="utf-8")


if __name__ == "__main__":
    main()