File size: 5,203 Bytes
"""Cross-backbone community-vector probe (LOO k-NN recall@1).
For each backbone, extract the mean-pooled hidden state at a target layer for
every val_200 sample, then compute leave-one-out 1-nearest-neighbor recall@1
over the 35 coarse community labels. This isolates the question:
"Is the discourse-community signal that v8a's CDH lights up specific to
Qwen, or is it latent in other 7B causal LMs as well?"
We do NOT need any adapter or training to answer this -- v8a's CDH at
layer 4 simply mean-pools and projects, so the headline number ought to be
recoverable from raw layer-4 features (paper's 0.484 figure is built on top
of this same signal, just refined through 64-D community vector contrastive
learning).
Usage:
python scripts/cross_backbone_probe.py \
--val-data data/val_200.jsonl \
--backbones Qwen/Qwen2.5-7B mistralai/Mistral-7B-v0.3 \
--layers 4 8 12 \
--out artifacts/cross_backbone.json
"""
from __future__ import annotations
import argparse
import gc
import json
import time
from pathlib import Path
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def mean_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors over the sequence axis, weighted by ``mask``.

    Args:
        h: Hidden states; the caller in this file passes ``(1, T, D)``.
        mask: Attention mask; the caller in this file passes ``(1, T)``.
            It is cast to ``h``'s dtype and multiplied in as per-token
            weights, so 0 drops a position and 1 keeps it.

    Returns:
        The weighted sum over dim 1 divided by the mask total, i.e. one
        pooled vector per batch row.
    """
    weights = mask.unsqueeze(-1).to(h.dtype)
    # Clamp so an all-zero mask divides by 1 instead of 0.
    token_count = torch.clamp(weights.sum(dim=1), min=1.0)
    weighted_total = torch.sum(h * weights, dim=1)
    return weighted_total / token_count
def loo_knn_recall_at_1(feats: np.ndarray, labels: np.ndarray) -> float:
    """Leave-one-out 1-nearest-neighbour recall@1 under cosine similarity.

    Every row of ``feats`` is matched to its most similar *other* row; the
    score is the fraction of rows whose match carries the same label.

    Args:
        feats: ``(n_samples, dim)`` feature matrix (array-like accepted).
        labels: One label per row of ``feats`` (array-like accepted).

    Returns:
        Recall@1 in ``[0, 1]``. With fewer than two samples no neighbour
        exists, so the result is ``0.0`` rather than a self-match.

    Raises:
        ValueError: If ``feats`` is not 2-D or the number of labels does not
            match the number of feature rows.
    """
    feats = np.asarray(feats)
    labels = np.asarray(labels)
    if feats.ndim != 2:
        raise ValueError(
            f"feats must be 2-D (n_samples, dim), got shape {feats.shape}")
    if labels.shape[:1] != feats.shape[:1]:
        raise ValueError(
            f"labels/feats length mismatch: {labels.shape} vs {feats.shape}")
    if feats.shape[0] < 2:
        # A lone sample would otherwise "retrieve" itself through the -inf
        # diagonal and report a meaningless perfect score.
        return 0.0

    # L2-normalise rows so the Gram matrix holds cosine similarities; the
    # epsilon keeps all-zero rows from dividing by zero.
    unit = feats / (np.linalg.norm(feats, axis=1, keepdims=True) + 1e-9)
    sim = unit @ unit.T
    # Leave-one-out: a sample may never pick itself as nearest neighbour.
    np.fill_diagonal(sim, -np.inf)
    nearest = sim.argmax(axis=1)
    return float((labels[nearest] == labels).mean())
@torch.no_grad()
def extract(backbone_id: str, samples: list[dict], layers: list[int],
            max_len: int, device: str) -> dict[int, np.ndarray]:
    """Mean-pool hidden states of one backbone at the requested layers.

    Loads the tokenizer and causal LM for ``backbone_id``, runs each sample
    through the model one at a time, and collects the mean-pooled hidden
    state of every usable layer.

    Args:
        backbone_id: Model identifier passed to ``from_pretrained``.
        samples: Records whose ``"text"`` field is tokenized and encoded.
        layers: 0-based transformer layer indices to probe. Duplicates are
            collapsed and indices outside ``[0, num_hidden_layers)`` are
            dropped; the surviving list is printed.
        max_len: Truncation length in tokens.
        device: Device string used both as ``device_map`` and as the target
            for the tokenized inputs.

    Returns:
        Mapping from layer index to a float32 array with one pooled row per
        sample, in sample order.

    Raises:
        ValueError: From ``np.stack`` when ``samples`` is empty and at least
            one layer is usable.
    """
    print(f"\n=== {backbone_id} ===")
    t0 = time.time()
    tok = AutoTokenizer.from_pretrained(backbone_id)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        backbone_id, torch_dtype=torch.bfloat16, device_map=device,
        attn_implementation="eager",
    )
    try:
        model.eval()
        n_layers = model.config.num_hidden_layers
        # dict.fromkeys de-duplicates while keeping the caller's order: a
        # repeated layer would otherwise be appended twice per sample and
        # yield 2N rows that no longer align with the labels. Negative
        # indices are rejected because h[layer + 1] would silently pick an
        # unrelated entry (e.g. -1 -> the embedding output).
        layers_eff = [layer for layer in dict.fromkeys(layers)
                      if 0 <= layer < n_layers]
        print(f" loaded in {time.time()-t0:.1f}s, n_layers={n_layers}, "
              f"using layers {layers_eff}")
        out = {layer: [] for layer in layers_eff}
        for i, ex in enumerate(samples):
            enc = tok(ex["text"], return_tensors="pt", truncation=True,
                      max_length=max_len).to(device)
            h = model(**enc, output_hidden_states=True,
                      use_cache=False).hidden_states
            # h is tuple of length n_layers+1 (embedding + per-layer), so
            # the output of layer k sits at index k + 1.
            for layer in layers_eff:
                pooled = mean_pool(h[layer + 1], enc.attention_mask)[0]
                out[layer].append(pooled.float().cpu().numpy())
            if (i + 1) % 50 == 0:
                print(f" [{i+1}/{len(samples)}] elapsed={time.time()-t0:.1f}s")
        return {layer: np.stack(vecs) for layer, vecs in out.items()}
    finally:
        # Release the weights even when extraction fails part-way, so the
        # next backbone does not start with this one still held in memory.
        del model, tok
        gc.collect()
        torch.cuda.empty_cache()
def main() -> None:
    """Run the cross-backbone probe from the command line.

    Parses CLI arguments, loads the JSONL validation set, extracts pooled
    features for every requested backbone and layer, scores them with
    leave-one-out k-NN recall@1, and writes the results as JSON to ``--out``.
    A backbone whose extraction raises is recorded with its error message
    and skipped, so the remaining backbones still run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--val-data", default="data/val_200.jsonl")
    ap.add_argument("--backbones", nargs="+",
                    default=["Qwen/Qwen2.5-7B",
                             "mistralai/Mistral-7B-v0.3",
                             "meta-llama/Meta-Llama-3-8B"])
    ap.add_argument("--layers", type=int, nargs="+",
                    default=[4, 8, 12, 16])
    ap.add_argument("--max-len", type=int, default=512)
    ap.add_argument("--out", default="artifacts/cross_backbone.json")
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()

    # Read as UTF-8 explicitly (the locale default is not UTF-8 everywhere)
    # and skip blank lines, which json.loads would reject.
    raw_lines = Path(args.val_data).read_text(encoding="utf-8").splitlines()
    samples = [json.loads(line) for line in raw_lines if line.strip()]
    if not samples:
        # Without samples there are no classes and chance would divide by 0.
        ap.error(f"no samples found in {args.val_data}")

    labels = np.array([s["community_id"] for s in samples])
    n_classes = len(set(labels.tolist()))
    chance = 1.0 / n_classes
    print(f"val_200: n={len(samples)} classes={n_classes} chance={chance:.3f}")

    results = {
        "n_samples": len(samples),
        "n_classes": n_classes,
        "chance_recall_at_1": chance,
        "backbones": {},
    }
    for backbone in args.backbones:
        try:
            feats = extract(backbone, samples, args.layers, args.max_len,
                            args.device)
        except Exception as e:  # noqa: BLE001
            # Deliberately broad: one failing backbone (gated repo, OOM, ...)
            # must not abort the others. The error is kept in the output.
            print(f" SKIP {backbone}: {e}")
            results["backbones"][backbone] = {"error": str(e)}
            continue
        per_layer = {}
        for layer, X in feats.items():
            r1 = loo_knn_recall_at_1(X, labels)
            per_layer[str(layer)] = {
                "recall_at_1": r1,
                "x_chance": r1 / chance,
                "feature_dim": int(X.shape[1]),
            }
            print(f" layer {layer:>2} D={X.shape[1]:>5} "
                  f"recall@1={r1:.3f} ({r1/chance:.1f}x chance)")
        results["backbones"][backbone] = {"per_layer": per_layer}

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"\nwrote {args.out}")
# Run the probe only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|