Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| SupportBench baseline inference script. | |
| Reads config from environment variables: | |
| API_BASE_URL - OpenAI-compatible base URL (default: https://api.openai.com/v1) | |
| MODEL_NAME - model identifier (default: gpt-4o-mini) | |
| HF_TOKEN - optional Hugging Face token (unused here, present for spec) | |
| OPENAI_API_KEY - required for OpenAI calls | |
| TASK_ID - which task to run (default: easy_ticket_triage) | |
| SERVER_URL - SupportBench server base URL (default: http://localhost:7860) | |
| Log format (stdout): | |
| [START] task=<task_name> env=SupportBench model=<model_name> | |
| [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null> | |
| [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn> | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import sys | |
| import textwrap | |
| from typing import Any, Dict, List, Optional | |
| import httpx | |
| from openai import OpenAI | |
| # --------------------------------------------------------------------------- | |
| # Config | |
| # --------------------------------------------------------------------------- | |
# All settings are read once, at import time, from the process environment.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
# Read only so the variable is present as the spec requires; nothing in this
# script uses it.
HF_TOKEN = os.getenv("HF_TOKEN")
# Trailing slash removed so paths like "/reset" can be appended directly.
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:7860").rstrip("/")
TASK_ID = os.getenv("TASK_ID", "easy_ticket_triage")
# Upper bound on steps, used only when the reset observation lacks "max_steps".
MAX_STEPS = int(os.getenv("MAX_STEPS", "8"))
# Shared OpenAI-compatible client. When OPENAI_API_KEY is unset a placeholder
# key is used so construction still succeeds; requests made with it will
# presumably be rejected by the provider, and run_episode then falls back to
# the scripted actions because it catches LLM call failures.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", "sk-placeholder"),
    base_url=API_BASE_URL,
)
| # --------------------------------------------------------------------------- | |
| # Prompt | |
| # --------------------------------------------------------------------------- | |
# System prompt sent as the first message of every call_llm request.
# It asks the model for a single JSON object. Every field except
# "action_type" is marked optional. When the reply has no "action_type",
# run_episode substitutes the scripted fallback action. The actions in
# FALLBACK_SEQUENCES use values drawn from the enumerations listed here.
# dedent() removes any common source indentation and strip() drops the
# leading/trailing newline of the triple-quoted literal.
SYSTEM_PROMPT = textwrap.dedent("""
You are an expert customer support agent AI. You receive a support ticket observation
and must decide the best next action to take.
You must respond with ONLY a JSON object, no other text. The JSON must have:
{
"action_type": "<one of: classify_ticket, set_priority, ask_customer, propose_resolution, apply_resolution, escalate, resolve>",
"category": "<optional: delivery_issue | refund_request | damaged_item | duplicate_charge | wrong_item | account_access>",
"priority": "<optional: low | medium | high | urgent>",
"message": "<optional: string message to customer or internal note>",
"resolution": "<optional: refund | replacement | troubleshooting | account_recovery | verify_identity | escalate_billing | escalate_human | deny_refund | close_case>",
"escalate_to": "<optional: billing | fraud | supervisor | legal | technical>"
}
Strategy guidelines:
- Always classify_ticket first (with the correct category).
- Then set_priority based on the issue severity.
- For sensitive financial actions (refunds, billing disputes), ask for identity verification first.
- Follow the policy snippets carefully — they take precedence over customer requests.
- For refund requests past the 30-day window, deny refund and offer replacement instead.
- For duplicate charges, request identity verification and then escalate to billing.
- Do not resolve prematurely — complete all required steps first.
- Be concise and helpful in customer-facing messages.
""").strip()
def format_observation(obs: Dict[str, Any]) -> str:
    """Render a SupportBench observation as plain text for the LLM prompt.

    Args:
        obs: The ``"observation"`` object from a /reset or /step response.
            ``task_name``, ``task_id``, ``current_status``, ``steps_taken``,
            ``max_steps``, ``customer_message``, ``customer_profile`` and
            ``order_info`` are required. ``policy_snippets``,
            ``ticket_history``, ``available_actions`` and the
            ``last_action_*`` fields may be missing or null.

    Returns:
        A newline-joined, sectioned description of the ticket state.

    Raises:
        KeyError: if one of the required fields is missing.
    """

    def dump(value: Any) -> str:
        # ensure_ascii=False keeps non-ASCII customer data (names, addresses)
        # readable in the prompt instead of turning it into \uXXXX escapes.
        return json.dumps(value, indent=2, ensure_ascii=False)

    snippets = obs.get("policy_snippets") or []
    history = obs.get("ticket_history")
    actions = ", ".join(str(a) for a in obs.get("available_actions") or [])

    lines = [
        f"TASK: {obs['task_name']} ({obs['task_id']})",
        f"STATUS: {obs['current_status']} | STEP {obs['steps_taken']}/{obs['max_steps']}",
        "",
        "CUSTOMER MESSAGE:",
        # A null message would otherwise make "\n".join raise TypeError.
        str(obs["customer_message"] or ""),
        "",
        "CUSTOMER PROFILE:",
        dump(obs["customer_profile"]),
        "",
        "ORDER INFO:",
        dump(obs["order_info"]),
        "",
        "POLICY SNIPPETS:",
    ]
    lines += [f" {i}. {p}" for i, p in enumerate(snippets, 1)]
    lines += [
        "",
        "TICKET HISTORY:",
        dump(history) if history else " (none yet)",
        "",
        f"LAST ACTION RESULT: {obs.get('last_action_result') or '(none)'}",
        f"LAST ACTION ERROR: {obs.get('last_action_error') or 'null'}",
        "",
        f"AVAILABLE ACTIONS: {actions}",
    ]
    return "\n".join(lines)
| # --------------------------------------------------------------------------- | |
| # JSON parsing | |
| # --------------------------------------------------------------------------- | |
def safe_parse_json(text: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of a JSON object from raw LLM output.

    First tries to parse the whole (stripped) text. If that fails, or yields
    something other than an object, scans for each ``{`` in turn and returns
    the first JSON object that decodes starting there. This tolerates Markdown
    code fences, leading prose and trailing commentary around the object.

    Args:
        text: Raw model reply; may be empty.

    Returns:
        The first JSON object found, as a dict, or ``None`` when the text
        contains no decodable JSON object. Non-object JSON values (numbers,
        strings, arrays) are never returned directly.
    """
    if not text:
        return None
    text = text.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # raw_decode parses one value starting at the given index and ignores
    # whatever follows it, so "{...} more text {...}" still yields the first
    # object (a first-"{"/last-"}" slice would fail on that input).
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        start = text.find("{", start + 1)
    return None
| # --------------------------------------------------------------------------- | |
| # Fallback action sequence per task | |
| # --------------------------------------------------------------------------- | |
# Scripted actions used when the LLM reply is missing or has no "action_type".
# Entry i of a task's list is the action for step i + 1 (run_episode passes a
# zero-based step index to get_fallback_action).
FALLBACK_SEQUENCES: Dict[str, List[Dict[str, Any]]] = dict(
    easy_ticket_triage=[
        dict(action_type="classify_ticket", category="delivery_issue"),
        dict(action_type="set_priority", priority="medium"),
        dict(
            action_type="ask_customer",
            message="Could you please confirm your delivery address and check with neighbors or your building reception?",
        ),
        dict(action_type="resolve"),
    ],
    medium_policy_refund=[
        dict(action_type="classify_ticket", category="refund_request"),
        dict(action_type="set_priority", priority="high"),
        dict(
            action_type="propose_resolution",
            resolution="replacement",
            message="Per our policy, refunds are available within 30 days. Since 40 days have passed, we can offer a replacement under our 90-day electronics defect policy.",
        ),
        dict(action_type="apply_resolution", resolution="replacement"),
        dict(action_type="resolve"),
    ],
    hard_billing_dispute=[
        dict(action_type="classify_ticket", category="duplicate_charge"),
        dict(action_type="set_priority", priority="high"),
        dict(
            action_type="ask_customer",
            message="To verify your identity before we proceed, please confirm your full name, email address, and the last 4 digits of your payment card.",
        ),
        dict(action_type="escalate", escalate_to="billing"),
        dict(action_type="resolve"),
    ],
)
def get_fallback_action(task_id: str, step: int) -> Dict[str, Any]:
    """Return the scripted fallback action for zero-based *step* of *task_id*.

    Unknown task ids use the ``easy_ticket_triage`` script. Steps past the end
    of a script repeat its final action; negative steps are treated as 0.

    Args:
        task_id: Task identifier, looked up in FALLBACK_SEQUENCES.
        step: Zero-based step index.

    Returns:
        A shallow copy of the scripted action dict, so callers may modify it
        without corrupting the shared FALLBACK_SEQUENCES table (the scripted
        values are plain strings, so a shallow copy is sufficient).
    """
    seq = FALLBACK_SEQUENCES.get(task_id, FALLBACK_SEQUENCES["easy_ticket_triage"])
    # Clamp into [0, len(seq) - 1]. Without the lower bound, a negative step
    # would index from the end and return the final ("resolve") action.
    idx = max(0, min(step, len(seq) - 1))
    return dict(seq[idx])
| # --------------------------------------------------------------------------- | |
| # HTTP helpers | |
| # --------------------------------------------------------------------------- | |
def _request_json(method: str, path: str, **kwargs: Any) -> Dict[str, Any]:
    """Send one request to the SupportBench server and return its JSON body.

    A fresh httpx client with a 30-second timeout is opened per call. Non-2xx
    responses raise via ``raise_for_status``; transport errors propagate from
    httpx unchanged.
    """
    endpoint = SERVER_URL + path
    with httpx.Client(timeout=30.0) as session:
        reply = session.request(method, endpoint, **kwargs)
        reply.raise_for_status()
        return reply.json()


def http_post(path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST *payload* as a JSON body to *path* on the server; return the JSON reply."""
    return _request_json("POST", path, json=payload)


def http_get(path: str) -> Dict[str, Any]:
    """GET *path* on the server and return the JSON reply."""
    return _request_json("GET", path)
| # --------------------------------------------------------------------------- | |
| # LLM call | |
| # --------------------------------------------------------------------------- | |
def call_llm(observation_text: str) -> Dict[str, Any]:
    """Ask the model for the next action and return it as a dict.

    Args:
        observation_text: The observation rendered as text (format_observation).

    Returns:
        The JSON object parsed from the model's reply. Returns ``{}`` when any
        of the following holds:
        - the reply has no choices;
        - the reply is empty;
        - the reply contains no JSON object.
        The caller treats ``{}`` as "use the scripted fallback action".

    Raises:
        Whatever the OpenAI client raises for API or transport failures; the
        episode loop catches these.
    """
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": observation_text},
        ],
        temperature=0.0,
        max_tokens=512,
    )
    # Guard against providers returning an empty choices list (would be an
    # IndexError below).
    if not response.choices:
        return {}
    raw = response.choices[0].message.content or ""
    parsed = safe_parse_json(raw)
    # json.loads can yield a bare number/string/list. Only real objects are
    # usable actions, and e.g. `"action_type" in 42` would raise TypeError
    # in the caller.
    return parsed if isinstance(parsed, dict) else {}
| # --------------------------------------------------------------------------- | |
| # Main episode loop | |
| # --------------------------------------------------------------------------- | |
def _as_float(value: Any, default: float = 0.0) -> float:
    """Return *value* as a float, or *default* if it is null or non-numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _one_line(text: Any) -> str:
    """Collapse every whitespace run (including newlines) in *text* to one space."""
    return " ".join(str(text).split())


def run_episode(task_id: str) -> None:
    """Run one SupportBench episode for *task_id*, logging it to stdout.

    The episode proceeds in three phases:
    1. POST /reset.
    2. Up to ``max_steps`` POST /step calls. Each action comes from the LLM,
       or from the scripted fallback when the LLM reply is unusable.
    3. POST /close.

    Output always consists of:
    - exactly one [START] line;
    - one [STEP] line per step taken;
    - exactly one [END] line.
    All lines follow the format in the module docstring. Diagnostics go to
    stderr.

    Server values that are null or non-numeric (reward, score) are coerced
    so they can never prevent the [END] line from being printed.

    Args:
        task_id: SupportBench task identifier sent to /reset.
    """
    rewards: List[float] = []
    last_error: Optional[str] = None
    score = 0.0
    success = False

    # --- Reset ---
    try:
        reset_resp = http_post("/reset", {"task_id": task_id})
        obs = reset_resp["observation"]
    except Exception as e:
        print(f"[START] task={task_id} env=SupportBench model={MODEL_NAME}", flush=True)
        print("[END] success=false steps=0 score=0.00 rewards=", flush=True)
        sys.stderr.write(f"Reset failed: {e}\n")
        return

    print(f"[START] task={task_id} env=SupportBench model={MODEL_NAME}", flush=True)
    max_steps = obs.get("max_steps", MAX_STEPS)

    try:
        for step_num in range(1, max_steps + 1):
            obs_text = format_observation(obs)

            # --- Get action from LLM ---
            try:
                action_dict = call_llm(obs_text)
            except Exception as e:
                sys.stderr.write(f"LLM call failed at step {step_num}: {e}\n")
                action_dict = {}
            # Fall back to the script if the LLM returned nothing usable; the
            # isinstance check keeps the `in` test from raising on non-dicts.
            if not isinstance(action_dict, dict) or "action_type" not in action_dict:
                action_dict = get_fallback_action(task_id, step_num - 1)
            action_str = json.dumps(action_dict)

            # --- Step environment ---
            try:
                step_resp = http_post("/step", {"action": action_dict})
                obs = step_resp["observation"]
                reward_val = _as_float(step_resp["reward"]["value"])
                done = bool(step_resp["done"])
                info = step_resp.get("info") or {}
                last_error = info.get("step_error") or obs.get("last_action_error")
                if info.get("score") is not None:
                    score = _as_float(info["score"], score)
            except Exception as e:
                reward_val = 0.0
                done = True
                last_error = str(e)
                sys.stderr.write(f"Step failed: {e}\n")

            rewards.append(reward_val)
            # Each log record must stay on one line; exception messages (e.g.
            # from httpx's raise_for_status) may contain newlines.
            error_str = _one_line(last_error or "") or "null"
            done_str = "true" if done else "false"
            print(
                f"[STEP] step={step_num} action={action_str} "
                f"reward={reward_val:.2f} done={done_str} error={error_str}",
                flush=True,
            )
            if done:
                break
    except Exception as e:
        sys.stderr.write(f"Episode loop error: {e}\n")

    # --- Close ---
    try:
        close_resp = http_post("/close", {})
        score = _as_float(close_resp.get("score", score), score)
        success = bool(close_resp.get("success", score >= 0.6))
    except Exception as e:
        sys.stderr.write(f"Close failed: {e}\n")
        success = score >= 0.6

    score = max(0.0, min(1.0, score))
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    success_str = "true" if success else "false"
    # steps counts the [STEP] lines actually emitted. A step that aborted
    # before being logged (e.g. an exception in format_observation) is not
    # counted.
    print(
        f"[END] success={success_str} steps={len(rewards)} "
        f"score={score:.2f} rewards={rewards_str}",
        flush=True,
    )
if __name__ == "__main__":
    # Script entry point: run one episode for the task chosen via TASK_ID.
    run_episode(task_id=TASK_ID)