Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """Baseline inference script for FinePrint-Env (OpenEnv Hackathon). | |
| STDOUT FORMAT (mandatory — any deviation = incorrect scoring): | |
| [START] task=<task_name> env=<benchmark> model=<model_name> | |
| [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null> | |
| [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn> | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| import sys | |
| import time | |
| import traceback | |
| from typing import Any, List, Optional | |
| import requests | |
| from openai import OpenAI | |
| # --------------------------------------------------------------------------- | |
| # Configuration from environment variables | |
| # --------------------------------------------------------------------------- | |
# Endpoint and model settings; each value can be overridden through the
# process environment.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.getenv("HF_TOKEN")
# Request URLs are built as f"{ENV_URL}{endpoint}" with endpoints that start
# with "/", so trailing slashes are dropped here to avoid "//reset"-style paths.
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860").rstrip("/")
# Optional — if you use from_docker_image():
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")

# Task identifiers run in order, and the benchmark name printed in [START].
TASKS = ["quote_accuracy", "drift_detection", "compliance_storm"]
BENCHMARK = "fineprint_env"

# Safety limits
MAX_STEPS_PER_TASK = 30
MAX_RETRIES_HTTP = 3
HTTP_TIMEOUT = 60  # passed as ``timeout=`` to requests.post
SUCCESS_SCORE_THRESHOLD = 0.5
| # --------------------------------------------------------------------------- | |
| # OpenAI client | |
| # --------------------------------------------------------------------------- | |
# Shared client used for every chat-completion request. HF_TOKEN takes
# precedence over OPENAI_API_KEY; an empty string is passed when neither is set.
client = OpenAI(
    api_key=HF_TOKEN or os.environ.get("OPENAI_API_KEY", ""),
    base_url=API_BASE_URL,
)
| # --------------------------------------------------------------------------- | |
| # Mandatory stdout log functions | |
| # --------------------------------------------------------------------------- | |
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record for one task run to stdout and flush it."""
    fields = f"task={task} env={env} model={model}"
    print("[START] " + fields, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record to stdout as a single line and flush it.

    Args:
        step: 1-based step number.
        action: Text describing the action that was sent.
        reward: Reward for the step, printed with two decimals.
        done: Whether the episode is finished; printed as ``true``/``false``.
        error: Error text for the step, or ``None``/empty for ``null``.
    """
    # A raw line break inside a field would split the record across several
    # lines, so line breaks are folded into spaces before printing.
    action_val = " ".join(str(action).splitlines())
    error_val = " ".join(str(error).splitlines()) if error else ""
    error_val = error_val or "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action_val} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record for a task run to stdout and flush it.

    ``score`` and every entry of ``rewards`` are printed with two decimals;
    the rewards are comma-separated with no spaces.
    """
    success_val = str(success).lower()
    formatted_rewards = [format(value, ".2f") for value in rewards]
    record = (
        f"[END] success={success_val} steps={steps} "
        f"score={score:.2f} rewards={','.join(formatted_rewards)}"
    )
    print(record, flush=True)
| # --------------------------------------------------------------------------- | |
| # System prompt | |
| # --------------------------------------------------------------------------- | |
# System message placed first in the chat history of every task run. Inside the
# literal, a line ending in a backslash is joined with the line that follows.
SYSTEM_PROMPT = """\
You are an expert customer service AI agent operating in a policy \
compliance environment. Your goal is to handle customer workflows while \
maintaining strict compliance with the company's current policies.
## CRITICAL: Policies can change at any time!
The company's policies may be updated during your conversation. If you \
quote outdated policies, you will be penalized heavily. Use \
'request_verification' to check for policy updates.
## Available Commands
Respond with exactly ONE JSON object per turn. The JSON must have \
"command" and "args" keys.
1. **view_policies** -- View your currently cached policy values.
{"command": "view_policies", "args": {}}
2. **view_workflow** -- See the current workflow state and conversation.
{"command": "view_workflow", "args": {}}
3. **check_compliance** -- Check your current compliance status.
{"command": "check_compliance", "args": {}}
4. **request_verification** -- Refresh your policy cache and detect drift.
{"command": "request_verification", "args": {}}
5. **quote_policy** -- Quote a specific policy field to the customer.
{"command": "quote_policy", "args": {"policy_field": "return.window_days", "quoted_value": "30"}}
6. **respond_to_user** -- Send a general message to the customer.
{"command": "respond_to_user", "args": {"message": "I can help you with that."}}
7. **take_action** -- Perform a workflow action (checkout, process return, etc.).
{"command": "take_action", "args": {"message": "Processing your return now."}}
8. **escalate** -- Escalate to a supervisor (only when policy drift detected).
{"command": "escalate", "args": {"message": "Connecting you to a supervisor."}}
9. **abort_workflow** -- Abort the current workflow (only when policies are unreliable).
{"command": "abort_workflow", "args": {"message": "I need to pause this workflow."}}
10. **clarify** -- Ask the customer for clarification.
{"command": "clarify", "args": {"message": "Could you clarify what you need?"}}
11. **submit** -- Submit your work for final grading.
{"command": "submit", "args": {}}
## Policy Fields (dot notation)
- return.window_days, return.refund_method, return.restocking_fee_percent, \
return.requires_receipt, return.electronics_window_days
- shipping.free_threshold, shipping.standard_delivery_days, \
shipping.express_delivery_days, shipping.international_available, \
shipping.express_surcharge
- subscription.auto_renewal, subscription.cancellation_notice_days, \
subscription.trial_period_days, subscription.monthly_fee_usd, \
subscription.refund_policy
- complaint.response_sla_hours, complaint.max_compensation_usd, \
complaint.escalation_available, complaint.compensation_types
- pricing.currency, pricing.price_match_guarantee, \
pricing.tax_included_in_price, pricing.bulk_discount_available
- booking.cancellation_window_hours, booking.cancellation_fee_usd, \
booking.modification_allowed, booking.modification_fee_usd
## Strategy
1. Start by viewing policies with view_policies.
2. View the current workflow with view_workflow.
3. When the user asks about a policy, quote_policy with the correct field.
4. Use request_verification periodically to check for policy drift.
5. If drift is detected, re-read policies before quoting.
6. When all workflows are complete, submit.
## Scoring
- 30% compliance accuracy (correct quotes / total quotes)
- 50% workflow completion (completed workflows / total workflows)
- 20% drift responsiveness (detected drifts / actual drifts)
## Response Format
You MUST respond with a single JSON object and nothing else. \
Do not include explanations outside the JSON. Example:
{"command": "view_policies", "args": {}}
"""
| # --------------------------------------------------------------------------- | |
| # HTTP helpers | |
| # --------------------------------------------------------------------------- | |
def _post(endpoint: str, body: dict[str, Any]) -> dict[str, Any]:
    """POST ``body`` as JSON to the environment server and return the decoded reply.

    The request is attempted up to ``MAX_RETRIES_HTTP`` times, sleeping
    ``attempt`` seconds after each failed attempt except the last.

    Args:
        endpoint: Path appended to ``ENV_URL`` (for example ``"/step"``).
        body: JSON-serialisable request payload.

    Returns:
        The reply decoded from JSON; always a ``dict``.

    Raises:
        requests.RequestException: The last attempt failed at the HTTP level
            (connection error, timeout, or non-2xx status).
        ValueError: The last attempt returned a body that is not JSON or is
            not a JSON object.
    """
    url = f"{ENV_URL}{endpoint}"
    for attempt in range(1, MAX_RETRIES_HTTP + 1):
        try:
            resp = requests.post(url, json=body, timeout=HTTP_TIMEOUT)
            resp.raise_for_status()
            payload = resp.json()
            # Callers use ``.get`` on the reply, so a list or scalar reply is
            # treated like any other malformed response and retried.
            if not isinstance(payload, dict):
                raise ValueError(
                    f"expected a JSON object from {url}, got {type(payload).__name__}"
                )
            return payload
        except (requests.RequestException, ValueError):
            if attempt == MAX_RETRIES_HTTP:
                raise
            time.sleep(1.0 * attempt)
    # Only reachable when MAX_RETRIES_HTTP < 1, i.e. no attempt was made.
    return {}
def reset_env(task_id: str) -> dict[str, Any]:
    """POST a ``/reset`` request for ``task_id`` in the "default" session and return the reply."""
    return _post(
        "/reset",
        {"session_id": "default", "options": {"task_id": task_id}},
    )
def step_env(action: dict[str, Any]) -> dict[str, Any]:
    """POST ``action`` to ``/step`` for the "default" session and return the reply."""
    return _post("/step", {"session_id": "default", "action": action})
| # --------------------------------------------------------------------------- | |
| # Prompt construction | |
| # --------------------------------------------------------------------------- | |
def build_prompt(
    obs: dict[str, Any], task_id: str, step_num: int, max_steps: int,
    history: list[dict[str, str]],
) -> str:
    """Assemble the user message for the next LLM turn.

    Sections are joined with blank lines, in this order: task description,
    available workflows, step counter, the last three history entries
    (results cut to 300 characters), the current output, a low-steps warning
    when three or fewer steps remain, and the response-format reminder.
    Sections whose source value is missing or empty are left out.

    Args:
        obs: Latest observation; ``task_description``, ``workflow_names`` and
            ``output`` are read from it.
        task_id: Task name shown in the task heading.
        step_num: Number of the step about to be taken.
        max_steps: Step budget for the task.
        history: Earlier steps, each with ``action`` and ``result`` strings.

    Returns:
        The prompt text.
    """
    sections: list[str] = []

    description = obs.get("task_description", "")
    if description:
        sections.append(f"## Task: {task_id}\n{description}")

    names = obs.get("workflow_names", [])
    if names:
        sections.append("Available workflows: " + ", ".join(names))

    steps_left = max_steps - step_num
    sections.append(f"Step {step_num}/{max_steps} (remaining: {steps_left})")

    if history:
        history_lines: list[str] = []
        for entry in history[-3:]:
            history_lines.append(f" Action: {entry['action']}")
            result = entry["result"]
            shown = result[:300] + ("..." if len(result) > 300 else "")
            history_lines.append(f" Result: {shown}")
        sections.append("## Recent History\n" + "\n".join(history_lines))

    current_output = obs.get("output", "")
    if current_output:
        sections.append(f"## Current Output\n{current_output}")

    if steps_left <= 3:
        sections.append(
            'WARNING: Running low on steps. Submit now: {"command": "submit", "args": {}}'
        )

    sections.append('Respond with a single JSON object: {"command": "...", "args": {...}}')
    return "\n\n".join(sections)
| # --------------------------------------------------------------------------- | |
| # LLM response parsing | |
| # --------------------------------------------------------------------------- | |
# Matches a triple-backtick fenced block, optionally tagged "json"; group 1
# captures the fenced text (lazily, so the first closing fence ends the match).
_JSON_BLOCK_RE = re.compile(r"```(?:json)?\s*\n?(.*?)\n?\s*```", re.DOTALL)
# Matches a brace-delimited span that contains at most one level of nested braces.
_JSON_OBJECT_RE = re.compile(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", re.DOTALL)
def parse_llm_response(text: str | None) -> dict[str, Any]:
    """Extract the action dict from a raw LLM reply.

    Candidates are tried in this order, and the first one that
    ``_try_parse_json`` accepts is returned:

    1. the whole reply,
    2. the contents of each fenced code block, in order of appearance,
    3. each JSON value that starts at a ``{`` in the reply, scanning left to
       right and skipping past values that decoded but were not actions.

    Args:
        text: Raw model output; may be ``None`` or empty.

    Returns:
        The parsed action, or ``{"command": "view_policies", "args": {}}``
        when no candidate yields one.
    """
    default_action: dict[str, Any] = {"command": "view_policies", "args": {}}
    if not text:
        return default_action

    text = text.strip()
    action = _try_parse_json(text)
    if action is not None:
        return action

    for fenced in _JSON_BLOCK_RE.finditer(text):
        action = _try_parse_json(fenced.group(1).strip())
        if action is not None:
            return action

    # raw_decode finds where a JSON value ends, so nesting depth and braces
    # inside string values do not matter the way they do for a regex.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            _, end = decoder.raw_decode(text, start)
        except ValueError:
            resume_at = start + 1
        else:
            action = _try_parse_json(text[start:end])
            if action is not None:
                return action
            resume_at = end
        start = text.find("{", resume_at)

    return default_action
| def _try_parse_json(text: str) -> dict[str, Any] | None: | |
| try: | |
| data = json.loads(text) | |
| if isinstance(data, dict) and "command" in data: | |
| if "args" not in data or not isinstance(data.get("args"), dict): | |
| data["args"] = data.get("args", {}) or {} | |
| return {"command": str(data["command"]), "args": data["args"]} | |
| except (json.JSONDecodeError, TypeError, ValueError): | |
| pass | |
| return None | |
| # --------------------------------------------------------------------------- | |
| # LLM interaction | |
| # --------------------------------------------------------------------------- | |
def call_llm(
    messages: list[dict[str, str]],
    temperature: float = 0.0,
    *,
    max_attempts: int = 2,
    retry_delay: float = 2.0,
) -> str:
    """Request one chat completion and return its text.

    Args:
        messages: Chat history to send.
        temperature: Sampling temperature forwarded to the API.
        max_attempts: Total number of tries; values below 1 are treated as 1.
        retry_delay: Seconds to sleep between tries.

    Returns:
        The content of the first choice, or ``""`` when that content is empty
        or ``None``.

    Raises:
        Exception: Whatever the last attempt raised, once all tries are used.
    """
    attempts = max(1, max_attempts)
    attempt = 1
    while True:
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME, messages=messages, temperature=temperature,
            )
            return response.choices[0].message.content or ""
        except Exception:
            # Broad on purpose: any failure is retried until the budget is
            # spent, then handed to the caller unchanged.
            if attempt >= attempts:
                raise
            time.sleep(retry_delay)
            attempt += 1
| # --------------------------------------------------------------------------- | |
| # Main inference loop | |
| # --------------------------------------------------------------------------- | |
def run_task(task_id: str) -> float:
    """Run a single task episode and return its score in [0, 1].

    Prints one ``[START]`` record, one ``[STEP]`` record per environment step
    that was attempted (a step whose request failed is logged with reward
    0.00 and the failure text in ``error``), and exactly one ``[END]`` record.
    The score is the ``reward`` of the last observation clamped to [0, 1];
    the run counts as a success when it reaches ``SUCCESS_SCORE_THRESHOLD``.

    Args:
        task_id: Task to reset the environment with.

    Returns:
        The clamped score, or 0.0 when the environment could not be reset.

    Raises:
        Exception: Unexpected errors inside the episode loop propagate to the
            caller after the ``[END]`` record has been printed.
    """
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    def reward_of(observation: dict[str, Any]) -> float:
        """Return the observation's reward as a float, 0.0 when absent/None."""
        value = observation.get("reward")
        return float(value) if value is not None else 0.0

    def force_submit() -> dict[str, Any]:
        """Send ``submit``, log it as its own step, and return the reply.

        Any failure yields a terminal zero-reward observation instead of
        raising, and nothing is logged for the failed submit.
        """
        nonlocal steps_taken
        try:
            result = step_env({"command": "submit", "args": {}})
            submit_reward = reward_of(result)
            rewards.append(submit_reward)
            steps_taken += 1
            log_step(
                step=steps_taken,
                action='{"command":"submit","args":{}}',
                reward=submit_reward,
                done=result.get("done", False),
                error=None,
            )
            return result
        except Exception:
            return {"done": True, "reward": 0.0}

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    try:
        obs = reset_env(task_id)
    except Exception as exc:
        print(f"[ERROR] Failed to reset: {exc}", file=sys.stderr)
        log_end(success=False, steps=0, score=0.0, rewards=[])
        return 0.0

    max_steps = MAX_STEPS_PER_TASK
    history: list[dict[str, str]] = []
    messages: list[dict[str, str]] = [{"role": "system", "content": SYSTEM_PROMPT}]

    try:
        while not obs.get("done", False):
            steps_taken += 1
            prompt = build_prompt(obs, task_id, steps_taken, max_steps, history)
            messages.append({"role": "user", "content": prompt})

            try:
                llm_text = call_llm(messages)
            except Exception:
                # Model unavailable: submit so the work done so far is graded.
                llm_text = json.dumps({"command": "submit", "args": {}})

            messages.append({"role": "assistant", "content": llm_text})
            action = parse_llm_response(llm_text)
            action_str = json.dumps(action)
            history.append({"action": action_str, "result": ""})

            try:
                obs = step_env(action)
            except Exception as exc:
                # The step was counted above, so it must also appear in the
                # log; otherwise [END] steps= disagrees with the records.
                error_text = " ".join(str(exc).split()) or type(exc).__name__
                rewards.append(0.0)
                log_step(
                    step=steps_taken, action=action_str, reward=0.0,
                    done=False, error=error_text,
                )
                obs = force_submit()
                break

            # ``output`` may be missing or None; keep the history entry a string.
            history[-1]["result"] = str(obs.get("output") or "")[:500]

            reward_val = reward_of(obs)
            done = obs.get("done", False)
            rewards.append(reward_val)
            log_step(
                step=steps_taken, action=action_str, reward=reward_val,
                done=done, error=None,
            )

            if steps_taken >= max_steps and not done:
                # Step budget exhausted: force grading and stop.
                obs = force_submit()
                break

            # Bound the context: keep the system message plus the 40 most
            # recent chat messages.
            if len(messages) > 41:
                messages = [messages[0]] + messages[-40:]

        final_reward = obs.get("reward", 0.0)
        score = float(final_reward) if final_reward is not None else 0.0
        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return score
def main() -> None:
    """Check that the environment is reachable, run every task, and report the average.

    Exits with status 1 when the ``/health`` request fails. A task that
    raises is scored 0.0 and its traceback is written to stderr. The average
    score is written to stderr so stdout carries only the log records.
    """
    try:
        resp = requests.get(f"{ENV_URL}/health", timeout=10)
        resp.raise_for_status()
    except Exception as exc:
        print(f"[FATAL] Cannot reach environment at {ENV_URL}: {exc}", file=sys.stderr)
        sys.exit(1)

    scores: dict[str, float] = {}
    for task_id in TASKS:
        try:
            scores[task_id] = run_task(task_id)
        except Exception:
            # run_task prints its [END] record from a ``finally`` block before
            # the exception reaches this handler; printing another one here
            # would emit two [END] lines for a single [START].
            traceback.print_exc(file=sys.stderr)
            scores[task_id] = 0.0

    avg = sum(scores.values()) / len(scores) if scores else 0.0
    print(f"\nAverage score: {avg:.2f}", file=sys.stderr)


if __name__ == "__main__":
    main()