| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import time |
| from dataclasses import dataclass |
| from typing import List, Dict, Sequence |
|
|
| import numpy as np |
| from sklearn.feature_extraction.text import TfidfVectorizer |
| from sklearn.metrics.pairwise import cosine_similarity |
|
|
| from crom_efficientllm.budget_packer.packer import budget_pack, Chunk |
| from crom_efficientllm.rerank_engine.rerank import hybrid_rerank |
|
|
| try: |
| from sentence_transformers import SentenceTransformer |
| except Exception: |
| SentenceTransformer = None |
|
|
| |
|
|
@dataclass
class Doc:
    """A single corpus document: a stable string id plus its raw text."""

    # Unique identifier; build_corpus falls back to the row index when the
    # JSONL record carries no explicit "id".
    id: str
    # Raw document text used for retrieval, reranking, and packing.
    text: str
|
|
def load_jsonl(path: str) -> List[Dict]:
    """Read a JSON-Lines file and return one dict per non-empty line.

    Blank lines (e.g. a trailing newline at end-of-file) are skipped
    instead of raising ``json.JSONDecodeError`` as the previous version did.

    Args:
        path: Filesystem path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        List of parsed JSON objects, in file order.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
|
|
def build_corpus(path: str) -> List[Doc]:
    """Load a JSONL file into Doc records, defaulting id to the row index."""
    docs: List[Doc] = []
    for idx, row in enumerate(load_jsonl(path)):
        doc_id = str(row.get("id", idx))
        docs.append(Doc(id=doc_id, text=str(row["text"])))
    return docs
|
|
def sparse_retrieval(query: str, corpus: Sequence[Doc], k: int = 100) -> List[Dict]:
    """Rank corpus documents against *query* with TF-IDF cosine similarity.

    Fits a unigram+bigram TF-IDF vectorizer on the corpus, scores every
    document against the query, and returns the top-k as dicts carrying
    the doc id, text, and sparse score (descending).
    """
    doc_texts = [doc.text for doc in corpus]
    vectorizer = TfidfVectorizer(ngram_range=(1, 2)).fit(doc_texts)
    doc_matrix = vectorizer.transform(doc_texts)
    query_vec = vectorizer.transform([query])
    scores = cosine_similarity(query_vec, doc_matrix).ravel()
    top_idx = np.argsort(-scores)[:k]
    return [
        {"id": corpus[i].id, "text": corpus[i].text, "score_sparse": float(scores[i])}
        for i in top_idx
    ]
|
|
def dense_embed_model(name: str):
    """Instantiate a SentenceTransformer, failing loudly when the optional
    dependency was not importable at module load time."""
    if SentenceTransformer is not None:
        return SentenceTransformer(name)
    raise RuntimeError("sentence-transformers not installed. Install with `pip install -e .`.")
|
|
| def _apply_flashrank(query: str, docs: List[Dict], model_name: str) -> List[Dict]: |
| try: |
| from crom_efficientllm.plugins.flashrank_reranker import flashrank_rerank |
| except Exception as e: |
| raise RuntimeError("FlashRank plugin not available. Install extras: pip install .[plugins]") from e |
| ranked = flashrank_rerank(query, docs, model_name=model_name) |
| |
| scores = np.array([d.get("score_flashrank", 0.0) for d in ranked], dtype=np.float32) |
| if scores.size and float(scores.max() - scores.min()) > 1e-12: |
| s = (scores - scores.min()) / (scores.max() - scores.min()) |
| else: |
| s = np.zeros_like(scores) |
| for i, d in enumerate(ranked): |
| d["score_final"] = float(s[i]) |
| return ranked |
|
|
| def _apply_llmlingua(text: str, ratio: float) -> str: |
| try: |
| from crom_efficientllm.plugins.llmlingua_compressor import compress_prompt |
| except Exception as e: |
| raise RuntimeError("LLMLingua plugin not available. Install extras: pip install .[plugins]") from e |
| return compress_prompt(text, target_ratio=ratio) |
|
|
| def _save_evidently_report(all_embs: List[List[float]], out_html: str) -> None: |
| try: |
| from crom_efficientllm.plugins.evidently_drift import drift_report |
| except Exception as e: |
| raise RuntimeError("Evidently plugin not available. Install extras: pip install .[plugins]") from e |
| n = len(all_embs) |
| if n < 4: |
| return |
| ref = all_embs[: n // 2] |
| cur = all_embs[n // 2 :] |
| rep = drift_report(ref, cur) |
| rep.save_html(out_html) |
|
|
def mock_llm_generate(prompt: str) -> str:
    """Simulate an LLM call: sleep ~5 ms, then echo the truncated prompt."""
    time.sleep(0.005)
    return f"[MOCK] {prompt[:160]}"
|
|
def e2e(args: argparse.Namespace) -> None:
    """End-to-end pipeline benchmark: retrieval -> rerank -> pack -> mock LLM.

    For every query in args.queries, times each stage in milliseconds,
    records token savings from budget packing, and writes one JSON row per
    query to <out_dir>/e2e_results.jsonl. With --use-evidently, also writes
    a drift report over mean-pooled packed-document embeddings.
    """
    corpus = build_corpus(args.corpus)
    queries = [r["query"] for r in load_jsonl(args.queries)]
    embed = dense_embed_model(args.model)
    # One mean-pooled embedding per query; consumed by the Evidently report.
    all_embs: List[List[float]] = []

    t0 = time.perf_counter()
    all_rows = []
    for q in queries:
        # Stage 1: sparse TF-IDF retrieval.
        t_s = time.perf_counter()
        cands = sparse_retrieval(q, corpus, k=args.k)
        t_sparse = (time.perf_counter() - t_s) * 1000

        # Stage 2: rerank via the FlashRank plugin or the built-in hybrid reranker.
        t_r = time.perf_counter()
        if args.use_flashrank:
            reranked = _apply_flashrank(q, cands, args.flashrank_model)
        else:
            reranked = hybrid_rerank(q, cands, embed, alpha=args.alpha)
        t_rerank = (time.perf_counter() - t_r) * 1000

        # Stage 3: budget packing. Token counts use a ~4-chars-per-token
        # heuristic (min 1); the budget is a fraction of the total tokens.
        # FlashRank path provides score_final; hybrid path may not, so fall
        # back to the sparse score.
        chunks = [
            Chunk(text=d["text"], score=d.get("score_final", d.get("score_sparse", 0.0)), tokens=max(1, len(d["text"]) // 4))
            for d in reranked
        ]
        budget_tokens = int(sum(c.tokens for c in chunks) * args.budget)
        t_p = time.perf_counter()
        packed = budget_pack(chunks, budget=budget_tokens)
        t_pack = (time.perf_counter() - t_p) * 1000

        prompt = "\n\n".join(c.text for c in packed) + f"\n\nQ: {q}\nA:"
        if args.use_llmlingua:
            prompt = _apply_llmlingua(prompt, ratio=args.compress_ratio)

        # Mean-pool the packed documents' embeddings for drift tracking;
        # numpy floating-point warnings are suppressed for this computation.
        with np.errstate(all="ignore"):
            if len(packed) > 0:
                doc_embs = embed.encode([c.text for c in packed], convert_to_numpy=True)
                vec = np.mean(doc_embs, axis=0).tolist()
                all_embs.append(vec)

        # Stage 4: mock LLM call (fixed small latency; output discarded).
        t_l = time.perf_counter()
        _ = mock_llm_generate(prompt)
        t_llm = (time.perf_counter() - t_l) * 1000

        # NOTE: total is measured from the start of stage 1 (t_s), so it
        # covers all stages for this query, not just the LLM call.
        total = (time.perf_counter() - t_s) * 1000
        all_rows.append({
            "query": q,
            "sparse_ms": t_sparse,
            "rerank_ms": t_rerank,
            "pack_ms": t_pack,
            "llm_ms": t_llm,
            "total_ms": total,
            "packed_tokens": sum(c.tokens for c in packed),
            "orig_tokens": sum(c.tokens for c in chunks),
            # Fraction of tokens removed by packing (0 when nothing saved).
            "save_ratio": 1 - (sum(c.tokens for c in packed) / max(1, sum(c.tokens for c in chunks))),
            "used_flashrank": bool(args.use_flashrank),
            "used_llmlingua": bool(args.use_llmlingua),
        })

    elapsed = (time.perf_counter() - t0) * 1000
    os.makedirs(args.out_dir, exist_ok=True)
    out_path = os.path.join(args.out_dir, "e2e_results.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for r in all_rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    print(f"saved results -> {out_path} ({len(all_rows)} queries) ; elapsed={elapsed:.2f}ms")

    if args.use_evidently and all_embs:
        html_path = os.path.join(args.out_dir, "evidently_report.html")
        _save_evidently_report(all_embs, html_path)
        print(f"evidently report -> {html_path}")
|
|
def budget_sweep(args: argparse.Namespace) -> None:
    """Sweep packing budgets over queries and record token-savings/quality stats.

    For each (query, budget-percentage) pair: sparse-retrieve, hybrid-rerank,
    pack under the budget, and log packed vs original token counts plus the
    mean packed score. Results go to <out_dir>/budget_sweep.jsonl; with
    --save-plots, also saves a save-ratio-vs-budget curve and a Pareto plot.

    Fixes vs previous version: matplotlib.pyplot was imported twice under
    two aliases (one unused); repeated sum() recomputation hoisted.
    """
    import itertools

    corpus = build_corpus(args.corpus)
    queries = [r["query"] for r in load_jsonl(args.queries)][: args.max_q]
    embed = dense_embed_model(args.model)

    # Budget fractions, e.g. b_min=10, b_max=90, b_step=10 -> 0.1 .. 0.9.
    budgets = [b / 100.0 for b in range(args.b_min, args.b_max + 1, args.b_step)]
    rows = []
    for q, b in itertools.product(queries, budgets):
        cands = sparse_retrieval(q, corpus, k=args.k)
        reranked = hybrid_rerank(q, cands, embed, alpha=args.alpha)
        # ~4 chars per token heuristic; at least 1 token per chunk.
        chunks = [Chunk(text=d["text"], score=d["score_final"], tokens=max(1, len(d["text"]) // 4)) for d in reranked]
        budget_tokens = int(sum(c.tokens for c in chunks) * b)
        packed = budget_pack(chunks, budget=budget_tokens)
        packed_tokens = sum(c.tokens for c in packed)
        orig_tokens = sum(c.tokens for c in chunks)
        rows.append({
            "query": q,
            "budget": b,
            "packed_tokens": packed_tokens,
            "orig_tokens": orig_tokens,
            "save_ratio": 1 - (packed_tokens / max(1, orig_tokens)),
            "avg_score": float(np.mean([c.score for c in packed])) if packed else 0.0,
        })

    os.makedirs(args.out_dir, exist_ok=True)
    out_path = os.path.join(args.out_dir, "budget_sweep.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    print(f"saved results -> {out_path} ; points={len(rows)}")

    if args.save_plots:
        try:
            import matplotlib.pyplot as plt
        except Exception:
            print("[warn] matplotlib not installed; install dev extras: pip install -e .[dev]")
        else:
            # Aggregate rows per budget to average across queries.
            import collections
            agg = collections.defaultdict(list)
            for r in rows:
                agg[r["budget"]].append(r)
            budgets_sorted = sorted(agg.keys())
            avg_save = [float(np.mean([x["save_ratio"] for x in agg[b]])) for b in budgets_sorted]
            avg_score = [float(np.mean([x["avg_score"] for x in agg[b]])) for b in budgets_sorted]

            # Plot 1: average save ratio vs budget.
            plt.figure()
            plt.plot([b * 100 for b in budgets_sorted], [s * 100 for s in avg_save], marker="o")
            plt.xlabel("Budget (%)")
            plt.ylabel("Avg Save Ratio (%)")
            plt.title("Budget Sweep: Save Ratio vs Budget")
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(args.out_dir, "budget_sweep.png"))

            # Plot 2: Pareto front of retained quality vs savings.
            plt.figure()
            plt.plot([s * 100 for s in avg_save], avg_score, marker="o")
            plt.xlabel("Save Ratio (%)")
            plt.ylabel("Avg Score (packed)")
            plt.title("Pareto: Quality vs Savings")
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(args.out_dir, "budget_pareto.png"))
            print("plots ->", os.path.join(args.out_dir, "budget_sweep.png"), ",", os.path.join(args.out_dir, "budget_pareto.png"))
|
|
def scaling(args: argparse.Namespace) -> None:
    """Time budget_pack on synthetic corpora of growing size (up to args.n_max)."""

    def synth_chunks(count: int, seed: int = 42):
        # Log-normal token lengths clipped to [5, 2000]; N(0, 1) scores;
        # text is a filler string sized ~4 chars per token.
        rng = np.random.default_rng(seed)
        tok = np.clip(rng.lognormal(4.0, 0.6, count).astype(int), 5, 2000)
        scr = rng.normal(0, 1, count)
        return [Chunk(text="x" * int(t * 4), score=float(s), tokens=int(t)) for s, t in zip(scr, tok)]

    for size in (1000, 5000, 10000, 20000, 50000, 100000):
        if size > args.n_max:
            break
        chunks = synth_chunks(size)
        token_budget = int(sum(c.tokens for c in chunks) * args.budget)
        start = time.perf_counter()
        _ = budget_pack(chunks, token_budget)
        elapsed_ms = (time.perf_counter() - start) * 1000
        print(f"n={size:6d} budget={args.budget:.0%} time={elapsed_ms:8.2f} ms")
|
|
def dp_curve(args: argparse.Namespace) -> None:
    """Compare greedy budget_pack against an exact 0/1-knapsack optimum.

    Generates synthetic chunks whose observed scores correlate (corr) with a
    latent true relevance, then for each budget reports the greedy packer's
    achieved relevance as a fraction of the DP optimum. The DP is O(n * B),
    so the optimum is computed only on the first args.n_opt chunks.
    With --save-plots, saves percent-of-optimal vs budget.

    Fixes vs previous version: matplotlib.pyplot was imported twice under
    two aliases (one unused); the chunk->index map is now built once instead
    of once per budget.
    """

    def make_synth(n: int, seed: int = 123, corr: float = 0.6):
        # Observed score = corr * true_rel + noise with unit total variance,
        # so ranking by score is an imperfect proxy for true relevance.
        rng = np.random.default_rng(seed)
        true_rel = rng.normal(0, 1, n)
        noise = rng.normal(0, 1, n) * np.sqrt(1 - corr**2)
        score = corr * true_rel + noise
        tokens = np.clip(rng.lognormal(4.0, 0.6, n).astype(int), 5, 2000)
        chunks = [Chunk(text="x" * int(t * 4), score=float(s), tokens=int(t)) for s, t in zip(score, tokens)]
        return chunks, true_rel

    def optimal(chunks: Sequence[Chunk], values: np.ndarray, budget: int) -> float:
        # Classic 1-D knapsack DP over the token budget. Negative values are
        # clamped to 0: skipping a harmful chunk is always allowed.
        B = budget
        dp = np.zeros(B + 1, dtype=np.float32)
        for i, ch in enumerate(chunks):
            wt = ch.tokens
            val = max(0.0, float(values[i]))
            for b in range(B, wt - 1, -1):
                dp[b] = max(dp[b], dp[b - wt] + val)
        return float(dp[B])

    chunks, true_rel = make_synth(args.n)
    total = sum(c.tokens for c in chunks)
    budgets = [int(total * b / 100.0) for b in range(args.b_min, args.b_max + 1, args.b_step)]
    out_rows = []

    # Map each chunk object back to its index once; invariant across budgets.
    idx_map = {id(c): i for i, c in enumerate(chunks)}
    for B in budgets:
        sel = budget_pack(chunks, B)
        rel_bp = float(np.sum([max(0.0, true_rel[idx_map[id(c)]]) for c in sel]))
        rel_opt = optimal(chunks[: args.n_opt], true_rel[: args.n_opt], min(B, sum(c.tokens for c in chunks[: args.n_opt])))
        pct = rel_bp / max(rel_opt, 1e-9)
        out_rows.append({"budget": B, "pct": pct, "rel_bp": rel_bp, "rel_opt": rel_opt})
        print(f"budget={B:8d} rel_bp={rel_bp:8.3f} rel_opt≈{rel_opt:8.3f} pct≈{pct*100:5.1f}% (subset n={args.n_opt})")

    if args.save_plots:
        try:
            import matplotlib.pyplot as plt
        except Exception:
            print("[warn] matplotlib not installed; install dev extras: pip install -e .[dev]")
        else:
            plt.figure()
            xs = [r["budget"] * 100.0 / total for r in out_rows]
            ys = [r["pct"] * 100 for r in out_rows]
            plt.plot(xs, ys, marker="o")
            plt.xlabel("Budget (%)")
            plt.ylabel("% of optimal (subset)")
            plt.title("DP Curve: Greedy vs Optimal")
            plt.grid(True)
            plt.tight_layout()
            os.makedirs(args.out_dir, exist_ok=True)
            plt.savefig(os.path.join(args.out_dir, "dp_curve.png"))
            print("plot ->", os.path.join(args.out_dir, "dp_curve.png"))
|
|
def compare_haystack(args: argparse.Namespace) -> None:
    """Benchmark Haystack BM25 vs dense retrieval latency on the sample corpus.

    Raises RuntimeError when the haystack extra is not installed.
    """
    try:
        from haystack.nodes import BM25Retriever, SentenceTransformersRetriever
        from haystack.document_stores import InMemoryDocumentStore
    except Exception as exc:
        raise RuntimeError("Install extras: pip install .[haystack]") from exc

    corpus = build_corpus(args.corpus)
    store = InMemoryDocumentStore(use_bm25=True)
    store.write_documents([{"content": doc.text, "meta": {"id": doc.id}} for doc in corpus])

    sparse = BM25Retriever(document_store=store)
    dense = SentenceTransformersRetriever(document_store=store, model_name_or_path=args.model)

    for query in [r["query"] for r in load_jsonl(args.queries)][: args.max_q]:
        started = time.perf_counter()
        sparse_hits = sparse.retrieve(query, top_k=args.k)
        dense_hits = dense.retrieve(query, top_k=args.k)
        elapsed_ms = (time.perf_counter() - started) * 1000
        print(f"{query[:40]:40s} bm25={len(sparse_hits):3d} dense={len(dense_hits):3d} time={elapsed_ms:7.2f} ms")
|
|
def main() -> None:
    """CLI entry point: build the `crom-bench` argument parser and dispatch
    to the sub-command's handler via the `func` default."""
    parser = argparse.ArgumentParser(prog="crom-bench")
    sub = parser.add_subparsers(dest="cmd", required=True)

    # --- e2e: full pipeline benchmark -------------------------------------
    e2e_p = sub.add_parser("e2e", help="end-to-end: retrieval → rerank → pack → mock LLM")
    e2e_p.add_argument("--corpus", default="examples/corpus/sample_docs.jsonl")
    e2e_p.add_argument("--queries", default="examples/corpus/sample_queries.jsonl")
    e2e_p.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
    e2e_p.add_argument("--k", type=int, default=200)
    e2e_p.add_argument("--alpha", type=float, default=0.5)
    e2e_p.add_argument("--budget", type=float, default=0.3)
    # Optional plugin toggles.
    e2e_p.add_argument("--use-flashrank", action="store_true")
    e2e_p.add_argument("--flashrank-model", default="ms-marco-TinyBERT-L-2-v2")
    e2e_p.add_argument("--use-llmlingua", action="store_true")
    e2e_p.add_argument("--compress-ratio", type=float, default=0.6)
    e2e_p.add_argument("--use-evidently", action="store_true")
    e2e_p.add_argument("--out-dir", default="benchmarks/out")
    e2e_p.set_defaults(func=e2e)

    # --- sweep: budget sweep + Pareto data --------------------------------
    sweep_p = sub.add_parser("sweep", help="budget sweep + Pareto csv")
    sweep_p.add_argument("--corpus", default="examples/corpus/sample_docs.jsonl")
    sweep_p.add_argument("--queries", default="examples/corpus/sample_queries.jsonl")
    sweep_p.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
    sweep_p.add_argument("--k", type=int, default=200)
    sweep_p.add_argument("--alpha", type=float, default=0.5)
    sweep_p.add_argument("--b-min", type=int, default=10)
    sweep_p.add_argument("--b-max", type=int, default=90)
    sweep_p.add_argument("--b-step", type=int, default=10)
    sweep_p.add_argument("--max-q", type=int, default=20)
    sweep_p.add_argument("--out-dir", default="benchmarks/out")
    sweep_p.add_argument("--save-plots", action="store_true")
    sweep_p.set_defaults(func=budget_sweep)

    # --- scale: synthetic runtime scaling ---------------------------------
    scale_p = sub.add_parser("scale", help="scaling runtime with synthetic data")
    scale_p.add_argument("--n-max", type=int, default=100000)
    scale_p.add_argument("--budget", type=float, default=0.3)
    scale_p.set_defaults(func=scaling)

    # --- dp-curve: greedy vs optimal packing ------------------------------
    dp_p = sub.add_parser("dp-curve", help="% of optimal vs budget (synthetic)")
    dp_p.add_argument("--n", type=int, default=2000)
    dp_p.add_argument("--n-opt", type=int, default=200)
    dp_p.add_argument("--b-min", type=int, default=10)
    dp_p.add_argument("--b-max", type=int, default=90)
    dp_p.add_argument("--b-step", type=int, default=10)
    dp_p.add_argument("--out-dir", default="benchmarks/out")
    dp_p.add_argument("--save-plots", action="store_true")
    dp_p.set_defaults(func=dp_curve)

    # --- haystack-compare: BM25 vs dense retrievers -----------------------
    hs_p = sub.add_parser("haystack-compare", help="compare BM25 vs dense retrievers (Haystack)")
    hs_p.add_argument("--corpus", default="examples/corpus/sample_docs.jsonl")
    hs_p.add_argument("--queries", default="examples/corpus/sample_queries.jsonl")
    hs_p.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
    hs_p.add_argument("--k", type=int, default=50)
    hs_p.add_argument("--max-q", type=int, default=10)
    hs_p.set_defaults(func=compare_haystack)

    args = parser.parse_args()
    args.func(args)
|
|
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
|