"""FlakySleuth baseline inference script. Environment variables: Preferred: HF_TOKEN / HUGGINGFACE_HUB_TOKEN (or OPENROUTER_API_KEY / API_KEY) API_BASE_URL (optional, defaults to https://openrouter.ai/api/v1 for router-style keys) MODEL_NAME (optional, defaults to qwen/qwen3.6-plus:free on OpenRouter) Optional fallback: OPENAI_API_KEY API_BASE_URL (defaults to https://api.openai.com/v1 when OpenAI key is used) MODEL_NAME (defaults to gpt-4o-mini for OpenAI) """ from __future__ import annotations import json import os import argparse import time from collections import defaultdict from pathlib import Path from typing import Any from openai import OpenAI try: from tqdm import tqdm except Exception: # pragma: no cover tqdm = None from env.environment import FlakySleuthEnv from env.models import FlakySleuthAction, FlakySleuthObservation OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY") OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") # Optional for environments created via from_docker_image(); kept for checklist parity. LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME") RAW_API_KEY = os.environ.get("API_KEY") API_KEY = RAW_API_KEY or OPENROUTER_API_KEY or OPENAI_API_KEY or HF_TOKEN or "" def _looks_like_openrouter_key(key: str | None) -> bool: return bool(key and key.startswith("sk-or-")) DEFAULT_BASE_URL = ( "https://router.huggingface.co/v1" if ( HF_TOKEN and not RAW_API_KEY and not OPENROUTER_API_KEY and not OPENAI_API_KEY ) else ( "https://openrouter.ai/api/v1" if ( (OPENROUTER_API_KEY and not RAW_API_KEY and not OPENAI_API_KEY) or (_looks_like_openrouter_key(RAW_API_KEY) and not OPENAI_API_KEY) ) else "https://api.openai.com/v1" ) ) API_BASE_URL = os.environ.get("API_BASE_URL", DEFAULT_BASE_URL) DEFAULT_MODEL = ( "openai/gpt-oss-120b:novita" if API_BASE_URL.startswith("https://router.huggingface.co") else ( "qwen/qwen3.6-plus:free" if API_BASE_URL.startswith("https://openrouter.ai") else "gpt-4o-mini" ) ) MODEL_NAME = os.environ.get("MODEL_NAME", DEFAULT_MODEL) # Keep a conservative default to stay under common hackathon runtime limits. EPISODES_PER_TASK = 2 MAX_STEPS = 20 MEMORY_MAX_CHARS = 900 LLM_MAX_RETRIES = 2 client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL) SYSTEM_PROMPT = """You are a flaky test detective. Respond ONLY with a single valid JSON object. Exploration actions: {"action_type": "read_file", "argument": "relative/path.py"} {"action_type": "search_code", "argument": "pattern"} {"action_type": "run_test", "argument": ""} Terminal actions: {"action_type": "classify_flakiness", "argument": "flaky"} {"action_type": "classify_flakiness", "argument": "stable"} {"action_type": "classify_root_cause", "argument": "OD"} {"action_type": "classify_root_cause", "argument": "OD-Brit"} {"action_type": "classify_root_cause", "argument": "OD-Vic"} {"action_type": "classify_root_cause", "argument": "NIO"} {"action_type": "classify_root_cause", "argument": "NOD"} {"action_type": "classify_root_cause", "argument": "TD"} {"action_type": "classify_root_cause", "argument": "TZD"} {"action_type": "classify_root_cause", "argument": "ID"} {"action_type": "propose_fix", "argument": "--- a/file.py\\n+++ b/file.py\\n@@ ... @@\\n-old\\n+new"} Rules: 1. Read the test file first. 2. Search for flaky signals: random, time, sleep, shared state, env vars. 3. Run the test for non-order-dependent scenarios. 4. Call one terminal action when confident. """ def _to_single_line(text: str) -> str: return " ".join(str(text).split()) def _short_error(text: str, max_chars: int = 220) -> str: one_line = _to_single_line(text) if len(one_line) <= max_chars: return one_line hidden = len(one_line) - max_chars return f"{one_line[:max_chars]}...[truncated {hidden} chars]" class _ActionParseError(Exception): def __init__(self, reason: str, detail: str) -> None: super().__init__(f"{reason}: {detail}") self.reason = reason self.detail = detail def _strip_code_fences(text: str) -> str: stripped = text.strip() if stripped.startswith("```"): lines = stripped.splitlines() if lines: lines = lines[1:] if lines and lines[-1].strip() == "```": lines = lines[:-1] stripped = "\n".join(lines).strip() if stripped.lower().startswith("json\n"): stripped = stripped[5:].strip() return stripped def _extract_first_json_object(text: str) -> str | None: start = -1 depth = 0 in_string = False escaped = False for idx, ch in enumerate(text): if in_string: if escaped: escaped = False continue if ch == "\\": escaped = True continue if ch == '"': in_string = False continue if ch == '"': in_string = True continue if ch == "{": if depth == 0: start = idx depth += 1 continue if ch == "}": if depth == 0: continue depth -= 1 if depth == 0 and start >= 0: return text[start : idx + 1] return None def _parse_action_payload(raw: str) -> tuple[FlakySleuthAction, str]: raw_text = (raw or "").strip() if not raw_text: raise _ActionParseError("llm_empty_output", "empty response body") candidates: list[str] = [] seen: set[str] = set() def add_candidate(value: str | None) -> None: if value is None: return cleaned = value.strip() if not cleaned or cleaned in seen: return seen.add(cleaned) candidates.append(cleaned) add_candidate(raw_text) stripped = _strip_code_fences(raw_text) add_candidate(stripped) add_candidate(_extract_first_json_object(stripped)) add_candidate(_extract_first_json_object(raw_text)) json_errors: list[str] = [] schema_errors: list[str] = [] for candidate in candidates: try: payload = json.loads(candidate) except json.JSONDecodeError as exc: json_errors.append(str(exc)) continue if not isinstance(payload, dict): schema_errors.append(f"top-level JSON must be an object, got {type(payload).__name__}") continue try: action = FlakySleuthAction.model_validate(payload) except Exception as exc: schema_errors.append(str(exc)) continue return action, candidate if schema_errors: raise _ActionParseError("llm_schema_error", _short_error(schema_errors[-1], max_chars=300)) if json_errors: raise _ActionParseError("llm_json_parse_error", _short_error(json_errors[-1], max_chars=300)) raise _ActionParseError("llm_json_parse_error", "unable to extract JSON object") def _json_repair_prompt(error_text: str, raw_output: str) -> str: clipped_raw = _short_error(raw_output or "(empty)", max_chars=300) clipped_err = _short_error(error_text, max_chars=260) return ( "Your previous response was invalid.\n" f"Parser error: {clipped_err}\n" f"Previous output (truncated): {clipped_raw}\n" "Respond again with ONLY one valid JSON object and no extra text.\n" 'Required schema: {"action_type": "", "argument": "", "metadata": {}}\n' 'Do NOT wrap in markdown fences. Do NOT add commentary.' ) def _chat_completion_request(messages: list[dict[str, str]]) -> Any: base_kwargs = { "model": MODEL_NAME, "messages": messages, "max_tokens": 400, "temperature": 0.0, } try: return client.chat.completions.create( response_format={"type": "json_object"}, **base_kwargs, ) except Exception as json_mode_exc: try: return client.chat.completions.create(**base_kwargs) except Exception as plain_mode_exc: raise RuntimeError( f"json_mode_error={json_mode_exc}; plain_mode_error={plain_mode_exc}" ) from plain_mode_exc def _compliance_log_start(task: str, benchmark: str, model: str) -> None: print(f"[START] task={task} env={benchmark} model={model}", flush=True) def _compliance_log_step( step: int, action: str, reward: float, done: bool, error: str | None, ) -> None: error_value = _to_single_line(error) if error else "null" print( f"[STEP] step={step} action={_to_single_line(action)} " f"reward={reward:.2f} done={str(bool(done)).lower()} error={error_value}", flush=True, ) def _compliance_log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None: rewards_value = ",".join(f"{r:.2f}" for r in rewards) print( f"[END] success={str(bool(success)).lower()} steps={steps} " f"score={score:.3f} rewards={rewards_value}", flush=True, ) def obs_to_prompt( obs: FlakySleuthObservation, *, memory_hint: str | None = None, max_steps: int = MAX_STEPS, ) -> str: tree_preview = "\n".join(obs.file_tree[:40]) return f"""TASK: {obs.task_description} Repository: {obs.repo_url} Test name: {obs.test_name} Step: {obs.step_count}/{max_steps} Test source code: ```python {obs.test_code} ``` Repository file tree: {tree_preview} Last tool output: {obs.tool_output or "(No action taken yet)"} Episode memory: {memory_hint or "(No memory yet.)"} Return only JSON action.""" def heuristic_action(obs: FlakySleuthObservation) -> FlakySleuthAction: if obs.step_count == 0 and obs.file_tree: return FlakySleuthAction(action_type="read_file", argument=obs.file_tree[0]) if obs.step_count < 2: return FlakySleuthAction(action_type="search_code", argument="random") if obs.task_type == "classify": return FlakySleuthAction(action_type="classify_flakiness", argument="flaky") if obs.task_type == "root_cause": return FlakySleuthAction(action_type="classify_root_cause", argument="NOD") return FlakySleuthAction( action_type="propose_fix", argument=( "--- a/src/math_utils.py\n" "+++ b/src/math_utils.py\n" "@@\n" "-def unstable_sum(values):\n" "- random.shuffle(values)\n" "- return values[0] + values[1]\n" "+def unstable_sum(values):\n" "+ ordered = sorted(values)\n" "+ return ordered[0] + ordered[1]\n" ), ) def llm_action( messages: list[dict[str, str]], ) -> tuple[FlakySleuthAction | None, dict[str, Any]]: meta: dict[str, Any] = { "attempted": False, "raw_output": "", "error": "", "reason": "", "attempt_count": 0, } if not API_KEY: meta["reason"] = "no_api_key" return None, meta work_messages = list(messages) last_error = "" for attempt in range(LLM_MAX_RETRIES + 1): meta["attempted"] = True meta["attempt_count"] = attempt + 1 try: response = _chat_completion_request(work_messages) except Exception as exc: last_error = f"request_failed attempt={attempt + 1}: {exc}" meta["error"] = _short_error(last_error, max_chars=500) meta["reason"] = "llm_http_error" if attempt < LLM_MAX_RETRIES: work_messages = work_messages + [ {"role": "user", "content": _json_repair_prompt(last_error, "")} ] continue return None, meta raw = (response.choices[0].message.content or "").strip() meta["raw_output"] = raw try: action, _ = _parse_action_payload(raw) meta["error"] = "" meta["reason"] = "ok" return action, meta except _ActionParseError as exc: last_error = f"{exc.reason}: {exc.detail}" meta["error"] = _short_error(last_error, max_chars=500) meta["reason"] = exc.reason if attempt < LLM_MAX_RETRIES: work_messages = work_messages + [ {"role": "user", "content": _json_repair_prompt(last_error, raw)} ] continue return None, meta meta["error"] = _short_error(last_error or "unknown llm failure", max_chars=500) meta["reason"] = meta["reason"] or "llm_error" return None, meta def _clip_text(text: str, max_chars: int) -> str: if max_chars <= 0: return text if len(text) <= max_chars: return text remaining = len(text) - max_chars return f"{text[:max_chars]}\n...[truncated {remaining} chars]" def _trace_print( enabled: bool, message: str, *, text: str | None = None, max_chars: int = 0, ) -> None: if not enabled: return print(message) if text is not None: print(_clip_text(text, max_chars)) def _format_duration(seconds: float) -> str: seconds = max(0.0, float(seconds)) mins, secs = divmod(int(round(seconds)), 60) hrs, mins = divmod(mins, 60) if hrs > 0: return f"{hrs:d}h {mins:02d}m {secs:02d}s" return f"{mins:02d}m {secs:02d}s" def _build_episode_memory( *, unique_read_files: list[str], zero_gain_read_files: set[str], search_patterns: list[str], blocked_duplicate_reads: int, no_progress_streak: int, max_chars: int, ) -> str: read_tail = ", ".join(unique_read_files[-8:]) if unique_read_files else "none" zero_tail = ", ".join(sorted(zero_gain_read_files)[-8:]) if zero_gain_read_files else "none" search_tail = ", ".join(search_patterns[-6:]) if search_patterns else "none" loop_warning = ( "WARNING: Possible loop detected. Stop repeating similar exploration. " "Switch strategy or take a terminal action." if no_progress_streak >= 3 or blocked_duplicate_reads >= 2 else "Status: exploration progress appears normal." ) memory = ( f"Read files (recent): {read_tail}\n" f"Zero-gain read files: {zero_tail}\n" f"Search patterns (recent): {search_tail}\n" f"Blocked duplicate reads: {blocked_duplicate_reads}\n" f"No-progress streak: {no_progress_streak}\n" f"{loop_warning}\n" "Guidance: Avoid rereading zero-gain files unless there is new evidence. " "Prefer targeted search_code or terminal action when confidence is enough." ) return _clip_text(memory, max_chars=max_chars) def _duplicate_read_replacement_pattern(obs: FlakySleuthObservation) -> str: test_hint = obs.test_name.split("::")[-1] if obs.test_name else "test" return ( f"{test_hint}|random|sleep|time|timeout|retry|asyncio|thread|" "fixture|global|shared|mock|patch" ) def _messages_char_count(messages: list[dict[str, str]]) -> int: # Lightweight size heuristic to avoid unbounded context growth. return sum(len(str(msg.get("content", ""))) + 32 for msg in messages) def _prune_messages_window( messages: list[dict[str, str]], *, step_number: int, prune_start_step: int, window_turns: int, max_chars: int, ) -> tuple[list[dict[str, str]], dict[str, Any] | None]: if len(messages) <= 2: return messages, None current_chars = _messages_char_count(messages) exceeds_step_threshold = step_number >= prune_start_step exceeds_char_budget = current_chars > max_chars if not exceeds_step_threshold and not exceeds_char_budget: return messages, None base = messages[:2] # system + initial prompt tail = messages[2:] keep_tail_items = max(2, window_turns * 2) if len(tail) > keep_tail_items: tail = tail[-keep_tail_items:] pruned = base + tail reason = "step_threshold" if exceeds_step_threshold else "char_budget" return pruned, { "reason": reason, "before_messages": len(messages), "after_messages": len(pruned), "before_chars": current_chars, "after_chars": _messages_char_count(pruned), "step": step_number, } def run_episode( env: FlakySleuthEnv, *, print_terminal: bool = True, trace_agent: bool = False, trace_prompts: bool = False, trace_max_chars: int = 2000, episode_label: str = "", compliance_stdout: bool = False, benchmark_name: str = "flakysleuth", compliance_task_name: str | None = None, history_prune_start_step: int = 12, history_window_turns: int = 4, history_max_chars: int = 50000, ) -> tuple[float, dict[str, Any]]: rewards: list[float] = [] steps_taken = 0 success = False episode_task_name = (compliance_task_name or episode_label.split(" ", 1)[0].strip() or "unknown") exploration_reward_total = 0.0 final_episode_score = 0.0 terminal_meta: dict[str, Any] = {} llm_steps = 0 heuristic_steps = 0 fallback_reasons: dict[str, int] = {} prune_events = 0 read_attempt_counts: dict[str, int] = {} unique_read_files: list[str] = [] zero_gain_read_files: set[str] = set() search_patterns: list[str] = [] blocked_duplicate_reads = 0 no_progress_streak = 0 memory_hint = _build_episode_memory( unique_read_files=unique_read_files, zero_gain_read_files=zero_gain_read_files, search_patterns=search_patterns, blocked_duplicate_reads=blocked_duplicate_reads, no_progress_streak=no_progress_streak, max_chars=MEMORY_MAX_CHARS, ) if compliance_stdout: _compliance_log_start(episode_task_name, benchmark_name, MODEL_NAME) try: obs = env.reset() initial_prompt = obs_to_prompt(obs, memory_hint=memory_hint, max_steps=env.max_steps) messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": initial_prompt}, ] if not compliance_stdout: _trace_print( trace_agent, ( f"\n[trace] {episode_label} " f"task={obs.task_type} repo={obs.repo_url} test={obs.test_name}" ).strip(), ) if trace_prompts and not compliance_stdout: _trace_print( trace_agent, "[trace] system prompt:", text=SYSTEM_PROMPT, max_chars=trace_max_chars, ) _trace_print( trace_agent, "[trace] initial user prompt:", text=initial_prompt, max_chars=trace_max_chars, ) for step_idx in range(env.max_steps): messages, prune_info = _prune_messages_window( messages, step_number=step_idx + 1, prune_start_step=history_prune_start_step, window_turns=history_window_turns, max_chars=history_max_chars, ) if prune_info: prune_events += 1 if trace_agent and not compliance_stdout: print( "[trace] context_pruned " f"reason={prune_info['reason']} " f"step={prune_info['step']} " f"messages={prune_info['before_messages']}->{prune_info['after_messages']} " f"chars={prune_info['before_chars']}->{prune_info['after_chars']}" ) action: FlakySleuthAction action_source = "heuristic" llm_meta: dict[str, Any] = {"attempted": False, "raw_output": "", "error": ""} step_fallback_reason: str | None = None try: candidate, llm_meta = llm_action(messages) if candidate is not None: action = candidate action_source = "llm" else: action = heuristic_action(obs) if llm_meta.get("attempted"): llm_meta["error"] = ( "Model response unavailable, using heuristic fallback." ) except Exception as exc: llm_meta["error"] = str(exc) action = heuristic_action(obs) if action.action_type == "read_file": prior_reads = read_attempt_counts.get(action.argument, 0) if prior_reads >= 1: blocked_duplicate_reads += 1 replacement = FlakySleuthAction( action_type="search_code", argument=_duplicate_read_replacement_pattern(obs), ) if trace_agent and not compliance_stdout: print( "[trace] action_overridden " f"reason=duplicate_read file={action.argument} " f"replacement={replacement.action_type}" ) action = replacement if action_source == "llm": llm_steps += 1 else: heuristic_steps += 1 if not API_KEY: reason_key = "no_api_key" elif llm_meta.get("reason") and llm_meta.get("reason") != "ok": reason_key = str(llm_meta.get("reason")) elif llm_meta.get("error"): reason_key = "llm_error" elif llm_meta.get("attempted"): reason_key = "empty_or_invalid_response" else: reason_key = "heuristic_default" step_fallback_reason = reason_key fallback_reasons[reason_key] = fallback_reasons.get(reason_key, 0) + 1 if trace_agent and not compliance_stdout: print(f"[trace] step={step_idx + 1} action_source={action_source}") if llm_meta.get("attempted"): _trace_print( True, "[trace] raw model output:", text=str(llm_meta.get("raw_output", "")), max_chars=trace_max_chars, ) if llm_meta.get("error"): print(f"[trace] llm_error={llm_meta['error']}") print(f"[trace] action={action.model_dump_json()}") obs, reward, done, info = env.step(action) rewards.append(reward) steps_taken = step_idx + 1 if action.action_type == "read_file": read_attempt_counts[action.argument] = read_attempt_counts.get(action.argument, 0) + 1 if action.argument not in unique_read_files: unique_read_files.append(action.argument) if reward <= 0: zero_gain_read_files.add(action.argument) elif action.action_type == "search_code": if action.argument not in search_patterns: search_patterns.append(action.argument) if done: no_progress_streak = 0 elif reward <= 0: no_progress_streak += 1 else: no_progress_streak = 0 memory_hint = _build_episode_memory( unique_read_files=unique_read_files, zero_gain_read_files=zero_gain_read_files, search_patterns=search_patterns, blocked_duplicate_reads=blocked_duplicate_reads, no_progress_streak=no_progress_streak, max_chars=MEMORY_MAX_CHARS, ) step_error: str | None = None if isinstance(info, dict): raw_err = info.get("last_action_error") if raw_err: step_error = str(raw_err) if not step_error and obs.tool_output and str(obs.tool_output).startswith("ERROR:"): step_error = str(obs.tool_output) if step_fallback_reason: fallback_detail = "" if llm_meta.get("error"): fallback_detail = f" detail={_short_error(str(llm_meta['error']))}" fallback_suffix = f"llm_fallback:{step_fallback_reason}{fallback_detail}" if step_error: step_error = f"{step_error}; {fallback_suffix}" else: step_error = fallback_suffix if compliance_stdout: _compliance_log_step( step=steps_taken, action=action.model_dump_json(), reward=reward, done=done, error=step_error, ) if trace_agent and not compliance_stdout: print( f"[trace] step_result reward={reward:.3f} done={done} " f"step_count={obs.step_count}" ) if obs.tool_output: _trace_print( True, "[trace] tool_output:", text=obs.tool_output, max_chars=trace_max_chars, ) if done: # Terminal reward already includes cumulative progress + terminal score. final_episode_score = reward terminal_meta = { "action_type": action.action_type, "terminal_score": float(info.get("terminal_score", 0) or 0), "progress_score": float(info.get("progress_score", 0) or 0), "explore_sum": exploration_reward_total, "episode_score": final_episode_score, "llm_steps": llm_steps, "heuristic_steps": heuristic_steps, "fallback_reasons": dict(fallback_reasons), "context_prune_events": prune_events, "duplicate_read_blocks": blocked_duplicate_reads, } success = final_episode_score > 0.0 if print_terminal: print( f" Terminal: {action.action_type}({action.argument[:40]}) " f"-> terminal={info.get('terminal_score', 0):.2f} " f"progress={info.get('progress_score', 0):.2f} " f"explore_sum={exploration_reward_total:.3f} " f"episode_score={final_episode_score:.3f}" ) break exploration_reward_total += reward messages.append({"role": "assistant", "content": action.model_dump_json()}) next_prompt = obs_to_prompt(obs, memory_hint=memory_hint, max_steps=env.max_steps) messages.append({"role": "user", "content": next_prompt}) if trace_agent and trace_prompts and not compliance_stdout: _trace_print( True, f"[trace] next user prompt (step={step_idx + 1}):", text=next_prompt, max_chars=trace_max_chars, ) except Exception as exc: terminal_meta["error"] = str(exc) success = False if not compliance_stdout: raise finally: if compliance_stdout: try: env.close() except Exception: pass _compliance_log_end( success=success, steps=steps_taken, score=min(max(final_episode_score, 0.001), 0.999), rewards=rewards, ) return final_episode_score, terminal_meta def _looks_like_placeholder_dataset(dataset_path: str) -> bool: path = Path(dataset_path) if not path.exists(): return False try: text = path.read_text(encoding="utf-8", errors="replace") except Exception: return False return "fixture://" in text def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run FlakySleuth baseline inference.") parser.add_argument( "--dataset-path", default="dataset/py_tasks.csv", help="Processed task CSV used by the environment.", ) parser.add_argument( "--episodes-per-task", type=int, default=EPISODES_PER_TASK, help="Episodes per task type.", ) parser.add_argument( "--task-types", default="classify,root_cause,fix_proposal", help="Comma-separated task types to run (classify,root_cause,fix_proposal).", ) parser.add_argument( "--max-steps", type=int, default=MAX_STEPS, help="Max steps per episode.", ) parser.add_argument( "--no-progress", action="store_true", help="Disable progress bars and print classic per-episode logs.", ) parser.add_argument( "--trace-agent", action="store_true", help=( "Print detailed agent trace: model output, chosen action/tool call, and " "step results for every episode." ), ) parser.add_argument( "--trace-prompts", action="store_true", help="When tracing, also print full prompts sent to the model.", ) parser.add_argument( "--trace-max-chars", type=int, default=2500, help="Max chars per traced text block (prompt/model output/tool output).", ) parser.add_argument( "--compliance-stdout", dest="compliance_stdout", action="store_true", help=( "Emit strict compliance logs to stdout using only [START]/[STEP]/[END] lines " "for each episode." ), ) parser.add_argument( "--no-compliance-stdout", dest="compliance_stdout", action="store_false", help="Disable strict compliance logs and print baseline summaries/progress.", ) parser.add_argument( "--benchmark-name", default="flakysleuth", help="Benchmark name used in [START] lines when --compliance-stdout is enabled.", ) parser.add_argument( "--history-prune-start-step", type=int, default=12, help="Start pruning conversation history only from this step onward.", ) parser.add_argument( "--history-window-turns", type=int, default=4, help="When pruning is active, keep this many recent assistant/user turns.", ) parser.add_argument( "--history-max-chars", type=int, default=50000, help="Approx max chars for messages before forced pruning by size.", ) parser.set_defaults(compliance_stdout=True) return parser.parse_args() def main() -> None: run_start = time.perf_counter() args = _parse_args() env = FlakySleuthEnv(dataset_path=args.dataset_path, max_steps=args.max_steps) allowed_task_types = {"classify", "root_cause", "fix_proposal"} task_types = [t.strip() for t in args.task_types.split(",") if t.strip()] invalid = [t for t in task_types if t not in allowed_task_types] if invalid: raise ValueError( f"Invalid task type(s): {invalid}. " "Valid values: classify,root_cause,fix_proposal." ) if not task_types: raise ValueError( "No task types selected. Pass --task-types with at least one value." ) results: dict[str, list[float]] = defaultdict(list) if _looks_like_placeholder_dataset(args.dataset_path) and not args.compliance_stdout: print( "[warning] dataset appears to contain fixture rows (fixture://...). " "Build real dataset from py-data.csv for real evaluation." ) use_progress = ( (tqdm is not None) and (not args.no_progress) and (not args.compliance_stdout) and os.isatty(1) ) if args.trace_agent and use_progress and not args.compliance_stdout: print( "[info] --trace-agent enabled, disabling progress bars for readable trace logs." ) use_progress = False overall_bar = None if use_progress: overall_bar = tqdm( total=len(task_types) * args.episodes_per_task, desc="All tasks", unit="ep", dynamic_ncols=True, ) for task_type in task_types: task_start = time.perf_counter() if not args.compliance_stdout: print(f"\n-- Task type: {task_type} --") env.loader.force_task_type(task_type) task_bar = None if use_progress: task_bar = tqdm( total=args.episodes_per_task, desc=f"{task_type}", unit="ep", leave=False, dynamic_ncols=True, ) for episode in range(args.episodes_per_task): score, meta = run_episode( env, print_terminal=(not use_progress) and (not args.compliance_stdout), trace_agent=args.trace_agent, trace_prompts=args.trace_prompts, trace_max_chars=args.trace_max_chars, episode_label=f"{task_type} ep={episode + 1}/{args.episodes_per_task}", compliance_stdout=args.compliance_stdout, benchmark_name=args.benchmark_name, compliance_task_name=task_type, history_prune_start_step=args.history_prune_start_step, history_window_turns=args.history_window_turns, history_max_chars=args.history_max_chars, ) results[task_type].append(score) if use_progress and task_bar is not None: task_bar.update(1) task_avg = sum(results[task_type]) / len(results[task_type]) task_bar.set_postfix( score=f"{score:.3f}", avg=f"{task_avg:.3f}", term=f"{meta.get('terminal_score', 0):.2f}", ) if overall_bar is not None: overall_bar.update(1) all_scores = [s for values in results.values() for s in values] overall_avg = sum(all_scores) / len(all_scores) overall_bar.set_postfix(task=task_type, avg=f"{overall_avg:.3f}") elif not args.compliance_stdout: print(f" Episode {episode + 1}: {score:.3f}") if task_bar is not None: task_bar.close() task_elapsed = time.perf_counter() - task_start if not args.compliance_stdout: avg_task = sum(results[task_type]) / max(1, len(results[task_type])) print( f" [time] task={task_type} elapsed={_format_duration(task_elapsed)} " f"avg_ep={task_elapsed / max(1, args.episodes_per_task):.2f}s " f"avg_score={avg_task:.3f}" ) if overall_bar is not None: overall_bar.close() if args.compliance_stdout: return total_elapsed = time.perf_counter() - run_start print("\n== BASELINE RESULTS ==") all_scores: list[float] = [] for task_type in task_types: scores = results[task_type] avg = sum(scores) / len(scores) all_scores.extend(scores) print(f" {task_type:12s} avg={avg:.3f} scores={[round(s, 3) for s in scores]}") overall = sum(all_scores) / len(all_scores) print(f" {'OVERALL':12s} avg={overall:.3f}") print( f" {'RUNTIME':12s} total={_format_duration(total_elapsed)} " f"episodes={len(all_scores)} " f"avg_ep={(total_elapsed / max(1, len(all_scores))):.2f}s" ) if __name__ == "__main__": main()