#!/usr/bin/env python3
"""
Batch Agent Runner

This module provides parallel batch processing capabilities for running the agent
across multiple prompts from a dataset. It includes:
- Dataset loading and batching
- Parallel batch processing with multiprocessing
- Checkpointing for fault tolerance and resumption
- Trajectory saving in the proper format (from/value pairs)
- Tool usage statistics aggregation across all batches

Usage:
    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run

    # Resume an interrupted run
    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume

    # Use a specific toolset distribution
    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
"""
# IMPORTANT: hermes_bootstrap must be the very first import — UTF-8 stdio
# on Windows. No-op on POSIX. See hermes_bootstrap.py for full rationale.
try:
    import hermes_bootstrap  # noqa: F401
except ModuleNotFoundError:
    # Graceful fallback when hermes_bootstrap isn't registered in the venv
    # yet — happens during partial ``hermes update`` where git-reset landed
    # new code but ``uv pip install -e .`` didn't finish. Missing bootstrap
    # means UTF-8 stdio setup is skipped on Windows; POSIX is unaffected.
    pass
import json
import logging
import os
import time
import traceback
from datetime import datetime
from multiprocessing import Pool, Lock
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import fire
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeRemainingColumn,
)

from run_agent import AIAgent
from toolset_distributions import (
    list_distributions,
    sample_toolsets_from_distribution,
    validate_distribution,
)
from model_tools import TOOL_TO_TOOLSET_MAP

logger = logging.getLogger(__name__)

# Global configuration for worker processes
_WORKER_CONFIG = {}

# All possible tools - auto-derived from the master mapping in model_tools.py.
# This stays in sync automatically when new tools are added to TOOL_TO_TOOLSET_MAP.
# Used for consistent schema in Arrow/Parquet (HuggingFace datasets) and for
# filtering corrupted entries during trajectory combination.
ALL_POSSIBLE_TOOLS = set(TOOL_TO_TOOLSET_MAP.keys())

# Default stats for tools that weren't used
DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0}

def _normalize_tool_stats(tool_stats: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]:
    """
    Normalize tool_stats to include all possible tools with consistent schema.

    This ensures HuggingFace datasets can load the JSONL without schema mismatch errors.
    Tools that weren't used get zero counts.

    Args:
        tool_stats (Dict): Raw tool statistics from extraction

    Returns:
        Dict: Normalized tool statistics with all tools present
    """
    normalized = {}
    # Add all possible tools with defaults
    for tool in ALL_POSSIBLE_TOOLS:
        if tool in tool_stats:
            normalized[tool] = tool_stats[tool].copy()
        else:
            normalized[tool] = DEFAULT_TOOL_STATS.copy()
    # Also include any unexpected tools (in case new tools are added)
    for tool, stats in tool_stats.items():
        if tool not in normalized:
            normalized[tool] = stats.copy()
    return normalized

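# Illustrative sketch (the tool name is hypothetical; real keys come from
# TOOL_TO_TOOLSET_MAP): given {"web_search": {"count": 2, "success": 1, "failure": 1}},
# the result keeps that entry and adds a zeroed {"count": 0, "success": 0, "failure": 0}
# for every other known tool, so each JSONL row carries an identical tool_stats schema.
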
def _normalize_tool_error_counts(tool_error_counts: Dict[str, int]) -> Dict[str, int]:
    """
    Normalize tool_error_counts to include all possible tools.

    Args:
        tool_error_counts (Dict): Raw error counts mapping

    Returns:
        Dict: Normalized error counts with all tools present
    """
    normalized = {}
    # Add all possible tools with zero defaults
    for tool in ALL_POSSIBLE_TOOLS:
        normalized[tool] = tool_error_counts.get(tool, 0)
    # Also include any unexpected tools
    for tool, count in tool_error_counts.items():
        if tool not in normalized:
            normalized[tool] = count
    return normalized

def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
    """
    Extract tool usage statistics from message history.

    Args:
        messages (List[Dict]): Message history

    Returns:
        Dict: Tool statistics with counts and success/failure rates
    """
    tool_stats = {}
    # Track tool calls and their results
    tool_calls_map = {}  # Map tool_call_id to tool name
    for msg in messages:
        # Track tool calls from assistant messages
        if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
            for tool_call in msg["tool_calls"]:
                if not tool_call or not isinstance(tool_call, dict):
                    continue
                tool_name = tool_call["function"]["name"]
                tool_call_id = tool_call["id"]
                # Initialize stats for this tool if not exists
                if tool_name not in tool_stats:
                    tool_stats[tool_name] = {
                        "count": 0,
                        "success": 0,
                        "failure": 0
                    }
                tool_stats[tool_name]["count"] += 1
                tool_calls_map[tool_call_id] = tool_name
        # Track tool responses
        elif msg["role"] == "tool":
            tool_call_id = msg.get("tool_call_id", "")
            content = msg.get("content", "")
            # Determine if tool call was successful
            is_success = True
            try:
                # Try to parse as JSON and check for actual error values
                content_json = json.loads(content) if isinstance(content, str) else content
                if isinstance(content_json, dict):
                    # Check if error field exists AND has a non-null value
                    if "error" in content_json and content_json["error"] is not None:
                        is_success = False
                    # Special handling for terminal tool responses
                    # Terminal wraps its response in a "content" field
                    if "content" in content_json and isinstance(content_json["content"], dict):
                        inner_content = content_json["content"]
                        # Check for actual error (non-null error field)
                        # Note: non-zero exit codes are not failures - the model can self-correct
                        if inner_content.get("error") is not None:
                            is_success = False
                    # Check for "success": false pattern used by some tools
                    if content_json.get("success") is False:
                        is_success = False
            except (json.JSONDecodeError, ValueError, TypeError):
                # If not JSON, check if content is empty or explicitly states an error
                # Note: We avoid simple substring matching to prevent false positives
                if not content:
                    is_success = False
                # Only mark as failure if it explicitly starts with "Error:" or "ERROR:"
                elif content.strip().lower().startswith("error:"):
                    is_success = False
            # Update success/failure count
            if tool_call_id in tool_calls_map:
                tool_name = tool_calls_map[tool_call_id]
                if is_success:
                    tool_stats[tool_name]["success"] += 1
                else:
                    tool_stats[tool_name]["failure"] += 1
    return tool_stats

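# How the classifier above treats some illustrative tool-message payloads
# (shapes inferred from the checks, not from any tool's spec):
#   '{"error": null, "stdout": "ok"}'              -> success (error is null)
#   '{"error": "timeout"}'                         -> failure (non-null error)
#   '{"content": {"exit_code": 1, "error": null}}' -> success (exit codes ignored)
#   '{"success": false}'                           -> failure
#   'Error: no such file'                          -> failure (explicit prefix)
#   'plain text result'                            -> success (no error signal)
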
def _extract_reasoning_stats(messages: List[Dict[str, Any]]) -> Dict[str, int]:
    """
    Count how many assistant turns have reasoning vs no reasoning.

    Checks for <REASONING_SCRATCHPAD> in content or a non-empty 'reasoning' field
    (native thinking tokens). Returns counts for tracking reasoning coverage.

    Args:
        messages: Message history

    Returns:
        Dict with 'total_assistant_turns', 'turns_with_reasoning', 'turns_without_reasoning'
    """
    total = 0
    with_reasoning = 0
    for msg in messages:
        if msg.get("role") != "assistant":
            continue
        total += 1
        content = msg.get("content", "") or ""
        has_scratchpad = "<REASONING_SCRATCHPAD>" in content
        has_native_reasoning = bool((msg.get("reasoning") or "").strip())
        if has_scratchpad or has_native_reasoning:
            with_reasoning += 1
    return {
        "total_assistant_turns": total,
        "turns_with_reasoning": with_reasoning,
        "turns_without_reasoning": total - with_reasoning,
        "has_any_reasoning": with_reasoning > 0,
    }

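# Illustrative sketch (message contents hypothetical): for
#   [{"role": "assistant", "content": "<REASONING_SCRATCHPAD>...</REASONING_SCRATCHPAD> done"},
#    {"role": "assistant", "content": "done"}]
# this returns {"total_assistant_turns": 2, "turns_with_reasoning": 1,
#               "turns_without_reasoning": 1, "has_any_reasoning": True}.
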
def _process_single_prompt(
    prompt_index: int,
    prompt_data: Dict[str, Any],
    batch_num: int,
    config: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Process a single prompt with the agent.

    Args:
        prompt_index (int): Index of prompt in dataset
        prompt_data (Dict): Prompt data containing 'prompt' field and optional 'image' field
        batch_num (int): Batch number
        config (Dict): Configuration dict with agent parameters

    Returns:
        Dict: Result containing trajectory, stats, and metadata
    """
    prompt = prompt_data["prompt"]
    task_id = f"task_{prompt_index}"
    # Per-prompt container image override: if the dataset row has an 'image' field,
    # register it for this task's sandbox. Works with Docker, Modal, Singularity, and Daytona.
    container_image = prompt_data.get("image") or prompt_data.get("docker_image")
    if container_image:
        # Verify the image is accessible before spending tokens on the agent loop.
        # For Docker: check local cache, then try pulling.
        # For Modal: skip local check (Modal pulls server-side).
        env_type = os.getenv("TERMINAL_ENV", "local")
        if env_type == "docker":
            import subprocess as _sp
            try:
                probe = _sp.run(
                    ["docker", "image", "inspect", container_image],
                    capture_output=True, timeout=10,
                )
                if probe.returncode != 0:
                    if config.get("verbose"):
                        print(f"  Prompt {prompt_index}: Pulling docker image {container_image}...", flush=True)
                    pull = _sp.run(
                        ["docker", "pull", container_image],
                        capture_output=True, text=True, timeout=600,
                    )
                    if pull.returncode != 0:
                        return {
                            "success": False,
                            "prompt_index": prompt_index,
                            "error": f"Docker image not available: {container_image}\n{pull.stderr[:500]}",
                            "trajectory": None,
                            "tool_stats": {},
                            "toolsets_used": [],
                            "metadata": {"batch_num": batch_num, "timestamp": datetime.now().isoformat()},
                        }
            except FileNotFoundError:
                pass  # Docker CLI not installed — skip check (e.g., Modal backend)
            except Exception as img_err:
                if config.get("verbose"):
                    print(f"  Prompt {prompt_index}: Docker image check failed: {img_err}", flush=True)
        from tools.terminal_tool import register_task_env_overrides
        overrides = {
            "docker_image": container_image,
            "modal_image": container_image,
            "singularity_image": f"docker://{container_image}",
            "daytona_image": container_image,
        }
        if prompt_data.get("cwd"):
            overrides["cwd"] = prompt_data["cwd"]
        register_task_env_overrides(task_id, overrides)
        if config.get("verbose"):
            print(f"  Prompt {prompt_index}: Using container image {container_image}")
    try:
        # Sample toolsets from distribution for this prompt
        selected_toolsets = sample_toolsets_from_distribution(config["distribution"])
        if config.get("verbose"):
            print(f"  Prompt {prompt_index}: Using toolsets {selected_toolsets}")
        # Initialize agent with sampled toolsets and log prefix for identification
        log_prefix = f"[B{batch_num}:P{prompt_index}]"
        agent = AIAgent(
            base_url=config.get("base_url"),
            api_key=config.get("api_key"),
            model=config["model"],
            max_iterations=config["max_iterations"],
            enabled_toolsets=selected_toolsets,
            save_trajectories=False,  # We handle saving ourselves
            verbose_logging=config.get("verbose", False),
            ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
            log_prefix_chars=config.get("log_prefix_chars", 100),
            log_prefix=log_prefix,
            providers_allowed=config.get("providers_allowed"),
            providers_ignored=config.get("providers_ignored"),
            providers_order=config.get("providers_order"),
            provider_sort=config.get("provider_sort"),
            openrouter_min_coding_score=config.get("openrouter_min_coding_score"),
            max_tokens=config.get("max_tokens"),
            reasoning_config=config.get("reasoning_config"),
            prefill_messages=config.get("prefill_messages"),
            skip_context_files=True,  # Don't pollute trajectories with SOUL.md/AGENTS.md
            skip_memory=True,  # Don't use persistent memory in batch runs
        )
        # Run the agent with task_id to ensure each task gets its own isolated VM
        result = agent.run_conversation(prompt, task_id=task_id)
        # Extract tool usage statistics
        tool_stats = _extract_tool_stats(result["messages"])
        # Extract reasoning coverage stats
        reasoning_stats = _extract_reasoning_stats(result["messages"])
        # Convert to trajectory format (using existing method)
        trajectory = agent._convert_to_trajectory_format(
            result["messages"],
            prompt,
            result["completed"]
        )
        return {
            "success": True,
            "prompt_index": prompt_index,
            "trajectory": trajectory,
            "tool_stats": tool_stats,
            "reasoning_stats": reasoning_stats,
            "completed": result["completed"],
            "partial": result.get("partial", False),
            "api_calls": result["api_calls"],
            "toolsets_used": selected_toolsets,
            "metadata": {
                "batch_num": batch_num,
                "timestamp": datetime.now().isoformat(),
                "model": config["model"]
            }
        }
    except Exception as e:
        print(f"❌ Error processing prompt {prompt_index}: {e}")
        if config.get("verbose"):
            traceback.print_exc()
        return {
            "success": False,
            "prompt_index": prompt_index,
            "error": str(e),
            "trajectory": None,
            "tool_stats": {},
            "toolsets_used": [],
            "metadata": {
                "batch_num": batch_num,
                "timestamp": datetime.now().isoformat()
            }
        }

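# Minimal single-prompt invocation sketch (config keys mirror the reads above;
# the values are placeholders, not recommended defaults):
#   config = {"distribution": "default", "model": "anthropic/claude-sonnet-4.6",
#             "max_iterations": 10, "verbose": False}
#   result = _process_single_prompt(0, {"prompt": "List the files in /tmp"}, 0, config)
#   # On success, result["trajectory"] holds the from/value conversation.
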
def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
    """
    Worker function to process a single batch of prompts.

    Args:
        args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config)

    Returns:
        Dict: Batch results with statistics
    """
    batch_num, batch_data, output_dir, completed_prompts_set, config = args
    output_dir = Path(output_dir)
    print(f"\n🔄 Batch {batch_num}: Starting ({len(batch_data)} prompts)")
    # Output file for this batch
    batch_output_file = output_dir / f"batch_{batch_num}.jsonl"
    # Filter out already completed prompts
    prompts_to_process = [
        (idx, data) for idx, data in batch_data
        if idx not in completed_prompts_set
    ]
    if not prompts_to_process:
        print(f"✅ Batch {batch_num}: Already completed (skipping)")
        return {
            "batch_num": batch_num,
            "processed": 0,
            "skipped": len(batch_data),
            "tool_stats": {},
            "completed_prompts": []
        }
    print(f"  Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)")
    # Initialize aggregated stats for this batch
    batch_tool_stats = {}
    batch_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}
    completed_in_batch = []
    discarded_no_reasoning = 0
    # Process each prompt sequentially in this batch
    for prompt_index, prompt_data in prompts_to_process:
        # Process the prompt
        result = _process_single_prompt(
            prompt_index,
            prompt_data,
            batch_num,
            config
        )
        # Save trajectory if successful
        if result["success"] and result["trajectory"]:
            # Discard samples with zero reasoning across all turns
            reasoning = result.get("reasoning_stats", {})
            if not reasoning.get("has_any_reasoning", True):
                print(f"  🚫 Prompt {prompt_index} discarded (no reasoning in any turn)")
                discarded_no_reasoning += 1
                completed_in_batch.append(prompt_index)
                continue
            # Get and normalize tool stats for consistent schema across all entries
            raw_tool_stats = result.get("tool_stats", {})
            tool_stats = _normalize_tool_stats(raw_tool_stats)
            # Create normalized tool_error_counts mapping tool names to their failure counts
            raw_error_counts = {
                tool_name: stats.get("failure", 0)
                for tool_name, stats in raw_tool_stats.items()
            }
            tool_error_counts = _normalize_tool_error_counts(raw_error_counts)
            trajectory_entry = {
                "prompt_index": prompt_index,
                "conversations": result["trajectory"],
                "metadata": result["metadata"],
                "completed": result["completed"],
                "partial": result.get("partial", False),  # True if stopped due to invalid tool calls
                "api_calls": result["api_calls"],
                "toolsets_used": result["toolsets_used"],
                "tool_stats": tool_stats,  # Full stats: {tool: {count, success, failure}} - normalized
                "tool_error_counts": tool_error_counts  # Simple: {tool: failure_count} - normalized
            }
            # Append to batch output file
            with open(batch_output_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")
            # Aggregate tool statistics
            for tool_name, stats in result.get("tool_stats", {}).items():
                if tool_name not in batch_tool_stats:
                    batch_tool_stats[tool_name] = {
                        "count": 0,
                        "success": 0,
                        "failure": 0
                    }
                batch_tool_stats[tool_name]["count"] += stats["count"]
                batch_tool_stats[tool_name]["success"] += stats["success"]
                batch_tool_stats[tool_name]["failure"] += stats["failure"]
            # Aggregate reasoning stats
            for key in batch_reasoning_stats:
                batch_reasoning_stats[key] += result.get("reasoning_stats", {}).get(key, 0)
        # Only mark as completed if successfully saved (failed prompts can be retried on resume)
        if result["success"] and result["trajectory"]:
            completed_in_batch.append(prompt_index)
            status = "⚠️ partial" if result.get("partial") else "✅"
            print(f"  {status} Prompt {prompt_index} completed")
        else:
            print(f"  ❌ Prompt {prompt_index} failed (will retry on resume)")
    print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
    return {
        "batch_num": batch_num,
        "processed": len(prompts_to_process),
        "skipped": len(batch_data) - len(prompts_to_process),
        "tool_stats": batch_tool_stats,
        "reasoning_stats": batch_reasoning_stats,
        "discarded_no_reasoning": discarded_no_reasoning,
        "completed_prompts": completed_in_batch
    }

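# Illustrative shape of one batch_<N>.jsonl line written above (values hypothetical):
#   {"prompt_index": 12,
#    "conversations": [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}],
#    "metadata": {"batch_num": 3, "timestamp": "...", "model": "..."},
#    "completed": true, "partial": false, "api_calls": 4,
#    "toolsets_used": ["..."], "tool_stats": {...}, "tool_error_counts": {...}}
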
class BatchRunner:
    """
    Manages batch processing of agent prompts with checkpointing and statistics.
    """
    def __init__(
        self,
        dataset_file: str,
        batch_size: int,
        run_name: str,
        distribution: str = "default",
        max_iterations: int = 10,
        base_url: str = None,
        api_key: str = None,
        model: str = "claude-opus-4-20250514",
        num_workers: int = 4,
        verbose: bool = False,
        ephemeral_system_prompt: str = None,
        log_prefix_chars: int = 100,
        providers_allowed: List[str] = None,
        providers_ignored: List[str] = None,
        providers_order: List[str] = None,
        provider_sort: str = None,
        openrouter_min_coding_score: Optional[float] = None,
        max_tokens: int = None,
        reasoning_config: Dict[str, Any] = None,
        prefill_messages: List[Dict[str, Any]] = None,
        max_samples: int = None,
    ):
        """
        Initialize the batch runner.

        Args:
            dataset_file (str): Path to the dataset JSONL file with 'prompt' field
            batch_size (int): Number of prompts per batch
            run_name (str): Name for this run (used for checkpointing and output)
            distribution (str): Toolset distribution to use (default: "default")
            max_iterations (int): Max iterations per agent run
            base_url (str): Base URL for model API
            api_key (str): API key for model
            model (str): Model name to use
            num_workers (int): Number of parallel workers
            verbose (bool): Enable verbose logging
            ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
            log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
            providers_allowed (List[str]): OpenRouter providers to allow (optional)
            providers_ignored (List[str]): OpenRouter providers to ignore (optional)
            providers_order (List[str]): OpenRouter providers to try in order (optional)
            provider_sort (str): Sort providers by price/throughput/latency (optional)
            openrouter_min_coding_score (float): Minimum OpenRouter coding score required of a provider (optional)
            max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
            reasoning_config (Dict): OpenRouter reasoning config override (e.g. {"effort": "none"} to disable thinking)
            prefill_messages (List[Dict]): Messages to prepend as prefilled conversation context (few-shot priming).
                NOTE: Anthropic Sonnet 4.6+ and Opus 4.6+ reject a trailing assistant-role prefill
                (400 error). For those models use output_config.format or structured-output
                schemas instead. Safe here for user-role priming and for older Claude / non-Claude models.
            max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
        """
        self.dataset_file = Path(dataset_file)
        self.batch_size = batch_size
        self.run_name = run_name
        self.distribution = distribution
        self.max_iterations = max_iterations
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.num_workers = num_workers
        self.verbose = verbose
        self.ephemeral_system_prompt = ephemeral_system_prompt
        self.log_prefix_chars = log_prefix_chars
        self.providers_allowed = providers_allowed
        self.providers_ignored = providers_ignored
        self.providers_order = providers_order
        self.provider_sort = provider_sort
        self.openrouter_min_coding_score = openrouter_min_coding_score
        self.max_tokens = max_tokens
        self.reasoning_config = reasoning_config
        self.prefill_messages = prefill_messages
        self.max_samples = max_samples
        # Validate distribution
        if not validate_distribution(distribution):
            raise ValueError(f"Unknown distribution: {distribution}. Available: {list(list_distributions().keys())}")
        # Setup output directory
        self.output_dir = Path("data") / run_name
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Checkpoint file
        self.checkpoint_file = self.output_dir / "checkpoint.json"
        # Statistics file
        self.stats_file = self.output_dir / "statistics.json"
        # Load dataset (and optionally truncate to max_samples)
        self.dataset = self._load_dataset()
        if self.max_samples and self.max_samples < len(self.dataset):
            full_count = len(self.dataset)
            self.dataset = self.dataset[:self.max_samples]
            print(f"✂️ Truncated dataset from {full_count} to {self.max_samples} samples (--max_samples)")
        # Create batches
        self.batches = self._create_batches()
        print("📊 Batch Runner Initialized")
        print(f"  Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
        print(f"  Batch size: {self.batch_size}")
        print(f"  Total batches: {len(self.batches)}")
        print(f"  Run name: {self.run_name}")
        print(f"  Distribution: {self.distribution}")
        print(f"  Output directory: {self.output_dir}")
        print(f"  Workers: {self.num_workers}")
        if self.ephemeral_system_prompt:
            prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
            print(f"  🔒 Ephemeral system prompt: '{prompt_preview}'")

    def _load_dataset(self) -> List[Dict[str, Any]]:
        """
        Load dataset from JSONL file.

        Returns:
            List[Dict]: List of dataset entries
        """
        if not self.dataset_file.exists():
            raise FileNotFoundError(f"Dataset file not found: {self.dataset_file}")
        dataset = []
        with open(self.dataset_file, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                    if 'prompt' not in entry:
                        print(f"⚠️ Warning: Line {line_num} missing 'prompt' field, skipping")
                        continue
                    dataset.append(entry)
                except json.JSONDecodeError as e:
                    print(f"⚠️ Warning: Invalid JSON on line {line_num}: {e}")
                    continue
        if not dataset:
            raise ValueError(f"No valid entries found in dataset file: {self.dataset_file}")
        return dataset

    def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]:
        """
        Split dataset into batches with indices.

        Returns:
            List of batches, where each batch is a list of (index, entry) tuples
        """
        batches = []
        for i in range(0, len(self.dataset), self.batch_size):
            batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)]
            batches.append(batch)
        return batches

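    # Illustrative sketch: with batch_size=2 and a 5-entry dataset, this yields
    #   [[(0, e0), (1, e1)], [(2, e2), (3, e3)], [(4, e4)]]
    # where each entry keeps its original dataset index for checkpoint tracking.
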
    def _load_checkpoint(self) -> Dict[str, Any]:
        """
        Load checkpoint data if it exists.

        Returns:
            Dict: Checkpoint data with completed prompt indices
        """
        if not self.checkpoint_file.exists():
            return {
                "run_name": self.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None
            }
        try:
            with open(self.checkpoint_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            print(f"⚠️ Warning: Failed to load checkpoint: {e}")
            return {
                "run_name": self.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None
            }

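    # Illustrative checkpoint.json written by _save_checkpoint (values hypothetical):
    #   {"run_name": "my_run",
    #    "completed_prompts": [0, 1, 4],
    #    "batch_stats": {"0": {"processed": 2, "skipped": 0, "discarded_no_reasoning": 0}},
    #    "last_updated": "2025-01-01T12:00:00"}
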
    def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock=None):
        """
        Save checkpoint data atomically.

        Args:
            checkpoint_data (Dict): Checkpoint data to save
            lock: Optional multiprocessing lock for safe concurrent access.
                (Left unannotated because multiprocessing.Lock is a factory
                function, not a type.)
        """
        checkpoint_data["last_updated"] = datetime.now().isoformat()
        from utils import atomic_json_write
        if lock:
            with lock:
                atomic_json_write(self.checkpoint_file, checkpoint_data)
        else:
            atomic_json_write(self.checkpoint_file, checkpoint_data)

    def _scan_completed_prompts_by_content(self) -> set:
        """
        Scan all batch files and extract completed prompts by their actual content.

        This provides a more robust resume mechanism that matches on prompt text
        rather than indices, allowing recovery even if indices don't match.

        Returns:
            set: Set of prompt texts that have been successfully processed
        """
        completed_prompts = set()
        batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
        if not batch_files:
            return completed_prompts
        print(f"📂 Scanning {len(batch_files)} batch files for completed prompts...")
        for batch_file in batch_files:
            try:
                with open(batch_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        try:
                            entry = json.loads(line.strip())
                            # Skip failed entries - we want to retry these
                            if entry.get("failed", False):
                                continue
                            # Extract the human/user prompt from conversations
                            conversations = entry.get("conversations", [])
                            for msg in conversations:
                                if msg.get("from") == "human":
                                    prompt_text = msg.get("value", "").strip()
                                    if prompt_text:
                                        completed_prompts.add(prompt_text)
                                    break  # Only need the first human message
                        except json.JSONDecodeError:
                            continue
            except Exception as e:
                print(f"  ⚠️ Warning: Error reading {batch_file.name}: {e}")
        return completed_prompts

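    # Content matching keys off the first human turn of each saved entry; e.g. a line
    # like {"conversations": [{"from": "human", "value": "Fix the bug"}, ...]} marks
    # "Fix the bug" as completed (illustrative values).
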
    def _filter_dataset_by_completed(self, completed_prompts: set) -> Tuple[List[Dict], List[int]]:
        """
        Filter the dataset to exclude prompts that have already been completed.

        Args:
            completed_prompts: Set of prompt texts that have been completed

        Returns:
            Tuple of (filtered_dataset, skipped_indices)
        """
        filtered_dataset = []
        skipped_indices = []
        for idx, entry in enumerate(self.dataset):
            # Extract prompt from the dataset entry
            prompt_text = entry.get("prompt", "").strip()
            # Also check conversations format
            if not prompt_text:
                conversations = entry.get("conversations", [])
                for msg in conversations:
                    role = msg.get("role") or msg.get("from")
                    if role in ("user", "human"):
                        prompt_text = (msg.get("content") or msg.get("value", "")).strip()
                        break
            if prompt_text in completed_prompts:
                skipped_indices.append(idx)
            else:
                # Keep original index for tracking
                filtered_dataset.append((idx, entry))
        return filtered_dataset, skipped_indices

    def run(self, resume: bool = False):
        """
        Run the batch processing pipeline.

        Args:
            resume (bool): Whether to resume from checkpoint
        """
        print("\n" + "=" * 70)
        print("🚀 Starting Batch Processing")
        print("=" * 70)
        # Smart resume: scan batch files by content to find completed prompts
        completed_prompt_texts = set()
        if resume:
            completed_prompt_texts = self._scan_completed_prompts_by_content()
            if completed_prompt_texts:
                print(f"  Found {len(completed_prompt_texts)} already-completed prompts by content matching")
        # Filter dataset to only include unprocessed prompts
        if resume and completed_prompt_texts:
            filtered_entries, skipped_indices = self._filter_dataset_by_completed(completed_prompt_texts)
            if not filtered_entries:
                print("\n✅ All prompts have already been processed!")
                return
            # Recreate batches from filtered entries (keeping original indices for tracking)
            batches_to_process = []
            for i in range(0, len(filtered_entries), self.batch_size):
                batch = filtered_entries[i:i + self.batch_size]
                batches_to_process.append(batch)
            self.batches = batches_to_process
            # Print prominent resume summary
            print("\n" + "=" * 70)
            print("📊 RESUME SUMMARY")
            print("=" * 70)
            print(f"  Original dataset size: {len(self.dataset):,} prompts")
            print(f"  Already completed: {len(skipped_indices):,} prompts")
            print("  ─────────────────────────────────────────")
            print(f"  🎯 RESUMING WITH: {len(filtered_entries):,} prompts")
            print(f"  New batches created: {len(batches_to_process)}")
            print("=" * 70 + "\n")
        # Load existing checkpoint (so resume doesn't clobber prior progress)
        checkpoint_data = self._load_checkpoint()
        if checkpoint_data.get("run_name") != self.run_name:
            checkpoint_data = {
                "run_name": self.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None
            }
        # Prepare configuration for workers
        config = {
            "distribution": self.distribution,
            "model": self.model,
            "max_iterations": self.max_iterations,
            "base_url": self.base_url,
            "api_key": self.api_key,
            "verbose": self.verbose,
            "ephemeral_system_prompt": self.ephemeral_system_prompt,
            "log_prefix_chars": self.log_prefix_chars,
            "providers_allowed": self.providers_allowed,
            "providers_ignored": self.providers_ignored,
            "providers_order": self.providers_order,
            "provider_sort": self.provider_sort,
            "openrouter_min_coding_score": self.openrouter_min_coding_score,
            "max_tokens": self.max_tokens,
            "reasoning_config": self.reasoning_config,
            "prefill_messages": self.prefill_messages,
        }
        # For backward compatibility, still track by index (but this is secondary to content matching)
        completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
        # Aggregate statistics across all batches
        total_tool_stats = {}
        start_time = time.time()
        print(f"\n🔧 Initializing {self.num_workers} worker processes...")
        # Checkpoint writes happen in the parent process; keep a lock for safety.
        checkpoint_lock = Lock()
        # Process batches in parallel
        with Pool(processes=self.num_workers) as pool:
            # Create tasks for each batch
            tasks = [
                (
                    batch_num,
                    batch_data,
                    str(self.output_dir),  # Convert Path to string for pickling
                    completed_prompts_set,
                    config
                )
                for batch_num, batch_data in enumerate(self.batches)
            ]
            print(f"✅ Created {len(tasks)} batch tasks")
            print("🚀 Starting parallel batch processing...\n")
            # Use rich Progress for a persistent bottom progress bar.
            # Worker output comes from separate processes, so rich cannot capture
            # it; stdout/stderr redirection stays off and the bar simply redraws
            # around worker prints.
            results = []
            console = Console(force_terminal=True)
            with Progress(
                SpinnerColumn(),
                TextColumn("[bold blue]📦 Batches"),
                BarColumn(bar_width=40),
                MofNCompleteColumn(),
                TextColumn("•"),
                TimeRemainingColumn(),
                console=console,
                refresh_per_second=2,
                transient=False,
                redirect_stdout=False,
                redirect_stderr=False,
            ) as progress:
                task = progress.add_task("Processing", total=len(tasks))
                # Temporarily suppress sub-WARNING logging to avoid bar interference
                root_logger = logging.getLogger()
                original_level = root_logger.level
                root_logger.setLevel(logging.WARNING)
                try:
                    for result in pool.imap_unordered(_process_batch_worker, tasks):
                        results.append(result)
                        progress.update(task, advance=1)
                        # Incremental checkpoint update (so resume works after crash)
                        try:
                            batch_num = result.get('batch_num')
                            completed = result.get('completed_prompts', []) or []
                            completed_prompts_set.update(completed)
                            if isinstance(batch_num, int):
                                checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = {
                                    'processed': result.get('processed', 0),
                                    'skipped': result.get('skipped', 0),
                                    'discarded_no_reasoning': result.get('discarded_no_reasoning', 0),
                                }
                            checkpoint_data['completed_prompts'] = sorted(completed_prompts_set)
                            self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
                        except Exception as ckpt_err:
                            # Don't fail the run if checkpoint write fails
                            print(f"⚠️ Warning: Failed to save incremental checkpoint: {ckpt_err}")
                except Exception as e:
                    logger.error("Batch worker failed: %s", e, exc_info=True)
                    raise
                finally:
                    root_logger.setLevel(original_level)
        # Aggregate all batch statistics and update checkpoint
        total_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}
        for batch_result in results:
            # Aggregate tool stats
            for tool_name, stats in batch_result.get("tool_stats", {}).items():
                if tool_name not in total_tool_stats:
                    total_tool_stats[tool_name] = {
                        "count": 0,
                        "success": 0,
                        "failure": 0
                    }
                total_tool_stats[tool_name]["count"] += stats["count"]
                total_tool_stats[tool_name]["success"] += stats["success"]
                total_tool_stats[tool_name]["failure"] += stats["failure"]
            # Aggregate reasoning stats
            for key in total_reasoning_stats:
                total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
        # Save final checkpoint (best-effort; incremental writes already happened)
        try:
            checkpoint_data["completed_prompts"] = sorted(completed_prompts_set)
            self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
        except Exception as ckpt_err:
            print(f"⚠️ Warning: Failed to save final checkpoint: {ckpt_err}")
        # Calculate success rates
        for tool_name in total_tool_stats:
            stats = total_tool_stats[tool_name]
            total_calls = stats["success"] + stats["failure"]
            if total_calls > 0:
                stats["success_rate"] = round(stats["success"] / total_calls * 100, 2)
                stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2)
            else:
                stats["success_rate"] = 0.0
                stats["failure_rate"] = 0.0
        # Combine ALL batch files in directory into a single trajectories.jsonl file.
        # This includes both old batches (from previous runs) and new batches (from resume).
        # Also filter out corrupted entries (where model generated invalid tool names).
        combined_file = self.output_dir / "trajectories.jsonl"
        print(f"\n📦 Combining ALL batch files into {combined_file.name}...")
        # Valid tools auto-derived from model_tools.py — no manual updates needed
        VALID_TOOLS = ALL_POSSIBLE_TOOLS
        total_entries = 0
        filtered_entries = 0
        batch_files_found = 0
        # Find ALL batch files in the output directory (handles resume merging old + new)
        all_batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
        with open(combined_file, 'w', encoding='utf-8') as outfile:
            for batch_file in all_batch_files:
                batch_files_found += 1
                batch_num = batch_file.stem.split("_")[1]  # Extract batch number for logging
                with open(batch_file, 'r', encoding='utf-8') as infile:
                    for line in infile:
                        total_entries += 1
                        try:
                            data = json.loads(line)
                            tool_stats = data.get('tool_stats', {})
                            # Check for invalid tool names (model hallucinations)
                            invalid_tools = [k for k in tool_stats if k not in VALID_TOOLS]
                            if invalid_tools:
                                filtered_entries += 1
                                invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0]
                                print(f"  ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'")
                                continue
                            outfile.write(line)
                        except json.JSONDecodeError:
                            filtered_entries += 1
                            print(f"  ⚠️ Filtering invalid JSON entry (batch {batch_num})")
        if filtered_entries > 0:
            print(f"⚠️ Filtered {filtered_entries} corrupted entries out of {total_entries} total")
        print(f"✅ Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)")
        # Save final statistics
        final_stats = {
            "run_name": self.run_name,
            "distribution": self.distribution,
            "total_prompts": len(self.dataset),
            "total_batches": len(self.batches),
            "batch_size": self.batch_size,
            "model": self.model,
            "completed_at": datetime.now().isoformat(),
            "duration_seconds": round(time.time() - start_time, 2),
            "tool_statistics": total_tool_stats,
            "reasoning_statistics": total_reasoning_stats,
        }
        with open(self.stats_file, 'w', encoding='utf-8') as f:
            json.dump(final_stats, f, indent=2, ensure_ascii=False)
        # Print summary
        print("\n" + "=" * 70)
        print("📊 BATCH PROCESSING COMPLETE")
        print("=" * 70)
        print(f"✅ Prompts processed this run: {sum(r.get('processed', 0) for r in results)}")
        print(f"✅ Total trajectories in merged file: {total_entries - filtered_entries}")
        print(f"✅ Total batch files merged: {batch_files_found}")
        print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
        print("\n📈 Tool Usage Statistics:")
        print("-" * 70)
        if total_tool_stats:
            # Sort by count descending
            sorted_tools = sorted(
                total_tool_stats.items(),
                key=lambda x: x[1]["count"],
                reverse=True
            )
            print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
            print("-" * 70)
            for tool_name, stats in sorted_tools:
                print(
                    f"{tool_name:<25} "
                    f"{stats['count']:<10} "
                    f"{stats['success']:<10} "
                    f"{stats['failure']:<10} "
                    f"{stats['success_rate']:.1f}%"
                )
        else:
            print("No tool calls were made during this run.")
        # Print reasoning coverage stats
        total_discarded = sum(r.get("discarded_no_reasoning", 0) for r in results)
        print("\n🧠 Reasoning Coverage:")
        print("-" * 70)
        total_turns = total_reasoning_stats["total_assistant_turns"]
        with_reasoning = total_reasoning_stats["turns_with_reasoning"]
        without_reasoning = total_reasoning_stats["turns_without_reasoning"]
        if total_turns > 0:
            pct_with = round(with_reasoning / total_turns * 100, 1)
            pct_without = round(without_reasoning / total_turns * 100, 1)
            print(f"  Total assistant turns: {total_turns:,}")
            print(f"  With reasoning: {with_reasoning:,} ({pct_with}%)")
            print(f"  Without reasoning: {without_reasoning:,} ({pct_without}%)")
        else:
            print("  No assistant turns recorded.")
        if total_discarded > 0:
            print(f"  🚫 Samples discarded (zero reasoning): {total_discarded:,}")
        print(f"\n💾 Results saved to: {self.output_dir}")
        print("  - Trajectories: trajectories.jsonl (combined)")
        print("  - Individual batches: batch_*.jsonl (for debugging)")
        print(f"  - Statistics: {self.stats_file.name}")
        print(f"  - Checkpoint: {self.checkpoint_file.name}")

def main(
    dataset_file: str = None,
    batch_size: int = None,
    run_name: str = None,
    distribution: str = "default",
    model: str = "anthropic/claude-sonnet-4.6",
    api_key: str = None,
    base_url: str = "https://openrouter.ai/api/v1",
    max_turns: int = 10,
    num_workers: int = 4,
    resume: bool = False,
    verbose: bool = False,
    list_distributions: bool = False,
    ephemeral_system_prompt: str = None,
    log_prefix_chars: int = 100,
    providers_allowed: str = None,
    providers_ignored: str = None,
    providers_order: str = None,
    provider_sort: str = None,
    max_tokens: int = None,
    reasoning_effort: str = None,
    reasoning_disabled: bool = False,
    prefill_messages_file: str = None,
    max_samples: int = None,
):
    """
    Run batch processing of agent prompts from a dataset.

    Args:
        dataset_file (str): Path to JSONL file with 'prompt' field in each entry
        batch_size (int): Number of prompts per batch
        run_name (str): Name for this run (used for output and checkpointing)
        distribution (str): Toolset distribution to use (default: "default")
        model (str): Model name to use (default: "anthropic/claude-sonnet-4.6")
        api_key (str): API key for model authentication
        base_url (str): Base URL for model API
        max_turns (int): Maximum number of tool calling iterations per prompt (default: 10)
        num_workers (int): Number of parallel worker processes (default: 4)
        resume (bool): Resume from checkpoint if run was interrupted (default: False)
        verbose (bool): Enable verbose logging (default: False)
        list_distributions (bool): List available toolset distributions and exit
        ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
        log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
        providers_allowed (str): Comma-separated list of OpenRouter providers to allow (e.g. "anthropic,openai")
        providers_ignored (str): Comma-separated list of OpenRouter providers to ignore (e.g. "together,deepinfra")
        providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
        provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
        max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
        reasoning_effort (str): OpenRouter reasoning effort level: "none", "minimal", "low", "medium", "high", "xhigh" (default: "medium")
        reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
        prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
        max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)

    Examples:
        # Basic usage
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run

        # Resume interrupted run
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume

        # Use specific distribution
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen

        # With disabled reasoning and max tokens
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
            --reasoning_disabled --max_tokens=128000

        # With prefill messages from file
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
            --prefill_messages_file=configs/prefill_opus.json

        # List available distributions
        python batch_runner.py --list_distributions
    """
    # Handle list distributions
    if list_distributions:
        # NOTE: the --list_distributions flag shadows the module-level import of
        # the same name, so re-import the function under an alias before calling it.
        from toolset_distributions import print_distribution_info
        from toolset_distributions import list_distributions as _list_distributions
        print("📊 Available Toolset Distributions")
        print("=" * 70)
        all_dists = _list_distributions()
        for dist_name in sorted(all_dists.keys()):
            print_distribution_info(dist_name)
        print("\n💡 Usage:")
        print("  python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\")
        print("      --run_name=my_run --distribution=<name>")
        return
    # Validate required arguments
    if not dataset_file:
        print("❌ Error: --dataset_file is required")
        return
    if not batch_size or batch_size < 1:
        print("❌ Error: --batch_size must be a positive integer")
        return
    if not run_name:
        print("❌ Error: --run_name is required")
        return
    # Parse provider preferences (comma-separated strings to lists)
    providers_allowed_list = [p.strip() for p in providers_allowed.split(",")] if providers_allowed else None
    providers_ignored_list = [p.strip() for p in providers_ignored.split(",")] if providers_ignored else None
    providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None
    # Build reasoning_config from CLI flags
    # --reasoning_disabled takes priority, then --reasoning_effort, then default (medium)
    reasoning_config = None
    if reasoning_disabled:
        # Completely disable reasoning/thinking tokens
        reasoning_config = {"effort": "none"}
        print("🧠 Reasoning: DISABLED (effort=none)")
    elif reasoning_effort:
        # Use specified effort level
        valid_efforts = ["none", "minimal", "low", "medium", "high", "xhigh"]
        if reasoning_effort not in valid_efforts:
            print(f"❌ Error: --reasoning_effort must be one of: {', '.join(valid_efforts)}")
            return
        reasoning_config = {"enabled": True, "effort": reasoning_effort}
        print(f"🧠 Reasoning effort: {reasoning_effort}")
    # Load prefill messages from JSON file if provided
    prefill_messages = None
    if prefill_messages_file:
        try:
            with open(prefill_messages_file, 'r', encoding='utf-8') as f:
                prefill_messages = json.load(f)
            if not isinstance(prefill_messages, list):
                print("❌ Error: prefill_messages_file must contain a JSON array of messages")
                return
            print(f"💬 Loaded {len(prefill_messages)} prefill messages from {prefill_messages_file}")
        except Exception as e:
            print(f"❌ Error loading prefill messages: {e}")
            return
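    # Illustrative prefill file contents, a JSON array of {role, content} dicts
    # (values hypothetical; see the BatchRunner docstring caveat about newer
    # Claude models rejecting a trailing assistant-role prefill):
    #   [{"role": "user", "content": "Always run the tests before finishing."},
    #    {"role": "assistant", "content": "Understood, I'll verify with tests."}]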
    # Initialize and run batch runner
    try:
        runner = BatchRunner(
            dataset_file=dataset_file,
            batch_size=batch_size,
            run_name=run_name,
            distribution=distribution,
            max_iterations=max_turns,
            base_url=base_url,
            api_key=api_key,
            model=model,
            num_workers=num_workers,
            verbose=verbose,
            ephemeral_system_prompt=ephemeral_system_prompt,
            log_prefix_chars=log_prefix_chars,
            providers_allowed=providers_allowed_list,
            providers_ignored=providers_ignored_list,
            providers_order=providers_order_list,
            provider_sort=provider_sort,
            max_tokens=max_tokens,
            reasoning_config=reasoning_config,
            prefill_messages=prefill_messages,
            max_samples=max_samples,
        )
        runner.run(resume=resume)
    except Exception as e:
        print(f"\n❌ Fatal error: {e}")
        if verbose:
            traceback.print_exc()
        return 1


if __name__ == "__main__":
    fire.Fire(main)