| |
| """ |
| Batch Agent Runner |
| |
| This module provides parallel batch processing capabilities for running the agent |
| across multiple prompts from a dataset. It includes: |
| - Dataset loading and batching |
| - Parallel batch processing with multiprocessing |
| - Checkpointing for fault tolerance and resumption |
| - Trajectory saving in the proper format (from/value pairs) |
| - Tool usage statistics aggregation across all batches |
| |
| Usage: |
| python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run |
| |
| # Resume an interrupted run |
| python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume |
| |
| # Use a specific toolset distribution |
| python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen |
| """ |
|
|
| import json |
| import logging |
| import os |
| import time |
| from pathlib import Path |
| from typing import List, Dict, Any, Optional, Tuple |
| from datetime import datetime |
| from multiprocessing import Pool, Lock |
| import traceback |
| from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn |
| from rich.console import Console |
| import fire |
|
|
| from run_agent import AIAgent |
| from toolset_distributions import ( |
| list_distributions, |
| sample_toolsets_from_distribution, |
| validate_distribution |
| ) |
| from model_tools import TOOL_TO_TOOLSET_MAP |
|
|
|
|
| |
# Module-level configuration holder. Nothing in the visible part of this file
# reads or writes it (presumably reserved for pool workers — TODO confirm).
_WORKER_CONFIG = {}


# Every tool name known to TOOL_TO_TOOLSET_MAP. The normalizers below use it to
# pad per-trajectory statistics to one fixed set of keys.
ALL_POSSIBLE_TOOLS = set(TOOL_TO_TOOLSET_MAP.keys())


# Zero-valued statistics template; callers copy it before use so the shared
# dict is never mutated.
DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0}
|
|
|
|
def _normalize_tool_stats(tool_stats: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]:
    """
    Pad raw tool statistics so that every known tool has an entry.

    A fixed key set keeps the JSONL schema identical from row to row, which
    lets HuggingFace datasets load the output without schema mismatch errors.

    Args:
        tool_stats (Dict): Raw per-tool statistics from extraction

    Returns:
        Dict: A new mapping holding a copy of the recorded stats for every
            tool in ALL_POSSIBLE_TOOLS (zeros when the tool was never used),
            plus copies of any recorded tools missing from that set
    """
    # Known tools first: recorded stats when present, the zero template otherwise.
    padded = {
        name: (tool_stats[name] if name in tool_stats else DEFAULT_TOOL_STATS).copy()
        for name in ALL_POSSIBLE_TOOLS
    }

    # Carry over tools that are not part of the known set.
    for name, recorded in tool_stats.items():
        if name in padded:
            continue
        padded[name] = recorded.copy()

    return padded
|
|
|
|
def _normalize_tool_error_counts(tool_error_counts: Dict[str, int]) -> Dict[str, int]:
    """
    Pad raw error counts so that every known tool has an entry.

    Args:
        tool_error_counts (Dict): Raw mapping of tool name to error count

    Returns:
        Dict: A new mapping with a count for every tool in ALL_POSSIBLE_TOOLS
            (0 when none was recorded), plus any recorded tools missing from
            that set
    """
    # Known tools first, defaulting to zero errors.
    filled = {name: tool_error_counts.get(name, 0) for name in ALL_POSSIBLE_TOOLS}

    # Carry over tools that are not part of the known set.
    for name, error_count in tool_error_counts.items():
        if name not in filled:
            filled[name] = error_count

    return filled
|
|
|
|
| def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]: |
| """ |
| Extract tool usage statistics from message history. |
| |
| Args: |
| messages (List[Dict]): Message history |
| |
| Returns: |
| Dict: Tool statistics with counts and success/failure rates |
| """ |
| tool_stats = {} |
| |
| |
| tool_calls_map = {} |
| |
| for msg in messages: |
| |
| if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]: |
| for tool_call in msg["tool_calls"]: |
| if not tool_call or not isinstance(tool_call, dict): continue |
| tool_name = tool_call["function"]["name"] |
| tool_call_id = tool_call["id"] |
| |
| |
| if tool_name not in tool_stats: |
| tool_stats[tool_name] = { |
| "count": 0, |
| "success": 0, |
| "failure": 0 |
| } |
| |
| tool_stats[tool_name]["count"] += 1 |
| tool_calls_map[tool_call_id] = tool_name |
| |
| |
| elif msg["role"] == "tool": |
| tool_call_id = msg.get("tool_call_id", "") |
| content = msg.get("content", "") |
| |
| |
| is_success = True |
| try: |
| |
| content_json = json.loads(content) if isinstance(content, str) else content |
| |
| if isinstance(content_json, dict): |
| |
| if "error" in content_json and content_json["error"] is not None: |
| is_success = False |
| |
| |
| |
| if "content" in content_json and isinstance(content_json["content"], dict): |
| inner_content = content_json["content"] |
| |
| |
| if inner_content.get("error") is not None: |
| is_success = False |
| |
| |
| if content_json.get("success") is False: |
| is_success = False |
| |
| except (json.JSONDecodeError, ValueError, TypeError): |
| |
| |
| if not content: |
| is_success = False |
| |
| elif content.strip().lower().startswith("error:"): |
| is_success = False |
| |
| |
| if tool_call_id in tool_calls_map: |
| tool_name = tool_calls_map[tool_call_id] |
| if is_success: |
| tool_stats[tool_name]["success"] += 1 |
| else: |
| tool_stats[tool_name]["failure"] += 1 |
| |
| return tool_stats |
|
|
|
|
| def _extract_reasoning_stats(messages: List[Dict[str, Any]]) -> Dict[str, int]: |
| """ |
| Count how many assistant turns have reasoning vs no reasoning. |
| |
| Checks for <REASONING_SCRATCHPAD> in content or a non-empty 'reasoning' field |
| (native thinking tokens). Returns counts for tracking reasoning coverage. |
| |
| Args: |
| messages: Message history |
| |
| Returns: |
| Dict with 'total_assistant_turns', 'turns_with_reasoning', 'turns_without_reasoning' |
| """ |
| total = 0 |
| with_reasoning = 0 |
| |
| for msg in messages: |
| if msg.get("role") != "assistant": |
| continue |
| total += 1 |
| |
| content = msg.get("content", "") or "" |
| has_scratchpad = "<REASONING_SCRATCHPAD>" in content |
| has_native_reasoning = bool(msg.get("reasoning", "").strip()) if msg.get("reasoning") else False |
| |
| if has_scratchpad or has_native_reasoning: |
| with_reasoning += 1 |
| |
| return { |
| "total_assistant_turns": total, |
| "turns_with_reasoning": with_reasoning, |
| "turns_without_reasoning": total - with_reasoning, |
| "has_any_reasoning": with_reasoning > 0, |
| } |
|
|
|
|
def _failed_prompt_result(prompt_index: int, batch_num: int, error: str) -> Dict[str, Any]:
    """Build the result dict returned for a prompt that could not be processed."""
    return {
        "success": False,
        "prompt_index": prompt_index,
        "error": error,
        "trajectory": None,
        "tool_stats": {},
        "toolsets_used": [],
        "metadata": {
            "batch_num": batch_num,
            "timestamp": datetime.now().isoformat(),
        },
    }


def _process_single_prompt(
    prompt_index: int,
    prompt_data: Dict[str, Any],
    batch_num: int,
    config: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Process a single prompt with the agent.

    When the dataset row names a container image ('image' or 'docker_image'),
    per-task environment overrides are registered first. With
    TERMINAL_ENV=docker the image is also probed locally and pulled if it is
    missing; a failed pull produces a failure result without running the agent.

    Args:
        prompt_index (int): Index of prompt in dataset
        prompt_data (Dict): Prompt data containing 'prompt' and optional
            'image' / 'docker_image' / 'cwd' fields
        batch_num (int): Batch number
        config (Dict): Configuration dict with agent parameters

    Returns:
        Dict: Result containing trajectory, stats, and metadata. On failure
            'success' is False, 'error' holds the message and 'trajectory'
            is None.
    """
    prompt = prompt_data["prompt"]
    task_id = f"task_{prompt_index}"

    container_image = prompt_data.get("image") or prompt_data.get("docker_image")
    if container_image:
        env_type = os.getenv("TERMINAL_ENV", "local")
        if env_type == "docker":
            import subprocess as _sp
            try:
                probe = _sp.run(
                    ["docker", "image", "inspect", container_image],
                    capture_output=True, timeout=10,
                )
                if probe.returncode != 0:
                    # Image is not available locally: try to pull it.
                    if config.get("verbose"):
                        print(f" Prompt {prompt_index}: Pulling docker image {container_image}...", flush=True)
                    pull = _sp.run(
                        ["docker", "pull", container_image],
                        capture_output=True, text=True, timeout=600,
                    )
                    if pull.returncode != 0:
                        return _failed_prompt_result(
                            prompt_index,
                            batch_num,
                            f"Docker image not available: {container_image}\n{pull.stderr[:500]}",
                        )
            except FileNotFoundError:
                # Best effort: no docker executable on PATH, so the image
                # cannot be checked here.
                pass
            except Exception as img_err:
                # Best effort: a failed check (e.g. a timeout) does not stop the prompt.
                if config.get("verbose"):
                    print(f" Prompt {prompt_index}: Docker image check failed: {img_err}", flush=True)

        from tools.terminal_tool import register_task_env_overrides
        overrides = {
            "docker_image": container_image,
            "modal_image": container_image,
            "singularity_image": f"docker://{container_image}",
            "daytona_image": container_image,
        }
        if prompt_data.get("cwd"):
            overrides["cwd"] = prompt_data["cwd"]
        register_task_env_overrides(task_id, overrides)
        if config.get("verbose"):
            print(f" Prompt {prompt_index}: Using container image {container_image}")

    try:
        # Toolsets are sampled independently for every prompt.
        selected_toolsets = sample_toolsets_from_distribution(config["distribution"])

        if config.get("verbose"):
            print(f" Prompt {prompt_index}: Using toolsets {selected_toolsets}")

        log_prefix = f"[B{batch_num}:P{prompt_index}]"
        agent = AIAgent(
            base_url=config.get("base_url"),
            api_key=config.get("api_key"),
            model=config["model"],
            max_iterations=config["max_iterations"],
            enabled_toolsets=selected_toolsets,
            save_trajectories=False,
            verbose_logging=config.get("verbose", False),
            ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
            log_prefix_chars=config.get("log_prefix_chars", 100),
            log_prefix=log_prefix,
            providers_allowed=config.get("providers_allowed"),
            providers_ignored=config.get("providers_ignored"),
            providers_order=config.get("providers_order"),
            provider_sort=config.get("provider_sort"),
            max_tokens=config.get("max_tokens"),
            reasoning_config=config.get("reasoning_config"),
            prefill_messages=config.get("prefill_messages"),
            skip_context_files=True,
            skip_memory=True,
        )

        result = agent.run_conversation(prompt, task_id=task_id)

        tool_stats = _extract_tool_stats(result["messages"])
        reasoning_stats = _extract_reasoning_stats(result["messages"])

        trajectory = agent._convert_to_trajectory_format(
            result["messages"],
            prompt,
            result["completed"]
        )

        return {
            "success": True,
            "prompt_index": prompt_index,
            "trajectory": trajectory,
            "tool_stats": tool_stats,
            "reasoning_stats": reasoning_stats,
            "completed": result["completed"],
            "partial": result.get("partial", False),
            "api_calls": result["api_calls"],
            "toolsets_used": selected_toolsets,
            "metadata": {
                "batch_num": batch_num,
                "timestamp": datetime.now().isoformat(),
                "model": config["model"]
            }
        }

    except Exception as e:
        # Any agent failure is reported as a failed result instead of being raised.
        print(f"❌ Error processing prompt {prompt_index}: {e}")
        if config.get("verbose"):
            traceback.print_exc()

        return _failed_prompt_result(prompt_index, batch_num, str(e))
|
|
|
|
| def _process_batch_worker(args: Tuple) -> Dict[str, Any]: |
| """ |
| Worker function to process a single batch of prompts. |
| |
| Args: |
| args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config) |
| |
| Returns: |
| Dict: Batch results with statistics |
| """ |
| batch_num, batch_data, output_dir, completed_prompts_set, config = args |
| |
| output_dir = Path(output_dir) |
| print(f"\nπ Batch {batch_num}: Starting ({len(batch_data)} prompts)") |
| |
| |
| batch_output_file = output_dir / f"batch_{batch_num}.jsonl" |
| |
| |
| prompts_to_process = [ |
| (idx, data) for idx, data in batch_data |
| if idx not in completed_prompts_set |
| ] |
| |
| if not prompts_to_process: |
| print(f"β
Batch {batch_num}: Already completed (skipping)") |
| return { |
| "batch_num": batch_num, |
| "processed": 0, |
| "skipped": len(batch_data), |
| "tool_stats": {}, |
| "completed_prompts": [] |
| } |
| |
| print(f" Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)") |
| |
| |
| batch_tool_stats = {} |
| batch_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0} |
| completed_in_batch = [] |
| discarded_no_reasoning = 0 |
| |
| |
| for prompt_index, prompt_data in prompts_to_process: |
| |
| result = _process_single_prompt( |
| prompt_index, |
| prompt_data, |
| batch_num, |
| config |
| ) |
| |
| |
| if result["success"] and result["trajectory"]: |
| |
| reasoning = result.get("reasoning_stats", {}) |
| if not reasoning.get("has_any_reasoning", True): |
| print(f" π« Prompt {prompt_index} discarded (no reasoning in any turn)") |
| discarded_no_reasoning += 1 |
| continue |
| |
| |
| raw_tool_stats = result.get("tool_stats", {}) |
| tool_stats = _normalize_tool_stats(raw_tool_stats) |
| |
| |
| raw_error_counts = { |
| tool_name: stats.get("failure", 0) |
| for tool_name, stats in raw_tool_stats.items() |
| } |
| tool_error_counts = _normalize_tool_error_counts(raw_error_counts) |
| |
| trajectory_entry = { |
| "prompt_index": prompt_index, |
| "conversations": result["trajectory"], |
| "metadata": result["metadata"], |
| "completed": result["completed"], |
| "partial": result.get("partial", False), |
| "api_calls": result["api_calls"], |
| "toolsets_used": result["toolsets_used"], |
| "tool_stats": tool_stats, |
| "tool_error_counts": tool_error_counts |
| } |
| |
| |
| with open(batch_output_file, 'a', encoding='utf-8') as f: |
| f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n") |
| |
| |
| for tool_name, stats in result.get("tool_stats", {}).items(): |
| if tool_name not in batch_tool_stats: |
| batch_tool_stats[tool_name] = { |
| "count": 0, |
| "success": 0, |
| "failure": 0 |
| } |
| |
| batch_tool_stats[tool_name]["count"] += stats["count"] |
| batch_tool_stats[tool_name]["success"] += stats["success"] |
| batch_tool_stats[tool_name]["failure"] += stats["failure"] |
| |
| |
| for key in batch_reasoning_stats: |
| batch_reasoning_stats[key] += result.get("reasoning_stats", {}).get(key, 0) |
| |
| |
| if result["success"] and result["trajectory"]: |
| completed_in_batch.append(prompt_index) |
| status = "β οΈ partial" if result.get("partial") else "β
" |
| print(f" {status} Prompt {prompt_index} completed") |
| else: |
| print(f" β Prompt {prompt_index} failed (will retry on resume)") |
| |
| print(f"β
Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)") |
| |
| return { |
| "batch_num": batch_num, |
| "processed": len(prompts_to_process), |
| "skipped": len(batch_data) - len(prompts_to_process), |
| "tool_stats": batch_tool_stats, |
| "reasoning_stats": batch_reasoning_stats, |
| "discarded_no_reasoning": discarded_no_reasoning, |
| "completed_prompts": completed_in_batch |
| } |
|
|
|
|
| class BatchRunner: |
| """ |
| Manages batch processing of agent prompts with checkpointing and statistics. |
| """ |
| |
def __init__(
    self,
    dataset_file: str,
    batch_size: int,
    run_name: str,
    distribution: str = "default",
    max_iterations: int = 10,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    model: str = "claude-opus-4-20250514",
    num_workers: int = 4,
    verbose: bool = False,
    ephemeral_system_prompt: Optional[str] = None,
    log_prefix_chars: int = 100,
    providers_allowed: Optional[List[str]] = None,
    providers_ignored: Optional[List[str]] = None,
    providers_order: Optional[List[str]] = None,
    provider_sort: Optional[str] = None,
    max_tokens: Optional[int] = None,
    reasoning_config: Optional[Dict[str, Any]] = None,
    prefill_messages: Optional[List[Dict[str, Any]]] = None,
    max_samples: Optional[int] = None,
):
    """
    Initialize the batch runner.

    Stores the settings, validates them, creates data/<run_name>/ as the
    output directory, loads (and optionally truncates) the dataset, splits it
    into batches and prints a summary.

    Args:
        dataset_file (str): Path to the dataset JSONL file with 'prompt' field
        batch_size (int): Number of prompts per batch (must be at least 1)
        run_name (str): Name for this run (used for checkpointing and output)
        distribution (str): Toolset distribution to use (default: "default")
        max_iterations (int): Max iterations per agent run
        base_url (str): Base URL for model API
        api_key (str): API key for model
        model (str): Model name to use
        num_workers (int): Number of parallel workers
        verbose (bool): Enable verbose logging
        ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
        log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
        providers_allowed (List[str]): OpenRouter providers to allow (optional)
        providers_ignored (List[str]): OpenRouter providers to ignore (optional)
        providers_order (List[str]): OpenRouter providers to try in order (optional)
        provider_sort (str): Sort providers by price/throughput/latency (optional)
        max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
        reasoning_config (Dict): OpenRouter reasoning config override (e.g. {"effort": "none"} to disable thinking)
        prefill_messages (List[Dict]): Messages to prepend as prefilled conversation context (few-shot priming)
        max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)

    Raises:
        ValueError: If batch_size is smaller than 1, the distribution is
            unknown, or the dataset holds no valid entries
        FileNotFoundError: If the dataset file does not exist
    """
    self.dataset_file = Path(dataset_file)
    self.batch_size = batch_size
    self.run_name = run_name
    self.distribution = distribution
    self.max_iterations = max_iterations
    self.base_url = base_url
    self.api_key = api_key
    self.model = model
    self.num_workers = num_workers
    self.verbose = verbose
    self.ephemeral_system_prompt = ephemeral_system_prompt
    self.log_prefix_chars = log_prefix_chars
    self.providers_allowed = providers_allowed
    self.providers_ignored = providers_ignored
    self.providers_order = providers_order
    self.provider_sort = provider_sort
    self.max_tokens = max_tokens
    self.reasoning_config = reasoning_config
    self.prefill_messages = prefill_messages
    self.max_samples = max_samples

    # Fail fast, before anything is created on disk: a negative size would
    # otherwise silently produce zero batches.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    if not validate_distribution(distribution):
        raise ValueError(f"Unknown distribution: {distribution}. Available: {list(list_distributions().keys())}")

    # Everything produced by this run lives under data/<run_name>/.
    self.output_dir = Path("data") / run_name
    self.output_dir.mkdir(parents=True, exist_ok=True)

    self.checkpoint_file = self.output_dir / "checkpoint.json"
    self.stats_file = self.output_dir / "statistics.json"

    self.dataset = self._load_dataset()
    # A falsy max_samples (None or 0) means "use the whole dataset".
    if self.max_samples and self.max_samples < len(self.dataset):
        full_count = len(self.dataset)
        self.dataset = self.dataset[:self.max_samples]
        print(f"✂️ Truncated dataset from {full_count} to {self.max_samples} samples (--max_samples)")

    self.batches = self._create_batches()

    print("📊 Batch Runner Initialized")
    print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
    print(f" Batch size: {self.batch_size}")
    print(f" Total batches: {len(self.batches)}")
    print(f" Run name: {self.run_name}")
    print(f" Distribution: {self.distribution}")
    print(f" Output directory: {self.output_dir}")
    print(f" Workers: {self.num_workers}")
    if self.ephemeral_system_prompt:
        # Show at most 60 characters of the prompt.
        prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
        print(f" 🔒 Ephemeral system prompt: '{prompt_preview}'")
| |
| def _load_dataset(self) -> List[Dict[str, Any]]: |
| """ |
| Load dataset from JSONL file. |
| |
| Returns: |
| List[Dict]: List of dataset entries |
| """ |
| if not self.dataset_file.exists(): |
| raise FileNotFoundError(f"Dataset file not found: {self.dataset_file}") |
| |
| dataset = [] |
| with open(self.dataset_file, 'r', encoding='utf-8') as f: |
| for line_num, line in enumerate(f, 1): |
| line = line.strip() |
| if not line: |
| continue |
| |
| try: |
| entry = json.loads(line) |
| if 'prompt' not in entry: |
| print(f"β οΈ Warning: Line {line_num} missing 'prompt' field, skipping") |
| continue |
| dataset.append(entry) |
| except json.JSONDecodeError as e: |
| print(f"β οΈ Warning: Invalid JSON on line {line_num}: {e}") |
| continue |
| |
| if not dataset: |
| raise ValueError(f"No valid entries found in dataset file: {self.dataset_file}") |
| |
| return dataset |
| |
| def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]: |
| """ |
| Split dataset into batches with indices. |
| |
| Returns: |
| List of batches, where each batch is a list of (index, entry) tuples |
| """ |
| batches = [] |
| for i in range(0, len(self.dataset), self.batch_size): |
| batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)] |
| batches.append(batch) |
| |
| return batches |
| |
| def _load_checkpoint(self) -> Dict[str, Any]: |
| """ |
| Load checkpoint data if it exists. |
| |
| Returns: |
| Dict: Checkpoint data with completed prompt indices |
| """ |
| if not self.checkpoint_file.exists(): |
| return { |
| "run_name": self.run_name, |
| "completed_prompts": [], |
| "batch_stats": {}, |
| "last_updated": None |
| } |
| |
| try: |
| with open(self.checkpoint_file, 'r', encoding='utf-8') as f: |
| return json.load(f) |
| except Exception as e: |
| print(f"β οΈ Warning: Failed to load checkpoint: {e}") |
| return { |
| "run_name": self.run_name, |
| "completed_prompts": [], |
| "batch_stats": {}, |
| "last_updated": None |
| } |
| |
def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None):
    """
    Write the checkpoint to disk, refreshing its timestamp first.

    The passed dict is modified in place: its "last_updated" field is set to
    the current local time in ISO format.

    Args:
        checkpoint_data (Dict): Checkpoint data to save
        lock (Lock): Optional lock held for the duration of the write
    """
    # Stamp the payload before it is written.
    checkpoint_data["last_updated"] = datetime.now().isoformat()

    from utils import atomic_json_write

    if not lock:
        atomic_json_write(self.checkpoint_file, checkpoint_data)
        return

    with lock:
        atomic_json_write(self.checkpoint_file, checkpoint_data)
| |
| def _scan_completed_prompts_by_content(self) -> set: |
| """ |
| Scan all batch files and extract completed prompts by their actual content. |
| |
| This provides a more robust resume mechanism that matches on prompt text |
| rather than indices, allowing recovery even if indices don't match. |
| |
| Returns: |
| set: Set of prompt texts that have been successfully processed |
| """ |
| completed_prompts = set() |
| batch_files = sorted(self.output_dir.glob("batch_*.jsonl")) |
| |
| if not batch_files: |
| return completed_prompts |
| |
| print(f"π Scanning {len(batch_files)} batch files for completed prompts...") |
| |
| for batch_file in batch_files: |
| try: |
| with open(batch_file, 'r', encoding='utf-8') as f: |
| for line in f: |
| try: |
| entry = json.loads(line.strip()) |
| |
| |
| if entry.get("failed", False): |
| continue |
| |
| |
| conversations = entry.get("conversations", []) |
| for msg in conversations: |
| if msg.get("from") == "human": |
| prompt_text = msg.get("value", "").strip() |
| if prompt_text: |
| completed_prompts.add(prompt_text) |
| break |
| except json.JSONDecodeError: |
| continue |
| except Exception as e: |
| print(f" β οΈ Warning: Error reading {batch_file.name}: {e}") |
| |
| return completed_prompts |
| |
| def _filter_dataset_by_completed(self, completed_prompts: set) -> Tuple[List[Dict], List[int]]: |
| """ |
| Filter the dataset to exclude prompts that have already been completed. |
| |
| Args: |
| completed_prompts: Set of prompt texts that have been completed |
| |
| Returns: |
| Tuple of (filtered_dataset, skipped_indices) |
| """ |
| filtered_dataset = [] |
| skipped_indices = [] |
| |
| for idx, entry in enumerate(self.dataset): |
| |
| prompt_text = entry.get("prompt", "").strip() |
| |
| |
| if not prompt_text: |
| conversations = entry.get("conversations", []) |
| for msg in conversations: |
| role = msg.get("role") or msg.get("from") |
| if role in ("user", "human"): |
| prompt_text = (msg.get("content") or msg.get("value", "")).strip() |
| break |
| |
| if prompt_text in completed_prompts: |
| skipped_indices.append(idx) |
| else: |
| |
| filtered_dataset.append((idx, entry)) |
| |
| return filtered_dataset, skipped_indices |
| |
| def run(self, resume: bool = False): |
| """ |
| Run the batch processing pipeline. |
| |
| Args: |
| resume (bool): Whether to resume from checkpoint |
| """ |
| print("\n" + "=" * 70) |
| print("π Starting Batch Processing") |
| print("=" * 70) |
| |
| |
| completed_prompt_texts = set() |
| if resume: |
| completed_prompt_texts = self._scan_completed_prompts_by_content() |
| if completed_prompt_texts: |
| print(f" Found {len(completed_prompt_texts)} already-completed prompts by content matching") |
| |
| |
| if resume and completed_prompt_texts: |
| filtered_entries, skipped_indices = self._filter_dataset_by_completed(completed_prompt_texts) |
| |
| if not filtered_entries: |
| print("\nβ
All prompts have already been processed!") |
| return |
| |
| |
| batches_to_process = [] |
| for i in range(0, len(filtered_entries), self.batch_size): |
| batch = filtered_entries[i:i + self.batch_size] |
| batches_to_process.append(batch) |
| |
| self.batches = batches_to_process |
| |
| |
| print("\n" + "=" * 70) |
| print("π RESUME SUMMARY") |
| print("=" * 70) |
| print(f" Original dataset size: {len(self.dataset):,} prompts") |
| print(f" Already completed: {len(skipped_indices):,} prompts") |
| print(" βββββββββββββββββββββββββββββββββββββββββ") |
| print(f" π― RESUMING WITH: {len(filtered_entries):,} prompts") |
| print(f" New batches created: {len(batches_to_process)}") |
| print("=" * 70 + "\n") |
| |
| |
| checkpoint_data = self._load_checkpoint() |
| if checkpoint_data.get("run_name") != self.run_name: |
| checkpoint_data = { |
| "run_name": self.run_name, |
| "completed_prompts": [], |
| "batch_stats": {}, |
| "last_updated": None |
| } |
| |
| |
| config = { |
| "distribution": self.distribution, |
| "model": self.model, |
| "max_iterations": self.max_iterations, |
| "base_url": self.base_url, |
| "api_key": self.api_key, |
| "verbose": self.verbose, |
| "ephemeral_system_prompt": self.ephemeral_system_prompt, |
| "log_prefix_chars": self.log_prefix_chars, |
| "providers_allowed": self.providers_allowed, |
| "providers_ignored": self.providers_ignored, |
| "providers_order": self.providers_order, |
| "provider_sort": self.provider_sort, |
| "max_tokens": self.max_tokens, |
| "reasoning_config": self.reasoning_config, |
| "prefill_messages": self.prefill_messages, |
| } |
| |
| |
| completed_prompts_set = set(checkpoint_data.get("completed_prompts", [])) |
| |
| |
| total_tool_stats = {} |
| |
| start_time = time.time() |
| |
| print(f"\nπ§ Initializing {self.num_workers} worker processes...") |
| |
| |
| checkpoint_lock = Lock() |
|
|
| |
| with Pool(processes=self.num_workers) as pool: |
| |
| tasks = [ |
| ( |
| batch_num, |
| batch_data, |
| str(self.output_dir), |
| completed_prompts_set, |
| config |
| ) |
| for batch_num, batch_data in enumerate(self.batches) |
| ] |
| |
| print(f"β
Created {len(tasks)} batch tasks") |
| print("π Starting parallel batch processing...\n") |
| |
| |
| |
| results = [] |
| console = Console(force_terminal=True) |
| with Progress( |
| SpinnerColumn(), |
| TextColumn("[bold blue]π¦ Batches"), |
| BarColumn(bar_width=40), |
| MofNCompleteColumn(), |
| TextColumn("β’"), |
| TimeRemainingColumn(), |
| console=console, |
| refresh_per_second=2, |
| transient=False, |
| redirect_stdout=False, |
| redirect_stderr=False, |
| ) as progress: |
| task = progress.add_task("Processing", total=len(tasks)) |
| |
| |
| root_logger = logging.getLogger() |
| original_level = root_logger.level |
| root_logger.setLevel(logging.WARNING) |
| |
| try: |
| for result in pool.imap_unordered(_process_batch_worker, tasks): |
| results.append(result) |
| progress.update(task, advance=1) |
|
|
| |
| try: |
| batch_num = result.get('batch_num') |
| completed = result.get('completed_prompts', []) or [] |
| completed_prompts_set.update(completed) |
|
|
| if isinstance(batch_num, int): |
| checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = { |
| 'processed': result.get('processed', 0), |
| 'skipped': result.get('skipped', 0), |
| 'discarded_no_reasoning': result.get('discarded_no_reasoning', 0), |
| } |
|
|
| checkpoint_data['completed_prompts'] = sorted(completed_prompts_set) |
| self._save_checkpoint(checkpoint_data, lock=checkpoint_lock) |
| except Exception as ckpt_err: |
| |
| print(f"β οΈ Warning: Failed to save incremental checkpoint: {ckpt_err}") |
| except Exception as e: |
| logger.error("Batch worker failed: %s", e, exc_info=True) |
| raise |
| finally: |
| root_logger.setLevel(original_level) |
| |
| |
| all_completed_prompts = list(completed_prompts_set) |
| total_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0} |
| |
| for batch_result in results: |
| |
| all_completed_prompts.extend(batch_result.get("completed_prompts", [])) |
| |
| |
| for tool_name, stats in batch_result.get("tool_stats", {}).items(): |
| if tool_name not in total_tool_stats: |
| total_tool_stats[tool_name] = { |
| "count": 0, |
| "success": 0, |
| "failure": 0 |
| } |
| |
| total_tool_stats[tool_name]["count"] += stats["count"] |
| total_tool_stats[tool_name]["success"] += stats["success"] |
| total_tool_stats[tool_name]["failure"] += stats["failure"] |
| |
| |
| for key in total_reasoning_stats: |
| total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0) |
| |
| |
| try: |
| checkpoint_data["completed_prompts"] = all_completed_prompts |
| self._save_checkpoint(checkpoint_data, lock=checkpoint_lock) |
| except Exception as ckpt_err: |
| print(f"Γ’Ε‘Β Γ―ΒΈΒ Warning: Failed to save final checkpoint: {ckpt_err}") |
| |
| |
| for tool_name in total_tool_stats: |
| stats = total_tool_stats[tool_name] |
| total_calls = stats["success"] + stats["failure"] |
| if total_calls > 0: |
| stats["success_rate"] = round(stats["success"] / total_calls * 100, 2) |
| stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2) |
| else: |
| stats["success_rate"] = 0.0 |
| stats["failure_rate"] = 0.0 |
| |
| |
| |
| |
| combined_file = self.output_dir / "trajectories.jsonl" |
| print(f"\nπ¦ Combining ALL batch files into {combined_file.name}...") |
| |
| |
| VALID_TOOLS = ALL_POSSIBLE_TOOLS |
| |
| total_entries = 0 |
| filtered_entries = 0 |
| batch_files_found = 0 |
| |
| |
| all_batch_files = sorted(self.output_dir.glob("batch_*.jsonl")) |
| |
| with open(combined_file, 'w', encoding='utf-8') as outfile: |
| for batch_file in all_batch_files: |
| batch_files_found += 1 |
| batch_num = batch_file.stem.split("_")[1] |
| |
| with open(batch_file, 'r', encoding='utf-8') as infile: |
| for line in infile: |
| total_entries += 1 |
| try: |
| data = json.loads(line) |
| tool_stats = data.get('tool_stats', {}) |
| |
| |
| invalid_tools = [k for k in tool_stats.keys() if k not in VALID_TOOLS] |
| |
| if invalid_tools: |
| filtered_entries += 1 |
| invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0] |
| print(f" β οΈ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'") |
| continue |
| |
| outfile.write(line) |
| except json.JSONDecodeError: |
| filtered_entries += 1 |
| print(f" β οΈ Filtering invalid JSON entry (batch {batch_num})") |
| |
| if filtered_entries > 0: |
| print(f"β οΈ Filtered {filtered_entries} corrupted entries out of {total_entries} total") |
| print(f"β
Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)") |
| |
| |
| final_stats = { |
| "run_name": self.run_name, |
| "distribution": self.distribution, |
| "total_prompts": len(self.dataset), |
| "total_batches": len(self.batches), |
| "batch_size": self.batch_size, |
| "model": self.model, |
| "completed_at": datetime.now().isoformat(), |
| "duration_seconds": round(time.time() - start_time, 2), |
| "tool_statistics": total_tool_stats, |
| "reasoning_statistics": total_reasoning_stats, |
| } |
| |
| with open(self.stats_file, 'w', encoding='utf-8') as f: |
| json.dump(final_stats, f, indent=2, ensure_ascii=False) |
| |
| |
| print("\n" + "=" * 70) |
| print("π BATCH PROCESSING COMPLETE") |
| print("=" * 70) |
| print(f"β
Prompts processed this run: {sum(r.get('processed', 0) for r in results)}") |
| print(f"β
Total trajectories in merged file: {total_entries - filtered_entries}") |
| print(f"β
Total batch files merged: {batch_files_found}") |
| print(f"β±οΈ Total duration: {round(time.time() - start_time, 2)}s") |
| print("\nπ Tool Usage Statistics:") |
| print("-" * 70) |
| |
| if total_tool_stats: |
| |
| sorted_tools = sorted( |
| total_tool_stats.items(), |
| key=lambda x: x[1]["count"], |
| reverse=True |
| ) |
| |
| print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}") |
| print("-" * 70) |
| for tool_name, stats in sorted_tools: |
| print( |
| f"{tool_name:<25} " |
| f"{stats['count']:<10} " |
| f"{stats['success']:<10} " |
| f"{stats['failure']:<10} " |
| f"{stats['success_rate']:.1f}%" |
| ) |
| else: |
| print("No tool calls were made during this run.") |
| |
| |
| total_discarded = sum(r.get("discarded_no_reasoning", 0) for r in results) |
| |
| print("\nπ§ Reasoning Coverage:") |
| print("-" * 70) |
| total_turns = total_reasoning_stats["total_assistant_turns"] |
| with_reasoning = total_reasoning_stats["turns_with_reasoning"] |
| without_reasoning = total_reasoning_stats["turns_without_reasoning"] |
| if total_turns > 0: |
| pct_with = round(with_reasoning / total_turns * 100, 1) |
| pct_without = round(without_reasoning / total_turns * 100, 1) |
| print(f" Total assistant turns: {total_turns:,}") |
| print(f" With reasoning: {with_reasoning:,} ({pct_with}%)") |
| print(f" Without reasoning: {without_reasoning:,} ({pct_without}%)") |
| else: |
| print(" No assistant turns recorded.") |
| if total_discarded > 0: |
| print(f" π« Samples discarded (zero reasoning): {total_discarded:,}") |
| |
| print(f"\nπΎ Results saved to: {self.output_dir}") |
| print(" - Trajectories: trajectories.jsonl (combined)") |
| print(" - Individual batches: batch_*.jsonl (for debugging)") |
| print(f" - Statistics: {self.stats_file.name}") |
| print(f" - Checkpoint: {self.checkpoint_file.name}") |
|
|
|
|
def _parse_provider_list(value) -> Optional[List[str]]:
    """Normalize a ``--providers_*`` value into a list of provider names.

    Accepts either a comma-separated string or a list/tuple of names. Python
    Fire converts an unquoted ``a,b`` command-line value into a tuple before
    ``main`` receives it, so both shapes must be handled. Surrounding
    whitespace is stripped and empty names (e.g. from a trailing comma) are
    dropped.

    Args:
        value: Raw argument value (str, list, tuple, or None).

    Returns:
        A list of non-empty provider names, or None when nothing usable was given.
    """
    if not value:
        return None
    parts = value.split(",") if isinstance(value, str) else value
    if not isinstance(parts, (list, tuple)):
        # Single non-string scalar: treat it as one provider name.
        parts = [parts]
    names = [str(part).strip() for part in parts]
    names = [name for name in names if name]
    return names or None


def main(
    dataset_file: str = None,
    batch_size: int = None,
    run_name: str = None,
    distribution: str = "default",
    model: str = "anthropic/claude-sonnet-4.6",
    api_key: str = None,
    base_url: str = "https://openrouter.ai/api/v1",
    max_turns: int = 10,
    num_workers: int = 4,
    resume: bool = False,
    verbose: bool = False,
    list_distributions: bool = False,
    ephemeral_system_prompt: str = None,
    log_prefix_chars: int = 100,
    providers_allowed: str = None,
    providers_ignored: str = None,
    providers_order: str = None,
    provider_sort: str = None,
    max_tokens: int = None,
    reasoning_effort: str = None,
    reasoning_disabled: bool = False,
    prefill_messages_file: str = None,
    max_samples: int = None,
):
    """
    Run batch processing of agent prompts from a dataset.

    Validates the command-line arguments, builds the provider lists, reasoning
    config and prefill messages, then constructs a BatchRunner and runs it.

    Args:
        dataset_file (str): Path to JSONL file with 'prompt' field in each entry
        batch_size (int): Number of prompts per batch (must be a positive integer)
        run_name (str): Name for this run (used for output and checkpointing)
        distribution (str): Toolset distribution to use (default: "default")
        model (str): Model name to use (default: "anthropic/claude-sonnet-4.6")
        api_key (str): API key for model authentication
        base_url (str): Base URL for model API (default: "https://openrouter.ai/api/v1")
        max_turns (int): Maximum number of tool calling iterations per prompt (default: 10);
            passed to BatchRunner as max_iterations
        num_workers (int): Number of parallel worker processes (default: 4)
        resume (bool): Resume from checkpoint if run was interrupted (default: False)
        verbose (bool): Enable verbose logging; also prints a traceback on fatal errors (default: False)
        list_distributions (bool): List available toolset distributions and exit
        ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
        log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
        providers_allowed (str): Comma-separated list of OpenRouter providers to allow (e.g. "anthropic,openai")
        providers_ignored (str): Comma-separated list of OpenRouter providers to ignore (e.g. "together,deepinfra")
        providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
        provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
        max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
        reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none"
            (default: None, which passes no reasoning config to the runner)
        reasoning_disabled (bool): Completely disable reasoning/thinking tokens; takes precedence
            over reasoning_effort (default: False)
        prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
        max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)

    Returns:
        None after listing distributions, after a validation error (the error is
        printed), or after a completed run; 1 if constructing or running the
        BatchRunner raised an exception.

    Examples:
        # Basic usage
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run

        # Resume interrupted run
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume

        # Use specific distribution
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen

        # With disabled reasoning and max tokens
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
            --reasoning_disabled --max_tokens=128000

        # With prefill messages from file
        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
            --prefill_messages_file=configs/prefill_opus.json

        # List available distributions
        python batch_runner.py --list_distributions
    """
    # --list_distributions short-circuits everything else.
    if list_distributions:
        # The boolean flag shadows the module-level ``list_distributions``
        # import, so the function is re-imported here under an alias.
        from toolset_distributions import list_distributions as get_all_dists, print_distribution_info

        print("π Available Toolset Distributions")
        print("=" * 70)

        all_dists = get_all_dists()
        for dist_name in sorted(all_dists.keys()):
            print_distribution_info(dist_name)

        print("\nπ‘ Usage:")
        print(" python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\")
        print(" --run_name=my_run --distribution=<name>")
        return

    # Required arguments.
    if not dataset_file:
        print("β Error: --dataset_file is required")
        return

    # The isinstance check keeps a non-numeric value (e.g. a string) from
    # raising TypeError in the comparison instead of printing the error.
    if not isinstance(batch_size, int) or batch_size < 1:
        print("β Error: --batch_size must be a positive integer")
        return

    if not run_name:
        print("β Error: --run_name is required")
        return

    # Provider routing preferences (None when the option was not given).
    providers_allowed_list = _parse_provider_list(providers_allowed)
    providers_ignored_list = _parse_provider_list(providers_ignored)
    providers_order_list = _parse_provider_list(providers_order)

    # Reasoning config: --reasoning_disabled wins over --reasoning_effort;
    # with neither set, no reasoning config is passed to the runner.
    reasoning_config = None
    if reasoning_disabled:
        reasoning_config = {"effort": "none"}
        print("π§  Reasoning: DISABLED (effort=none)")
    elif reasoning_effort:
        valid_efforts = ["xhigh", "high", "medium", "low", "minimal", "none"]
        if reasoning_effort not in valid_efforts:
            print(f"β Error: --reasoning_effort must be one of: {', '.join(valid_efforts)}")
            return
        reasoning_config = {"enabled": True, "effort": reasoning_effort}
        print(f"π§  Reasoning effort: {reasoning_effort}")

    # Optional prefill messages; the file must hold a JSON array.
    prefill_messages = None
    if prefill_messages_file:
        try:
            with open(prefill_messages_file, 'r', encoding='utf-8') as f:
                prefill_messages = json.load(f)
        except Exception as e:
            # CLI boundary: report any read/parse failure and stop.
            print(f"β Error loading prefill messages: {e}")
            return
        if not isinstance(prefill_messages, list):
            print("β Error: prefill_messages_file must contain a JSON array of messages")
            return
        print(f"π¬ Loaded {len(prefill_messages)} prefill messages from {prefill_messages_file}")

    # Build the runner and process the dataset.
    try:
        runner = BatchRunner(
            dataset_file=dataset_file,
            batch_size=batch_size,
            run_name=run_name,
            distribution=distribution,
            max_iterations=max_turns,
            base_url=base_url,
            api_key=api_key,
            model=model,
            num_workers=num_workers,
            verbose=verbose,
            ephemeral_system_prompt=ephemeral_system_prompt,
            log_prefix_chars=log_prefix_chars,
            providers_allowed=providers_allowed_list,
            providers_ignored=providers_ignored_list,
            providers_order=providers_order_list,
            provider_sort=provider_sort,
            max_tokens=max_tokens,
            reasoning_config=reasoning_config,
            prefill_messages=prefill_messages,
            max_samples=max_samples,
        )

        runner.run(resume=resume)

    except Exception as e:
        # Top-level boundary: report the failure and signal it with a
        # non-zero return value rather than propagating.
        print(f"\nβ Fatal error: {e}")
        if verbose:
            traceback.print_exc()
        return 1
|
|
|
|
if __name__ == "__main__":
    import sys

    # main() returns 1 when the run fails with a fatal error and None otherwise.
    # fire.Fire() returns the value produced by the command, so forward it to
    # sys.exit() to make the process exit status reflect the failure;
    # sys.exit(None) still exits with status 0.
    sys.exit(fire.Fire(main))
|
|
|
|