Instructions to use GenomaLabs-com/kv-cache-eviction-mla with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use GenomaLabs-com/kv-cache-eviction-mla with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("GenomaLabs-com/kv-cache-eviction-mla", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
GENOMA LABS / research
Round 3b prep: control-validation driver (plain HF transformers, no AirLLM)
754890f | """Round 3b: validate the NIAH harness against a small control model on plain | |
| HuggingFace transformers (no AirLLM). Confirms the harness works end-to-end | |
| on a real LLM before we point it at the much heavier Kimi K2.6 path. | |
| Defaults to Qwen2.5-3B-Instruct (downloads ~6 GB on first run). For a smaller | |
| footprint, pass `--model unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit` (~2 GB). | |
| Run: | |
| python scripts/run_control_niah.py \ | |
| --model Qwen/Qwen2.5-3B-Instruct \ | |
| --context-lengths 2000 8000 \ | |
| --trials 10 \ | |
| --out-dir /tmp/niah_control | |
| Expected: Qwen2.5-3B should score >= 80% on the 2K-char (~500-token) cell. | |
| The 8K cell will be lower; that's the point — proves the harness picks up | |
| context-length quality decay. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import sys | |
| import time | |
| from pathlib import Path | |
| # Local imports | |
| sys.path.insert(0, str(Path(__file__).resolve().parent)) | |
| from niah_harness import run_niah_cell, write_cell_csv | |
def _resolve_torch_dtype(dtype: str):
    """Translate a dtype name into the value handed to ``from_pretrained``.

    ``"auto"`` is passed through unchanged so the checkpoint's own dtype is
    used. Any other name must resolve to a real ``torch.dtype`` attribute;
    otherwise we warn and fall back to ``torch.float16`` (the historical
    fallback) instead of silently picking it.
    """
    import torch

    if dtype == "auto":
        return "auto"
    candidate = getattr(torch, dtype, None)
    # hasattr() alone is not enough: names like "nn" or "cuda" exist on the
    # torch module but are not dtypes.
    if isinstance(candidate, torch.dtype):
        return candidate
    print(f"[control] WARNING: unknown dtype {dtype!r}; falling back to float16")
    return torch.float16


def make_hf_generate_fn(model_id: str, dtype: str = "bfloat16", device: str = "auto", max_new_tokens: int = 64):
    """Return (model, tokenizer, generate_fn) for a HF transformers model.

    Args:
        model_id: HuggingFace model ID or local path.
        dtype: Name of a ``torch`` dtype (e.g. ``"bfloat16"``) or ``"auto"``.
            Unrecognised names fall back to float16 with a warning.
        device: Forwarded to ``from_pretrained`` as ``device_map``.
        max_new_tokens: Upper bound on tokens generated per prompt.

    Returns:
        A ``(model, tokenizer, generate)`` tuple, where ``generate(prompt)``
        greedily decodes a completion and returns only the newly generated
        text. The prompt is tokenized as-is; no chat template is applied.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"[control] loading {model_id} ({dtype}) ...")
    t0 = time.time()
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=_resolve_torch_dtype(dtype),
        device_map=device,
        trust_remote_code=True,
    )
    model.eval()
    print(f"[control] loaded in {time.time()-t0:.1f}s, params: {sum(p.numel() for p in model.parameters()):,}")

    def generate(prompt: str) -> str:
        inputs = tok(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            # Greedy decoding. No `temperature` is passed: it is ignored when
            # do_sample=False and only triggers generation-config warnings.
            out = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tok.eos_token_id,
            )
        # Strip prompt tokens, decode just the new tokens
        new_tokens = out[0, prompt_len:]
        return tok.decode(new_tokens, skip_special_tokens=True)

    return model, tok, generate
def main(argv=None):
    """Run the control NIAH sweep and write per-cell and summary CSVs.

    One cell is run per requested context length. Each cell's per-trial
    results go to ``<out-dir>/<cell_id>.csv`` and an aggregate table goes to
    ``<out-dir>/summary.csv``. Finally, the smallest-context cell is checked
    against a 50% accuracy floor and a PASS/WARNING verdict is printed.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, so existing ``main()``
            callers behave exactly as before.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct",
                    help="HF model ID. Defaults to Qwen2.5-3B-Instruct.")
    ap.add_argument("--dtype", default="bfloat16")
    ap.add_argument("--device", default="auto")
    ap.add_argument("--context-lengths", type=int, nargs="+", default=[2000, 8000])
    ap.add_argument("--trials", type=int, default=10)
    ap.add_argument("--max-new-tokens", type=int, default=64)
    ap.add_argument("--out-dir", default="/tmp/niah_control")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args(argv)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Only the generate callable is needed here; the model and tokenizer are
    # returned by the factory but unused in this driver.
    _model, _tok, gen = make_hf_generate_fn(
        args.model, dtype=args.dtype, device=args.device, max_new_tokens=args.max_new_tokens
    )

    print(f"\n[control] sweep: {len(args.context_lengths)} cells, {args.trials} trials each")
    summary = []
    for ctx in args.context_lengths:
        cell_id = f"ctx{ctx}_baseline"
        print(f"\n[control] === cell {cell_id} ===")
        result = run_niah_cell(
            generate_fn=gen,
            target_chars=ctx,
            n_trials=args.trials,
            seed=args.seed,
        )
        cell_csv = out_dir / f"{cell_id}.csv"
        write_cell_csv(result, cell_csv)
        print(f"[control] accuracy={result.accuracy:.2%} ({result.n_correct}/{result.n_trials}) "
              f"elapsed={result.elapsed_s:.1f}s mean_response_chars={result.mean_response_chars:.0f}")
        summary.append({
            "cell_id": cell_id,
            "ctx_chars": ctx,
            "n_trials": result.n_trials,
            "n_correct": result.n_correct,
            "accuracy": round(result.accuracy, 4),
            "elapsed_s": round(result.elapsed_s, 1),
        })

    # Nothing ran, so there is nothing to write or sanity-check.
    if not summary:
        return

    summary_path = out_dir / "summary.csv"
    # Explicit encoding keeps the output identical across platforms.
    with open(summary_path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(summary[0].keys()))
        w.writeheader()
        w.writerows(summary)
    print(f"\n[control] summary -> {summary_path}")

    # Sanity check: the harness should resolve the needle on at least the
    # smallest context. If that cell's accuracy is < 50%, either the harness
    # is broken or the model is fundamentally unable to do retrieval.
    smallest = min(summary, key=lambda s: s["ctx_chars"])
    if smallest["accuracy"] < 0.5:
        print(f"\n[control] WARNING: smallest-context accuracy {smallest['accuracy']:.2%} is < 50%.")
        print("[control] either the harness is broken or the model can't do NIAH at all.")
        print("[control] inspect a per-trial CSV to see what the model actually produced.")
    else:
        print(f"\n[control] PASS: smallest-context accuracy {smallest['accuracy']:.2%} >= 50%.")
        print("[control] harness validated end-to-end on a real LLM.")
# Script entry point: run the sweep only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()