File size: 6,198 Bytes
173f28e
 
 
 
 
bf74331
173f28e
d6ca6d1
173f28e
f56dbf3
173f28e
 
 
 
 
 
 
 
 
 
d6ca6d1
 
 
 
 
 
 
 
 
173f28e
 
 
 
 
 
 
 
bf74331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173f28e
 
d6ca6d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf74331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173f28e
 
 
bf74331
173f28e
 
 
bf74331
173f28e
 
 
 
 
 
 
 
 
 
f56dbf3
bf74331
 
 
 
 
 
 
 
 
173f28e
 
 
 
f56dbf3
173f28e
 
 
 
 
 
f56dbf3
173f28e
 
 
 
bf74331
 
173f28e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from __future__ import annotations

import argparse

from corpus import build_corpus
from dataset_config import DATASET_PRESETS, DatasetConfig
from evals import evaluate_memory, evaluate_quality, evaluate_speed
from models import REGISTRY, ModelConfig, load_custom_models_from_file, register_model
from report import print_report
from wrapper import load_model


def _positive_int(value: str) -> int:
    """Parse *value* as an integer and require it to be at least 1.

    Intended as an argparse ``type=`` callable so that zero or negative
    sizes/counts are rejected with a regular usage error at parse time
    instead of being handed to the corpus builder or the evaluators.

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer or is < 1.
    """
    try:
        number = int(value)
    except ValueError:
        # Same wording argparse itself uses for ``type=int`` failures.
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def _build_parser() -> argparse.ArgumentParser:
    """Create the ``embedding-bench`` command-line parser."""
    parser = argparse.ArgumentParser(
        prog="embedding-bench",
        description="Compare embedding models on quality, speed, and memory.",
    )
    parser.add_argument(
        "--models",
        nargs="+",
        default=None,
        help="Models to benchmark (default: all registered)",
    )
    parser.add_argument(
        "--add-model",
        action="append",
        default=[],
        metavar="KEY:NAME:MODEL_ID:BACKEND[:GGUF_FILE]",
        help="Register a custom model. Can be repeated.",
    )
    # Integer defaults are not passed through ``type``; only user-supplied
    # strings are validated by _positive_int.
    parser.add_argument("--corpus-size", type=_positive_int, default=1000)
    parser.add_argument("--batch-size", type=_positive_int, default=64)
    parser.add_argument("--num-runs", type=_positive_int, default=3)
    parser.add_argument("--skip-quality", action="store_true")
    parser.add_argument("--skip-speed", action="store_true")
    parser.add_argument("--skip-memory", action="store_true")

    # Dataset configuration
    parser.add_argument(
        "--datasets",
        nargs="+",
        default=["sts"],
        choices=list(DATASET_PRESETS),
        help="Dataset presets to evaluate (default: sts). "
             f"Available: {', '.join(DATASET_PRESETS)}",
    )
    parser.add_argument("--max-pairs", type=_positive_int, default=None,
                        help="Limit number of pairs per dataset (useful for large datasets)")

    # Custom dataset (overrides --datasets)
    parser.add_argument("--dataset", default=None,
                        help="Custom HF dataset name (overrides --datasets)")
    parser.add_argument("--config", default=None,
                        help="Dataset config/subset name (e.g. 'triplet')")
    parser.add_argument("--split", default="test")
    parser.add_argument("--query-col", default="sentence1")
    parser.add_argument("--passage-col", default="sentence2")
    parser.add_argument("--score-col", default="score",
                        help="Score column name. Pass 'none' for pair-only datasets.")
    parser.add_argument("--score-scale", type=float, default=5.0)

    # Output options
    parser.add_argument("--csv", default=None, metavar="PATH",
                        help="Export results to a CSV file")
    parser.add_argument("--charts", default=None, metavar="DIR",
                        help="Save charts to a directory (e.g. ./results)")
    return parser


def _register_custom_models(parser: argparse.ArgumentParser, specs: list[str]) -> None:
    """Register every ``--add-model`` spec, exiting via *parser* on bad input.

    Each spec has the form ``KEY:NAME:MODEL_ID:BACKEND[:GGUF_FILE]``. Specs
    with too few fields, too many fields, or an empty required field are
    reported through ``parser.error`` (which exits), as is any ``ValueError``
    raised while building or registering the model.
    """
    for spec in specs:
        parts = spec.split(":")
        if len(parts) < 4:
            parser.error(f"--add-model requires KEY:NAME:MODEL_ID:BACKEND, got: {spec}")
        if len(parts) > 5:
            # Previously the surplus fields were dropped without notice.
            parser.error(
                f"--add-model accepts at most KEY:NAME:MODEL_ID:BACKEND:GGUF_FILE, got: {spec}"
            )
        key, name, model_id, backend = parts[:4]
        if not all((key, name, model_id, backend)):
            parser.error(
                f"--add-model fields KEY, NAME, MODEL_ID and BACKEND must be non-empty, got: {spec}"
            )
        # A trailing ':' yields an empty fifth field; treat it as "no file".
        gguf_file = (parts[4] or None) if len(parts) == 5 else None
        try:
            register_model(key, ModelConfig(
                name=name, model_id=model_id, backend=backend, gguf_file=gguf_file,
            ))
        except ValueError as e:
            parser.error(str(e))


def _resolve_dataset_configs(args: argparse.Namespace) -> list[DatasetConfig]:
    """Return the dataset configs to evaluate for the parsed *args*.

    A custom ``--dataset`` takes precedence and produces a single config built
    from the custom-dataset options; otherwise the ``--datasets`` preset keys
    are looked up in ``DATASET_PRESETS``.
    """
    if args.dataset:
        # Custom dataset overrides presets
        return [DatasetConfig(
            name=args.dataset,
            config=args.config,
            split=args.split,
            query_col=args.query_col,
            passage_col=args.passage_col,
            # The literal string 'none' (any case) means "no score column".
            score_col=None if args.score_col.lower() == "none" else args.score_col,
            score_scale=args.score_scale,
        )]
    return [DATASET_PRESETS[k] for k in args.datasets]


def _benchmark_model(
    cfg: ModelConfig,
    ds_configs: list[DatasetConfig],
    corpus: list[str] | None,
    args: argparse.Namespace,
) -> dict:
    """Run the enabled evaluations for one model and return its result dict.

    The result always has ``name`` and ``is_baseline``; ``quality``, ``speed``
    and ``memory_mb`` are added only for evaluations that were not skipped.
    Progress is printed as each stage runs.
    """
    print(f"\n{'='*50}")
    print(f"Benchmarking: {cfg.name}")
    print(f"{'='*50}")

    result: dict = {"name": cfg.name, "is_baseline": cfg.is_baseline}

    if not args.skip_quality:
        model = load_model(cfg)
        quality_results = {}
        for ds_cfg in ds_configs:
            # Results are keyed by the last path component of the dataset name.
            ds_key = ds_cfg.name.split("/")[-1]
            print(f"  Evaluating quality on {ds_cfg.name}...")
            quality_results[ds_key] = evaluate_quality(
                model, ds_cfg, max_pairs=args.max_pairs,
            )
            print(f"    {quality_results[ds_key]}")
        result["quality"] = quality_results
        del model

    # NOTE(review): the model is loaded again for the speed pass rather than
    # reused from the quality pass; kept as-is so measurements are unchanged.
    if not args.skip_speed and corpus is not None:
        print(f"  Evaluating speed ({args.num_runs} runs, {args.corpus_size} sentences)...")
        model = load_model(cfg)
        result["speed"] = evaluate_speed(
            model, corpus, num_runs=args.num_runs, batch_size=args.batch_size,
        )
        print(f"  Speed: {result['speed']['sentences_per_second']} sent/s")
        del model

    if not args.skip_memory and corpus is not None:
        print("  Evaluating memory (isolated subprocess)...")
        result["memory_mb"] = evaluate_memory(
            cfg.model_id, corpus, batch_size=args.batch_size, backend=cfg.backend,
        )
        print(f"  Memory: {result['memory_mb']} MB")

    return result


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point for ``embedding-bench``.

    Parses *argv* (``None`` lets argparse read ``sys.argv[1:]``), registers any
    custom models, resolves the models and datasets to run, benchmarks each
    model on the enabled evaluations, and hands the collected results to
    ``print_report``.

    Invalid arguments (unknown model key, malformed ``--add-model`` spec,
    non-positive sizes/counts) terminate through ``parser.error``, which
    raises ``SystemExit``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Load persisted custom models and register any --add-model entries
    load_custom_models_from_file()
    _register_custom_models(parser, args.add_model)

    if args.models is None:
        args.models = list(REGISTRY.keys())
    else:
        for k in args.models:
            if k not in REGISTRY:
                parser.error(f"Unknown model key: '{k}'. Available: {list(REGISTRY.keys())}")

    ds_configs = _resolve_dataset_configs(args)

    configs = [REGISTRY[k] for k in args.models]
    baseline_name = next((c.name for c in configs if c.is_baseline), None)

    # Use first dataset for corpus building; the corpus is only needed by the
    # speed and memory evaluations, so skip it when both are disabled.
    corpus: list[str] | None = None
    if not args.skip_speed or not args.skip_memory:
        print(f"Preparing corpus ({args.corpus_size} sentences)...")
        corpus = build_corpus(args.corpus_size, ds_configs[0])

    results = [_benchmark_model(cfg, ds_configs, corpus, args) for cfg in configs]

    print_report(results, baseline_name=baseline_name,
                 csv_path=args.csv, chart_dir=args.charts)


if __name__ == "__main__":
    # Run the CLI only when executed as a script; argv=None is main's default
    # and lets the argument parser read the process command line.
    main(argv=None)