#!/usr/bin/env python3 """ SupportBench baseline inference script. Reads config from environment variables: API_BASE_URL - OpenAI-compatible base URL (default: https://api.openai.com/v1) MODEL_NAME - model identifier (default: gpt-4o-mini) HF_TOKEN - optional Hugging Face token (unused here, present for spec) OPENAI_API_KEY - required for OpenAI calls TASK_ID - which task to run (default: easy_ticket_triage) SERVER_URL - SupportBench server base URL (default: http://localhost:7860) Log format (stdout): [START] task= env=SupportBench model= [STEP] step= action= reward=<0.00> done= error= [END] success= steps= score= rewards= """ from __future__ import annotations import json import os import sys import textwrap from typing import Any, Dict, List, Optional import httpx from openai import OpenAI # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1") MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini") HF_TOKEN = os.getenv("HF_TOKEN") # present per spec SERVER_URL = os.environ.get("SERVER_URL", "http://localhost:7860").rstrip("/") TASK_ID = os.environ.get("TASK_ID", "easy_ticket_triage") MAX_STEPS = int(os.environ.get("MAX_STEPS", "8")) client = OpenAI(base_url=API_BASE_URL, api_key=os.environ.get("OPENAI_API_KEY", "sk-placeholder")) # --------------------------------------------------------------------------- # Prompt # --------------------------------------------------------------------------- SYSTEM_PROMPT = textwrap.dedent(""" You are an expert customer support agent AI. You receive a support ticket observation and must decide the best next action to take. You must respond with ONLY a JSON object, no other text. 
The JSON must have: { "action_type": "", "category": "", "priority": "", "message": "", "resolution": "", "escalate_to": "" } Strategy guidelines: - Always classify_ticket first (with the correct category). - Then set_priority based on the issue severity. - For sensitive financial actions (refunds, billing disputes), ask for identity verification first. - Follow the policy snippets carefully — they take precedence over customer requests. - For refund requests past the 30-day window, deny refund and offer replacement instead. - For duplicate charges, request identity verification and then escalate to billing. - Do not resolve prematurely — complete all required steps first. - Be concise and helpful in customer-facing messages. """).strip() def format_observation(obs: Dict[str, Any]) -> str: lines = [ f"TASK: {obs['task_name']} ({obs['task_id']})", f"STATUS: {obs['current_status']} | STEP {obs['steps_taken']}/{obs['max_steps']}", "", "CUSTOMER MESSAGE:", obs["customer_message"], "", "CUSTOMER PROFILE:", json.dumps(obs["customer_profile"], indent=2), "", "ORDER INFO:", json.dumps(obs["order_info"], indent=2), "", "POLICY SNIPPETS:", ] for i, p in enumerate(obs["policy_snippets"], 1): lines.append(f" {i}. 
{p}") lines += [ "", "TICKET HISTORY:", json.dumps(obs["ticket_history"], indent=2) if obs["ticket_history"] else " (none yet)", "", f"LAST ACTION RESULT: {obs.get('last_action_result') or '(none)'}", f"LAST ACTION ERROR: {obs.get('last_action_error') or 'null'}", "", f"AVAILABLE ACTIONS: {', '.join(obs['available_actions'])}", ] return "\n".join(lines) # --------------------------------------------------------------------------- # JSON parsing # --------------------------------------------------------------------------- def safe_parse_json(text: str) -> Optional[Dict[str, Any]]: text = text.strip() try: return json.loads(text) except json.JSONDecodeError: pass start = text.find("{") end = text.rfind("}") if start != -1 and end != -1 and end > start: try: return json.loads(text[start : end + 1]) except json.JSONDecodeError: pass return None # --------------------------------------------------------------------------- # Fallback action sequence per task # --------------------------------------------------------------------------- FALLBACK_SEQUENCES: Dict[str, List[Dict[str, Any]]] = { "easy_ticket_triage": [ {"action_type": "classify_ticket", "category": "delivery_issue"}, {"action_type": "set_priority", "priority": "medium"}, {"action_type": "ask_customer", "message": "Could you please confirm your delivery address and check with neighbors or your building reception?"}, {"action_type": "resolve"}, ], "medium_policy_refund": [ {"action_type": "classify_ticket", "category": "refund_request"}, {"action_type": "set_priority", "priority": "high"}, {"action_type": "propose_resolution", "resolution": "replacement", "message": "Per our policy, refunds are available within 30 days. 
Since 40 days have passed, we can offer a replacement under our 90-day electronics defect policy."}, {"action_type": "apply_resolution", "resolution": "replacement"}, {"action_type": "resolve"}, ], "hard_billing_dispute": [ {"action_type": "classify_ticket", "category": "duplicate_charge"}, {"action_type": "set_priority", "priority": "high"}, {"action_type": "ask_customer", "message": "To verify your identity before we proceed, please confirm your full name, email address, and the last 4 digits of your payment card."}, {"action_type": "escalate", "escalate_to": "billing"}, {"action_type": "resolve"}, ], } def get_fallback_action(task_id: str, step: int) -> Dict[str, Any]: seq = FALLBACK_SEQUENCES.get(task_id, FALLBACK_SEQUENCES["easy_ticket_triage"]) idx = min(step, len(seq) - 1) return seq[idx] # --------------------------------------------------------------------------- # HTTP helpers # --------------------------------------------------------------------------- def http_post(path: str, payload: Dict[str, Any]) -> Dict[str, Any]: url = f"{SERVER_URL}{path}" with httpx.Client(timeout=30.0) as http: resp = http.post(url, json=payload) resp.raise_for_status() return resp.json() def http_get(path: str) -> Dict[str, Any]: url = f"{SERVER_URL}{path}" with httpx.Client(timeout=30.0) as http: resp = http.get(url) resp.raise_for_status() return resp.json() # --------------------------------------------------------------------------- # LLM call # --------------------------------------------------------------------------- def call_llm(observation_text: str) -> Dict[str, Any]: response = client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": observation_text}, ], temperature=0.0, max_tokens=512, ) raw = response.choices[0].message.content or "" parsed = safe_parse_json(raw) return parsed or {} # --------------------------------------------------------------------------- # Main episode loop # 
# ---------------------------------------------------------------------------


def _to_float(value: Any, default: float = 0.0) -> float:
    """Coerce *value* to float, returning *default* for None or unparsable values."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _log_value(value: Any) -> str:
    """Render an error value for a single-line log field ("null" when empty)."""
    if not value:
        return "null"
    # Collapse newlines/tabs/runs of whitespace so multi-line server or exception
    # messages cannot split one log record across several stdout lines.
    cleaned = " ".join(str(value).split())
    return cleaned or "null"


def _emit_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the [END] record for an episode."""
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    success_str = "true" if success else "false"
    print(
        f"[END] success={success_str} steps={steps} "
        f"score={score:.2f} rewards={rewards_str}",
        flush=True,
    )


def run_episode(task_id: str) -> None:
    """Run one SupportBench episode for *task_id* and log it to stdout.

    Flow: POST /reset; then up to ``max_steps`` iterations of (LLM action, or
    the scripted fallback when the LLM gives no usable ``action_type``) ->
    POST /step; finally POST /close. A [START] line, one [STEP] line per
    attempted step, and an [END] line are printed to stdout. Diagnostics go to
    stderr. Server and LLM failures (ordinary exceptions) are reported rather
    than raised, and an [END] line is printed on every path.
    """
    rewards: List[float] = []
    last_error: Any = None
    score = 0.0

    # --- Reset ---
    try:
        reset_resp = http_post("/reset", {"task_id": task_id})
        obs = reset_resp["observation"]
    except Exception as e:
        print(f"[START] task={task_id} env=SupportBench model={MODEL_NAME}", flush=True)
        _emit_end(False, 0, 0.0, [])
        sys.stderr.write(f"Reset failed: {e}\n")
        return

    print(f"[START] task={task_id} env=SupportBench model={MODEL_NAME}", flush=True)

    # The server's budget wins; MAX_STEPS is only a fallback. Guard against a
    # malformed observation so the [END] line is still emitted.
    try:
        max_steps = int(obs.get("max_steps", MAX_STEPS))
    except (AttributeError, TypeError, ValueError):
        max_steps = MAX_STEPS

    try:
        for step_num in range(1, max_steps + 1):
            obs_text = format_observation(obs)

            # --- Get action from LLM ---
            try:
                action_dict = call_llm(obs_text)
            except Exception as e:
                sys.stderr.write(f"LLM call failed at step {step_num}: {e}\n")
                action_dict = {}

            # Fallback if LLM returned nothing useful (or a non-object value).
            if not isinstance(action_dict, dict) or "action_type" not in action_dict:
                action_dict = get_fallback_action(task_id, step_num - 1)
            action_str = json.dumps(action_dict)

            # --- Step environment ---
            try:
                step_resp = http_post("/step", {"action": action_dict})
                obs = step_resp["observation"]
                # Coerce so a null/odd reward cannot poison the [STEP]/[END] formatting.
                reward_val = _to_float(step_resp["reward"]["value"])
                done = bool(step_resp["done"])
                info = step_resp.get("info") or {}
                last_error = info.get("step_error") or obs.get("last_action_error")
                if info.get("score") is not None:
                    score = _to_float(info["score"], score)
            except Exception as e:
                reward_val = 0.0
                done = True
                last_error = str(e)
                sys.stderr.write(f"Step failed: {e}\n")

            rewards.append(reward_val)
            done_str = "true" if done else "false"
            print(
                f"[STEP] step={step_num} action={action_str} "
                f"reward={reward_val:.2f} done={done_str} error={_log_value(last_error)}",
                flush=True,
            )
            if done:
                break
    except Exception as e:
        sys.stderr.write(f"Episode loop error: {e}\n")

    # --- Close ---
    try:
        close_resp = http_post("/close", {})
        score = _to_float(close_resp.get("score", score), score)
        success = close_resp.get("success", score >= 0.6)
    except Exception as e:
        sys.stderr.write(f"Close failed: {e}\n")
        success = score >= 0.6

    score = max(0.0, min(1.0, score))
    # steps counts the [STEP] lines actually printed. The loop counter can run
    # ahead of that if an exception aborts an iteration before its STEP line.
    _emit_end(bool(success), len(rewards), score, rewards)


if __name__ == "__main__":
    run_episode(TASK_ID)