File size: 5,417 Bytes
754890f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Round 3b: validate the NIAH harness against a small control model on plain
HuggingFace transformers (no AirLLM). Confirms the harness works end-to-end
on a real LLM before we point it at the much heavier Kimi K2.6 path.

Defaults to Qwen2.5-3B-Instruct (downloads ~6 GB on first run). For a smaller
footprint, pass `--model unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit` (~2 GB).

Run:
    python scripts/run_control_niah.py \
        --model Qwen/Qwen2.5-3B-Instruct \
        --context-lengths 2000 8000 \
        --trials 10 \
        --out-dir /tmp/niah_control

Expected: Qwen2.5-3B should score >= 80% on the 2K-char (~500-token) cell.
The 8K cell will be lower; that's the point — proves the harness picks up
context-length quality decay.
"""
from __future__ import annotations

import argparse
import csv
import sys
import time
from pathlib import Path

# Local imports
sys.path.insert(0, str(Path(__file__).resolve().parent))
from niah_harness import run_niah_cell, write_cell_csv


def make_hf_generate_fn(model_id: str, dtype: str = "bfloat16", device: str = "auto", max_new_tokens: int = 64):
    """Load a HF transformers causal LM and build a greedy generation callable.

    Args:
        model_id: Model ID or path handed to ``from_pretrained`` for both the
            tokenizer and the model.
        dtype: Name of a ``torch`` dtype attribute (e.g. ``"bfloat16"``). If the
            name does not resolve to a ``torch.dtype``, ``torch.float16`` is
            used instead and a warning is printed.
        device: Forwarded unchanged as ``device_map`` to ``from_pretrained``.
        max_new_tokens: Upper bound on the number of tokens generated per call.

    Returns:
        A ``(model, tokenizer, generate)`` tuple, where ``generate(prompt)``
        returns only the newly generated text (prompt tokens stripped, special
        tokens skipped) as a ``str``.
    """
    # Imported lazily so that importing this module does not require torch or
    # transformers to be installed.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"[control] loading {model_id} ({dtype}) ...")
    t0 = time.time()
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # `hasattr(torch, name)` is also true for modules and functions such as
    # "cuda" or "nn"; only accept attributes that really are dtypes.
    torch_dtype = getattr(torch, dtype, None)
    if not isinstance(torch_dtype, torch.dtype):
        print(f"[control] WARNING: {dtype!r} is not a torch dtype; falling back to float16")
        torch_dtype = torch.float16

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=torch_dtype,
        device_map=device,
        trust_remote_code=True,
    )
    model.eval()
    print(f"[control] loaded in {time.time()-t0:.1f}s, params: {sum(p.numel() for p in model.parameters()):,}")

    def generate(prompt: str) -> str:
        """Greedy-decode a continuation of *prompt* and return only the new text."""
        inputs = tok(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            # Greedy decoding: with do_sample=False the sampling knobs are not
            # used, so no temperature is passed (temperature=0.0 only triggers
            # a generation-config validation warning).
            out = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tok.eos_token_id,
            )
        # Strip prompt tokens, decode just the new tokens
        new_tokens = out[0, prompt_len:]
        return tok.decode(new_tokens, skip_special_tokens=True)

    return model, tok, generate


def main():
    """Run the control NIAH sweep from the command line.

    Parses CLI arguments, loads the model once, runs one NIAH cell per
    requested context length, writes a per-cell CSV plus ``summary.csv`` into
    ``--out-dir``, and prints a sanity verdict based on the accuracy of the
    smallest-context cell.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct",
                    help="HF model ID. Defaults to Qwen2.5-3B-Instruct.")
    ap.add_argument("--dtype", default="bfloat16")
    ap.add_argument("--device", default="auto")
    ap.add_argument("--context-lengths", type=int, nargs="+", default=[2000, 8000])
    ap.add_argument("--trials", type=int, default=10)
    ap.add_argument("--max-new-tokens", type=int, default=64)
    ap.add_argument("--out-dir", default="/tmp/niah_control")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    # Reject nonsensical sweeps up front, before paying for the model load.
    if args.trials < 1:
        ap.error("--trials must be >= 1")
    if any(ctx < 1 for ctx in args.context_lengths):
        ap.error("--context-lengths values must be >= 1")
    if args.max_new_tokens < 1:
        ap.error("--max-new-tokens must be >= 1")

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Only the generate callable is needed here; it closes over the model and
    # tokenizer, which keeps them alive for the duration of the sweep.
    _, _, gen = make_hf_generate_fn(
        args.model, dtype=args.dtype, device=args.device, max_new_tokens=args.max_new_tokens
    )

    print(f"\n[control] sweep: {len(args.context_lengths)} cells, {args.trials} trials each")
    summary = []
    for ctx in args.context_lengths:
        cell_id = f"ctx{ctx}_baseline"
        print(f"\n[control] === cell {cell_id} ===")
        result = run_niah_cell(
            generate_fn=gen,
            target_chars=ctx,
            n_trials=args.trials,
            seed=args.seed,
        )
        cell_csv = out_dir / f"{cell_id}.csv"
        write_cell_csv(result, cell_csv)
        print(f"[control] accuracy={result.accuracy:.2%} ({result.n_correct}/{result.n_trials})  "
              f"elapsed={result.elapsed_s:.1f}s  mean_response_chars={result.mean_response_chars:.0f}")
        summary.append({
            "cell_id": cell_id,
            "ctx_chars": ctx,
            "n_trials": result.n_trials,
            "n_correct": result.n_correct,
            "accuracy": round(result.accuracy, 4),
            "elapsed_s": round(result.elapsed_s, 1),
        })

    if not summary:
        return

    summary_path = out_dir / "summary.csv"
    # Explicit encoding so the file does not depend on the platform default.
    with open(summary_path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(summary[0].keys()))
        w.writeheader()
        w.writerows(summary)
    print(f"\n[control] summary -> {summary_path}")

    # Sanity check: the harness should resolve the needle on at least the
    # smallest context. If that cell's accuracy is < 50%, something is wrong
    # with either the harness OR the model is fundamentally broken at retrieval.
    smallest = min(summary, key=lambda s: s["ctx_chars"])
    if smallest["accuracy"] < 0.5:
        print(f"\n[control] WARNING: smallest-context accuracy {smallest['accuracy']:.2%} is < 50%.")
        print("[control] either the harness is broken or the model can't do NIAH at all.")
        print("[control] inspect a per-trial CSV to see what the model actually produced.")
    else:
        print(f"\n[control] PASS: smallest-context accuracy {smallest['accuracy']:.2%} >= 50%.")
        print("[control] harness validated end-to-end on a real LLM.")


# Script entry point: run the sweep only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()