| """ | |
| Inference Script — TriageSieve-OpenEnv | |
| ======================================== | |
| MANDATORY environment variables (see pre-submission checklist): | |
| API_BASE_URL The API endpoint for the LLM. | |
| MODEL_NAME The model identifier to use for inference. | |
| HF_TOKEN Your Hugging Face / API key. | |
| OPTIONAL: | |
| LOCAL_IMAGE_NAME Docker image name (when running locally via from_docker_image()). | |
| ENV_URL Remote environment URL (overrides default HF Space URL). | |
| When LOCAL_IMAGE_NAME is set, connects via Docker. Otherwise, connects to the | |
| deployed HF Space at the default URL (or ENV_URL if provided). | |
| Defaults are set only for API_BASE_URL and MODEL_NAME. | |
| All LLM calls use the OpenAI client configured via these variables. | |
| Stdout logs follow the required structured format ([START]/[STEP]/[END]). | |
| """ | |
from __future__ import annotations

import asyncio
import json
import os
import re

# Load .env file if present (so HF_TOKEN, LOCAL_IMAGE_NAME etc. work without
# manually exporting in the shell). python-dotenv is already available via litellm.
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    # python-dotenv is optional; fall back to plain environment variables.
    pass

import textwrap
from typing import Any, List, Optional

from openai import OpenAI

from triagesieve_env import TriageSieveEnv
from triagesieve_env.models import (
    ActionType,
    CloseReason,
    Impact,
    IssueFamily,
    IssueSubtype,
    QueueId,
    TriageSieveAction,
    TriageSieveObservation,
    TaskDifficulty,
    Urgency,
)
# ---------------------------------------------------------------------------
# Configuration (matches pre-submission checklist EXACTLY)
# ---------------------------------------------------------------------------
# LLM endpoint — defaults exist only for API_BASE_URL and MODEL_NAME.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
# Environment connection: Docker (local) or HF Space (remote).
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
# HF Space URL — used when LOCAL_IMAGE_NAME is not set (e.g., during hackathon validation).
HF_SPACE_URL = os.getenv("ENV_URL", "https://angshuman28-triagesieve-env.hf.space")

BENCHMARK = "triagesieve_env"
TEMPERATURE = 0.0
MAX_TOKENS = 512
SUCCESS_SCORE_THRESHOLD = 0.5  # minimum final score for "success"

# Task ladder: matches episode_engine budget exactly, plus a small overflow buffer.
# Difficulty name doubles as the task name.
TASK_CONFIGS = [
    {"task_name": name, "seed": seed, "difficulty": name, "max_steps": steps}
    for name, seed, steps in (("easy", 0, 8), ("medium", 1, 14), ("hard", 2, 20))
]
# Enum fields requiring lowercase normalization when parsing LLM output.
# parse_action() lowercases the string values for these keys before building
# a TriageSieveAction, since models sometimes emit e.g. "Billing" or "HIGH".
# The mapped enum classes are informational (the keys drive the normalization).
_ENUM_FIELDS: dict[str, type] = {
    "action_type": ActionType,
    "issue_family": IssueFamily,
    "issue_subtype": IssueSubtype,
    "impact": Impact,
    "urgency": Urgency,
    "queue_id": QueueId,
    "close_reason": CloseReason,
}
| # --------------------------------------------------------------------------- | |
| # Mandatory stdout logging (DO NOT MODIFY FORMAT) | |
| # --------------------------------------------------------------------------- | |
def log_start(task: str, env: str, model: str) -> None:
    """Emit the mandatory [START] line (format must not change)."""
    line = "[START] task={} env={} model={}".format(task, env, model)
    print(line, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit the mandatory [STEP] line (format must not change)."""
    fields = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",  # falsy error (None/"") logs as "null"
    ]
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Emit the mandatory [END] line (format must not change)."""
    joined = ",".join("%.2f" % r for r in rewards)
    line = f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={joined}"
    print(line, flush=True)
| # --------------------------------------------------------------------------- | |
| # Observation → text serialization (mirrors baseline/llm_baseline.py) | |
| # --------------------------------------------------------------------------- | |
def serialize_observation(obs: TriageSieveObservation) -> str:
    """Render an observation as the plain-text body of the LLM prompt.

    Sections (in order): episode context, inbox summaries, the focused
    ticket (if any), legal actions, routing policies, SLA policies,
    available templates (if any), and an optional hint. Sections are
    joined with blank lines. Mirrors baseline/llm_baseline.py so prompts
    stay comparable with the baseline agent.
    """
    parts: list[str] = []
    # --- Episode context: step/budget/difficulty plus last-action feedback ---
    parts.append(
        f"=== Episode Context ===\n"
        f"Step: {obs.step_count} | Budget remaining: {obs.action_budget_remaining} | "
        f"Difficulty: {obs.task_difficulty.value} | Time: {obs.current_time}\n"
        f"Last action result: {obs.last_action_result}"
    )
    # --- Inbox: one summary line + preview per ticket ---
    parts.append("=== Inbox ===")
    for item in obs.inbox_summaries:
        # SLA may be absent (no deadline applies) — render "n/a" then.
        sla = f"{item.sla_remaining_minutes}min" if item.sla_remaining_minutes is not None else "n/a"
        parts.append(
            f"- [{item.ticket_id}] {item.subject} | from: {item.sender_email} | "
            f"status: {item.status.value} | tier: {item.customer_tier.value} | "
            f"SLA: {sla} | attachment: {item.has_attachment}\n"
            f" Preview: {item.short_preview}"
        )
    # --- Focused ticket: full detail only when one is currently open ---
    if obs.focused_ticket is not None:
        ft = obs.focused_ticket
        parts.append(
            f"=== Focused Ticket: {ft.ticket_id} ===\n"
            f"Subject: {ft.subject}\n"
            f"Latest message: {ft.latest_message}"
        )
        if ft.thread_history:
            parts.append("Thread history:")
            # thread_history entries are plain dicts, hence the .get fallbacks.
            for msg in ft.thread_history:
                parts.append(f" [{msg.get('role', '?')}] {msg.get('content', '')}")
        if ft.attachments:
            parts.append(f"Attachments: {', '.join(ft.attachments)}")
        if ft.visible_internal_notes:
            parts.append(f"Internal notes: {'; '.join(ft.visible_internal_notes)}")
        if ft.prior_actions_taken:
            parts.append(f"Prior actions: {', '.join(ft.prior_actions_taken)}")
    # --- Action space the environment will accept this turn ---
    parts.append(
        f"=== Legal Actions ===\n"
        f"{', '.join(a.value for a in obs.legal_actions)}"
    )
    # --- Routing policy cards: queue prerequisites and handled families ---
    parts.append("=== Routing Policies ===")
    for card in obs.routing_policy_cards:
        prereqs = ", ".join(card.prerequisites) if card.prerequisites else "none"
        families = ", ".join(f.value for f in card.handles_families)
        parts.append(
            f"- {card.queue_id.value}: {card.description} | "
            f"prereqs: {prereqs} | families: {families}"
        )
    # --- SLA policy cards: per-tier response/resolution deadlines ---
    parts.append("=== SLA Policies ===")
    for card in obs.sla_policy_cards:
        parts.append(
            f"- {card.tier.value}: respond {card.response_deadline_minutes}min, "
            f"resolve {card.resolution_deadline_minutes}min"
        )
    # --- Reply templates (plain dicts, hence the .get fallbacks) ---
    if obs.available_templates:
        parts.append("=== Templates ===")
        for tpl in obs.available_templates:
            parts.append(
                f"- {tpl.get('template_id', '?')}: {tpl.get('name', '?')} "
                f"({tpl.get('applies_to', '?')})"
            )
    # --- Optional hint supplied by the environment ---
    if obs.hint:
        parts.append(f"=== Hint ===\n{obs.hint}")
    return "\n\n".join(parts)
| # --------------------------------------------------------------------------- | |
| # System prompt | |
| # --------------------------------------------------------------------------- | |
# Full instruction prompt sent as the "system" message on every LLM call.
# Must stay in sync with the environment's action schema and enum values
# (see triagesieve_env.models). dedent + strip keep the literal readable
# here while sending flush-left text to the model.
SYSTEM_PROMPT = textwrap.dedent("""
    You are a support-ticket triage agent. Your job is to process an inbox of support tickets by taking structured actions.
    You must respond with EXACTLY ONE JSON object per turn. No extra text, no markdown fences, just the JSON.
    == ACTION TYPES AND REQUIRED FIELDS ==
    1. open_ticket: {"action_type": "open_ticket", "ticket_id": "<id>"}
    2. classify_ticket: {"action_type": "classify_ticket", "ticket_id": "<id>", "issue_family": "<family>", "issue_subtype": "<subtype>"}
    3. set_impact_urgency: {"action_type": "set_impact_urgency", "ticket_id": "<id>", "impact": "<impact>", "urgency": "<urgency>"}
    4. route_ticket: {"action_type": "route_ticket", "ticket_id": "<id>", "queue_id": "<queue>"}
    5. request_information: {"action_type": "request_information", "ticket_id": "<id>", "requested_fields": ["field1", ...], "template_id": "<optional>"}
    6. escalate_ticket: {"action_type": "escalate_ticket", "ticket_id": "<id>", "queue_id": "<queue>", "reason_code": "<reason>"}
    7. merge_duplicate: {"action_type": "merge_duplicate", "ticket_id": "<id>", "target_ticket_id": "<original_id>"}
    8. close_ticket: {"action_type": "close_ticket", "ticket_id": "<id>", "close_reason": "<reason>", "template_id": "<optional>"}
    9. skip_turn: {"action_type": "skip_turn"}
    10. finish_episode: {"action_type": "finish_episode"}
    == ENUM VALUES ==
    issue_family: billing, technical, account, security, shipping
    issue_subtype:
    billing: refund, invoice_error, failed_charge
    technical: bug_report, api_error, integration_failure
    account: password_reset, sso_issue, account_lockout
    security: suspicious_login, exposure_risk, abuse_report
    shipping: delay, tracking_problem, lost_package
    impact: single_user, team, org_wide, revenue_affecting
    urgency: low, medium, high, critical
    queue_id: billing_team, tech_support_l1, tech_support_l2, account_team, security_team, shipping_team, refund_team, spam_filter, sales_or_feature_requests
    close_reason: resolved, duplicate, non_actionable, feature_request, no_response
    == PRIORITY DERIVATION (for your reasoning only) ==
    single_user: low/low/medium/high (columns: urgency low/medium/high/critical)
    team: low/medium/high/high
    org_wide: medium/high/high/critical
    revenue_affecting: high/high/critical/critical
    == STRATEGY ==
    1. Open tickets starting with highest-priority ones (enterprise/critical SLA first).
    2. Classify after reading the ticket content carefully.
    3. Set impact and urgency based on the ticket details.
    4. Request missing information if needed before routing.
    5. Route to the correct queue. Note: tech_support_l2 and security_team are gated (need classification + impact/urgency first).
    6. Close with the appropriate reason and template.
    7. If a ticket looks like spam or non-actionable, close it as non_actionable.
    8. If a ticket is a duplicate, merge it with the original.
    9. Use finish_episode when all tickets are fully handled.
    10. Only use skip_turn if you truly cannot determine any useful action.
    Respond with ONLY the JSON action object. No explanation.
""").strip()
| # --------------------------------------------------------------------------- | |
| # LLM call | |
| # --------------------------------------------------------------------------- | |
def get_model_action(
    client: OpenAI,
    obs_text: str,
    last_reward: float,
    step: int,
) -> str:
    """Ask the LLM for the next action.

    Returns the model's raw reply text (stripped), or "" when the call fails
    for any reason — the caller treats "" as a parse failure and falls back
    to skip_turn.
    """
    prompt = f"Step {step} | Last reward: {last_reward:.2f}\n\n{obs_text}"
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        reply = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        content = reply.choices[0].message.content
        return (content or "").strip()
    except Exception as exc:
        # Deliberate best-effort: log the failure and degrade to an empty reply.
        print(f"[DEBUG] LLM call failed at step {step}: {exc}", flush=True)
        return ""
| # --------------------------------------------------------------------------- | |
| # Action parsing (mirrors baseline/llm_baseline.py) | |
| # --------------------------------------------------------------------------- | |
def parse_action(raw_text: str) -> Optional[TriageSieveAction]:
    """Parse the model's reply into a TriageSieveAction, or None on failure.

    Strips markdown code fences, tries to parse the whole text as JSON,
    and otherwise falls back to the first balanced {...} span. Enum-valued
    fields are lowercased before validation.
    """
    if not raw_text or not raw_text.strip():
        return None
    cleaned = raw_text.strip()
    # Remove any markdown fences the model wrapped around the JSON.
    cleaned = re.sub(r"```(?:json)?\s*", "", cleaned)
    cleaned = re.sub(r"```\s*$", "", cleaned).strip()
    payload: Optional[dict[str, Any]] = None
    try:
        payload = json.loads(cleaned)
    except json.JSONDecodeError:
        payload = None
    if payload is None:
        # Fallback: scan for the first balanced brace span and parse that.
        start = cleaned.find("{")
        if start == -1:
            return None
        depth = 0
        end = -1
        for idx in range(start, len(cleaned)):
            ch = cleaned[idx]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    end = idx
                    break
        if end == -1:
            return None
        try:
            payload = json.loads(cleaned[start:end + 1])
        except json.JSONDecodeError:
            return None
    if not isinstance(payload, dict) or "action_type" not in payload:
        return None
    # Normalize enum-valued string fields to lowercase (models often capitalize).
    for key in _ENUM_FIELDS:
        value = payload.get(key)
        if isinstance(value, str):
            payload[key] = value.lower()
    payload.setdefault("metadata", {})
    try:
        return TriageSieveAction(**payload)
    except (ValueError, TypeError) as exc:
        print(f"[DEBUG] Action validation failed: {exc}", flush=True)
        return None
def action_to_str(action: TriageSieveAction) -> str:
    """Produce a concise one-token-ish string for [STEP] logging."""
    pieces = [action.action_type.value]
    if action.ticket_id:
        pieces.append(action.ticket_id)
    # Optional enum fields contribute their .value only when set.
    for enum_field in (action.queue_id, action.issue_family, action.close_reason):
        if enum_field:
            pieces.append(enum_field.value)
    return ":".join(pieces)
| # --------------------------------------------------------------------------- | |
| # Main inference loop | |
| # --------------------------------------------------------------------------- | |
async def run_task(
    client: OpenAI,
    env: TriageSieveEnv,
    task_name: str,
    seed: int,
    difficulty: str,
    max_steps: int,
) -> dict[str, Any]:
    """Run one episode against `env` and return a result summary dict.

    Emits the mandatory [START]/[STEP]/[END] log lines. The env is always
    closed (best-effort) and an [END] line is always printed, even when
    the episode raises mid-run.

    Returns:
        {"task": ..., "score": ..., "success": ..., "steps": ...}
    """
    rewards: list[float] = []
    # Defaults matter: log_end in `finally` reads these even if the
    # episode fails before any step completes.
    steps_taken = 0
    score = 0.0
    success = False
    episode_done = False
    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
    try:
        result = await env.reset(seed=seed, difficulty=difficulty, mode="eval_strict")
        obs: TriageSieveObservation = result.observation
        last_reward = 0.0
        for step in range(1, max_steps + 1):
            # Stop once the episode ended or the action budget is exhausted.
            if episode_done or obs.action_budget_remaining <= 0:
                break
            obs_text = serialize_observation(obs)
            raw = get_model_action(client, obs_text, last_reward, step)
            action = parse_action(raw)
            if action is None:
                # Unparseable model output: burn the turn rather than crash.
                print(f"[DEBUG] Parse failure at step {step}, using skip_turn", flush=True)
                action = TriageSieveAction(action_type=ActionType.SKIP_TURN, metadata={})
            result = await env.step(action)
            obs = result.observation
            reward = result.reward if result.reward is not None else 0.0
            episode_done = result.done or obs.done
            # "ok" means the action was accepted; anything else is an error string.
            error_str = None if obs.last_action_result == "ok" else obs.last_action_result
            rewards.append(reward)
            steps_taken = step
            last_reward = reward
            log_step(step=step, action=action_to_str(action), reward=reward, done=episode_done, error=error_str)
            if episode_done:
                break
        # Send finish_episode if budget ran out but episode isn't done
        if not episode_done:
            finish = TriageSieveAction(action_type=ActionType.FINISH_EPISODE, metadata={})
            result = await env.step(finish)
            obs = result.observation
            reward = result.reward if result.reward is not None else 0.0
            episode_done = True
            steps_taken += 1
            rewards.append(reward)
            log_step(step=steps_taken, action="finish_episode", reward=reward, done=True, error=None)
        # Final score is the terminal observation.reward (already normalized to [0, 1])
        score = rewards[-1] if rewards else 0.0
        # Phase 2 requires scores strictly in (0, 1); eps >= 1e-3 so .3f never rounds to "0.000"/"1.000"
        score = min(max(score, 1e-3), 1.0 - 1e-3)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        try:
            await env.close()
        except Exception as exc:
            # Best-effort cleanup; a close failure must not mask the episode result.
            print(f"[DEBUG] env.close() error: {exc}", flush=True)
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    # NOTE(review): the return sits after the try/finally, so an exception
    # raised mid-episode still propagates to the caller after [END] is
    # logged — confirm callers are expected to handle that.
    return {"task": task_name, "score": score, "success": success, "steps": steps_taken}
async def create_env_from_docker(image_name: str, timeout_s: float = 120.0) -> TriageSieveEnv:
    """Start a Docker container and connect with a generous timeout.

    The default 30s from_docker_image timeout is too tight for first-start
    on some machines (Windows, CI). This helper gives 120s instead.
    """
    from openenv.core.containers.runtime.providers import LocalDockerProvider

    docker = LocalDockerProvider()
    url = docker.start_container(image_name)
    docker.wait_for_ready(url, timeout_s=timeout_s)
    env = TriageSieveEnv(base_url=url, provider=docker)
    await env.connect()
    return env
async def create_env_from_space(space_url: str) -> TriageSieveEnv:
    """Connect to an already-running HF Space (or any remote OpenEnv server)."""
    env = TriageSieveEnv(base_url=space_url)
    await env.connect()
    return env
async def main() -> None:
    """Run the full task ladder and print a results summary."""
    if not HF_TOKEN:
        raise SystemExit("ERROR: HF_TOKEN environment variable is not set.")
    docker_mode = bool(LOCAL_IMAGE_NAME)
    if docker_mode:
        print(f"[INFO] Using Docker image: {LOCAL_IMAGE_NAME}", flush=True)
    else:
        print(f"[INFO] Using HF Space: {HF_SPACE_URL}", flush=True)
    llm = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    summaries = []
    for cfg in TASK_CONFIGS:
        # A fresh environment connection is made for every task in the ladder.
        env = (
            await create_env_from_docker(LOCAL_IMAGE_NAME)
            if docker_mode
            else await create_env_from_space(HF_SPACE_URL)
        )
        outcome = await run_task(
            client=llm,
            env=env,
            task_name=cfg["task_name"],
            seed=cfg["seed"],
            difficulty=cfg["difficulty"],
            max_steps=cfg["max_steps"],
        )
        summaries.append(outcome)
    print("\n=== RESULTS SUMMARY ===", flush=True)
    for item in summaries:
        verdict = "PASS" if item["success"] else "FAIL"
        print(
            f" {item['task']}: score={item['score']:.3f} steps={item['steps']} [{verdict}]",
            flush=True,
        )


if __name__ == "__main__":
    asyncio.run(main())