Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any, List | |
| try: | |
| from openai import OpenAI | |
| import openai as openai_module | |
| except ImportError: | |
| OpenAI = None | |
| import openai as openai_module | |
| from support_queue_env.client import SupportQueueEnv | |
| from support_queue_env.models import TaskCard, SupportQueueAction, SupportQueueObservation | |
| from support_queue_env.tasks import TASKS | |
| def load_dotenv_file(path: str = ".env") -> None: | |
| env_path = Path(path) | |
| if not env_path.exists(): | |
| return | |
| for raw_line in env_path.read_text(encoding="utf-8").splitlines(): | |
| line = raw_line.strip() | |
| if not line or line.startswith("#") or "=" not in line: | |
| continue | |
| key, value = line.split("=", 1) | |
| key = key.strip() | |
| value = value.strip().strip('"').strip("'") | |
| if key and key not in os.environ: | |
| os.environ[key] = value | |
| load_dotenv_file() | |
| API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1") | |
| MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini") | |
| API_KEY = os.getenv("API_KEY") | |
| HF_TOKEN = os.getenv("HF_TOKEN") | |
| PROXY_API_KEY = API_KEY or HF_TOKEN | |
| LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") | |
| ENV_BASE_URL = os.getenv("ENV_BASE_URL") | |
| ALLOW_DIRECT_OPENAI = os.getenv("ALLOW_DIRECT_OPENAI") == "1" | |
| BENCHMARK = "support_queue_env" | |
| SUCCESS_SCORE_THRESHOLD = 0.80 | |
| MAX_TOKENS = 250 | |
| SCORE_EPSILON = 0.0001 | |
| def clamp_task_score(score: float) -> float: | |
| return min(max(score, SCORE_EPSILON), 1.0 - SCORE_EPSILON) | |
| def log_start(task: str, env: str, model: str) -> None: | |
| print(f"[START] task={task} env={env} model={model}", flush=True) | |
| def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None: | |
| error_value = "none" if error is None else error.replace("\n", " ") | |
| print( | |
| f"[STEP] step={step} action={action} reward={reward:.4f} done={str(done).lower()} error={error_value}", | |
| flush=True, | |
| ) | |
| def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None: | |
| print( | |
| f"[END] success={str(success).lower()} steps={steps} score={score:.4f} rewards={json.dumps([round(r, 4) for r in rewards])}", | |
| flush=True, | |
| ) | |
| def create_openai_client() -> Any: | |
| # Support both the newer API_KEY contract and the earlier HF_TOKEN contract. | |
| # In either case, all traffic still goes through API_BASE_URL. | |
| if not PROXY_API_KEY: | |
| return None | |
| if "api.openai.com" in API_BASE_URL and not ALLOW_DIRECT_OPENAI: | |
| print( | |
| "[DEBUG] Refusing to use direct OpenAI base URL. Set API_BASE_URL to the provided proxy, or set ALLOW_DIRECT_OPENAI=1 for local-only testing.", | |
| flush=True, | |
| ) | |
| return None | |
| if OpenAI is not None: | |
| return OpenAI(base_url=API_BASE_URL, api_key=PROXY_API_KEY) | |
| openai_module.api_base = API_BASE_URL | |
| openai_module.api_key = PROXY_API_KEY | |
| return openai_module | |
| def warmup_model_client(client: Any) -> None: | |
| if client is None: | |
| print("[DEBUG] No API_KEY/HF_TOKEN found; skipping model warmup.", flush=True) | |
| return | |
| try: | |
| if hasattr(client, "chat") and hasattr(client.chat, "completions"): | |
| client.chat.completions.create( | |
| model=MODEL_NAME, | |
| messages=[ | |
| {"role": "system", "content": "Reply with ok."}, | |
| {"role": "user", "content": "ok"}, | |
| ], | |
| temperature=0.0, | |
| max_tokens=2, | |
| stream=False, | |
| ) | |
| else: | |
| client.ChatCompletion.create( | |
| model=MODEL_NAME, | |
| messages=[ | |
| {"role": "system", "content": "Reply with ok."}, | |
| {"role": "user", "content": "ok"}, | |
| ], | |
| temperature=0.0, | |
| max_tokens=2, | |
| stream=False, | |
| ) | |
| except Exception as exc: | |
| print(f"[DEBUG] Model warmup failed: {exc}", flush=True) | |
| def get_model_message( | |
| client: Any, | |
| step: int, | |
| observation: SupportQueueObservation, | |
| last_reward: float, | |
| history: List[str], | |
| ) -> str: | |
| if client is None: | |
| return "hello" | |
| prompt = ( | |
| "Return a short support-triage recommendation as JSON with fields priority, queue, disposition, summary, response. " | |
| f"Step: {step}. Last reward: {last_reward:.4f}. History: {history[-4:]}. Observation: {observation.model_dump_json()}" | |
| ) | |
| try: | |
| if hasattr(client, "chat") and hasattr(client.chat, "completions"): | |
| completion = client.chat.completions.create( | |
| model=MODEL_NAME, | |
| messages=[ | |
| {"role": "system", "content": "You are assisting a support triage agent."}, | |
| {"role": "user", "content": prompt}, | |
| ], | |
| temperature=0.0, | |
| max_tokens=MAX_TOKENS, | |
| stream=False, | |
| ) | |
| text = (completion.choices[0].message.content or "").strip() | |
| else: | |
| completion = client.ChatCompletion.create( | |
| model=MODEL_NAME, | |
| messages=[ | |
| {"role": "system", "content": "You are assisting a support triage agent."}, | |
| {"role": "user", "content": prompt}, | |
| ], | |
| temperature=0.0, | |
| max_tokens=MAX_TOKENS, | |
| stream=False, | |
| ) | |
| text = (completion["choices"][0]["message"]["content"] or "").strip() | |
| return text if text else "hello" | |
| except Exception as exc: | |
| print(f"[DEBUG] Model request failed: {exc}", flush=True) | |
| return "hello" | |
| def available_tasks() -> list[TaskCard]: | |
| return [ | |
| TaskCard( | |
| task_id=task.task_id, | |
| title=task.title, | |
| difficulty=task.difficulty, | |
| description=task.description, | |
| ticket_count=len(task.tickets), | |
| ) | |
| for task in TASKS | |
| ] | |
| def heuristic_action(observation: SupportQueueObservation) -> SupportQueueAction: | |
| text = " ".join( | |
| [ | |
| observation.ticket.subject, | |
| observation.ticket.body, | |
| " ".join(observation.ticket.recent_events), | |
| observation.task_title, | |
| ] | |
| ).lower() | |
| if any(word in text for word in ["password reset", "account is locked", "locked out"]): | |
| return SupportQueueAction( | |
| priority="P3", | |
| queue="technical", | |
| disposition="respond", | |
| summary="Customer account locked after password reset in the admin portal.", | |
| response=( | |
| "Thanks for reporting this. Please verify the account owner details and we will unlock the account and " | |
| "confirm the next reset step for you." | |
| ), | |
| confidence=0.82, | |
| ) | |
| if any(word in text for word in ["phishing", "credentials", "oauth", "unknown ip", "contractor", "security"]): | |
| return SupportQueueAction( | |
| priority="P1", | |
| queue="security", | |
| disposition="escalate", | |
| summary="Security issue involving phishing, credentials, or unknown OAuth access.", | |
| response=( | |
| "Thanks for flagging this quickly. This is escalated to our security team now. Please do not click the message " | |
| "again, revoke suspicious access where possible, and keep audit logs ready." | |
| ), | |
| confidence=0.9, | |
| ) | |
| if any(word in text for word in ["502", "500", "webhook", "login", "blocked", "outage", "rollout"]): | |
| priority = "P1" if any(word in text for word in ["all agents", "entire", "502", "blocked"]) else "P2" | |
| return SupportQueueAction( | |
| priority=priority, | |
| queue="technical", | |
| disposition="escalate", | |
| summary="Technical incident affecting login, webhook delivery, or a recent rollout.", | |
| response=( | |
| "I am escalating this incident to engineering right away. Please keep example timestamps and logs handy while " | |
| "we investigate the rollout behavior and urgent production impact." | |
| ), | |
| confidence=0.88, | |
| ) | |
| if any(word in text for word in ["renewal", "discount", "cfo", "quote"]): | |
| return SupportQueueAction( | |
| priority="P2", | |
| queue="success", | |
| disposition="escalate", | |
| summary="Renewal quote issue where the committed discount is blocking the CFO review.", | |
| response=( | |
| "I am escalating this to the account manager now. We will review the quote, confirm the discount commitment, " | |
| "and share the escalated renewal update as soon as possible." | |
| ), | |
| confidence=0.83, | |
| ) | |
| if any(word in text for word in ["cancel", "data export"]): | |
| return SupportQueueAction( | |
| priority="P3", | |
| queue="success", | |
| disposition="request_info", | |
| summary="Customer wants cancellation and a data export after verification.", | |
| response=( | |
| "I can help with the export and cancellation flow. Please verify that you are the account owner and confirm " | |
| "the workspace name so we can start the export safely." | |
| ), | |
| confidence=0.8, | |
| ) | |
| if any(word in text for word in ["invoice", "charged", "billed", "refund", "billing"]): | |
| unclear = any(word in text for word in ["maybe", "not fully sure", "thinks", "what details"]) | |
| return SupportQueueAction( | |
| priority="P2" if any(word in text for word in ["charged twice", "double billed", "two identical charges"]) else "P3", | |
| queue="billing", | |
| disposition="request_info" if unclear else "respond", | |
| summary=( | |
| "Billing issue is unclear because only one invoice is visible today." | |
| if unclear | |
| else "Duplicate charge appears tied to a specific invoice in billing." | |
| ), | |
| response=( | |
| "I can review this with billing. Please send the invoice number, charged amount, and the last four digits of " | |
| "the payment method so we can compare the records." | |
| if unclear | |
| else "I am checking this with our billing team now. If this is a duplicate charge, we will investigate the invoice and share the refund update for you." | |
| ), | |
| confidence=0.84, | |
| ) | |
| return SupportQueueAction( | |
| priority="P3", | |
| queue="technical", | |
| disposition="respond", | |
| summary="General product issue that needs standard technical follow-up.", | |
| response="Thanks for the report. We will verify the issue and share the next reset or troubleshooting step.", | |
| confidence=0.7, | |
| ) | |
| async def build_env() -> SupportQueueEnv: | |
| if ENV_BASE_URL: | |
| env = SupportQueueEnv(base_url=ENV_BASE_URL) | |
| connect = getattr(env, "connect", None) | |
| if callable(connect): | |
| maybe_coro = connect() | |
| if asyncio.iscoroutine(maybe_coro): | |
| await maybe_coro | |
| return env | |
| return await SupportQueueEnv.from_docker_image(LOCAL_IMAGE_NAME or "support-queue-openenv") | |
| async def run_task(client: Any, env: SupportQueueEnv, task: TaskCard) -> dict[str, Any]: | |
| history: List[str] = [] | |
| rewards: List[float] = [] | |
| steps_taken = 0 | |
| score = clamp_task_score(0.0) | |
| success = False | |
| log_start(task=task.task_id, env=BENCHMARK, model=MODEL_NAME) | |
| try: | |
| result = await env.reset(task_id=task.task_id) | |
| last_reward = 0.0 | |
| for step in range(1, task.ticket_count + 1): | |
| if result.done: | |
| break | |
| observation = result.observation | |
| _ = get_model_message(client, step, observation, last_reward, history) | |
| action = heuristic_action(observation) | |
| try: | |
| result = await env.step(action) | |
| except Exception as exc: | |
| action_payload = json.dumps(action.model_dump(), separators=(",", ":"), sort_keys=True) | |
| log_step(step=step, action=action_payload, reward=0.0, done=True, error=str(exc)) | |
| break | |
| reward = result.reward or 0.0 | |
| done = result.done | |
| error = None | |
| rewards.append(reward) | |
| steps_taken = step | |
| last_reward = reward | |
| action_payload = json.dumps(action.model_dump(), separators=(",", ":"), sort_keys=True) | |
| log_step(step=step, action=action_payload, reward=reward, done=done, error=error) | |
| history.append(f"Step {step}: {action_payload} -> reward {reward:+.2f}") | |
| if done: | |
| break | |
| score = sum(rewards) / len(rewards) if rewards else 0.0 | |
| score = clamp_task_score(score) | |
| success = score >= SUCCESS_SCORE_THRESHOLD | |
| except Exception as exc: | |
| print(f"[DEBUG] Task {task.task_id} failed: {exc}", flush=True) | |
| finally: | |
| log_end(success=success, steps=steps_taken, score=score, rewards=rewards) | |
| return { | |
| "task_id": task.task_id, | |
| "score": score, | |
| "steps": steps_taken, | |
| "rewards": rewards, | |
| "success": success, | |
| } | |
| async def main() -> None: | |
| client = create_openai_client() | |
| tasks = available_tasks() | |
| results: list[dict[str, Any]] = [] | |
| env: SupportQueueEnv | None = None | |
| try: | |
| warmup_model_client(client) | |
| env = await build_env() | |
| for task in tasks: | |
| results.append(await run_task(client, env, task)) | |
| except Exception as exc: | |
| print(f"[DEBUG] Environment bootstrap failed: {exc}", flush=True) | |
| for task in tasks: | |
| log_start(task=task.task_id, env=BENCHMARK, model=MODEL_NAME) | |
| log_end(success=False, steps=0, score=clamp_task_score(0.0), rewards=[]) | |
| results.append( | |
| { | |
| "task_id": task.task_id, | |
| "score": clamp_task_score(0.0), | |
| "steps": 0, | |
| "rewards": [], | |
| "success": False, | |
| } | |
| ) | |
| finally: | |
| if env is not None: | |
| try: | |
| await env.close() | |
| except Exception as exc: | |
| print(f"[DEBUG] env.close() error (container cleanup): {exc}", flush=True) | |
| aggregate = { | |
| "benchmark": BENCHMARK, | |
| "model": MODEL_NAME, | |
| "average_score": round(sum(item["score"] for item in results) / len(results), 4) if results else 0.0, | |
| "tasks": results, | |
| } | |
| with open("inference_results.json", "w", encoding="utf-8") as handle: | |
| json.dump(aggregate, handle, indent=2) | |
| if __name__ == "__main__": | |
| try: | |
| asyncio.run(main()) | |
| except Exception as exc: | |
| print(f"[DEBUG] Fatal inference error: {exc}", flush=True) | |