| """ |
| Baseline inference for support-triage-openenv. |
| |
| Required environment variables before submission: |
| - API_BASE_URL |
| - MODEL_NAME |
| - HF_TOKEN |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import re |
| from dataclasses import asdict, dataclass |
| from typing import Dict, Optional |
|
|
| from openai import OpenAI |
|
|
| from support_triage_env.models import SupportTriageAction, SupportTriageObservation |
| from support_triage_env.server.environment import SupportTriageEnvironment |
| from support_triage_env.tasks import TASK_ORDER |
|
|
# --- Endpoint / credentials, resolved from the process environment -----------
# API_BASE_URL falls back to the Hugging Face router when unset.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
# HF_TOKEN wins; OPENAI_API_KEY is only consulted when HF_TOKEN is unset/empty.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
# No default: main() refuses to run when this is missing.
MODEL_NAME = os.environ.get("MODEL_NAME")

# --- Episode / sampling settings ---------------------------------------------
MAX_STEPS = 14  # upper bound on model-driven steps per task episode
TEMPERATURE = 0.1  # passed as `temperature` to the chat completion call
MAX_TOKENS = 220  # passed as `max_tokens` to the chat completion call

# Action types accepted from the model; anything else triggers the heuristic
# fallback in parse_action().
ACTION_TYPES = {
    "set_priority", "route_team", "add_tag", "draft_reply",
    "request_info", "close_ticket", "noop",
}

# System prompt: one sentence per entry, joined with single spaces.
SYSTEM_PROMPT = " ".join(
    (
        "You are a customer support triage agent operating in an RL environment.",
        "Return exactly one JSON object with keys action_type and value.",
        "Valid action_type values are: set_priority, route_team, add_tag, "
        "draft_reply, request_info, close_ticket, noop.",
        "Do not include markdown, explanations, or extra text.",
    )
)
|
|
|
|
@dataclass
class EpisodeReport:
    """Result record for one task episode.

    run_task() builds one of these from the environment's final evaluation,
    and main() serialises them with ``dataclasses.asdict`` into
    ``baseline_scores.json``.
    """

    # Identifier of the task the episode was run on.
    task_id: str
    # Value of the "steps" entry of the final evaluation, coerced to int.
    steps: int
    # Value of the "score" entry of the final evaluation, coerced to float.
    score: float
    # Copy of the "breakdown" mapping from the final evaluation.
    breakdown: Dict[str, float]
|
|
|
|
def build_user_prompt(step: int, obs: SupportTriageObservation) -> str:
    """Render the current observation as the user message for the model.

    Args:
        step: 1-based step counter of the running episode.
        obs: Latest observation returned by the environment.

    Returns:
        A newline-separated "Label: value" listing of the observation fields,
        ending with an instruction to respond with JSON only.
    """
    prompt_lines = [
        f"Step: {step}",
        f"Task: {obs.task_id} ({obs.difficulty})",
        f"Objective: {obs.objective}",
        f"Title: {obs.title}",
        f"Customer tier: {obs.customer_tier}",
        f"Customer message: {obs.customer_message}",
        f"Current priority: {obs.priority}",
        f"Current team: {obs.routed_team}",
        f"Current tags: {obs.tags}",
        f"Info requested: {obs.info_requested}",
        f"Current draft reply: {obs.draft_reply}",
        f"Steps remaining: {obs.steps_remaining}",
        f"Last feedback: {obs.last_feedback}",
        f"Allowed actions: {obs.allowed_actions}",
        "Respond with JSON only.",
    ]
    return "\n".join(prompt_lines)
|
|
|
|
def _extract_json(text: str) -> Optional[Dict[str, object]]:
    """Return the first JSON object found in *text*, or ``None``.

    The model is asked for bare JSON but may wrap it in prose or markdown
    fences. Instead of a greedy ``{...}`` regex (which spans from the first
    ``{`` to the last ``}`` and therefore fails whenever any other brace
    appears before or after the object), every ``{`` is tried in order as the
    start of a JSON value and the first one that decodes to a dict wins.

    Args:
        text: Raw model output; ``None`` or empty/whitespace-only is accepted.

    Returns:
        The decoded dict, or ``None`` when no JSON object can be decoded.
        Note that if an outer object is malformed (e.g. truncated), a
        well-formed object nested inside it may be returned instead.
    """
    text = (text or "").strip()
    if not text:
        return None

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value beginning at `start` and
            # tolerates trailing text after it.
            candidate, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)

    return None
|
|
|
|
# Per-task script for the heuristic fallback. For each known task id:
#   priority / team : values used when the observation has none yet
#   tags            : tags to add, in this order, one per call
#   info_request    : text for request_info, or None when the task needs none
#   reply           : text for draft_reply
_FALLBACK_PLAYBOOK = {
    "easy_password_reset": {
        "priority": "medium",
        "team": "account",
        "tags": ("password-reset", "login"),
        "info_request": None,
        "reply": (
            "Sorry for the login trouble. Please use the reset link again and "
            "enable 2FA after login. If this continues, support can verify your token."
        ),
    },
    "medium_double_charge": {
        "priority": "high",
        "team": "billing",
        "tags": ("refund", "double-charge", "vip"),
        "info_request": "Please share the transaction ID and last 4 digits of the charged card.",
        "reply": (
            "Sorry for this frustration. Our billing team will investigate the "
            "double charge and process any eligible refund after verification."
        ),
    },
    "hard_account_takeover": {
        "priority": "urgent",
        "team": "trust_safety",
        "tags": ("security", "account-takeover", "fraud", "content-abuse"),
        "info_request": "Please share screenshot evidence, timestamps, and the suspicious order ID.",
        "reply": (
            "We have escalated this security case. Please secure your account, reset "
            "password now, and enable two-factor authentication immediately."
        ),
    },
}


def fallback_action(obs: SupportTriageObservation) -> SupportTriageAction:
    """Pick the next scripted action when the model's reply is unusable.

    Each call returns exactly one action: the first still-missing item in the
    sequence priority -> team -> tags -> info request -> draft reply, and
    ``close_ticket`` once nothing is missing. Unknown task ids get priority
    "medium" and team "technical", then go straight to ``close_ticket``.

    Args:
        obs: Current observation; its task_id, priority, routed_team, tags,
            info_requested and draft_reply fields are consulted.

    Returns:
        The next heuristic action for the observed ticket state.
    """
    plan = _FALLBACK_PLAYBOOK.get(obs.task_id)

    if not obs.priority:
        level = plan["priority"] if plan else "medium"
        return SupportTriageAction(action_type="set_priority", value=level)

    if not obs.routed_team:
        team = plan["team"] if plan else "technical"
        return SupportTriageAction(action_type="route_team", value=team)

    if plan is not None:
        # First scripted tag that is not on the ticket yet, if any.
        missing_tag = next((tag for tag in plan["tags"] if tag not in obs.tags), None)
        if missing_tag is not None:
            return SupportTriageAction(action_type="add_tag", value=missing_tag)

        question = plan["info_request"]
        if question is not None and not obs.info_requested:
            return SupportTriageAction(action_type="request_info", value=question)

        if not obs.draft_reply:
            return SupportTriageAction(action_type="draft_reply", value=plan["reply"])

    return SupportTriageAction(action_type="close_ticket", value="")
|
|
|
|
def parse_action(response_text: str, obs: SupportTriageObservation) -> SupportTriageAction:
    """Turn the model's raw reply into an action, falling back to the heuristic.

    The heuristic (``fallback_action``) is used when no JSON object can be
    extracted from the reply, when the object is empty, or when its
    ``action_type`` is not one of ``ACTION_TYPES``.

    Args:
        response_text: Raw text returned by the model (may be empty).
        obs: Current observation, forwarded to the fallback heuristic.

    Returns:
        The action to submit to the environment.
    """
    parsed = _extract_json(response_text)
    if not parsed:
        return fallback_action(obs)

    # ACTION_TYPES are all lowercase; normalise case so replies such as
    # "Set_Priority" are honoured instead of being discarded.
    action_type = str(parsed.get("action_type", "noop")).strip().lower()
    if action_type not in ACTION_TYPES:
        return fallback_action(obs)

    # A missing or null value becomes the empty string; anything else is
    # stringified.
    raw_value = parsed.get("value")
    value = "" if raw_value is None else str(raw_value)

    return SupportTriageAction(action_type=action_type, value=value)
|
|
|
|
def _request_model_reply(client: OpenAI, task_id: str, step: int, user_prompt: str) -> str:
    """Ask the model for its next action; return "" if the call fails.

    Any exception raised by the request (or while reading the first choice)
    is reported on stdout and swallowed, so the caller ends up on the
    heuristic fallback for this step.
    """
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        return completion.choices[0].message.content or ""
    except Exception as exc:
        print(f"Model call failed on {task_id} step {step}: {exc}. Falling back to heuristic.")
        return ""


def run_task(client: OpenAI, task_id: str) -> EpisodeReport:
    """Run one episode of *task_id* and return its evaluation report.

    A fresh environment is created and reset for the task, then the model is
    queried once per step until the observation reports ``done`` or
    ``MAX_STEPS`` steps have been taken. Each step's action, reward and done
    flag are printed.

    Args:
        client: OpenAI-compatible client used for chat completions.
        task_id: Identifier of the task to reset the environment with.

    Returns:
        An ``EpisodeReport`` built from ``env.evaluate()``.
    """
    env = SupportTriageEnvironment()
    obs = env.reset(task_id=task_id)

    step = 0
    while step < MAX_STEPS and not obs.done:
        step += 1

        prompt = build_user_prompt(step, obs)
        reply = _request_model_reply(client, task_id, step, prompt)

        action = parse_action(reply, obs)
        obs = env.step(action)

        print(
            f"[{task_id}] step={step} action={action.action_type}:{action.value} "
            f"reward={obs.reward:+.3f} done={obs.done}"
        )

    final = env.evaluate()
    return EpisodeReport(
        task_id=task_id,
        steps=int(final["steps"]),
        score=float(final["score"]),
        breakdown=dict(final["breakdown"]),
    )
|
|
|
|
def main() -> None:
    """Run every task in ``TASK_ORDER`` and write ``baseline_scores.json``.

    Prints a per-task score summary and the average score, then saves the
    model name, API base URL, rounded average and per-task reports as JSON
    in the current working directory.

    Raises:
        RuntimeError: If no API key (HF_TOKEN / OPENAI_API_KEY) or no
            MODEL_NAME is configured in the environment.
    """
    if not API_KEY:
        raise RuntimeError("Missing HF_TOKEN (or OPENAI_API_KEY fallback) environment variable.")
    if not MODEL_NAME:
        raise RuntimeError("Missing MODEL_NAME environment variable.")

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    reports = [run_task(client, task_id) for task_id in TASK_ORDER]

    # An empty TASK_ORDER would otherwise divide by zero; report 0.0 instead.
    avg_score = sum(r.score for r in reports) / len(reports) if reports else 0.0

    print("\n=== Baseline Scores ===")
    for report in reports:
        print(f"{report.task_id}: score={report.score:.4f} steps={report.steps}")
    print(f"Average score: {avg_score:.4f}")

    payload = {
        "model": MODEL_NAME,
        "api_base_url": API_BASE_URL,
        "average_score": round(avg_score, 4),
        "tasks": [asdict(report) for report in reports],
    }

    with open("baseline_scores.json", "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)

    print("Saved baseline_scores.json")


if __name__ == "__main__":
    main()
|
|
|