# Initial commit: Add project structure and all source files (Flamehaven, aca22b8)
from __future__ import annotations
import argparse
import json
import os
import time
from dataclasses import dataclass
from typing import List, Dict, Sequence
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from crom_efficientllm.budget_packer.packer import budget_pack, Chunk
from crom_efficientllm.rerank_engine.rerank import hybrid_rerank
try:
from sentence_transformers import SentenceTransformer
except Exception: # pragma: no cover
SentenceTransformer = None # type: ignore
# Optional plugins are imported lazily when flags are set
@dataclass
class Doc:
    """A single retrievable document: a stable identifier plus its raw text."""

    id: str    # document identifier, always stringified
    text: str  # full document body
def load_jsonl(path: str) -> List[Dict]:
    """Parse a JSON-Lines file: one JSON object per line, returned in order."""
    records = []
    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            records.append(json.loads(raw))
    return records
def build_corpus(path: str) -> List[Doc]:
    """Load a JSONL corpus into Doc records, defaulting missing ids to the row index."""
    docs: List[Doc] = []
    for idx, row in enumerate(load_jsonl(path)):
        docs.append(Doc(id=str(row.get("id", idx)), text=str(row["text"])))
    return docs
def sparse_retrieval(query: str, corpus: Sequence[Doc], k: int = 100) -> List[Dict]:
    """Rank corpus documents against *query* by TF-IDF cosine similarity.

    Uses uni- and bi-grams; returns the top-k documents as dicts carrying
    a 'score_sparse' field. NOTE: the vectorizer is re-fit on every call,
    which is acceptable for a benchmark but O(corpus) per query.
    """
    bodies = [doc.text for doc in corpus]
    vectorizer = TfidfVectorizer(ngram_range=(1, 2)).fit(bodies)
    doc_matrix = vectorizer.transform(bodies)
    query_vec = vectorizer.transform([query])
    scores = cosine_similarity(query_vec, doc_matrix).ravel()
    top = np.argsort(-scores)[:k]
    return [
        {"id": corpus[i].id, "text": corpus[i].text, "score_sparse": float(scores[i])}
        for i in top
    ]
def dense_embed_model(name: str):
    """Return a SentenceTransformer encoder for *name*.

    Raises RuntimeError when the optional sentence-transformers dependency
    failed to import at module load time.
    """
    if SentenceTransformer is not None:
        return SentenceTransformer(name)
    raise RuntimeError("sentence-transformers not installed. Install with `pip install -e .`.")
def _apply_flashrank(query: str, docs: List[Dict], model_name: str) -> List[Dict]:
try:
from crom_efficientllm.plugins.flashrank_reranker import flashrank_rerank
except Exception as e: # pragma: no cover
raise RuntimeError("FlashRank plugin not available. Install extras: pip install .[plugins]") from e
ranked = flashrank_rerank(query, docs, model_name=model_name)
# Normalize plugin score to 0..1 and put into score_final
scores = np.array([d.get("score_flashrank", 0.0) for d in ranked], dtype=np.float32)
if scores.size and float(scores.max() - scores.min()) > 1e-12:
s = (scores - scores.min()) / (scores.max() - scores.min())
else:
s = np.zeros_like(scores)
for i, d in enumerate(ranked):
d["score_final"] = float(s[i])
return ranked
def _apply_llmlingua(text: str, ratio: float) -> str:
try:
from crom_efficientllm.plugins.llmlingua_compressor import compress_prompt
except Exception as e: # pragma: no cover
raise RuntimeError("LLMLingua plugin not available. Install extras: pip install .[plugins]") from e
return compress_prompt(text, target_ratio=ratio)
def _save_evidently_report(all_embs: List[List[float]], out_html: str) -> None:
try:
from crom_efficientllm.plugins.evidently_drift import drift_report
except Exception as e: # pragma: no cover
raise RuntimeError("Evidently plugin not available. Install extras: pip install .[plugins]") from e
n = len(all_embs)
if n < 4:
return
ref = all_embs[: n // 2]
cur = all_embs[n // 2 :]
rep = drift_report(ref, cur)
rep.save_html(out_html)
def mock_llm_generate(prompt: str) -> str:
    """Stand-in for an LLM call: sleep briefly, then echo a truncated prompt."""
    time.sleep(0.005)  # simulate small latency
    truncated = prompt[:160]
    return f"[MOCK] {truncated}"
def e2e(args: argparse.Namespace) -> None:
    """End-to-end benchmark: sparse retrieval -> rerank -> budget pack -> mock LLM.

    Writes one JSONL row per query (stage timings, token counts, save ratio)
    to <out_dir>/e2e_results.jsonl and, when --use-evidently is set, an HTML
    drift report over the collected packed-context embeddings.
    """
    corpus = build_corpus(args.corpus)
    queries = [r["query"] for r in load_jsonl(args.queries)]
    embed = dense_embed_model(args.model)
    # One mean-pooled embedding of the packed context per query, for drift reporting.
    all_embs: List[List[float]] = []
    t0 = time.perf_counter()
    all_rows = []
    for q in queries:
        # Stage 1: TF-IDF sparse retrieval of top-k candidates.
        t_s = time.perf_counter()
        cands = sparse_retrieval(q, corpus, k=args.k)
        t_sparse = (time.perf_counter() - t_s) * 1000
        # Stage 2: rerank via FlashRank plugin or the built-in hybrid reranker.
        t_r = time.perf_counter()
        if args.use_flashrank:
            reranked = _apply_flashrank(q, cands, args.flashrank_model)
        else:
            reranked = hybrid_rerank(q, cands, embed, alpha=args.alpha)
        t_rerank = (time.perf_counter() - t_r) * 1000
        # token heuristic + budget pack (~4 chars per token; fall back to the
        # sparse score when the reranker did not set score_final)
        chunks = [
            Chunk(text=d["text"], score=d.get("score_final", d.get("score_sparse", 0.0)), tokens=max(1, len(d["text"]) // 4))
            for d in reranked
        ]
        budget_tokens = int(sum(c.tokens for c in chunks) * args.budget)
        t_p = time.perf_counter()
        packed = budget_pack(chunks, budget=budget_tokens)
        t_pack = (time.perf_counter() - t_p) * 1000
        prompt = "\n\n".join(c.text for c in packed) + f"\n\nQ: {q}\nA:"
        if args.use_llmlingua:
            prompt = _apply_llmlingua(prompt, ratio=args.compress_ratio)
        # collect embeddings for drift snapshot (mean-pooled); skipped when
        # nothing was packed, so all_embs may have fewer entries than queries
        with np.errstate(all="ignore"):
            if len(packed) > 0:
                doc_embs = embed.encode([c.text for c in packed], convert_to_numpy=True)
                vec = np.mean(doc_embs, axis=0).tolist()
                all_embs.append(vec)
        # Stage 4: mock LLM call (fixed small latency, output discarded).
        t_l = time.perf_counter()
        _ = mock_llm_generate(prompt)
        t_llm = (time.perf_counter() - t_l) * 1000
        # Wall time for this query across all stages.
        total = (time.perf_counter() - t_s) * 1000
        all_rows.append({
            "query": q,
            "sparse_ms": t_sparse,
            "rerank_ms": t_rerank,
            "pack_ms": t_pack,
            "llm_ms": t_llm,
            "total_ms": total,
            "packed_tokens": sum(c.tokens for c in packed),
            "orig_tokens": sum(c.tokens for c in chunks),
            "save_ratio": 1 - (sum(c.tokens for c in packed) / max(1, sum(c.tokens for c in chunks))),
            "used_flashrank": bool(args.use_flashrank),
            "used_llmlingua": bool(args.use_llmlingua),
        })
    elapsed = (time.perf_counter() - t0) * 1000
    os.makedirs(args.out_dir, exist_ok=True)
    out_path = os.path.join(args.out_dir, "e2e_results.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for r in all_rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    print(f"saved results -> {out_path} ({len(all_rows)} queries) ; elapsed={elapsed:.2f}ms")
    if args.use_evidently and all_embs:
        html_path = os.path.join(args.out_dir, "evidently_report.html")
        _save_evidently_report(all_embs, html_path)
        print(f"evidently report -> {html_path}")
def budget_sweep(args: argparse.Namespace) -> None:
    """Sweep packing budgets across queries and record quality/savings points.

    Writes one JSONL row per (query, budget) pair to <out_dir>/budget_sweep.jsonl
    and, with --save-plots, renders save-ratio and Pareto charts.

    Fixes vs. previous revision: matplotlib.pyplot was imported twice under two
    aliases (the `plt` alias was unused); chunk scores now fall back to
    score_sparse via .get(), consistent with e2e(), instead of a raw subscript
    that would raise KeyError if the reranker omitted score_final.
    """
    import itertools
    corpus = build_corpus(args.corpus)
    queries = [r["query"] for r in load_jsonl(args.queries)][: args.max_q]
    embed = dense_embed_model(args.model)
    # b_min..b_max are integer percentages; convert to fractions (e.g. 10 -> 0.1).
    budgets = [b / 100.0 for b in range(args.b_min, args.b_max + 1, args.b_step)]
    rows = []
    for q, b in itertools.product(queries, budgets):
        cands = sparse_retrieval(q, corpus, k=args.k)
        reranked = hybrid_rerank(q, cands, embed, alpha=args.alpha)
        # ~4 chars per token heuristic; defensive score fallback mirrors e2e().
        chunks = [
            Chunk(text=d["text"], score=d.get("score_final", d.get("score_sparse", 0.0)), tokens=max(1, len(d["text"]) // 4))
            for d in reranked
        ]
        budget_tokens = int(sum(c.tokens for c in chunks) * b)
        packed = budget_pack(chunks, budget=budget_tokens)
        rows.append({
            "query": q,
            "budget": b,
            "packed_tokens": sum(c.tokens for c in packed),
            "orig_tokens": sum(c.tokens for c in chunks),
            "save_ratio": 1 - (sum(c.tokens for c in packed) / max(1, sum(c.tokens for c in chunks))),
            "avg_score": float(np.mean([c.score for c in packed])) if packed else 0.0,
        })
    os.makedirs(args.out_dir, exist_ok=True)
    out_path = os.path.join(args.out_dir, "budget_sweep.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    print(f"saved results -> {out_path} ; points={len(rows)}")
    if args.save_plots:
        try:
            import matplotlib.pyplot as _plt
        except Exception:
            print("[warn] matplotlib not installed; install dev extras: pip install -e .[dev]")
        else:
            # Aggregate by budget
            import collections
            agg = collections.defaultdict(list)
            for r in rows:
                agg[r["budget"]].append(r)
            budgets_sorted = sorted(agg.keys())
            avg_save = [float(np.mean([x["save_ratio"] for x in agg[b]])) for b in budgets_sorted]
            avg_score = [float(np.mean([x["avg_score"] for x in agg[b]])) for b in budgets_sorted]
            _plt.figure()
            _plt.plot([b * 100 for b in budgets_sorted], [s * 100 for s in avg_save], marker="o")
            _plt.xlabel("Budget (%)")
            _plt.ylabel("Avg Save Ratio (%)")
            _plt.title("Budget Sweep: Save Ratio vs Budget")
            _plt.grid(True)
            _plt.tight_layout()
            _plt.savefig(os.path.join(args.out_dir, "budget_sweep.png"))
            _plt.figure()
            _plt.plot([s * 100 for s in avg_save], avg_score, marker="o")
            _plt.xlabel("Save Ratio (%)")
            _plt.ylabel("Avg Score (packed)")
            _plt.title("Pareto: Quality vs Savings")
            _plt.grid(True)
            _plt.tight_layout()
            _plt.savefig(os.path.join(args.out_dir, "budget_pareto.png"))
            print("plots ->", os.path.join(args.out_dir, "budget_sweep.png"), ",", os.path.join(args.out_dir, "budget_pareto.png"))
def scaling(args: argparse.Namespace) -> None:
    """Time budget_pack on synthetic corpora of increasing size (up to --n-max)."""

    def _synthetic_chunks(count: int, seed: int = 42):
        # Lognormal token lengths clipped to [5, 2000]; N(0, 1) relevance scores.
        rng = np.random.default_rng(seed)
        token_counts = np.clip(rng.lognormal(4.0, 0.6, count).astype(int), 5, 2000)
        rel = rng.normal(0, 1, count)
        return [Chunk(text="x" * int(t * 4), score=float(s), tokens=int(t)) for s, t in zip(rel, token_counts)]

    for n in [1000, 5000, 10000, 20000, 50000, 100000]:
        if n > args.n_max:
            break
        chunks = _synthetic_chunks(n)
        budget = int(sum(c.tokens for c in chunks) * args.budget)
        started = time.perf_counter()
        _ = budget_pack(chunks, budget)
        ms = (time.perf_counter() - started) * 1000
        print(f"n={n:6d} budget={args.budget:.0%} time={ms:8.2f} ms")
def dp_curve(args: argparse.Namespace) -> None:
    """Compare greedy budget_pack against an exact knapsack DP optimum.

    On synthetic chunks whose scores correlate noisily with a latent true
    relevance, prints the fraction of the (subset) optimal relevance the
    greedy packer attains per budget; optionally saves a plot.

    Fixes vs. previous revision: the chunk->index map was rebuilt inside the
    per-budget loop although it never changes (now hoisted); matplotlib.pyplot
    was imported twice under two aliases (the unused `plt` alias is gone).
    """

    def make_synth(n: int, seed: int = 123, corr: float = 0.6):
        # Scores correlate with latent true relevance at strength `corr`.
        rng = np.random.default_rng(seed)
        true_rel = rng.normal(0, 1, n)
        noise = rng.normal(0, 1, n) * np.sqrt(1 - corr**2)
        score = corr * true_rel + noise
        tokens = np.clip(rng.lognormal(4.0, 0.6, n).astype(int), 5, 2000)
        chunks = [Chunk(text="x" * int(t * 4), score=float(s), tokens=int(t)) for s, t in zip(score, tokens)]
        return chunks, true_rel

    def optimal(chunks: Sequence[Chunk], values: np.ndarray, budget: int) -> float:
        # 0/1 knapsack DP over the token budget; negative values clipped to 0.
        B = budget
        dp = np.zeros(B + 1, dtype=np.float32)
        for i, ch in enumerate(chunks):
            wt = ch.tokens
            val = max(0.0, float(values[i]))
            for b in range(B, wt - 1, -1):
                dp[b] = max(dp[b], dp[b - wt] + val)
        return float(dp[B])

    chunks, true_rel = make_synth(args.n)
    total = sum(c.tokens for c in chunks)
    budgets = [int(total * b / 100.0) for b in range(args.b_min, args.b_max + 1, args.b_step)]
    # Loop-invariant: the identity->index map over chunks never changes.
    idx_map = {id(c): i for i, c in enumerate(chunks)}
    out_rows = []
    for B in budgets:
        sel = budget_pack(chunks, B)
        rel_bp = float(np.sum([max(0.0, true_rel[idx_map[id(c)]]) for c in sel]))
        # The exact DP is too slow for all n chunks; evaluate on the first n_opt.
        rel_opt = optimal(chunks[: args.n_opt], true_rel[: args.n_opt], min(B, sum(c.tokens for c in chunks[: args.n_opt])))
        pct = rel_bp / max(rel_opt, 1e-9)
        out_rows.append({"budget": B, "pct": pct, "rel_bp": rel_bp, "rel_opt": rel_opt})
        print(f"budget={B:8d} rel_bp={rel_bp:8.3f} rel_opt≈{rel_opt:8.3f} pct≈{pct*100:5.1f}% (subset n={args.n_opt})")
    if args.save_plots:
        try:
            import matplotlib.pyplot as _plt
        except Exception:
            print("[warn] matplotlib not installed; install dev extras: pip install -e .[dev]")
        else:
            _plt.figure()
            xs = [r["budget"] * 100.0 / total for r in out_rows]
            ys = [r["pct"] * 100 for r in out_rows]
            _plt.plot(xs, ys, marker="o")
            _plt.xlabel("Budget (%)")
            _plt.ylabel("% of optimal (subset)")
            _plt.title("DP Curve: Greedy vs Optimal")
            _plt.grid(True)
            _plt.tight_layout()
            os.makedirs(args.out_dir, exist_ok=True)
            _plt.savefig(os.path.join(args.out_dir, "dp_curve.png"))
            print("plot ->", os.path.join(args.out_dir, "dp_curve.png"))
def compare_haystack(args: argparse.Namespace) -> None:
    """Benchmark Haystack BM25 vs dense retrieval side by side on the corpus.

    Prints hit counts and combined latency per query. Raises RuntimeError
    when the optional haystack extra is not installed.
    """
    try:
        from haystack.nodes import BM25Retriever, SentenceTransformersRetriever
        from haystack.document_stores import InMemoryDocumentStore
    except Exception as err:  # pragma: no cover
        raise RuntimeError("Install extras: pip install .[haystack]") from err
    corpus = build_corpus(args.corpus)
    documents = [{"content": doc.text, "meta": {"id": doc.id}} for doc in corpus]
    store = InMemoryDocumentStore(use_bm25=True)
    store.write_documents(documents)
    sparse_retriever = BM25Retriever(document_store=store)
    dense_retriever = SentenceTransformersRetriever(document_store=store, model_name_or_path=args.model)
    queries = [row["query"] for row in load_jsonl(args.queries)][: args.max_q]
    for query in queries:
        started = time.perf_counter()
        sparse_hits = sparse_retriever.retrieve(query, top_k=args.k)
        dense_hits = dense_retriever.retrieve(query, top_k=args.k)
        ms = (time.perf_counter() - started) * 1000
        print(f"{query[:40]:40s} bm25={len(sparse_hits):3d} dense={len(dense_hits):3d} time={ms:7.2f} ms")
def main() -> None:
    """CLI entry point: build the crom-bench parser and dispatch the subcommand.

    Decomposed from one long flat function into one registration helper per
    subcommand; flags, defaults, and help text are unchanged.
    """
    ap = argparse.ArgumentParser(prog="crom-bench")
    sub = ap.add_subparsers(dest="cmd", required=True)
    _register_e2e(sub)
    _register_sweep(sub)
    _register_scale(sub)
    _register_dp_curve(sub)
    _register_haystack(sub)
    args = ap.parse_args()
    args.func(args)


def _register_e2e(sub) -> None:
    """Register the full retrieval -> rerank -> pack -> mock-LLM pipeline command."""
    p = sub.add_parser("e2e", help="end-to-end: retrieval → rerank → pack → mock LLM")
    p.add_argument("--corpus", default="examples/corpus/sample_docs.jsonl")
    p.add_argument("--queries", default="examples/corpus/sample_queries.jsonl")
    p.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
    p.add_argument("--k", type=int, default=200)
    p.add_argument("--alpha", type=float, default=0.5)
    p.add_argument("--budget", type=float, default=0.3)
    # plugins
    p.add_argument("--use-flashrank", action="store_true")
    p.add_argument("--flashrank-model", default="ms-marco-TinyBERT-L-2-v2")
    p.add_argument("--use-llmlingua", action="store_true")
    p.add_argument("--compress-ratio", type=float, default=0.6)
    p.add_argument("--use-evidently", action="store_true")
    p.add_argument("--out-dir", default="benchmarks/out")
    p.set_defaults(func=e2e)


def _register_sweep(sub) -> None:
    """Register the budget-sweep / Pareto command."""
    p = sub.add_parser("sweep", help="budget sweep + Pareto csv")
    p.add_argument("--corpus", default="examples/corpus/sample_docs.jsonl")
    p.add_argument("--queries", default="examples/corpus/sample_queries.jsonl")
    p.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
    p.add_argument("--k", type=int, default=200)
    p.add_argument("--alpha", type=float, default=0.5)
    p.add_argument("--b-min", type=int, default=10)
    p.add_argument("--b-max", type=int, default=90)
    p.add_argument("--b-step", type=int, default=10)
    p.add_argument("--max-q", type=int, default=20)
    p.add_argument("--out-dir", default="benchmarks/out")
    p.add_argument("--save-plots", action="store_true")
    p.set_defaults(func=budget_sweep)


def _register_scale(sub) -> None:
    """Register the synthetic-data scaling benchmark command."""
    p = sub.add_parser("scale", help="scaling runtime with synthetic data")
    p.add_argument("--n-max", type=int, default=100000)
    p.add_argument("--budget", type=float, default=0.3)
    p.set_defaults(func=scaling)


def _register_dp_curve(sub) -> None:
    """Register the greedy-vs-optimal DP curve command."""
    p = sub.add_parser("dp-curve", help="% of optimal vs budget (synthetic)")
    p.add_argument("--n", type=int, default=2000)
    p.add_argument("--n-opt", type=int, default=200)
    p.add_argument("--b-min", type=int, default=10)
    p.add_argument("--b-max", type=int, default=90)
    p.add_argument("--b-step", type=int, default=10)
    p.add_argument("--out-dir", default="benchmarks/out")
    p.add_argument("--save-plots", action="store_true")
    p.set_defaults(func=dp_curve)


def _register_haystack(sub) -> None:
    """Register the Haystack BM25-vs-dense comparison command."""
    p = sub.add_parser("haystack-compare", help="compare BM25 vs dense retrievers (Haystack)")
    p.add_argument("--corpus", default="examples/corpus/sample_docs.jsonl")
    p.add_argument("--queries", default="examples/corpus/sample_queries.jsonl")
    p.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
    p.add_argument("--k", type=int, default=50)
    p.add_argument("--max-q", type=int, default=10)
    p.set_defaults(func=compare_haystack)


if __name__ == "__main__":
    main()