# clarify-rl / inference.py
# Anurag Agarwal
# Port 7860 + neon theme
# 657f0fa
"""
ClarifyRL — Baseline Inference Script
======================================
MANDATORY:
- Before submitting, ensure the following variables are defined in your
environment configuration:
API_BASE_URL The API endpoint for the LLM.
MODEL_NAME The model identifier to use for inference.
HF_TOKEN Your Hugging Face / API key.
- The inference script must be named `inference.py` and placed in the root
directory of the project.
- Participants must use OpenAI Client for all LLM calls using above variables.
"""
from __future__ import annotations
import asyncio
import json
import os
import re
import sys
import textwrap
import time
from typing import Optional
try:
import truststore; truststore.inject_into_ssl()
except ImportError:
pass
from openai import OpenAI
# --- Runtime configuration -------------------------------------------------
# Most values can be overridden through environment variables of the same name.
# OpenAI-compatible endpoint and model identifier used for every LLM call.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
# First non-empty credential wins: HF_TOKEN, then OPENAI_API_KEY, then API_KEY.
API_KEY = HF_TOKEN or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
# "policy" disables the LLM entirely (see create_client and _choose_action).
BASELINE_MODE = os.getenv("BASELINE_MODE", "hybrid").lower()
# HTTP(S) base URL of the environment server; _get_ws_url derives the WebSocket URL from it.
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
# Sampling temperature passed to chat.completions.create.
TEMPERATURE = 0.7
# Qwen3 with enable_thinking=False usually fits in <200 tokens; we leave 800
# as a safety margin in case any backend still emits a <think> block (e.g. if
# `chat_template_kwargs` is silently dropped by an OpenAI-style proxy that
# doesn't forward `extra_body` to vLLM).
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "800"))
# Upper bound on LLM calls per task; once reached, _choose_action uses the scripted policy.
MAX_LLM_STEPS_PER_TASK = int(os.getenv("MAX_LLM_STEPS_PER_TASK", "8"))
# Minimum score for an episode to be logged with success=true.
SUCCESS_SCORE_THRESHOLD = 0.5
# System message placed first in every episode's chat history (see
# run_task_async). It describes the three tools and the plan-format rules.
SYSTEM_PROMPT = (
"You are a helpful assistant that books and plans things for users.\n"
"The user's request will be intentionally ambiguous \u2014 you do NOT yet have all the information needed to make a good plan.\n"
"\n"
"You have three tools:\n"
" - ask_question(question): ask the user ONE targeted clarifying question (max 6 across the episode).\n"
" - propose_plan(plan): submit your final plan as a JSON STRING with the required fields. This ENDS the episode.\n"
" - get_task_info(): re-read the original user request.\n"
"\n"
"Strategy:\n"
" 1. Read the required plan fields listed in the task description.\n"
" 2. Use ask_question to ask about EACH required field you do not already know.\n"
" 3. When you have enough info, call propose_plan with a JSON string containing ALL required fields.\n"
"\n"
"Rules:\n"
" - Be efficient. Each unnecessary question costs reward.\n"
" - Your plan MUST include every required field listed in the task. Missing fields score zero.\n"
" - NEVER include fields in your plan that you weren't told about. No hallucinating values.\n"
" - The `plan` argument MUST be a JSON STRING (not a dict). Use the exact field names from the required fields list.\n"
)
# Plan fields shown to the model as "required", keyed by the task `family`
# read from the reset observation. Families missing here are reported to the
# model as "unknown" (see run_task_async).
REQUIRED_KEYS_BY_FAMILY: dict[str, list[str]] = {
"coding_requirements": ["stack", "scale", "auth", "datastore"],
"medical_intake": ["primary_symptom", "duration", "severity"],
"support_triage": ["order_id", "item_issue", "refund_or_replace"],
"meeting_scheduling": ["participants", "date", "time"],
"event_planning": ["event_type", "date", "guest_count", "venue"],
}
# Scripted fallback actions per task id, consumed in order by
# _next_policy_action. Each entry is a (tool_name, arguments) pair; once the
# list is exhausted the policy proposes a plan instead.
POLICY_PLANS = {
"easy": [
("get_task_info", {}),
("ask_question", {"question": "What is the main requirement?"}),
("ask_question", {"question": "Any specific preferences or constraints?"}),
],
"medium": [
("get_task_info", {}),
("ask_question", {"question": "What is the main requirement?"}),
("ask_question", {"question": "What are the specific details needed?"}),
("ask_question", {"question": "Any constraints or preferences?"}),
("ask_question", {"question": "What is the timeline or deadline?"}),
],
"hard": [
("get_task_info", {}),
("ask_question", {"question": "What is the main requirement?"}),
("ask_question", {"question": "What are the technical specifications?"}),
("ask_question", {"question": "What is the scale or scope?"}),
("ask_question", {"question": "Any constraints or limitations?"}),
("ask_question", {"question": "What is the timeline?"}),
("ask_question", {"question": "Any other preferences?"}),
],
}
def create_client() -> Optional[OpenAI]:
    """Build the OpenAI-compatible client, or return None when the LLM is off.

    None is returned in three situations: ``BASELINE_MODE`` is ``"policy"``,
    no API key is configured, or the client constructor raises.
    """
    llm_disabled = BASELINE_MODE == "policy"
    if llm_disabled:
        return None

    if API_KEY:
        try:
            client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        except Exception as err:
            print(f"[DEBUG] Failed to create OpenAI client: {err}", flush=True)
            return None
        return client

    print("[DEBUG] No API key found; policy fallback will be used.", flush=True)
    return None
# Upper-cased line prefixes (as normalised by _parse_prefixed_call) mapped to
# the canonical tool name they stand for.
_PREFIX_TO_TOOL = {
"ASK": "ask_question",
"ASK_QUESTION": "ask_question",
"QUESTION": "ask_question",
"Q": "ask_question",
"PROPOSE": "propose_plan",
"PROPOSE_PLAN": "propose_plan",
"PLAN": "propose_plan",
"INFO": "get_task_info",
"GET_TASK_INFO": "get_task_info",
"TASK_INFO": "get_task_info",
}
def parse_tool_call(response_text: str) -> tuple[Optional[str], dict]:
    """Extract ``(tool_name, args)`` from a raw model reply.

    Formats are tried in a fixed order and the first hit wins:
      1. ``TOOL: name`` with an optional ``ARGS: {...}`` block,
      2. a JSON object that carries a tool name,
      3. a balanced ``tool_name(...)`` call for a known tool,
      4. a prefixed line such as ``ASK: ...``,
      5. ``Action: name(...)``.

    Returns ``(None, {})`` when no format matches.
    """
    text = _strip_reasoning(response_text)

    # 1. Explicit TOOL:/ARGS: markers.
    marker = re.search(r"TOOL:\s*(\w+)", text, re.IGNORECASE)
    args_block = _extract_args_block(text)
    if marker:
        marker_args = _load_json_like(args_block) if args_block else {}
        return marker.group(1).strip(), marker_args

    # 2. A JSON object naming the tool.
    name, call_args = _parse_json_tool_call(text)
    if name:
        return name, call_args

    # 3. Function-call syntax for one of the known tools.
    call = _find_balanced_func_call(text)
    if call:
        name, body = call
        if name in ("ask_question", "propose_plan", "get_task_info"):
            return name, _parse_positional_args(name, body.strip())

    # 4. "ASK: ..." / "PROPOSE: ..." style prefixes.
    name, call_args = _parse_prefixed_call(text)
    if name:
        return name, call_args

    # 5. ReAct-style "Action: tool(key="value")" or "Action: tool(raw)".
    action = re.search(
        r'Action:\s*(\w+)\((?:(\w+)\s*=\s*["\'](.+?)["\']|([^)]*))\)',
        text, re.DOTALL,
    )
    if not action:
        return None, {}

    name = action.group(1)
    key, quoted_value, bare = action.group(2), action.group(3), action.group(4)
    if key and quoted_value is not None:
        unescaped = quoted_value.replace('\\"', '"').replace("\\'", "'")
        return name, {key: unescaped}
    if bare:
        bare = bare.strip()
        if "=" in bare:
            left, _, right = bare.partition("=")
            return name, {left.strip(): right.strip().strip("\"'")}
    return name, {}
def _parse_prefixed_call(text: str) -> tuple[Optional[str], dict]:
"""Handle Qwen3 GRPO outputs like:
ASK: {"question": "What is the budget?"}
ASK: What is the budget?
PROPOSE: {"date": "2024-12-25", ...}
Q: What is the budget?
The 0.6B GRPO checkpoint emits these ~20% of the time. We map the
prefix to the canonical tool name and parse the rest as either a JSON
object or a free-form question/plan string.
"""
match = re.match(r"^\s*([A-Za-z_]+)\s*:\s*(.*)$", text, flags=re.DOTALL)
if not match:
return None, {}
prefix = match.group(1).upper().replace("-", "_")
if prefix not in _PREFIX_TO_TOOL:
return None, {}
tool_name = _PREFIX_TO_TOOL[prefix]
rest = match.group(2).strip()
if rest.startswith("{"):
parsed = _load_json_like(rest)
if isinstance(parsed, dict) and parsed:
if tool_name == "ask_question":
question = parsed.get("question") or parsed.get("q") or parsed.get("text")
if isinstance(question, str):
return tool_name, {"question": question}
return tool_name, {"question": json.dumps(parsed)}
if tool_name == "propose_plan":
inner = parsed.get("plan") if isinstance(parsed.get("plan"), (dict, str)) else None
if inner is not None:
plan_str = inner if isinstance(inner, str) else json.dumps(inner)
return tool_name, {"plan": plan_str}
return tool_name, {"plan": json.dumps(parsed)}
return tool_name, {}
if tool_name == "ask_question":
question = rest.strip().strip('"').strip("'")
if question:
return tool_name, {"question": question}
if tool_name == "propose_plan" and rest:
return tool_name, {"plan": rest}
if tool_name == "get_task_info":
return tool_name, {}
return None, {}
def _strip_reasoning(response_text: str) -> str:
cleaned = re.sub(r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE)
cleaned = cleaned.replace("```json", "```")
cleaned = cleaned.replace("```tool", "```")
return cleaned.strip()
def _extract_args_block(response_text: str) -> Optional[str]:
args_marker = re.search(r"ARGS:\s*", response_text, re.IGNORECASE)
if not args_marker:
return None
start = response_text.find("{", args_marker.end())
if start == -1:
return None
depth = 0
in_string = False
escape = False
for index in range(start, len(response_text)):
char = response_text[index]
if in_string:
if escape:
escape = False
elif char == "\\":
escape = True
elif char == '"':
in_string = False
continue
if char == '"':
in_string = True
elif char == "{":
depth += 1
elif char == "}":
depth -= 1
if depth == 0:
return response_text[start:index + 1]
return None
def _candidate_json_objects(text: str) -> list[str]:
candidates = []
start = None
depth = 0
in_string = False
escape = False
for index, char in enumerate(text):
if start is None:
if char == "{":
start = index
depth = 1
continue
if in_string:
if escape:
escape = False
elif char == "\\":
escape = True
elif char == '"':
in_string = False
continue
if char == '"':
in_string = True
elif char == "{":
depth += 1
elif char == "}":
depth -= 1
if depth == 0 and start is not None:
candidates.append(text[start:index + 1])
start = None
return candidates
def _load_json_like(raw: str) -> dict:
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
normalized = raw.strip()
normalized = re.sub(r"(\w+)\s*=", r'"\1": ', normalized)
normalized = normalized.replace("'", '"')
try:
parsed = json.loads(normalized)
except json.JSONDecodeError:
return _parse_args_fallback(raw)
return parsed if isinstance(parsed, dict) else {}
def _parse_json_tool_call(response_text: str) -> tuple[Optional[str], dict]:
    """Find the first JSON object in the reply that names a tool.

    The tool name is read from ``tool``, ``tool_name``, ``name`` or
    ``action`` (first truthy value); arguments from ``args``, ``arguments``
    or ``parameters``. String arguments that look like JSON are parsed;
    non-dict arguments become ``{}``. Returns ``(None, {})`` if no object
    qualifies.
    """
    name_keys = ("tool", "tool_name", "name", "action")
    arg_keys = ("args", "arguments", "parameters")

    for blob in _candidate_json_objects(response_text):
        obj = _load_json_like(blob)
        if not obj:
            continue

        # First truthy value wins; otherwise the last key's value is kept.
        name = None
        for key in name_keys:
            name = obj.get(key)
            if name:
                break
        if not isinstance(name, str):
            continue

        call_args = {}
        for key in arg_keys:
            value = obj.get(key)
            if value:
                call_args = value
                break

        if isinstance(call_args, str) and call_args.strip().startswith("{"):
            call_args = _load_json_like(call_args)
        if not isinstance(call_args, dict):
            call_args = {}
        return name.strip(), call_args

    return None, {}
def _parse_args_fallback(raw: str) -> dict:
args = {}
for match in re.finditer(r'"(\w+)"\s*:\s*"([^"]*)"', raw):
args[match.group(1)] = match.group(2)
for match in re.finditer(r'"(\w+)"\s*:\s*(\d+)', raw):
args[match.group(1)] = int(match.group(2))
return args
# Tool names that _find_balanced_func_call recognises when scanning for `name(...)` calls.
_TOOL_NAMES = ("ask_question", "propose_plan", "get_task_info")
def _find_balanced_func_call(text: str) -> Optional[tuple[str, str]]:
"""Find the first `tool_name(...)` call with balanced parens.
Returns (name, body) where body is the parenthesized content with the
outer parens stripped. Handles nested parens inside JSON plans and quoted
questions like `What is your budget? (in USD)`. None if no recognised
tool name is found.
"""
for match in re.finditer(r"\b(\w+)\s*\(", text):
name = match.group(1)
if name not in _TOOL_NAMES:
continue
body_start = match.end()
depth = 1
in_str = False
quote_char = ""
escape = False
for index in range(body_start, len(text)):
char = text[index]
if escape:
escape = False
continue
if in_str:
if char == "\\":
escape = True
elif char == quote_char:
in_str = False
continue
if char in ("'", '"'):
in_str = True
quote_char = char
continue
if char == "(":
depth += 1
elif char == ")":
depth -= 1
if depth == 0:
return name, text[body_start:index]
return None
def _parse_positional_args(tool_name: str, raw_args: str) -> dict:
"""Parse the body of a `tool_name(...)` call.
Handles three syntaxes the trained Qwen3 models actually produce:
1. Single keyword arg with quoted value: `question="What is your budget?"`
2. Bare keyword arg (unquoted JSON-ish): `plan={"event_type": "wedding"}`
3. Pure positional (legacy): `What is your budget?`
The previous implementation just split on `,` and stripped end quotes,
which corrupted `question="..."` into a literal `question="...` string.
"""
if not raw_args:
return {}
arg_map = {
"ask_question": ["question"],
"propose_plan": ["plan"],
}
param_names = arg_map.get(tool_name, [])
text = raw_args.strip()
kw_quoted = re.match(
r"^\s*(\w+)\s*=\s*(['\"])(.*)\2\s*$",
text,
flags=re.DOTALL,
)
if kw_quoted:
key = kw_quoted.group(1)
val = kw_quoted.group(3).replace('\\"', '"').replace("\\'", "'")
return {key: val}
kw_brace = re.match(r"^\s*(\w+)\s*=\s*(\{.*\})\s*$", text, flags=re.DOTALL)
if kw_brace:
return {kw_brace.group(1): kw_brace.group(2)}
if "=" in text and len(param_names) == 1:
key, _, val = text.partition("=")
key_clean = key.strip()
if key_clean and key_clean.isidentifier():
return {key_clean: val.strip().strip("'\"")}
quoted = re.match(r"""^\s*(['"])(.*)\1\s*$""", text, flags=re.DOTALL)
if quoted and param_names:
val = quoted.group(2).replace('\\"', '"').replace("\\'", "'")
return {param_names[0]: val}
if param_names and text.startswith("{") and text.endswith("}"):
return {param_names[0]: text}
parts = _split_top_level_commas(text)
args: dict = {}
for i, part in enumerate(parts):
cleaned = part.strip().strip("'\"")
if i < len(param_names):
args[param_names[i]] = cleaned
return args
def _split_top_level_commas(text: str) -> list[str]:
"""Split on commas only when not inside quotes / brackets / braces."""
out: list[str] = []
depth_paren = 0
depth_brace = 0
depth_brack = 0
in_str = False
quote = ""
escape = False
buf: list[str] = []
for ch in text:
if escape:
buf.append(ch)
escape = False
continue
if in_str:
if ch == "\\":
escape = True
elif ch == quote:
in_str = False
buf.append(ch)
continue
if ch in ("'", '"'):
in_str = True
quote = ch
buf.append(ch)
continue
if ch == "(":
depth_paren += 1
elif ch == ")":
depth_paren -= 1
elif ch == "{":
depth_brace += 1
elif ch == "}":
depth_brace -= 1
elif ch == "[":
depth_brack += 1
elif ch == "]":
depth_brack -= 1
elif ch == "," and depth_paren == 0 and depth_brace == 0 and depth_brack == 0:
out.append("".join(buf))
buf = []
continue
buf.append(ch)
if buf:
out.append("".join(buf))
return out
def _parse_result_field(obs: dict) -> str:
result_raw = obs.get("result", "")
if not result_raw:
return str(obs)
try:
parsed = json.loads(result_raw)
if isinstance(parsed, dict) and "tool_result" in parsed:
return parsed["tool_result"]
return json.dumps(parsed, indent=2)
except (json.JSONDecodeError, TypeError):
return str(result_raw)
def _next_policy_action(
    task_id: str, step_index: int, request_text: str, revealed: dict,
) -> tuple[str, dict]:
    """Return the scripted action for *step_index* of *task_id*.

    Unknown task ids use the "medium" script. Once the script is exhausted
    the policy proposes a plan made of everything in *revealed*.
    ``request_text`` is accepted but not used.
    """
    script = POLICY_PLANS.get(task_id, POLICY_PLANS["medium"])
    script_exhausted = not (step_index < len(script))
    if script_exhausted:
        return ("propose_plan", {"plan": json.dumps(revealed)})
    return script[step_index]
def _choose_action(
    task_id: str,
    messages: list[dict],
    llm_client: Optional[OpenAI],
    step_index: int,
    llm_attempts: int,
    request_text: str,
    revealed: dict,
) -> tuple[str, dict, bool, int]:
    """Pick the next tool call, preferring the LLM and falling back to policy.

    Returns ``(tool_name, args, used_policy, llm_attempts)``. The scripted
    policy is used when the mode is "policy", there is no client, or the
    per-task LLM budget is spent. If the LLM call raises or its reply cannot
    be mapped to a known tool, the policy action is returned and the attempt
    counter is set to ``MAX_LLM_STEPS_PER_TASK`` so the LLM is not consulted
    again for this task. Every LLM reply that was received is appended to
    *messages* as an assistant turn.
    """
    fallback_tool, fallback_args = _next_policy_action(
        task_id, step_index, request_text, revealed
    )

    llm_unavailable = BASELINE_MODE == "policy" or llm_client is None
    if llm_unavailable or llm_attempts >= MAX_LLM_STEPS_PER_TASK:
        return fallback_tool, fallback_args, True, llm_attempts

    try:
        # Qwen3 enables "<think>...</think>" reasoning by default, which can
        # consume a small token budget before the reply reaches the tool call
        # we parse. Reasoning is therefore switched off through
        # `chat_template_kwargs` in `extra_body`; the comment this replaces
        # states that vLLM forwards the field to the chat template and that
        # backends which do not support it drop it silently.
        completion = llm_client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            extra_body={"chat_template_kwargs": {"enable_thinking": False}},
        )
        reply = completion.choices[0].message.content or ""
        llm_attempts += 1
    except Exception as exc:
        print(f" LLM unavailable, switching to policy: {exc}")
        return fallback_tool, fallback_args, True, MAX_LLM_STEPS_PER_TASK

    tool_name, args = parse_tool_call(reply)
    recognised = bool(tool_name) and tool_name in (
        "ask_question", "propose_plan", "get_task_info",
    )

    # The reply is recorded in the history whether or not it was usable.
    messages.append({"role": "assistant", "content": reply})
    if recognised:
        return tool_name, args, False, llm_attempts

    if tool_name:
        print(f" LLM suggested unknown tool {tool_name}; using policy instead.")
    else:
        print(" Could not parse tool call; using policy instead.")
    return fallback_tool, fallback_args, True, MAX_LLM_STEPS_PER_TASK
def _get_ws_url() -> str:
ws_url = ENV_BASE_URL.replace("https://", "wss://").replace("http://", "ws://")
return f"{ws_url}/ws"
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens a task's log section."""
    header = f"[START] task={task} env={env} model={model}"
    print(header, flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str] = None,
) -> None:
    """Print one ``[STEP]`` record.

    The reward is shown with two decimals, ``done`` as the lower-cased string
    form of the value, and a missing/empty error as ``null``.
    """
    shown_error = error or "null"
    shown_done = str(done).lower()
    record = (
        f"[STEP] step={step} action={action} reward={reward:.2f} done={shown_done} error={shown_error}"
    )
    print(record, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
    """Print the ``[END]`` record that closes a task's log section.

    The score is shown with three decimals and the per-step rewards as a
    comma-separated list with two decimals each.
    """
    formatted_rewards = [f"{value:.2f}" for value in rewards]
    record = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={','.join(formatted_rewards)}"
    )
    print(record, flush=True)
async def ws_reset(ws, task_id: str) -> dict:
    """Send a ``reset`` message for *task_id* and return the normalised reply.

    Args:
        ws: Connected WebSocket exposing awaitable ``send`` and ``recv``.
        task_id: Task identifier forwarded to the environment.

    Returns:
        A dict with ``observation`` (dict), ``reward`` (number) and ``done``
        (bool). When the server answers with ``type == "error"`` the dict
        additionally carries the server's ``data`` under ``error``.
    """
    await ws.send(json.dumps({"type": "reset", "data": {"task_id": task_id}}))
    resp = json.loads(await ws.recv())
    if resp.get("type") == "error":
        return {"observation": {}, "reward": 0.0, "done": False, "error": resp.get("data", {})}

    # dict.get defaults only apply to missing keys; a JSON null would
    # otherwise reach the caller as None, so falsy values are normalised here.
    data = resp.get("data") or {}
    return {
        "observation": data.get("observation") or {},
        "reward": data.get("reward") or 0.0,
        "done": bool(data.get("done")),
    }
async def ws_step(ws, tool_name: str, args: dict) -> dict:
    """Send one ``call_tool`` step and return the normalised reply.

    Args:
        ws: Connected WebSocket exposing awaitable ``send`` and ``recv``.
        tool_name: Name of the tool to invoke.
        args: Arguments for the tool call.

    Returns:
        A dict with ``observation`` (dict), ``reward`` (number) and ``done``
        (bool). A server reply with ``type == "error"`` is converted into an
        observation whose ``result`` is a JSON string ``{"error": ...}``,
        with zero reward and ``done`` False.
    """
    action = {"type": "call_tool", "tool_name": tool_name, "arguments": args}
    await ws.send(json.dumps({"type": "step", "data": action}))
    resp = json.loads(await ws.recv())

    if resp.get("type") == "error":
        err = resp.get("data", {})
        if isinstance(err, dict):
            message = err.get("message", "Unknown error")
        else:
            # The error payload is not guaranteed to be an object here.
            message = str(err) if err else "Unknown error"
        return {
            "observation": {"result": json.dumps({"error": message})},
            "reward": 0.0,
            "done": False,
        }

    # dict.get defaults only apply to missing keys; a JSON null would
    # otherwise reach the caller as None, so falsy values are normalised here.
    data = resp.get("data") or {}
    return {
        "observation": data.get("observation") or {},
        "reward": data.get("reward") or 0.0,
        "done": bool(data.get("done")),
    }
def wait_for_server(base_url: str, timeout: int = 120) -> bool:
    """Poll the environment server until it answers HTTP 200 or time runs out.

    ``{base_url}/health`` and ``{base_url}/`` are tried in turn, with a
    five-second timeout per request and a pause of up to two seconds between
    rounds.

    Args:
        base_url: HTTP(S) base URL of the server; trailing slashes are ignored.
        timeout: Overall time budget in seconds.

    Returns:
        True as soon as either URL returns status 200, False on timeout.
    """
    import ssl
    import urllib.error
    import urllib.request

    ctx = ssl.create_default_context()
    try:
        import certifi
        ctx.load_verify_locations(certifi.where())
    except ImportError:
        # certifi is optional; the default trust store is used without it.
        pass

    root = base_url.rstrip("/")
    urls = [f"{root}/health", f"{root}/"]

    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        for url in urls:
            try:
                # The context manager closes the response (and its socket).
                with urllib.request.urlopen(url, timeout=5, context=ctx) as resp:
                    if resp.status == 200:
                        return True
            except Exception:
                # Best-effort probe: any failure just means "not ready yet".
                pass
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(2, remaining))
    return False
async def run_task_async(llm_client: Optional[OpenAI], task_id: str, task_title: str) -> float:
    """Run one episode of *task_id* over WebSocket and return its score.

    The function resets the environment, builds the initial chat context,
    then loops for up to ``max_steps`` turns: choose an action (LLM or
    policy), send it, record the reward, and feed the tool result back into
    the chat history. A ``[START]`` record is logged on entry and an ``[END]``
    record is always logged on exit, including on errors and cancellation.

    Args:
        llm_client: OpenAI-compatible client, or None to use the policy only.
        task_id: Task identifier sent to the environment.
        task_title: Human-readable title, used only for console output.

    Returns:
        The episode score: the terminal ``final_score`` / ``score`` when the
        episode ends, the clamped sum of rewards when the step budget runs
        out, or 0.0 if an exception interrupts the episode before that.
    """
    rewards_list: list[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    log_start(task=task_id, env="clarify_rl", model=MODEL_NAME or "policy")
    try:
        import websockets

        print(f"\nTask: {task_id} ({task_title})")
        print("-" * 50)
        ws_url = _get_ws_url()
        async with websockets.connect(ws_url, open_timeout=30, close_timeout=10) as ws:
            reset_result = await ws_reset(ws, task_id)
            if reset_result.get("error"):
                # Surface reset failures instead of silently running on an empty observation.
                print(f"[DEBUG] Reset for task {task_id} returned an error: {reset_result['error']}", flush=True)
            obs = reset_result.get("observation") or {}
            initial_result = obs.get("result", "")
            try:
                initial_data = json.loads(initial_result) if initial_result else {}
            except (json.JSONDecodeError, TypeError):
                initial_data = {}
            if not isinstance(initial_data, dict):
                # Valid JSON that is not an object carries none of the fields read below.
                initial_data = {}
            request_text = initial_data.get("request", str(initial_data))
            max_steps = initial_data.get("max_steps", 10)
            family = initial_data.get("family", "")
            questions_remaining = initial_data.get("questions_remaining", 6)
            rk = REQUIRED_KEYS_BY_FAMILY.get(family, [])
            required_keys_str = ", ".join(rk) if rk else "unknown"
            initial_context = (
                f"USER REQUEST: {request_text}\n"
                f"Task family: {family}\n"
                f"Required plan fields: {required_keys_str}\n"
                f"You have {max_steps} turns and may ask up to {questions_remaining} clarifying questions.\n"
                f"Use the tools to ask about each required field, then call propose_plan with a JSON string containing ALL required fields."
            )
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": initial_context},
            ]
            task_step_budget = max_steps
            llm_attempts = 0
            revealed: dict = {}
            for step in range(1, task_step_budget + 1):
                tool_name, args, used_policy, llm_attempts = _choose_action(
                    task_id, messages, llm_client, step - 1, llm_attempts,
                    request_text, revealed,
                )
                args_str = json.dumps(args) if args else "{}"
                action_str = f"{tool_name}({args_str})"
                source = "policy" if used_policy else "llm"
                print(f" Step {step}: [{source}] {action_str}")
                result = await ws_step(ws, tool_name, args)
                obs_data = result.get("observation") or {}
                reward = result.get("reward")
                if reward is None:
                    # A null reward would break the "{reward:.2f}" formatting in log_step.
                    reward = 0.0
                done = bool(result.get("done", False))
                tool_result = _parse_result_field(obs_data)
                try:
                    result_parsed = json.loads(tool_result) if isinstance(tool_result, str) else tool_result
                    if isinstance(result_parsed, dict):
                        # Remember every returned field except bookkeeping keys;
                        # the policy's final plan is built from these.
                        for k, v in result_parsed.items():
                            if k not in ("error", "episode_done", "questions_remaining", "fields_revealed"):
                                revealed[k] = v
                except (json.JSONDecodeError, TypeError):
                    pass
                rewards_list.append(reward)
                steps_taken = step
                log_step(step=step, action=action_str, reward=reward, done=done, error=None)
                if len(str(tool_result)) > 1500:
                    tool_result = str(tool_result)[:1500] + "... [truncated]"
                if used_policy:
                    # Policy actions never came from the model, so record them
                    # in the history in the TOOL/ARGS format.
                    messages.append({
                        "role": "assistant",
                        "content": f"TOOL: {tool_name}\nARGS: {json.dumps(args)}",
                    })
                messages.append({
                    "role": "user",
                    "content": f"Tool result:\n{tool_result}\n\nReward: {reward}\nSteps remaining: {max_steps - step}",
                })
                if done:
                    try:
                        terminal_data = json.loads(obs_data.get("result", "{}"))
                    except (json.JSONDecodeError, TypeError):
                        terminal_data = {}
                    if not isinstance(terminal_data, dict):
                        terminal_data = {}
                    score = terminal_data.get("final_score", terminal_data.get("score", reward))
                    if not isinstance(score, (int, float)):
                        # Missing, null or non-numeric score: fall back to the last reward.
                        score = reward
                    success = score >= SUCCESS_SCORE_THRESHOLD
                    breakdown = terminal_data.get("score_breakdown", {})
                    print(f" --> Episode ended. Score: {score}")
                    if isinstance(breakdown, dict) and breakdown:
                        for comp, val in breakdown.items():
                            print(f" {comp}: {val}")
                    break
            else:
                # Step budget exhausted without a terminal observation.
                score = sum(rewards_list) if rewards_list else 0.0
                score = min(max(score, 0.0), 1.0)
                success = score >= SUCCESS_SCORE_THRESHOLD
                print(f" --> Max steps reached. Score: {score}")
    except Exception as exc:
        print(f"[DEBUG] Task {task_id} error: {exc}", flush=True)
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards_list)
    return score
def main():
    """Run the three baseline tasks in sequence and print a score summary.

    Waits for the environment server first. If it never becomes reachable, a
    zero-score ``[START]``/``[END]`` pair is logged for every task and the
    process exits with status 0. Otherwise each task runs under a 300-second
    timeout; a task that times out or crashes is scored 0.0.
    """
    print("=" * 60)
    print(" ClarifyRL — Baseline Inference")
    print("=" * 60)
    print(f"Mode: {BASELINE_MODE}")
    print(f"API: {API_BASE_URL}")
    print(f"Model: {MODEL_NAME}")
    print(f"Environment: {ENV_BASE_URL}")
    tasks = [
        ("easy", "Mild Ambiguity (2-3 fields)"),
        ("medium", "Moderate Ambiguity (4-5 fields)"),
        ("hard", "High Ambiguity (6-7 fields)"),
    ]
    print("\nWaiting for environment server...", flush=True)
    server_ok = wait_for_server(ENV_BASE_URL)
    if not server_ok:
        print("ERROR: Environment server not reachable.", flush=True)
        for task_id, title in tasks:
            log_start(task=task_id, env="clarify_rl", model=MODEL_NAME or "policy")
            log_end(success=False, steps=0, score=0.0, rewards=[])
        print("Emitted zero-score logs for all tasks. Exiting.", flush=True)
        sys.exit(0)
    print("Server is ready.\n", flush=True)
    llm_client = create_client()
    task_timeout = 300
    scores = {}
    for task_id, title in tasks:
        try:
            score = asyncio.run(
                asyncio.wait_for(run_task_async(llm_client, task_id, title), timeout=task_timeout)
            )
        except asyncio.TimeoutError:
            # run_task_async logs [START] on entry and, through its finally
            # block, [END] when wait_for cancels it. Logging the pair again
            # here would duplicate the task's records.
            print(f"[DEBUG] Task {task_id} timed out after {task_timeout}s", flush=True)
            score = 0.0
        except Exception as exc:
            print(f"[DEBUG] Task {task_id} crashed: {exc}", flush=True)
            score = 0.0
        scores[task_id] = score
    print("\n" + "=" * 60)
    print(" Summary")
    print("=" * 60)
    for task_id, title in tasks:
        print(f" {task_id:<8s} ({title}): {scores.get(task_id, 0.0):.2f}")
    avg = sum(scores.values()) / len(scores) if scores else 0.0
    print(f"\n Average: {avg:.2f}")
    print("=" * 60)
if __name__ == "__main__":
    # Script entry point. Any ordinary exception is reported and turned into
    # exit status 0. SystemExit is not an Exception subclass, so an explicit
    # sys.exit() from main() propagates untouched.
    try:
        main()
    except Exception as fatal:
        print(f"[DEBUG] Fatal error in main: {fatal}", flush=True)
        sys.exit(0)