"""
Mass evaluation script for running predefined prompts through all checkpoints of a model.
Simple, clean, and minimal approach with readable markdown logging.
"""

import os
import sys
import glob
import time
import json
import argparse
from datetime import datetime
from typing import List, Optional


# Add the parent of this script's directory to sys.path so project modules
# (e.g. the inference module imported below) can be found.
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))


def load_prompts(prompts_file: str = "prompts.json") -> List[str]:
    """
    Load benchmark prompts from JSON file.

    Args:
        prompts_file: Path to the prompts JSON file

    Returns:
        List of prompt strings
    """
    # Resolve the prompts file relative to this script, not the caller's working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    prompts_path = os.path.join(script_dir, prompts_file)

    if not os.path.exists(prompts_path):
        print(f"⚠️ Prompts file not found: {prompts_path}")
        print("Using default fallback prompts...")
        return ["Hello, how are you?"]

    try:
        with open(prompts_path, 'r') as f:
            prompts = json.load(f)

        # Accept either a {"benchmark_prompts": [...]} object or a bare list of strings.
        if isinstance(prompts, dict) and "benchmark_prompts" in prompts:
            # Entries may be {"text": ...} objects or plain strings.
            prompts = [p.get("text", str(p)) if isinstance(p, dict) else str(p)
                       for p in prompts["benchmark_prompts"]]
        elif isinstance(prompts, list):
            pass
        else:
            print("⚠️ Invalid prompts format, using fallback")
            return ["Hello, how are you?"]

        print(f"📋 Loaded {len(prompts)} prompts from {prompts_file}")
        return prompts

    except json.JSONDecodeError as e:
        print(f"❌ Error parsing prompts file: {e}")
        print("Using default fallback prompts...")
        return ["Hello, how are you?"]
    except Exception as e:
        print(f"❌ Error loading prompts file: {e}")
        print("Using default fallback prompts...")
        return ["Hello, how are you?"]


def discover_checkpoints(model_name: str, base_dir: str = "../pico-train/runs") -> List[str]:
    """
    Discover all available checkpoints for a given model.

    Args:
        model_name: Name of the model
        base_dir: Base directory for model runs

    Returns:
        List of checkpoint paths sorted by step number
    """
    model_path = os.path.join(base_dir, model_name, "checkpoints")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model directory not found: {model_path}")

    pattern = os.path.join(model_path, "step_*")
    checkpoint_dirs = glob.glob(pattern)

    # Keep only directories whose names parse as step_<int>.
    valid_checkpoints = []
    for checkpoint_dir in checkpoint_dirs:
        if os.path.isdir(checkpoint_dir):
            try:
                step_num = int(os.path.basename(checkpoint_dir).split('_')[1])
                valid_checkpoints.append((step_num, checkpoint_dir))
            except (IndexError, ValueError):
                continue

    # Sort numerically by step so the report walks checkpoints in training order.
    valid_checkpoints.sort(key=lambda x: x[0])
    return [checkpoint_path for _, checkpoint_path in valid_checkpoints]


def run_benchmark(model_name: str, output_dir: str = "results", prompts_file: str = "prompts.json") -> Optional[str]:
    """
    Run benchmark evaluation on all checkpoints of a model.

    Args:
        model_name: Name of the model to benchmark
        output_dir: Directory to save results
        prompts_file: Path to the prompts JSON file

    Returns:
        Path to the generated report file, or None on failure
    """
| | print(f"π Starting benchmark for model: {model_name}") |
| | |
| | |
| | benchmark_prompts = load_prompts(prompts_file) |
| | if not benchmark_prompts: |
| | print("β No prompts loaded") |
| | return None |
| | |
| | |
| | os.makedirs(output_dir, exist_ok=True) |
| | |
    try:
        checkpoints = discover_checkpoints(model_name)
        print(f"📁 Found {len(checkpoints)} checkpoints")
    except FileNotFoundError as e:
        print(f"❌ Error: {e}")
        return None

    if not checkpoints:
        print("❌ No valid checkpoints found")
        return None

    # Timestamp the report filename so repeated runs don't overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_file = os.path.join(output_dir, f"{model_name}_benchmark_{timestamp}.md")

    # Import lazily so argument and path errors surface before heavy model dependencies load.
    try:
        from inference import PicoLMInference
    except ImportError as e:
        print(f"❌ Failed to import inference module: {e}")
        return None

    # Stream results to the markdown report as they are produced, so partial
    # results survive an interrupted run.
    with open(report_file, 'w') as f:
        f.write(f"# Benchmark Report: {model_name}\n\n")
        f.write(f"**Generated**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"**Total Checkpoints**: {len(checkpoints)}\n")
        f.write(f"**Total Prompts**: {len(benchmark_prompts)}\n\n")
        f.write("---\n\n")

        for i, checkpoint_path in enumerate(checkpoints, 1):
            checkpoint_name = os.path.basename(checkpoint_path)
            print(f"🔄 Processing {checkpoint_name} ({i}/{len(checkpoints)})")

            f.write(f"## Checkpoint: {checkpoint_name}\n\n")
            f.write(f"**Path**: `{checkpoint_path}`\n\n")

            try:
                # Load the checkpoint and record how long loading took.
                start_time = time.time()
                inference = PicoLMInference(checkpoint_path=checkpoint_path, device="cuda")
                load_time = time.time() - start_time

                f.write(f"**Load Time**: {load_time:.2f}s\n\n")

                # Run every prompt against this checkpoint.
                for j, prompt_text in enumerate(benchmark_prompts, 1):
                    print(f"  └─ Prompt {j}/{len(benchmark_prompts)}: {prompt_text[:30]}...")

                    f.write(f"### Prompt {j}: \"{prompt_text}\"\n\n")

                    try:
                        gen_start = time.time()
                        response = inference.generate_completion(
                            prompt=prompt_text,
                            max_length=100,
                            temperature=0.7
                        )
                        gen_time = time.time() - gen_start

                        f.write(f"**Response**:\n```\n{response}\n```\n\n")
                        f.write(f"**Metadata**: max_length=100, temperature=0.7, time={gen_time:.2f}s\n\n")

                    except Exception as e:
                        # A single failing prompt shouldn't abort the whole checkpoint.
                        f.write(f"**Error**: {str(e)}\n\n")
                        print(f"  ⚠️ Error on prompt {j}: {e}")

            except Exception as e:
                # A broken checkpoint is logged and skipped; the sweep continues.
                f.write(f"**Checkpoint Error**: {str(e)}\n\n")
                print(f"  ❌ Failed to load checkpoint: {e}")

            f.write("---\n\n")

    print(f"✅ Benchmark complete! Report saved to: {report_file}")
    return report_file


|
| |
|
| | def main(): |
| | """Main function with command-line interface.""" |
| | parser = argparse.ArgumentParser( |
| | description="Run benchmark evaluation on all checkpoints of a model", |
| | formatter_class=argparse.RawDescriptionHelpFormatter, |
| | epilog=""" |
| | Examples: |
| | python benchmark.py pico-decoder-tiny-dolma5M-v1 |
| | python benchmark.py pico-decoder-tiny-dolma29k-v3 --output results/ |
| | """ |
| | ) |
| | |
| | parser.add_argument("model_name", type=str, |
| | help="Model name (e.g., 'pico-decoder-tiny-dolma5M-v1')") |
| | parser.add_argument("--output", "-o", type=str, default="results", |
| | help="Output directory for results (default: results)") |
| | parser.add_argument("--prompts", "-p", type=str, default="prompts.json", |
| | help="Prompts JSON file (default: prompts.json)") |
| | |
| | args = parser.parse_args() |
| | |
    try:
        report_file = run_benchmark(args.model_name, args.output, args.prompts)
        if report_file:
            print(f"\n📄 Report available at: {report_file}")
            return 0
        else:
            return 1

    except KeyboardInterrupt:
        print("\n⏹️ Benchmark interrupted by user")
        return 1
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())