"""
ClarifyRL — Baseline Inference Script
======================================

MANDATORY:
- Before submitting, ensure the following variables are defined in your
  environment configuration:
    API_BASE_URL   The API endpoint for the LLM.
    MODEL_NAME     The model identifier to use for inference.
    HF_TOKEN       Your Hugging Face / API key.
- The inference script must be named `inference.py` and placed in the root
  directory of the project.
- Participants must use OpenAI Client for all LLM calls using above variables.

Optional variables also read by this script:
    OPENAI_API_KEY / API_KEY   Fallback API keys, used only when HF_TOKEN is unset.
    BASELINE_MODE              "policy" never calls the LLM; any other value
                               (default "hybrid") lets the LLM choose actions
                               and falls back to a scripted policy.
    ENV_BASE_URL               Base URL of the ClarifyRL environment server
                               (default http://localhost:7860).
    TEMPERATURE                Sampling temperature for LLM calls (default 0.7).
    MAX_TOKENS                 Completion token budget per LLM call (default 800).
    MAX_LLM_STEPS_PER_TASK     Max LLM calls per episode (default 8).
"""
from __future__ import annotations

import asyncio
import json
import os
import re
import sys
import textwrap
import time
from typing import Optional

try:
    # If the optional `truststore` package is installed, make `ssl` use the
    # operating system's certificate store instead of the bundled one.
    import truststore

    truststore.inject_into_ssl()
except ImportError:
    pass

from openai import OpenAI

API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
# HF_TOKEN wins; the other two names are accepted for OpenAI-style setups.
API_KEY = HF_TOKEN or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
BASELINE_MODE = os.getenv("BASELINE_MODE", "hybrid").strip().lower()
# Trailing slashes are stripped because callers append "/ws", "/health", "/".
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860").rstrip("/")
# Configurable so sampling can be tuned without editing the script; the
# default matches the previous hard-coded value.
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))

# Qwen3 with enable_thinking=False usually fits in <200 tokens; we leave 800
# as a safety margin in case any backend still emits a `think` reasoning block
# (e.g. if `chat_template_kwargs` is silently dropped by an OpenAI-style proxy
# that doesn't forward `extra_body` to vLLM).
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "800"))
MAX_LLM_STEPS_PER_TASK = int(os.getenv("MAX_LLM_STEPS_PER_TASK", "8"))
# Per-request timeout (seconds) passed to the OpenAI client. The LLM call in
# `_choose_action` is synchronous, so the per-task `asyncio.wait_for` in
# `main` cannot interrupt it; this bound keeps one hung request from stalling
# the whole run (on timeout the episode falls back to the scripted policy).
LLM_TIMEOUT_S = float(os.getenv("LLM_TIMEOUT_S", "60"))
SUCCESS_SCORE_THRESHOLD = 0.5

SYSTEM_PROMPT = (
    "You are a helpful assistant that books and plans things for users.\n"
    "The user's request will be intentionally ambiguous \u2014 you do NOT yet have all the information needed to make a good plan.\n"
    "\n"
    "You have three tools:\n"
    " - ask_question(question): ask the user ONE targeted clarifying question (max 6 across the episode).\n"
    " - propose_plan(plan): submit your final plan as a JSON STRING with the required fields. This ENDS the episode.\n"
    " - get_task_info(): re-read the original user request.\n"
    "\n"
    "Strategy:\n"
    " 1. Read the required plan fields listed in the task description.\n"
    " 2. Use ask_question to ask about EACH required field you do not already know.\n"
    " 3. When you have enough info, call propose_plan with a JSON string containing ALL required fields.\n"
    "\n"
    "Rules:\n"
    " - Be efficient. Each unnecessary question costs reward.\n"
    " - Your plan MUST include every required field listed in the task. Missing fields score zero.\n"
    " - NEVER include fields in your plan that you weren't told about. No hallucinating values.\n"
    " - The `plan` argument MUST be a JSON STRING (not a dict). Use the exact field names from the required fields list.\n"
)

# Required plan fields per task family; listed to the model in the first user
# message of each episode (see `run_task_async`).
REQUIRED_KEYS_BY_FAMILY: dict[str, list[str]] = {
    "coding_requirements": ["stack", "scale", "auth", "datastore"],
    "medical_intake": ["primary_symptom", "duration", "severity"],
    "support_triage": ["order_id", "item_issue", "refund_or_replace"],
    "meeting_scheduling": ["participants", "date", "time"],
    "event_planning": ["event_type", "date", "guest_count", "venue"],
}

# Scripted fallback actions per difficulty. Once a script is exhausted,
# `_next_policy_action` proposes everything revealed so far as the plan.
POLICY_PLANS = {
    "easy": [
        ("get_task_info", {}),
        ("ask_question", {"question": "What is the main requirement?"}),
        ("ask_question", {"question": "Any specific preferences or constraints?"}),
    ],
    "medium": [
        ("get_task_info", {}),
        ("ask_question", {"question": "What is the main requirement?"}),
        ("ask_question", {"question": "What are the specific details needed?"}),
        ("ask_question", {"question": "Any constraints or preferences?"}),
        ("ask_question", {"question": "What is the timeline or deadline?"}),
    ],
    "hard": [
        ("get_task_info", {}),
        ("ask_question", {"question": "What is the main requirement?"}),
        ("ask_question", {"question": "What are the technical specifications?"}),
        ("ask_question", {"question": "What is the scale or scope?"}),
        ("ask_question", {"question": "Any constraints or limitations?"}),
        ("ask_question", {"question": "What is the timeline?"}),
        ("ask_question", {"question": "Any other preferences?"}),
    ],
}


def create_client() -> Optional[OpenAI]:
    """Build the OpenAI-compatible client used for LLM-driven steps.

    Returns:
        A client bound to ``API_BASE_URL``/``API_KEY`` with a per-request
        timeout of ``LLM_TIMEOUT_S`` seconds, or ``None`` (meaning "use the
        scripted policy") when ``BASELINE_MODE`` is ``"policy"``, when no API
        key is configured, or when constructing the client fails.
    """
    if BASELINE_MODE == "policy":
        return None
    if not API_KEY:
        print("[DEBUG] No API key found; policy fallback will be used.", flush=True)
        return None
    try:
        return OpenAI(base_url=API_BASE_URL, api_key=API_KEY, timeout=LLM_TIMEOUT_S)
    except Exception as exc:
        print(f"[DEBUG] Failed to create OpenAI client: {exc}", flush=True)
        return None


# Canonical tool names accepted by the environment.
_TOOL_NAMES = ("ask_question", "propose_plan", "get_task_info")

# Upper-cased prefixes/aliases the models emit, mapped to canonical tool names.
_PREFIX_TO_TOOL = {
    "ASK": "ask_question",
    "ASK_QUESTION": "ask_question",
    "QUESTION": "ask_question",
    "Q": "ask_question",
    "PROPOSE": "propose_plan",
    "PROPOSE_PLAN": "propose_plan",
    "PLAN": "propose_plan",
    "INFO": "get_task_info",
    "GET_TASK_INFO": "get_task_info",
    "TASK_INFO": "get_task_info",
}


def _canonical_tool_name(name: str) -> str:
    """Map case/alias variants (``ASK_QUESTION``, ``Ask``, ``PLAN``) to a canonical tool name.

    Names that are not known aliases are returned stripped but otherwise
    unchanged, so the caller can still report them as unknown tools.
    """
    stripped = name.strip()
    return _PREFIX_TO_TOOL.get(stripped.upper().replace("-", "_"), stripped)


def parse_tool_call(response_text: str) -> tuple[Optional[str], dict]:
    """Extract ``(tool_name, args)`` from a free-form model reply.

    Reasoning blocks are stripped first, then these formats are tried in order:

    1. ``TOOL: <name>`` with an optional ``ARGS: {...}`` JSON block.
    2. A JSON object carrying ``tool``/``tool_name``/``name``/``action``.
    3. A call expression such as ``ask_question(question="...")``.
    4. A prefixed line such as ``ASK: ...`` or ``PROPOSE: {...}``.
    5. A ReAct-style ``Action: name(key="value")``.

    Tool names from forms 1, 2 and 5 are canonicalised through
    ``_PREFIX_TO_TOOL`` (so ``TOOL: ASK_QUESTION`` yields ``ask_question``);
    other names are returned as-is for the caller to reject.

    Returns:
        ``(tool_name, args)``, or ``(None, {})`` when nothing matches.
    """
    cleaned = _strip_reasoning(response_text)

    tool_match = re.search(r"TOOL:\s*(\w+)", cleaned, re.IGNORECASE)
    if tool_match:
        raw_args = _extract_args_block(cleaned)
        args = _load_json_like(raw_args) if raw_args else {}
        return _canonical_tool_name(tool_match.group(1)), args

    json_tool_name, json_tool_args = _parse_json_tool_call(cleaned)
    if json_tool_name:
        return _canonical_tool_name(json_tool_name), json_tool_args

    fn_call = _find_balanced_func_call(cleaned)
    if fn_call:
        tool_name, raw_body = fn_call
        if tool_name in _TOOL_NAMES:
            return tool_name, _parse_positional_args(tool_name, raw_body.strip())

    prefix_tool, prefix_args = _parse_prefixed_call(cleaned)
    if prefix_tool:
        return prefix_tool, prefix_args

    action_match = re.search(
        r'Action:\s*(\w+)\((?:(\w+)\s*=\s*["\'](.+?)["\']|([^)]*))\)',
        cleaned,
        re.DOTALL,
    )
    if action_match:
        tool_name = _canonical_tool_name(action_match.group(1))
        if action_match.group(2) and action_match.group(3) is not None:
            key = action_match.group(2)
            val = action_match.group(3).replace('\\"', '"').replace("\\'", "'")
            return tool_name, {key: val}
        if action_match.group(4):
            raw = action_match.group(4).strip()
            if "=" in raw:
                k, _, v = raw.partition("=")
                return tool_name, {k.strip(): v.strip().strip("\"'")}
        return tool_name, {}
    return None, {}


def _parse_prefixed_call(text: str) -> tuple[Optional[str], dict]:
    """Handle Qwen3 GRPO outputs like:

        ASK: {"question": "What is the budget?"}
        ASK: What is the budget?
        PROPOSE: {"date": "2024-12-25", ...}
        Q: What is the budget?

    The 0.6B GRPO checkpoint emits these ~20% of the time. The prefix is
    mapped to the canonical tool name, and the rest is parsed either as a
    JSON object or as a free-form question/plan string.

    Returns:
        ``(tool_name, args)``, or ``(None, {})`` when the text has no known
        prefix or nothing usable follows it.
    """
    match = re.match(r"^\s*([A-Za-z_]+)\s*:\s*(.*)$", text, flags=re.DOTALL)
    if not match:
        return None, {}
    tool_name = _PREFIX_TO_TOOL.get(match.group(1).upper().replace("-", "_"))
    if tool_name is None:
        return None, {}
    rest = match.group(2).strip()

    if rest.startswith("{"):
        parsed = _load_json_like(rest)
        if isinstance(parsed, dict) and parsed:
            if tool_name == "ask_question":
                question = parsed.get("question") or parsed.get("q") or parsed.get("text")
                if isinstance(question, str):
                    return tool_name, {"question": question}
                return tool_name, {"question": json.dumps(parsed)}
            if tool_name == "propose_plan":
                inner = parsed.get("plan")
                if isinstance(inner, (dict, str)):
                    plan_str = inner if isinstance(inner, str) else json.dumps(inner)
                    return tool_name, {"plan": plan_str}
                # The whole object is the plan itself.
                return tool_name, {"plan": json.dumps(parsed)}
        return tool_name, {}

    if tool_name == "ask_question":
        question = rest.strip().strip('"').strip("'")
        if question:
            return tool_name, {"question": question}
    if tool_name == "propose_plan" and rest:
        return tool_name, {"plan": rest}
    if tool_name == "get_task_info":
        return tool_name, {}
    return None, {}


# "<" and ">" are spelled \x3c / \x3e so that tooling which strips HTML-like
# tags cannot silently reduce these to the no-op pattern ".*?".
_THINK_BLOCK_RE = re.compile(r"\x3cthink\x3e.*?\x3c/think\x3e", re.DOTALL | re.IGNORECASE)
_THINK_CLOSE_RE = re.compile(r"\x3c/think\x3e", re.IGNORECASE)


def _strip_reasoning(response_text: str) -> str:
    """Remove ``think``-tag reasoning and code-fence labels from a model reply.

    Complete ``think`` blocks are deleted. If a closing tag is still left over
    (chat templates that put the opening tag in the prompt), only the text
    after the last closing tag is kept. The ```` ```json ```` / ```` ```tool ````
    fence labels are reduced to bare ```` ``` ````.

    Returns:
        The cleaned text with surrounding whitespace stripped.
    """
    cleaned = _THINK_BLOCK_RE.sub("", response_text)
    closers = list(_THINK_CLOSE_RE.finditer(cleaned))
    if closers:
        cleaned = cleaned[closers[-1].end():]
    cleaned = cleaned.replace("```json", "```")
    cleaned = cleaned.replace("```tool", "```")
    return cleaned.strip()


def _extract_args_block(response_text: str) -> Optional[str]:
    """Return the brace-balanced ``{...}`` following the first ``ARGS:`` marker.

    Braces inside double-quoted strings (with backslash escapes) are ignored.

    Returns:
        The object text, or ``None`` if there is no marker, no ``{`` after
        it, or the braces never balance.
    """
    args_marker = re.search(r"ARGS:\s*", response_text, re.IGNORECASE)
    if not args_marker:
        return None
    start = response_text.find("{", args_marker.end())
    if start == -1:
        return None
    depth = 0
    in_string = False
    escape = False
    for index in range(start, len(response_text)):
        char = response_text[index]
        if in_string:
            if escape:
                escape = False
            elif char == "\\":
                escape = True
            elif char == '"':
                in_string = False
            continue
        if char == '"':
            in_string = True
        elif char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return response_text[start:index + 1]
    return None


def _candidate_json_objects(text: str) -> list[str]:
    """Return every top-level brace-balanced ``{...}`` substring of ``text``, in order.

    Text between objects is skipped. Braces inside double-quoted strings do
    not count, and an object whose braces never balance is dropped.
    """
    candidates = []
    start = None
    depth = 0
    in_string = False
    escape = False
    for index, char in enumerate(text):
        if start is None:
            if char == "{":
                start = index
                depth = 1
            continue
        if in_string:
            if escape:
                escape = False
            elif char == "\\":
                escape = True
            elif char == '"':
                in_string = False
            continue
        if char == '"':
            in_string = True
        elif char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0 and start is not None:
                candidates.append(text[start:index + 1])
                start = None
    return candidates


def _load_json_like(raw: str) -> dict:
    """Best-effort decode of a JSON-ish object string into a dict.

    Tries strict JSON first. Then it tries a normalised form: ``key=`` becomes
    ``"key": ``, and single quotes become double quotes. Finally it falls back
    to regex extraction of ``"key": "str"`` / ``"key": int`` pairs.

    Returns:
        A dict; ``{}`` if the JSON decodes to a non-object.
    """
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        normalized = raw.strip()
        normalized = re.sub(r"(\w+)\s*=", r'"\1": ', normalized)
        normalized = normalized.replace("'", '"')
        try:
            parsed = json.loads(normalized)
        except json.JSONDecodeError:
            return _parse_args_fallback(raw)
    return parsed if isinstance(parsed, dict) else {}


def _parse_json_tool_call(response_text: str) -> tuple[Optional[str], dict]:
    """Find the first JSON object that names a tool and return ``(name, args)``.

    The tool name is read from ``tool``, ``tool_name``, ``name`` or ``action``.
    Arguments are read from ``args``, ``arguments`` or ``parameters``; a
    JSON-string payload is decoded, and anything that is not a dict becomes
    ``{}``.

    Returns:
        ``(name, args)``, or ``(None, {})`` if no candidate names a tool.
    """
    for candidate in _candidate_json_objects(response_text):
        parsed = _load_json_like(candidate)
        if not parsed:
            continue
        tool_name = (
            parsed.get("tool")
            or parsed.get("tool_name")
            or parsed.get("name")
            or parsed.get("action")
        )
        if not isinstance(tool_name, str):
            continue
        args = parsed.get("args") or parsed.get("arguments") or parsed.get("parameters") or {}
        if isinstance(args, str) and args.strip().startswith("{"):
            args = _load_json_like(args)
        if not isinstance(args, dict):
            args = {}
        return tool_name.strip(), args
    return None, {}


def _parse_args_fallback(raw: str) -> dict:
    """Regex-scrape ``"key": "string"`` and ``"key": 123`` pairs from malformed JSON."""
    args = {}
    for match in re.finditer(r'"(\w+)"\s*:\s*"([^"]*)"', raw):
        args[match.group(1)] = match.group(2)
    for match in re.finditer(r'"(\w+)"\s*:\s*(\d+)', raw):
        args[match.group(1)] = int(match.group(2))
    return args


def _find_balanced_func_call(text: str) -> Optional[tuple[str, str]]:
    """Find the first `tool_name(...)` call with balanced parens.

    Only names in ``_TOOL_NAMES`` are considered. Parentheses inside single-
    or double-quoted strings are ignored, so JSON plans and quoted questions
    like `"What is your budget? (in USD)"` survive intact.

    Returns:
        ``(name, body)``, where body is the content between the outer parens,
        or ``None`` if no recognised call with balanced parens is found.
    """
    for match in re.finditer(r"\b(\w+)\s*\(", text):
        name = match.group(1)
        if name not in _TOOL_NAMES:
            continue
        body_start = match.end()
        depth = 1
        in_str = False
        quote_char = ""
        escape = False
        for index in range(body_start, len(text)):
            char = text[index]
            if escape:
                escape = False
                continue
            if in_str:
                if char == "\\":
                    escape = True
                elif char == quote_char:
                    in_str = False
                continue
            if char in ("'", '"'):
                in_str = True
                quote_char = char
                continue
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    return name, text[body_start:index]
    return None


def _parse_positional_args(tool_name: str, raw_args: str) -> dict:
    """Parse the body of a `tool_name(...)` call.

    Handles the syntaxes the trained Qwen3 models actually produce:

    1. Single keyword arg with quoted value: `question="What is your budget?"`
    2. Bare keyword arg (unquoted JSON-ish): `plan={"event_type": "wedding"}`
    3. Pure positional (legacy): `What is your budget?` or `"What ..."`

    A naive split on `,` with end-quote stripping would corrupt
    `question="..."` into a literal `question="...` string, so keyword forms
    are matched before falling back to top-level comma splitting.

    Returns:
        Argument dict; ``{}`` for an empty body or a tool with no parameters.
    """
    if not raw_args:
        return {}
    arg_map = {
        "ask_question": ["question"],
        "propose_plan": ["plan"],
    }
    param_names = arg_map.get(tool_name, [])
    text = raw_args.strip()

    kw_quoted = re.match(
        r"^\s*(\w+)\s*=\s*(['\"])(.*)\2\s*$",
        text,
        flags=re.DOTALL,
    )
    if kw_quoted:
        key = kw_quoted.group(1)
        val = kw_quoted.group(3).replace('\\"', '"').replace("\\'", "'")
        return {key: val}

    kw_brace = re.match(r"^\s*(\w+)\s*=\s*(\{.*\})\s*$", text, flags=re.DOTALL)
    if kw_brace:
        return {kw_brace.group(1): kw_brace.group(2)}

    if "=" in text and len(param_names) == 1:
        key, _, val = text.partition("=")
        key_clean = key.strip()
        if key_clean and key_clean.isidentifier():
            return {key_clean: val.strip().strip("'\"")}

    quoted = re.match(r"""^\s*(['"])(.*)\1\s*$""", text, flags=re.DOTALL)
    if quoted and param_names:
        val = quoted.group(2).replace('\\"', '"').replace("\\'", "'")
        return {param_names[0]: val}

    if param_names and text.startswith("{") and text.endswith("}"):
        return {param_names[0]: text}

    # Positional fallback: zip declared parameter names with top-level parts.
    parts = _split_top_level_commas(text)
    return {name: part.strip().strip("'\"") for name, part in zip(param_names, parts)}


def _split_top_level_commas(text: str) -> list[str]:
    """Split on commas only when not inside quotes / brackets / braces.

    Parts are returned unstripped; an empty input yields ``[]``.
    """
    out: list[str] = []
    depth_paren = 0
    depth_brace = 0
    depth_brack = 0
    in_str = False
    quote = ""
    escape = False
    buf: list[str] = []
    for ch in text:
        if escape:
            buf.append(ch)
            escape = False
            continue
        if in_str:
            if ch == "\\":
                escape = True
            elif ch == quote:
                in_str = False
            buf.append(ch)
            continue
        if ch in ("'", '"'):
            in_str = True
            quote = ch
            buf.append(ch)
            continue
        if ch == "(":
            depth_paren += 1
        elif ch == ")":
            depth_paren -= 1
        elif ch == "{":
            depth_brace += 1
        elif ch == "}":
            depth_brace -= 1
        elif ch == "[":
            depth_brack += 1
        elif ch == "]":
            depth_brack -= 1
        elif ch == "," and depth_paren == 0 and depth_brace == 0 and depth_brack == 0:
            out.append("".join(buf))
            buf = []
            continue
        buf.append(ch)
    if buf:
        out.append("".join(buf))
    return out


def _parse_result_field(obs: dict) -> str:
    """Render an observation's ``result`` field for the transcript.

    Returns:
        The ``tool_result`` value when ``result`` is a JSON object that has
        one. Otherwise the pretty-printed JSON, or the raw value if it is not
        JSON. Returns ``str(obs)`` when ``result`` is missing or empty.
    """
    result_raw = obs.get("result", "")
    if not result_raw:
        return str(obs)
    try:
        parsed = json.loads(result_raw)
        if isinstance(parsed, dict) and "tool_result" in parsed:
            return parsed["tool_result"]
        return json.dumps(parsed, indent=2)
    except (json.JSONDecodeError, TypeError):
        return str(result_raw)


def _parse_json_object(raw) -> dict:
    """Decode ``raw`` as a JSON object; ``{}`` if it is empty, invalid, or not an object."""
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


# Keys in tool results that are bookkeeping or task metadata (the latter are
# read from the reset payload in `run_task_async`), never plan fields, so they
# must not end up in the policy's proposed plan.
_NON_PLAN_RESULT_KEYS = frozenset({
    "error",
    "episode_done",
    "questions_remaining",
    "fields_revealed",
    "request",
    "max_steps",
    "family",
})


def _update_revealed(revealed: dict, tool_result) -> None:
    """Merge plan-relevant keys of a JSON-object tool result into ``revealed`` in place.

    Non-JSON or non-object results are ignored.
    """
    try:
        parsed = json.loads(tool_result) if isinstance(tool_result, str) else tool_result
    except (json.JSONDecodeError, TypeError):
        return
    if isinstance(parsed, dict):
        for key, value in parsed.items():
            if key not in _NON_PLAN_RESULT_KEYS:
                revealed[key] = value


def _next_policy_action(
    task_id: str,
    step_index: int,
    request_text: str,
    revealed: dict,
) -> tuple[str, dict]:
    """Return the scripted action for ``step_index`` (0-based).

    Unknown ``task_id`` values use the ``"medium"`` script. Past the end of
    the script, the action is ``propose_plan`` with ``revealed`` serialised
    as the plan. ``request_text`` is currently unused.
    """
    plan = POLICY_PLANS.get(task_id, POLICY_PLANS["medium"])
    if step_index < len(plan):
        return plan[step_index]
    return ("propose_plan", {"plan": json.dumps(revealed)})


def _choose_action(
    task_id: str,
    messages: list[dict],
    llm_client: Optional[OpenAI],
    step_index: int,
    llm_attempts: int,
    request_text: str,
    revealed: dict,
) -> tuple[str, dict, bool, int]:
    """Pick the next action, from the LLM when allowed, else the scripted policy.

    When the LLM is called, its raw reply is appended to ``messages``.

    Returns:
        ``(tool_name, args, used_policy, llm_attempts)``. If the LLM call
        raises, or its reply has no recognised tool call, the policy action is
        returned. ``llm_attempts`` is then set to ``MAX_LLM_STEPS_PER_TASK`` so
        the LLM is not consulted again for this task.
    """
    policy_tool, policy_args = _next_policy_action(task_id, step_index, request_text, revealed)
    if BASELINE_MODE == "policy" or llm_client is None:
        return policy_tool, policy_args, True, llm_attempts
    if llm_attempts >= MAX_LLM_STEPS_PER_TASK:
        return policy_tool, policy_args, True, llm_attempts

    try:
        # Qwen3 ships with `think`-tag reasoning enabled by default, which on
        # a small token budget can burn the entire reply inside the reasoning
        # block and never reach the TOOL/ARGS block we parse. Training disables
        # it via `chat_template_kwargs={"enable_thinking": False}` (see
        # train_grpo.py), so eval must do the same to match the deployment
        # contract. vLLM forwards `chat_template_kwargs` from `extra_body`
        # straight into the tokenizer's apply_chat_template; backends that
        # don't support it (HF Router) silently drop the field, so it's safe
        # to always include.
        response = llm_client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            extra_body={"chat_template_kwargs": {"enable_thinking": False}},
        )
        assistant_msg = response.choices[0].message.content or ""
        llm_attempts += 1
    except Exception as exc:
        print(f" LLM unavailable, switching to policy: {exc}")
        return policy_tool, policy_args, True, MAX_LLM_STEPS_PER_TASK

    tool_name, args = parse_tool_call(assistant_msg)
    messages.append({"role": "assistant", "content": assistant_msg})
    if tool_name in _TOOL_NAMES:
        return tool_name, args, False, llm_attempts

    if tool_name:
        print(f" LLM suggested unknown tool {tool_name}; using policy instead.")
    else:
        print(" Could not parse tool call; using policy instead.")
    return policy_tool, policy_args, True, MAX_LLM_STEPS_PER_TASK


def _get_ws_url() -> str:
    """Derive the environment WebSocket endpoint (``ws(s)://.../ws``) from ``ENV_BASE_URL``."""
    ws_url = ENV_BASE_URL.replace("https://", "wss://").replace("http://", "ws://")
    return f"{ws_url}/ws"


def log_start(task: str, env: str, model: str) -> None:
    """Emit the structured ``[START]`` line for one task."""
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(
    step: int,
    action: str,
    reward: float,
    done: bool,
    error: Optional[str] = None,
) -> None:
    """Emit the structured ``[STEP]`` line; a missing ``error`` is printed as ``null``."""
    error_val = error if error else "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
    """Emit the structured ``[END]`` line with per-step rewards comma-joined."""
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
        flush=True,
    )


def _step_payload(data: dict) -> dict:
    """Normalise a reset/step response ``data`` dict.

    A missing or ``None`` observation becomes ``{}``, a missing or ``None``
    reward becomes ``0.0``, and ``done`` is coerced to ``bool``.
    """
    reward = data.get("reward")
    return {
        "observation": data.get("observation") or {},
        "reward": reward if reward is not None else 0.0,
        "done": bool(data.get("done", False)),
    }


async def ws_reset(ws, task_id: str) -> dict:
    """Send a ``reset`` for ``task_id`` and return ``{observation, reward, done}``.

    On an ``error`` response, the observation is empty and the error payload
    is attached under ``"error"``.
    """
    await ws.send(json.dumps({"type": "reset", "data": {"task_id": task_id}}))
    resp = json.loads(await ws.recv())
    if resp.get("type") == "error":
        return {"observation": {}, "reward": 0.0, "done": False, "error": resp.get("data", {})}
    return _step_payload(resp.get("data") or {})


async def ws_step(ws, tool_name: str, args: dict) -> dict:
    """Send a ``call_tool`` step and return ``{observation, reward, done}``.

    An ``error`` response is turned into an observation whose ``result`` is
    ``{"error": <message>}`` JSON, with zero reward and ``done=False``.
    """
    action = {"type": "call_tool", "tool_name": tool_name, "arguments": args}
    await ws.send(json.dumps({"type": "step", "data": action}))
    resp = json.loads(await ws.recv())
    if resp.get("type") == "error":
        return {
            "observation": {"result": json.dumps({"error": resp.get("data", {}).get("message", "Unknown error")})},
            "reward": 0.0,
            "done": False,
        }
    return _step_payload(resp.get("data") or {})


def wait_for_server(base_url: str, timeout: int = 120) -> bool:
    """Poll ``{base_url}/health`` then ``{base_url}/`` until one answers HTTP 200.

    Probes use a 5 s request timeout and rounds are ~2 s apart. TLS uses the
    default SSL context plus the ``certifi`` bundle when that package is
    installed.

    Returns:
        ``True`` as soon as a probe succeeds; ``False`` after ``timeout`` seconds.
    """
    import ssl
    import urllib.request

    ctx = ssl.create_default_context()
    try:
        import certifi

        ctx.load_verify_locations(certifi.where())
    except ImportError:
        pass
    urls = [f"{base_url}/health", f"{base_url}/"]
    deadline = time.time() + timeout
    while time.time() < deadline:
        for url in urls:
            try:
                # Context manager closes the connection whatever the status.
                with urllib.request.urlopen(url, timeout=5, context=ctx) as resp:
                    if resp.status == 200:
                        return True
            except Exception:
                # Not up yet (refused, timeout, HTTP error): keep polling.
                pass
        time.sleep(2)
    return False


async def run_task_async(llm_client: Optional[OpenAI], task_id: str, task_title: str) -> float:
    """Play one episode of ``task_id`` over the environment WebSocket.

    Prints one ``[START]`` line, one ``[STEP]`` line per step, and one
    ``[END]`` line. ``[END]`` comes from the ``finally`` clause, so it is
    printed even when the episode errors out or this coroutine is cancelled
    by ``main``'s timeout.

    Returns:
        The terminal score if the episode ended. If the step budget ran out,
        the summed step rewards clamped to [0, 1]. On error, whatever score
        had been reached (0.0 by default).
    """
    rewards_list: list[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    log_start(task=task_id, env="clarify_rl", model=MODEL_NAME or "policy")
    try:
        import websockets

        print(f"\nTask: {task_id} ({task_title})")
        print("-" * 50)
        ws_url = _get_ws_url()
        async with websockets.connect(ws_url, open_timeout=30, close_timeout=10) as ws:
            reset_result = await ws_reset(ws, task_id)
            obs = reset_result.get("observation") or {}
            initial_data = _parse_json_object(obs.get("result", ""))
            request_text = initial_data.get("request", str(initial_data))
            max_steps = initial_data.get("max_steps", 10)
            family = initial_data.get("family", "")
            questions_remaining = initial_data.get("questions_remaining", 6)
            rk = REQUIRED_KEYS_BY_FAMILY.get(family, [])
            required_keys_str = ", ".join(rk) if rk else "unknown"
            initial_context = (
                f"USER REQUEST: {request_text}\n"
                f"Task family: {family}\n"
                f"Required plan fields: {required_keys_str}\n"
                f"You have {max_steps} turns and may ask up to {questions_remaining} clarifying questions.\n"
                f"Use the tools to ask about each required field, then call propose_plan with a JSON string containing ALL required fields."
            )
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": initial_context},
            ]
            llm_attempts = 0
            revealed: dict = {}

            for step in range(1, max_steps + 1):
                tool_name, args, used_policy, llm_attempts = _choose_action(
                    task_id, messages, llm_client, step - 1, llm_attempts, request_text, revealed,
                )
                args_str = json.dumps(args) if args else "{}"
                action_str = f"{tool_name}({args_str})"
                source = "policy" if used_policy else "llm"
                print(f" Step {step}: [{source}] {action_str}")

                result = await ws_step(ws, tool_name, args)
                obs_data = result.get("observation") or {}
                reward = result.get("reward", 0.0)
                done = result.get("done", False)
                tool_result = _parse_result_field(obs_data)
                _update_revealed(revealed, tool_result)

                rewards_list.append(reward)
                steps_taken = step
                log_step(step=step, action=action_str, reward=reward, done=done, error=None)

                if len(str(tool_result)) > 1500:
                    tool_result = str(tool_result)[:1500] + "... [truncated]"
                if used_policy:
                    # Keep the transcript coherent for a later LLM turn.
                    messages.append({
                        "role": "assistant",
                        "content": f"TOOL: {tool_name}\nARGS: {json.dumps(args)}",
                    })
                messages.append({
                    "role": "user",
                    "content": f"Tool result:\n{tool_result}\n\nReward: {reward}\nSteps remaining: {max_steps - step}",
                })

                if done:
                    terminal_data = _parse_json_object(obs_data.get("result", "{}"))
                    score = terminal_data.get("final_score", terminal_data.get("score", reward))
                    if score is None:
                        score = reward
                    try:
                        score = float(score)
                    except (TypeError, ValueError):
                        score = float(reward)
                    success = score >= SUCCESS_SCORE_THRESHOLD
                    breakdown = terminal_data.get("score_breakdown", {})
                    print(f" --> Episode ended. Score: {score}")
                    if isinstance(breakdown, dict):
                        for comp, val in breakdown.items():
                            print(f" {comp}: {val}")
                    break
            else:
                # Step budget exhausted without a terminal observation.
                score = sum(rewards_list) if rewards_list else 0.0
                score = min(max(score, 0.0), 1.0)
                success = score >= SUCCESS_SCORE_THRESHOLD
                print(f" --> Max steps reached. Score: {score}")
    except Exception as exc:
        print(f"[DEBUG] Task {task_id} error: {exc}", flush=True)
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards_list)
    return score


def main() -> None:
    """Run each difficulty level once and print a per-task score summary.

    If the environment server never becomes reachable, zero-score
    ``[START]``/``[END]`` lines are printed for every task and the process
    exits with status 0.
    """
    print("=" * 60)
    print(" ClarifyRL — Baseline Inference")
    print("=" * 60)
    print(f"Mode: {BASELINE_MODE}")
    print(f"API: {API_BASE_URL}")
    print(f"Model: {MODEL_NAME}")
    print(f"Environment: {ENV_BASE_URL}")

    tasks = [
        ("easy", "Mild Ambiguity (2-3 fields)"),
        ("medium", "Moderate Ambiguity (4-5 fields)"),
        ("hard", "High Ambiguity (6-7 fields)"),
    ]

    print("\nWaiting for environment server...", flush=True)
    if not wait_for_server(ENV_BASE_URL):
        print("ERROR: Environment server not reachable.", flush=True)
        for task_id, _title in tasks:
            log_start(task=task_id, env="clarify_rl", model=MODEL_NAME or "policy")
            log_end(success=False, steps=0, score=0.0, rewards=[])
        print("Emitted zero-score logs for all tasks. Exiting.", flush=True)
        sys.exit(0)
    print("Server is ready.\n", flush=True)

    llm_client = create_client()
    task_timeout = 300
    scores = {}
    for task_id, title in tasks:
        try:
            score = asyncio.run(
                asyncio.wait_for(run_task_async(llm_client, task_id, title), timeout=task_timeout)
            )
        except asyncio.TimeoutError:
            # run_task_async already printed [START], and its `finally` clause
            # printed [END] while being cancelled; emitting them again here
            # would duplicate this task in the structured log.
            print(f"[DEBUG] Task {task_id} timed out after {task_timeout}s", flush=True)
            score = 0.0
        except Exception as exc:
            print(f"[DEBUG] Task {task_id} crashed: {exc}", flush=True)
            score = 0.0
        scores[task_id] = score

    print("\n" + "=" * 60)
    print(" Summary")
    print("=" * 60)
    for task_id, title in tasks:
        print(f" {task_id:<8s} ({title}): {scores.get(task_id, 0.0):.2f}")
    avg = sum(scores.values()) / len(scores) if scores else 0.0
    print(f"\n Average: {avg:.2f}")
    print("=" * 60)


if __name__ == "__main__":
    try:
        main()
    except SystemExit:
        raise
    except Exception as exc:
        print(f"[DEBUG] Fatal error in main: {exc}", flush=True)
        sys.exit(0)