#!/usr/bin/env python3
"""
Cortex Benchmark Harness — CLI Entry Point

Usage:
    # Quick test (10 examples, fast tasks only)
    python -m benchmark.run_benchmark --n 10 --tasks hellaswag piqa

    # Standard suite (50 examples, all tasks)
    python -m benchmark.run_benchmark --n 50

    # Full evaluation (all examples)
    python -m benchmark.run_benchmark --n 0 --tasks hellaswag piqa arc-easy arc-challenge winogrande mmlu

    # Custom model
    python -m benchmark.run_benchmark --model meta-llama/Llama-3.2-1B --n 50

    # Save results
    python -m benchmark.run_benchmark --n 50 --output results.json
"""

import argparse
import json
import sys
import os

# Put the repo root on sys.path so `from benchmark.runner import ...` works
# when this file is executed directly (not only via `python -m`).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def main():
    parser = argparse.ArgumentParser(description="Cortex Benchmark Harness")
    parser.add_argument(
        "--model", type=str, default="HuggingFaceTB/SmolLM2-135M",
        help="HuggingFace model ID to evaluate",
    )
    parser.add_argument(
        "--tasks", nargs="+",
        default=["hellaswag", "piqa", "arc-easy", "winogrande"],
        choices=["hellaswag", "piqa", "arc-easy", "arc-challenge", "winogrande", "mmlu", "halueval"],
        help="Tasks to run",
    )
    parser.add_argument(
        "--n", type=int, default=50,
        help="Number of examples per task (0 = all available)",
    )
    parser.add_argument(
        "--no-memory", action="store_true",
        help="Skip memory benchmarks (passkey, multi-hop)",
    )
    parser.add_argument(
        "--passkey-lengths", nargs="+", type=int, default=[128, 256, 512],
        help="Context lengths for passkey retrieval test",
    )
    parser.add_argument(
        "--n-passkey", type=int, default=5,
        help="Number of passkey examples per context length",
    )
    parser.add_argument(
        "--device", type=str, default="auto",
        help="Device: cuda, mps, cpu, or auto (auto: cuda > mps > cpu)",
    )
    parser.add_argument(
        "--dtype", type=str, default="float32",
        choices=["float32", "float16", "bfloat16"],
        help="Model dtype",
    )
    parser.add_argument(
        "--output", type=str, default=None,
        help="Path to save JSON results",
    )
    parser.add_argument(
        "--cortex-weights", type=str, default=None,
        help="Optional Cortex weights file to load before the Cortex phase",
    )
    
    args = parser.parse_args()
    
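    # Deferred import: the runner (and whatever heavy model dependencies it
    # pulls in) is only loaded after arguments parse cleanly, so `--help`
    # and argument errors stay fast.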
    from benchmark.runner import BenchmarkRunner
    
    runner = BenchmarkRunner(
        model_name=args.model,
        device=args.device,
        dtype=args.dtype,
        cortex_weights=args.cortex_weights,
    )
    
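    # --n 0 means "use every available example"; the runner takes None as
    # no per-task limit.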
    n = args.n if args.n > 0 else None
    
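    # Memory benchmarks (passkey retrieval, multi-hop) are part of the
    # comparison unless --no-memory was given.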
    results = runner.run_comparison(
        tasks=args.tasks,
        n=n,
        include_memory=not args.no_memory,
        n_passkey=args.n_passkey,
        passkey_lengths=args.passkey_lengths,
    )
    
    BenchmarkRunner.print_summary(results)
    
    if args.output:
        # Recursively convert values json.dump cannot handle natively
        # (tensors, dtypes, custom objects, ...) to strings rather than
        # dropping them.
        def make_serializable(obj):
            if isinstance(obj, dict):
                return {k: make_serializable(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [make_serializable(v) for v in obj]
            elif isinstance(obj, (bool, int, float, str, type(None))):
                return obj
            else:
                return str(obj)
        
        with open(args.output, "w") as f:
            json.dump(make_serializable(results), f, indent=2)
        print(f"\nResults saved to {args.output}")


if __name__ == "__main__":
    main()