Spaces:
Sleeping
Sleeping
| """ | |
| ClarifyRL — Baseline Inference Script | |
| ====================================== | |
| MANDATORY: | |
| - Before submitting, ensure the following variables are defined in your | |
| environment configuration: | |
| API_BASE_URL The API endpoint for the LLM. | |
| MODEL_NAME The model identifier to use for inference. | |
| HF_TOKEN Your Hugging Face / API key. | |
| - The inference script must be named `inference.py` and placed in the root | |
| directory of the project. | |
| - Participants must use the OpenAI client for all LLM calls, configured with the variables above. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import os | |
| import re | |
| import sys | |
| import textwrap | |
| import time | |
| from typing import Optional | |
| try: | |
| import truststore; truststore.inject_into_ssl() | |
| except ImportError: | |
| pass | |
| from openai import OpenAI | |
# --- Runtime configuration -------------------------------------------------
# LLM endpoint and model id; both can be overridden through the environment.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
# First non-empty credential wins: HF_TOKEN, then OPENAI_API_KEY, then API_KEY.
API_KEY = HF_TOKEN or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
# Lower-cased mode switch; the value "policy" disables the LLM (see create_client).
BASELINE_MODE = os.getenv("BASELINE_MODE", "hybrid").lower()
# HTTP(S) base URL of the environment server; the websocket URL is derived from it.
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
# Sampling temperature passed to every chat-completion request.
TEMPERATURE = 0.7
# Qwen3 with enable_thinking=False usually fits in <200 tokens; we leave 800
# as a safety margin in case any backend still emits a <think> block (e.g. if
# `chat_template_kwargs` is silently dropped by an OpenAI-style proxy that
# doesn't forward `extra_body` to vLLM).
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "800"))
# Upper bound on LLM calls per task; once reached the scripted policy takes over.
MAX_LLM_STEPS_PER_TASK = int(os.getenv("MAX_LLM_STEPS_PER_TASK", "8"))
# An episode is reported as successful when its score is >= this value.
SUCCESS_SCORE_THRESHOLD = 0.5
# System message sent as the first chat message of every episode. It describes
# the three tools the agent may call and the rules it is asked to follow.
SYSTEM_PROMPT = (
    "You are a helpful assistant that books and plans things for users.\n"
    "The user's request will be intentionally ambiguous \u2014 you do NOT yet have all the information needed to make a good plan.\n"
    "\n"
    "You have three tools:\n"
    " - ask_question(question): ask the user ONE targeted clarifying question (max 6 across the episode).\n"
    " - propose_plan(plan): submit your final plan as a JSON STRING with the required fields. This ENDS the episode.\n"
    " - get_task_info(): re-read the original user request.\n"
    "\n"
    "Strategy:\n"
    " 1. Read the required plan fields listed in the task description.\n"
    " 2. Use ask_question to ask about EACH required field you do not already know.\n"
    " 3. When you have enough info, call propose_plan with a JSON string containing ALL required fields.\n"
    "\n"
    "Rules:\n"
    " - Be efficient. Each unnecessary question costs reward.\n"
    " - Your plan MUST include every required field listed in the task. Missing fields score zero.\n"
    " - NEVER include fields in your plan that you weren't told about. No hallucinating values.\n"
    " - The `plan` argument MUST be a JSON STRING (not a dict). Use the exact field names from the required fields list.\n"
)
# Plan fields shown to the model as "Required plan fields", keyed by the task
# family name reported in the reset observation. Unknown families are rendered
# as "unknown" by run_task_async.
REQUIRED_KEYS_BY_FAMILY: dict[str, list[str]] = {
    "coding_requirements": ["stack", "scale", "auth", "datastore"],
    "medical_intake": ["primary_symptom", "duration", "severity"],
    "support_triage": ["order_id", "item_issue", "refund_or_replace"],
    "meeting_scheduling": ["participants", "date", "time"],
    "event_planning": ["event_type", "date", "guest_count", "venue"],
}
# Scripted fallback policy: for each task id, the ordered (tool_name, arguments)
# pairs to play, indexed by zero-based step. Once the script is exhausted,
# _next_policy_action submits a plan built from the values revealed so far.
POLICY_PLANS = {
    "easy": [
        ("get_task_info", {}),
        ("ask_question", {"question": "What is the main requirement?"}),
        ("ask_question", {"question": "Any specific preferences or constraints?"}),
    ],
    "medium": [
        ("get_task_info", {}),
        ("ask_question", {"question": "What is the main requirement?"}),
        ("ask_question", {"question": "What are the specific details needed?"}),
        ("ask_question", {"question": "Any constraints or preferences?"}),
        ("ask_question", {"question": "What is the timeline or deadline?"}),
    ],
    "hard": [
        ("get_task_info", {}),
        ("ask_question", {"question": "What is the main requirement?"}),
        ("ask_question", {"question": "What are the technical specifications?"}),
        ("ask_question", {"question": "What is the scale or scope?"}),
        ("ask_question", {"question": "Any constraints or limitations?"}),
        ("ask_question", {"question": "What is the timeline?"}),
        ("ask_question", {"question": "Any other preferences?"}),
    ],
}
def create_client() -> Optional[OpenAI]:
    """Build the OpenAI-compatible client, or return None to use the policy.

    Returns None when BASELINE_MODE is "policy", when no API key is
    configured, or when constructing the client raises; the last two cases
    print a [DEBUG] line explaining why.
    """
    client: Optional[OpenAI] = None
    if BASELINE_MODE != "policy":
        if API_KEY:
            try:
                client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
            except Exception as exc:
                print(f"[DEBUG] Failed to create OpenAI client: {exc}", flush=True)
        else:
            print("[DEBUG] No API key found; policy fallback will be used.", flush=True)
    return client
# Upper-cased line prefixes (as in "ASK: ...") accepted by _parse_prefixed_call,
# mapped to the canonical tool name they stand for.
_PREFIX_TO_TOOL = {
    "ASK": "ask_question",
    "ASK_QUESTION": "ask_question",
    "QUESTION": "ask_question",
    "Q": "ask_question",
    "PROPOSE": "propose_plan",
    "PROPOSE_PLAN": "propose_plan",
    "PLAN": "propose_plan",
    "INFO": "get_task_info",
    "GET_TASK_INFO": "get_task_info",
    "TASK_INFO": "get_task_info",
}
def parse_tool_call(response_text: str) -> tuple[Optional[str], dict]:
    """Extract a (tool_name, arguments) pair from a raw model reply.

    The reply is first cleaned with _strip_reasoning, then the following
    formats are tried in order and the first hit wins:

      1. ``TOOL: name`` with an optional ``ARGS: {...}`` block,
      2. a JSON object naming the tool (_parse_json_tool_call),
      3. a ``tool_name(...)`` call with balanced parentheses,
      4. a prefixed line such as ``ASK: ...`` (_parse_prefixed_call),
      5. an ``Action: name(key="value")`` line.

    Returns (None, {}) when nothing matches. The returned tool name is not
    validated against the known tools for formats 1, 2 and 5.
    """
    text = _strip_reasoning(response_text)

    # 1) Canonical TOOL:/ARGS: format.
    explicit = re.search(r"TOOL:\s*(\w+)", text, re.IGNORECASE)
    if explicit is not None:
        args_block = _extract_args_block(text)
        parsed_args = _load_json_like(args_block) if args_block else {}
        return explicit.group(1).strip(), parsed_args

    # 2) JSON object carrying the tool name.
    name, call_args = _parse_json_tool_call(text)
    if name:
        return name, call_args

    # 3) Function-call syntax for one of the known tools.
    call = _find_balanced_func_call(text)
    if call:
        name, body = call
        if name in ("ask_question", "propose_plan", "get_task_info"):
            return name, _parse_positional_args(name, body.strip())

    # 4) "ASK: ..." / "PROPOSE: ..." style prefixes.
    name, call_args = _parse_prefixed_call(text)
    if name:
        return name, call_args

    # 5) ReAct-style "Action: tool(key='value')" line.
    action = re.search(
        r'Action:\s*(\w+)\((?:(\w+)\s*=\s*["\'](.+?)["\']|([^)]*))\)',
        text, re.DOTALL,
    )
    if action is None:
        return None, {}
    name, key, quoted_value, bare = action.groups()
    if key and quoted_value is not None:
        unescaped = quoted_value.replace('\\"', '"').replace("\\'", "'")
        return name, {key: unescaped}
    if bare:
        stripped = bare.strip()
        if "=" in stripped:
            k, _, v = stripped.partition("=")
            return name, {k.strip(): v.strip().strip("\"'")}
    return name, {}
def _parse_prefixed_call(text: str) -> tuple[Optional[str], dict]:
    """Handle Qwen3 GRPO outputs like:

        ASK: {"question": "What is the budget?"}
        ASK: What is the budget?
        PROPOSE: {"date": "2024-12-25", ...}
        Q: What is the budget?
        ASK-QUESTION: What is the budget?

    The 0.6B GRPO checkpoint emits these ~20% of the time. We map the
    prefix to the canonical tool name and parse the rest as either a JSON
    object or a free-form question/plan string.

    Returns (tool_name, arguments), or (None, {}) when the text does not
    start with a recognised prefix or carries no usable payload.
    """
    # The character class includes "-" so hyphenated prefixes reach the
    # "-" -> "_" normalisation below (previously they could never match).
    match = re.match(r"^\s*([A-Za-z_-]+)\s*:\s*(.*)$", text, flags=re.DOTALL)
    if not match:
        return None, {}
    prefix = match.group(1).upper().replace("-", "_")
    if prefix not in _PREFIX_TO_TOOL:
        return None, {}
    tool_name = _PREFIX_TO_TOOL[prefix]
    rest = match.group(2).strip()

    if rest.startswith("{"):
        parsed = _load_json_like(rest)
        if isinstance(parsed, dict) and parsed:
            if tool_name == "ask_question":
                question = parsed.get("question") or parsed.get("q") or parsed.get("text")
                if isinstance(question, str):
                    return tool_name, {"question": question}
                # No recognisable question key: forward the whole object as text.
                return tool_name, {"question": json.dumps(parsed)}
            if tool_name == "propose_plan":
                inner = parsed.get("plan") if isinstance(parsed.get("plan"), (dict, str)) else None
                if inner is not None:
                    plan_str = inner if isinstance(inner, str) else json.dumps(inner)
                    return tool_name, {"plan": plan_str}
                # The object itself is the plan.
                return tool_name, {"plan": json.dumps(parsed)}
            return tool_name, {}
        # Payload looked like JSON but could not be parsed: fall through and
        # treat it as free-form text rather than sending empty arguments.

    if tool_name == "ask_question":
        question = rest.strip().strip('"').strip("'")
        if question:
            return tool_name, {"question": question}
    if tool_name == "propose_plan" and rest:
        return tool_name, {"plan": rest}
    if tool_name == "get_task_info":
        return tool_name, {}
    return None, {}
def _strip_reasoning(response_text: str) -> str:
    """Remove closed <think>...</think> blocks and language-tagged fences.

    ``` fences tagged "json" or "tool" are reduced to plain ``` fences, and
    the result is stripped of surrounding whitespace.
    """
    visible = re.sub(
        r"<think>.*?</think>", "", response_text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    for tagged_fence in ("```json", "```tool"):
        visible = visible.replace(tagged_fence, "```")
    return visible.strip()
def _extract_args_block(response_text: str) -> Optional[str]:
    """Return the balanced ``{...}`` object that follows an ``ARGS:`` marker.

    The marker is matched case-insensitively. Braces inside double-quoted
    strings (with backslash escapes) are ignored while balancing. Returns
    None when there is no marker, no opening brace after it, or the braces
    never balance.
    """
    marker = re.search(r"ARGS:\s*", response_text, re.IGNORECASE)
    if marker is None:
        return None
    open_at = response_text.find("{", marker.end())
    if open_at < 0:
        return None

    nesting = 0
    inside_string = False
    escaped = False
    for pos, ch in enumerate(response_text[open_at:], start=open_at):
        if inside_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                inside_string = False
        elif ch == '"':
            inside_string = True
        elif ch == "{":
            nesting += 1
        elif ch == "}":
            nesting -= 1
            if nesting == 0:
                return response_text[open_at:pos + 1]
    return None
def _candidate_json_objects(text: str) -> list[str]:
    """Collect every top-level balanced ``{...}`` substring of *text*, in order.

    Braces inside double-quoted strings (with backslash escapes) do not count
    toward nesting. Nested objects are not reported separately, and scanning
    stops at the first object that is never closed.
    """
    found: list[str] = []
    length = len(text)
    cursor = 0
    while cursor < length:
        open_at = text.find("{", cursor)
        if open_at == -1:
            break

        nesting = 0
        inside_string = False
        escaped = False
        close_at = -1
        for pos in range(open_at, length):
            ch = text[pos]
            if inside_string:
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == '"':
                    inside_string = False
            elif ch == '"':
                inside_string = True
            elif ch == "{":
                nesting += 1
            elif ch == "}":
                nesting -= 1
                if nesting == 0:
                    close_at = pos
                    break

        if close_at == -1:
            # Unterminated object: nothing after it can be a top-level object.
            break
        found.append(text[open_at:close_at + 1])
        cursor = close_at + 1
    return found
def _load_json_like(raw: str) -> dict:
    """Parse *raw* as a JSON object, tolerating common model deviations.

    Attempts, in order:
      1. strict ``json.loads``;
      2. ``ast.literal_eval`` for Python-dict syntax (single-quoted strings,
         ``True``/``False``/``None``), accepted only when the result is a
         dict with string keys that survives a JSON round trip;
      3. a crude rewrite (``key=`` -> ``"key":`` and ``'`` -> ``"``) followed
         by ``json.loads``;
      4. the regex-based ``_parse_args_fallback``.

    Returns the parsed dict, or ``{}`` when the payload parses to something
    that is not a dict.
    """
    import ast  # local import: only needed on the non-JSON path

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Step 2 runs before the crude rewrite because the rewrite corrupts
        # values that contain an apostrophe or an "=" sign.
        try:
            literal = ast.literal_eval(raw.strip())
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            literal = None
        if isinstance(literal, dict) and all(isinstance(k, str) for k in literal):
            try:
                # Round trip normalises tuples to lists and rejects values
                # JSON cannot represent (sets, bytes, ...).
                return json.loads(json.dumps(literal))
            except (TypeError, ValueError):
                pass

        normalized = raw.strip()
        normalized = re.sub(r"(\w+)\s*=", r'"\1": ', normalized)
        normalized = normalized.replace("'", '"')
        try:
            parsed = json.loads(normalized)
        except json.JSONDecodeError:
            return _parse_args_fallback(raw)
    return parsed if isinstance(parsed, dict) else {}
def _parse_json_tool_call(response_text: str) -> tuple[Optional[str], dict]:
    """Find the first JSON object in the reply that names a tool.

    The tool name is read from the first truthy value among the keys
    ``tool``, ``tool_name``, ``name``, ``action``; the arguments from the
    first truthy value among ``args``, ``arguments``, ``parameters``. A
    string of arguments that starts with ``{`` is parsed as JSON; anything
    that is not a dict in the end becomes ``{}``.

    Returns (None, {}) when no object carries a string tool name.
    """
    for blob in _candidate_json_objects(response_text):
        payload = _load_json_like(blob)
        if not payload:
            continue

        name = None
        for key in ("tool", "tool_name", "name", "action"):
            name = payload.get(key)
            if name:
                break
        if not isinstance(name, str):
            continue

        call_args = {}
        for key in ("args", "arguments", "parameters"):
            value = payload.get(key)
            if value:
                call_args = value
                break
        if isinstance(call_args, str) and call_args.strip().startswith("{"):
            call_args = _load_json_like(call_args)
        if not isinstance(call_args, dict):
            call_args = {}
        return name.strip(), call_args
    return None, {}
def _parse_args_fallback(raw: str) -> dict:
    """Last-resort extraction of ``"key": value`` pairs from malformed JSON.

    Picks up double-quoted string values and plain numbers. Numbers may be
    negative and may carry a fractional part; values with a decimal point
    become ``float``, all others ``int``. Anything else (booleans, nested
    objects, arrays) is ignored.
    """
    args: dict = {}
    for key, value in re.findall(r'"(\w+)"\s*:\s*"([^"]*)"', raw):
        args[key] = value
    # Sign and fraction are part of the match so 1.5 is not truncated to 1
    # and -2 is not silently dropped.
    for key, number in re.findall(r'"(\w+)"\s*:\s*(-?\d+(?:\.\d+)?)', raw):
        args[key] = float(number) if "." in number else int(number)
    return args
# Tool names recognised when the model writes a call as `tool_name(...)`.
_TOOL_NAMES = ("ask_question", "propose_plan", "get_task_info")


def _scan_call_body(text: str, body_start: int, quote_chars: str) -> Optional[int]:
    """Return the index of the `)` closing a call whose body starts at *body_start*.

    Characters listed in *quote_chars* open a string literal (closed by the
    same character, with backslash escapes); parentheses inside a string do
    not count. Returns None when the parentheses never balance.
    """
    depth = 1
    active_quote = ""
    escaped = False
    for index in range(body_start, len(text)):
        char = text[index]
        if escaped:
            escaped = False
        elif active_quote:
            if char == "\\":
                escaped = True
            elif char == active_quote:
                active_quote = ""
        elif char in quote_chars:
            active_quote = char
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return index
    return None


def _find_balanced_func_call(text: str) -> Optional[tuple[str, str]]:
    """Find the first `tool_name(...)` call with balanced parens.

    Returns (name, body) where body is the parenthesized content with the
    outer parens stripped. Handles nested parens inside JSON plans and quoted
    questions like `What is your budget? (in USD)`. None if no recognised
    tool name is found or no call can be balanced.

    Two passes are made over the candidate calls. The first treats both `'`
    and `"` as string delimiters. Only if it finds nothing, a second pass
    treats `'` as an ordinary character, so an apostrophe in an unquoted
    question (`ask_question(What's your budget?)`) no longer hides the
    closing parenthesis.
    """
    calls = [
        (match.group(1), match.end())
        for match in re.finditer(r"\b(\w+)\s*\(", text)
        if match.group(1) in _TOOL_NAMES
    ]
    for quote_chars in ("'\"", '"'):
        for name, body_start in calls:
            end = _scan_call_body(text, body_start, quote_chars)
            if end is not None:
                return name, text[body_start:end]
    return None
def _parse_positional_args(tool_name: str, raw_args: str) -> dict:
    """Turn the text between the parentheses of `tool_name(...)` into kwargs.

    Recognised shapes, tried in this order:
      * `key="value"` / `key='value'`  -> {key: value} with \\" and \\' unescaped
      * `key={...}`                    -> {key: "{...}"} (kept as a string)
      * `key=value` (single-parameter tools only, key must be an identifier)
      * a single quoted string         -> bound to the tool's first parameter
      * a bare `{...}` object          -> bound to the tool's first parameter
      * anything else                  -> split on top-level commas and bound
                                          positionally; surplus parts are dropped

    Only ask_question (`question`) and propose_plan (`plan`) have positional
    parameters; empty input returns {}.
    """
    if not raw_args:
        return {}

    def _unescape(value: str) -> str:
        return value.replace('\\"', '"').replace("\\'", "'")

    positional = {
        "ask_question": ["question"],
        "propose_plan": ["plan"],
    }.get(tool_name, [])
    body = raw_args.strip()

    keyword_quoted = re.match(
        r"^\s*(\w+)\s*=\s*(['\"])(.*)\2\s*$",
        body,
        flags=re.DOTALL,
    )
    if keyword_quoted:
        return {keyword_quoted.group(1): _unescape(keyword_quoted.group(3))}

    keyword_object = re.match(r"^\s*(\w+)\s*=\s*(\{.*\})\s*$", body, flags=re.DOTALL)
    if keyword_object:
        return {keyword_object.group(1): keyword_object.group(2)}

    if len(positional) == 1 and "=" in body:
        name, _, value = body.partition("=")
        name = name.strip()
        if name and name.isidentifier():
            return {name: value.strip().strip("'\"")}

    bare_quoted = re.match(r"""^\s*(['"])(.*)\1\s*$""", body, flags=re.DOTALL)
    if bare_quoted and positional:
        return {positional[0]: _unescape(bare_quoted.group(2))}

    if positional and body.startswith("{") and body.endswith("}"):
        return {positional[0]: body}

    pieces = _split_top_level_commas(body)
    return {
        name: piece.strip().strip("'\"")
        for name, piece in zip(positional, pieces)
    }
def _split_top_level_commas(text: str) -> list[str]:
    """Split *text* at commas that sit outside quotes, (), {} and [].

    Each bracket kind is counted separately; a comma splits only when all
    three counters are zero. Pieces keep their surrounding whitespace, and
    a trailing empty piece is not emitted.
    """
    openers = {"(": 0, "{": 0, "[": 0}  # running depth per opening bracket
    closer_to_opener = {")": "(", "}": "{", "]": "["}
    pieces: list[str] = []
    current: list[str] = []
    open_quote = ""
    escaped = False

    for ch in text:
        if escaped:
            escaped = False
        elif open_quote:
            if ch == "\\":
                escaped = True
            elif ch == open_quote:
                open_quote = ""
        elif ch in ("'", '"'):
            open_quote = ch
        elif ch in openers:
            openers[ch] += 1
        elif ch in closer_to_opener:
            openers[closer_to_opener[ch]] -= 1
        elif ch == "," and not any(openers.values()):
            pieces.append("".join(current))
            current = []
            continue
        current.append(ch)

    if current:
        pieces.append("".join(current))
    return pieces
def _parse_result_field(obs: dict) -> str:
    """Render the ``result`` entry of an observation for the chat transcript.

    * missing/empty ``result``              -> ``str(obs)``
    * JSON object with a ``tool_result`` key -> that value, returned as-is
    * any other valid JSON                  -> pretty-printed with indent=2
    * not JSON / not a string               -> ``str(result)``
    """
    payload = obs.get("result", "")
    if not payload:
        return str(obs)
    try:
        decoded = json.loads(payload)
    except (json.JSONDecodeError, TypeError):
        return str(payload)
    if isinstance(decoded, dict) and "tool_result" in decoded:
        return decoded["tool_result"]
    return json.dumps(decoded, indent=2)
def _next_policy_action(
    task_id: str, step_index: int, request_text: str, revealed: dict,
) -> tuple[str, dict]:
    """Return the scripted (tool_name, arguments) pair for *step_index*.

    The script comes from POLICY_PLANS (the "medium" script for unknown task
    ids). Once the script is exhausted, the action is a propose_plan call
    whose plan is the JSON dump of *revealed*. *request_text* is accepted
    but not used.
    """
    script = POLICY_PLANS.get(task_id, POLICY_PLANS["medium"])
    if step_index >= len(script):
        return ("propose_plan", {"plan": json.dumps(revealed)})
    return script[step_index]
def _choose_action(
    task_id: str,
    messages: list[dict],
    llm_client: Optional[OpenAI],
    step_index: int,
    llm_attempts: int,
    request_text: str,
    revealed: dict,
) -> tuple[str, dict, bool, int]:
    """Pick the next tool call, preferring the LLM and falling back to the policy.

    Returns ``(tool_name, arguments, used_policy, llm_attempts)``.

    The scripted policy is used when BASELINE_MODE is "policy", when there is
    no client, or when the LLM budget (MAX_LLM_STEPS_PER_TASK) is spent. If
    the LLM call raises, or its reply cannot be parsed into one of the three
    known tools, the policy action is returned and the attempt counter is set
    to the budget so the LLM is not consulted again for this task.

    Side effect: whenever the LLM produced a reply, that reply is appended to
    *messages* as an assistant turn.
    """
    fallback_tool, fallback_args = _next_policy_action(
        task_id, step_index, request_text, revealed,
    )

    llm_disabled = BASELINE_MODE == "policy" or llm_client is None
    if llm_disabled or llm_attempts >= MAX_LLM_STEPS_PER_TASK:
        return fallback_tool, fallback_args, True, llm_attempts

    try:
        # Qwen3 enables "<think>...</think>" reasoning by default, which can
        # consume a tight token budget before the TOOL/ARGS block we parse is
        # ever emitted. Training disables it through
        # `chat_template_kwargs={"enable_thinking": False}` (see
        # train_grpo.py), so evaluation sends the same flag. vLLM forwards
        # `chat_template_kwargs` from `extra_body` into the tokenizer's
        # apply_chat_template; backends that don't support it (HF Router)
        # silently drop the field, so it is safe to always include.
        completion = llm_client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            extra_body={"chat_template_kwargs": {"enable_thinking": False}},
        )
        reply = completion.choices[0].message.content or ""
    except Exception as exc:
        print(f" LLM unavailable, switching to policy: {exc}")
        return fallback_tool, fallback_args, True, MAX_LLM_STEPS_PER_TASK
    llm_attempts += 1

    tool_name, args = parse_tool_call(reply)
    # The raw reply is recorded whether or not it turns out to be usable.
    messages.append({"role": "assistant", "content": reply})
    if tool_name in ("ask_question", "propose_plan", "get_task_info"):
        return tool_name, args, False, llm_attempts

    if tool_name:
        print(f" LLM suggested unknown tool {tool_name}; using policy instead.")
    else:
        print(" Could not parse tool call; using policy instead.")
    return fallback_tool, fallback_args, True, MAX_LLM_STEPS_PER_TASK
def _get_ws_url() -> str:
    """Derive the environment's websocket endpoint from ENV_BASE_URL.

    ``https://`` becomes ``wss://`` and ``http://`` becomes ``ws://``; only a
    leading scheme is rewritten. Trailing slashes on the base URL are dropped
    so the result never contains ``//ws``.
    """
    base = ENV_BASE_URL.rstrip("/")
    for http_scheme, ws_scheme in (("https://", "wss://"), ("http://", "ws://")):
        if base.startswith(http_scheme):
            base = ws_scheme + base[len(http_scheme):]
            break
    return f"{base}/ws"
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens a task's log, flushing stdout."""
    record = f"[START] task={task} env={env} model={model}"
    print(record, flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str] = None,
) -> None:
    """Print one ``[STEP]`` record, flushing stdout.

    The reward is shown with two decimals, *done* as its lower-cased string
    form, and a missing/empty *error* as the literal ``null``.
    """
    shown_error = error or "null"
    shown_done = str(done).lower()
    record = f"[STEP] step={step} action={action} reward={reward:.2f} done={shown_done} error={shown_error}"
    print(record, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
    """Print the ``[END]`` record that closes a task's log, flushing stdout.

    The score is shown with three decimals and the per-step rewards as a
    comma-separated list with two decimals each.
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    rewards_str = ",".join(formatted_rewards)
    success_str = str(success).lower()
    record = f"[END] success={success_str} steps={steps} score={score:.3f} rewards={rewards_str}"
    print(record, flush=True)
async def ws_reset(ws, task_id: str) -> dict:
    """Send a ``reset`` message for *task_id* and return the normalised reply.

    Returns a dict with ``observation`` (dict), ``reward`` (number) and
    ``done`` (bool). When the server answers with ``type == "error"`` the
    dict additionally carries the server's payload under ``error`` and the
    other fields are neutral defaults.

    JSON ``null`` values are replaced by neutral defaults ({} / 0.0 / False);
    ``dict.get(key, default)`` alone would hand None through to the caller.
    """
    await ws.send(json.dumps({"type": "reset", "data": {"task_id": task_id}}))
    resp = json.loads(await ws.recv())
    if resp.get("type") == "error":
        return {
            "observation": {},
            "reward": 0.0,
            "done": False,
            "error": resp.get("data") or {},
        }
    data = resp.get("data") or {}
    reward = data.get("reward")
    return {
        "observation": data.get("observation") or {},
        "reward": 0.0 if reward is None else reward,
        "done": bool(data.get("done")),
    }
async def ws_step(ws, tool_name: str, args: dict) -> dict:
    """Send one ``call_tool`` step and return the normalised reply.

    Returns a dict with ``observation`` (dict), ``reward`` (number) and
    ``done`` (bool). A server reply of ``type == "error"`` is converted into
    an observation whose ``result`` is a JSON string ``{"error": <message>}``
    with zero reward and ``done`` False.

    JSON ``null`` values are replaced by neutral defaults ({} / 0.0 / False)
    so callers can format the reward and index the observation safely.
    """
    action = {"type": "call_tool", "tool_name": tool_name, "arguments": args}
    await ws.send(json.dumps({"type": "step", "data": action}))
    resp = json.loads(await ws.recv())

    if resp.get("type") == "error":
        detail = resp.get("data")
        if isinstance(detail, dict):
            message = detail.get("message", "Unknown error")
        else:
            # The payload may be a bare string or null rather than an object.
            message = str(detail) if detail else "Unknown error"
        return {
            "observation": {"result": json.dumps({"error": message})},
            "reward": 0.0,
            "done": False,
        }

    data = resp.get("data") or {}
    reward = data.get("reward")
    return {
        "observation": data.get("observation") or {},
        "reward": 0.0 if reward is None else reward,
        "done": bool(data.get("done")),
    }
def wait_for_server(base_url: str, timeout: int = 120) -> bool:
    """Poll the environment server until it answers HTTP 200 or *timeout* expires.

    Probes ``<base_url>/health`` and then ``<base_url>/`` roughly every two
    seconds. Returns True on the first 200 response, False once *timeout*
    seconds have elapsed. Trailing slashes on *base_url* are ignored.

    Any error raised by a probe (connection refused, TLS failure, HTTP
    error status, ...) is treated as "not ready yet".
    """
    import ssl
    import urllib.error
    import urllib.request

    ctx = ssl.create_default_context()
    try:
        # Prefer certifi's CA bundle when it is installed.
        import certifi
        ctx.load_verify_locations(certifi.where())
    except ImportError:
        pass

    root = base_url.rstrip("/")
    urls = [f"{root}/health", f"{root}/"]
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        for url in urls:
            try:
                # The context manager closes the response (and its socket)
                # on every poll instead of leaking one per attempt.
                with urllib.request.urlopen(url, timeout=5, context=ctx) as resp:
                    if resp.status == 200:
                        return True
            except Exception:
                # Deliberately best-effort: any failure means "try again".
                continue
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(2.0, remaining))
    return False
async def run_task_async(llm_client: Optional[OpenAI], task_id: str, task_title: str) -> float:
    """Run one episode of *task_id* against the environment and return its score.

    Flow: open the websocket, reset the environment, then loop for at most
    ``max_steps`` turns (taken from the reset observation, default 10). Each
    turn picks an action with _choose_action, sends it with ws_step, records
    the reward, and appends the tool result to the chat transcript. When the
    environment reports ``done``, the score is read from ``final_score`` /
    ``score`` in the terminal result (falling back to the last reward). If
    the step budget runs out first, the score is the reward sum clamped to
    [0, 1].

    A ``[START]`` record is printed first and an ``[END]`` record is always
    printed from ``finally``. Exceptions are logged and swallowed, so the
    function then returns whatever score was reached (0.0 by default).

    _choose_action makes a synchronous HTTP call to the LLM, so it runs in
    the default executor. Calling it directly would block the event loop for
    the whole request, which stalls websocket keepalive handling and
    prevents an outer ``asyncio.wait_for`` timeout from firing.
    """
    rewards_list: list[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    log_start(task=task_id, env="clarify_rl", model=MODEL_NAME or "policy")
    try:
        import websockets

        print(f"\nTask: {task_id} ({task_title})")
        print("-" * 50)
        ws_url = _get_ws_url()
        loop = asyncio.get_running_loop()
        async with websockets.connect(ws_url, open_timeout=30, close_timeout=10) as ws:
            reset_result = await ws_reset(ws, task_id)
            obs = reset_result.get("observation", {})
            initial_result = obs.get("result", "")
            try:
                initial_data = json.loads(initial_result) if initial_result else {}
            except (json.JSONDecodeError, TypeError):
                initial_data = {}
            request_text = initial_data.get("request", str(initial_data))
            max_steps = initial_data.get("max_steps", 10)
            family = initial_data.get("family", "")
            questions_remaining = initial_data.get("questions_remaining", 6)
            rk = REQUIRED_KEYS_BY_FAMILY.get(family, [])
            required_keys_str = ", ".join(rk) if rk else "unknown"
            initial_context = (
                f"USER REQUEST: {request_text}\n"
                f"Task family: {family}\n"
                f"Required plan fields: {required_keys_str}\n"
                f"You have {max_steps} turns and may ask up to {questions_remaining} clarifying questions.\n"
                f"Use the tools to ask about each required field, then call propose_plan with a JSON string containing ALL required fields."
            )
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": initial_context},
            ]
            task_step_budget = max_steps
            llm_attempts = 0
            revealed: dict = {}
            for step in range(1, task_step_budget + 1):
                # Offload the blocking LLM request to a worker thread.
                tool_name, args, used_policy, llm_attempts = await loop.run_in_executor(
                    None, _choose_action,
                    task_id, messages, llm_client, step - 1, llm_attempts,
                    request_text, revealed,
                )
                args_str = json.dumps(args) if args else "{}"
                action_str = f"{tool_name}({args_str})"
                source = "policy" if used_policy else "llm"
                print(f" Step {step}: [{source}] {action_str}")

                result = await ws_step(ws, tool_name, args)
                obs_data = result.get("observation", {})
                reward = result.get("reward", 0.0)
                if reward is None:
                    # A JSON null reward would break the ":.2f" log format
                    # and the reward sum below.
                    reward = 0.0
                done = result.get("done", False)
                tool_result = _parse_result_field(obs_data)

                # Harvest answered fields for the policy's final plan,
                # skipping bookkeeping keys.
                try:
                    result_parsed = json.loads(tool_result) if isinstance(tool_result, str) else tool_result
                    if isinstance(result_parsed, dict):
                        for k, v in result_parsed.items():
                            if k not in ("error", "episode_done", "questions_remaining", "fields_revealed"):
                                revealed[k] = v
                except (json.JSONDecodeError, TypeError):
                    pass

                rewards_list.append(reward)
                steps_taken = step
                log_step(step=step, action=action_str, reward=reward, done=done, error=None)

                # Keep the transcript bounded.
                if len(str(tool_result)) > 1500:
                    tool_result = str(tool_result)[:1500] + "... [truncated]"
                if used_policy:
                    # Policy actions are recorded in the canonical TOOL/ARGS format.
                    messages.append({
                        "role": "assistant",
                        "content": f"TOOL: {tool_name}\nARGS: {json.dumps(args)}",
                    })
                messages.append({
                    "role": "user",
                    "content": f"Tool result:\n{tool_result}\n\nReward: {reward}\nSteps remaining: {max_steps - step}",
                })

                if done:
                    try:
                        terminal_data = json.loads(obs_data.get("result", "{}"))
                    except (json.JSONDecodeError, TypeError):
                        terminal_data = {}
                    score = terminal_data.get("final_score", terminal_data.get("score", reward))
                    if score is None:
                        score = reward
                    success = score >= SUCCESS_SCORE_THRESHOLD
                    breakdown = terminal_data.get("score_breakdown", {})
                    print(f" --> Episode ended. Score: {score}")
                    if breakdown:
                        for comp, val in breakdown.items():
                            print(f" {comp}: {val}")
                    break
            else:
                # Step budget exhausted without the environment ending the episode.
                score = sum(rewards_list) if rewards_list else 0.0
                score = min(max(score, 0.0), 1.0)
                success = score >= SUCCESS_SCORE_THRESHOLD
                print(f" --> Max steps reached. Score: {score}")
    except Exception as exc:
        print(f"[DEBUG] Task {task_id} error: {exc}", flush=True)
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards_list)
    return score
def main():
    """Run the three baseline tasks in sequence and print a score summary.

    Waits for the environment server first. If it never becomes reachable, a
    zero-score ``[START]``/``[END]`` pair is emitted for every task and the
    process exits with status 0. Otherwise each task runs in its own event
    loop with a 300-second timeout; a timed-out or crashed task scores 0.0.
    """
    print("=" * 60)
    print(" ClarifyRL — Baseline Inference")
    print("=" * 60)
    print(f"Mode: {BASELINE_MODE}")
    print(f"API: {API_BASE_URL}")
    print(f"Model: {MODEL_NAME}")
    print(f"Environment: {ENV_BASE_URL}")
    tasks = [
        ("easy", "Mild Ambiguity (2-3 fields)"),
        ("medium", "Moderate Ambiguity (4-5 fields)"),
        ("hard", "High Ambiguity (6-7 fields)"),
    ]

    print("\nWaiting for environment server...", flush=True)
    server_ok = wait_for_server(ENV_BASE_URL)
    if not server_ok:
        print("ERROR: Environment server not reachable.", flush=True)
        for task_id, title in tasks:
            log_start(task=task_id, env="clarify_rl", model=MODEL_NAME or "policy")
            log_end(success=False, steps=0, score=0.0, rewards=[])
        print("Emitted zero-score logs for all tasks. Exiting.", flush=True)
        sys.exit(0)
    print("Server is ready.\n", flush=True)

    llm_client = create_client()
    task_timeout = 300
    scores = {}
    for task_id, title in tasks:
        try:
            score = asyncio.run(
                asyncio.wait_for(run_task_async(llm_client, task_id, title), timeout=task_timeout)
            )
        except asyncio.TimeoutError:
            # run_task_async prints [START] before its first await and [END]
            # from its `finally` block when the timeout cancels it, so no
            # second pair is emitted here (that would duplicate the record).
            print(f"[DEBUG] Task {task_id} timed out after {task_timeout}s", flush=True)
            score = 0.0
        except Exception as exc:
            print(f"[DEBUG] Task {task_id} crashed: {exc}", flush=True)
            score = 0.0
        scores[task_id] = score

    print("\n" + "=" * 60)
    print(" Summary")
    print("=" * 60)
    for task_id, title in tasks:
        print(f" {task_id:<8s} ({title}): {scores.get(task_id, 0.0):.2f}")
    avg = sum(scores.values()) / len(scores) if scores else 0.0
    print(f"\n Average: {avg:.2f}")
    print("=" * 60)
if __name__ == "__main__":
    # Script entry point. Any ordinary exception is reported and converted
    # into a zero exit status. SystemExit is not a subclass of Exception, so
    # an explicit sys.exit() inside main() propagates unchanged.
    try:
        main()
    except Exception as exc:
        print(f"[DEBUG] Fatal error in main: {exc}", flush=True)
        sys.exit(0)