RiverRider committed on
Commit
e310afe
·
verified ·
1 Parent(s): 156b473

cross-backbone probe (scripts/cross_backbone_probe.py)

Browse files
Files changed (1) hide show
  1. scripts/cross_backbone_probe.py +138 -0
scripts/cross_backbone_probe.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cross-backbone community-vector probe (LOO k-NN recall@1).
2
+
3
+ For each backbone, extract the mean-pooled hidden state at a target layer for
4
+ every val_200 sample, then compute leave-one-out 1-nearest-neighbor recall@1
5
+ over the 35 coarse community labels. This isolates the question:
6
+
7
+ "Is the discourse-community signal that v8a's CDH lights up specific to
8
+ Qwen, or is it latent in other 7B causal LMs as well?"
9
+
10
+ We do NOT need any adapter or training to answer this -- v8a's CDH at
11
+ layer 4 simply mean-pools and projects, so the headline number ought to be
12
+ recoverable from raw layer-4 features (paper's 0.484 figure is built on top
13
+ of this same signal, just refined through 64-D community vector contrastive
14
+ learning).
15
+
16
+ Usage:
17
+ python scripts/cross_backbone_probe.py \
18
+ --val-data data/val_200.jsonl \
19
+ --backbones Qwen/Qwen2.5-7B mistralai/Mistral-7B-v0.3 \
20
+ --layers 4 8 12 \
21
+ --out artifacts/cross_backbone.json
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import argparse
27
+ import gc
28
+ import json
29
+ import time
30
+ from pathlib import Path
31
+
32
+ import numpy as np
33
+ import torch
34
+ from transformers import AutoModelForCausalLM, AutoTokenizer
35
+
36
+
37
def mean_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average hidden states over the sequence axis, ignoring masked-out tokens.

    Args:
        h: Hidden states; the caller in this file passes shape (1, T, D).
        mask: Attention mask; the caller passes shape (1, T). Non-zero
            entries mark tokens that contribute to the mean.

    Returns:
        The masked mean over dim 1, in ``h``'s dtype. The token count is
        clamped to at least 1 so an all-zero mask yields zeros instead of
        a division by zero.
    """
    # Broadcast the mask over the feature axis and match h's dtype.
    weights = mask.unsqueeze(-1).to(h.dtype)
    masked_total = torch.sum(h * weights, dim=1)
    token_count = torch.clamp(weights.sum(dim=1), min=1.0)
    return masked_total / token_count
41
+
42
+
43
def loo_knn_recall_at_1(feats: np.ndarray, labels: np.ndarray) -> float:
    """Leave-one-out 1-nearest-neighbour recall@1 under cosine similarity.

    Each row of ``feats`` is matched to its most similar *other* row, and the
    score is the fraction of rows whose nearest neighbour shares their label.

    Args:
        feats: 2-D array-like of shape (N, D), one feature vector per sample.
        labels: 1-D array-like of length N with the class label per sample.

    Returns:
        Recall@1 in [0, 1]. With fewer than two samples there is no
        neighbour to retrieve, so the result is 0.0.

    Raises:
        ValueError: If ``feats`` is not 2-D, ``labels`` is not 1-D, or their
            lengths disagree.
    """
    feats = np.asarray(feats)
    labels = np.asarray(labels)
    if feats.ndim != 2:
        raise ValueError(f"feats must be 2-D, got shape {feats.shape}")
    if labels.ndim != 1 or labels.shape[0] != feats.shape[0]:
        raise ValueError(
            f"labels shape {labels.shape} does not match "
            f"{feats.shape[0]} feature rows"
        )
    # Without a second sample, argmax over the all -inf row would pick the
    # sample itself and report a spurious perfect score.
    if feats.shape[0] < 2:
        return 0.0

    # L2-normalise rows so the Gram matrix holds cosine similarities; the
    # epsilon keeps all-zero rows from dividing by zero.
    unit = feats / (np.linalg.norm(feats, axis=1, keepdims=True) + 1e-9)
    sim = unit @ unit.T
    # Leave-one-out: a sample may never retrieve itself.
    np.fill_diagonal(sim, -np.inf)
    nearest = sim.argmax(axis=1)
    return float((labels[nearest] == labels).mean())
50
+
51
+
52
@torch.no_grad()
def extract(backbone_id: str, samples: list[dict], layers: list[int],
            max_len: int, device: str) -> dict[int, np.ndarray]:
    """Collect mean-pooled hidden states for every sample at the given layers.

    Loads the tokenizer and causal LM for ``backbone_id``, runs each sample's
    ``"text"`` field through the model one at a time, and mean-pools the
    hidden states of each requested layer over the attention mask.

    Args:
        backbone_id: Model identifier passed to ``from_pretrained``.
        samples: Dicts that each contain a ``"text"`` entry.
        layers: Zero-based transformer layer indices to probe. Indices outside
            ``[0, num_hidden_layers)`` are dropped and duplicates are
            collapsed, so each kept layer yields exactly one row per sample.
        max_len: Truncation length handed to the tokenizer.
        device: Device string used for both ``device_map`` and the inputs.

    Returns:
        Mapping from each kept layer index to a float32 array with one row
        per sample.
    """
    print(f"\n=== {backbone_id} ===")
    t0 = time.time()
    # Bound up front so the cleanup in ``finally`` is valid even when loading
    # fails part-way through.
    tok = None
    model = None
    try:
        tok = AutoTokenizer.from_pretrained(backbone_id)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            backbone_id, torch_dtype=torch.bfloat16, device_map=device,
            attn_implementation="eager",
        )
        model.eval()
        n_layers = model.config.num_hidden_layers
        # dict.fromkeys de-duplicates while preserving the requested order;
        # the lower bound stops negative indices from silently aliasing a
        # different hidden state through Python's negative indexing.
        layers_eff = [layer for layer in dict.fromkeys(layers)
                      if 0 <= layer < n_layers]
        print(f" loaded in {time.time()-t0:.1f}s, n_layers={n_layers}, "
              f"using layers {layers_eff}")
        pooled = {layer: [] for layer in layers_eff}
        for i, ex in enumerate(samples):
            enc = tok(ex["text"], return_tensors="pt", truncation=True,
                      max_length=max_len).to(device)
            h = model(**enc, output_hidden_states=True,
                      use_cache=False).hidden_states
            # h is tuple of length n_layers+1 (embedding + per-layer), hence
            # the +1 offset when indexing a transformer layer.
            for layer in layers_eff:
                v = mean_pool(h[layer + 1], enc.attention_mask)[0]
                pooled[layer].append(v.float().cpu().numpy())
            if (i + 1) % 50 == 0:
                print(f" [{i+1}/{len(samples)}] elapsed={time.time()-t0:.1f}s")
        return {layer: np.stack(pooled[layer]) for layer in layers_eff}
    finally:
        # Release the model even when extraction raises, so the next backbone
        # does not start with this one's weights still resident.
        del model, tok
        gc.collect()
        torch.cuda.empty_cache()
85
+
86
+
87
def _load_samples(path: str) -> list[dict]:
    """Parse a JSONL file into a list of records, skipping blank lines."""
    text = Path(path).read_text(encoding="utf-8")
    return [json.loads(line) for line in text.splitlines() if line.strip()]


def main() -> None:
    """Run the probe for every requested backbone and write a JSON report.

    Reads samples from ``--val-data`` (one JSON object per line, each with a
    ``"community_id"`` field), extracts per-layer features for each backbone
    via ``extract``, scores them with ``loo_knn_recall_at_1``, and writes the
    accumulated results to ``--out``. A backbone whose extraction raises is
    recorded with an ``"error"`` entry instead of aborting the run.

    Raises:
        SystemExit: If the validation file contains no samples.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--val-data", default="data/val_200.jsonl")
    ap.add_argument("--backbones", nargs="+",
                    default=["Qwen/Qwen2.5-7B",
                             "mistralai/Mistral-7B-v0.3",
                             "meta-llama/Meta-Llama-3-8B"])
    ap.add_argument("--layers", type=int, nargs="+",
                    default=[4, 8, 12, 16])
    ap.add_argument("--max-len", type=int, default=512)
    ap.add_argument("--out", default="artifacts/cross_backbone.json")
    ap.add_argument("--device", default="cuda")
    args = ap.parse_args()

    samples = _load_samples(args.val_data)
    if not samples:
        # Without samples the chance level below would divide by zero.
        raise SystemExit(f"no samples found in {args.val_data}")
    labels = np.array([s["community_id"] for s in samples])
    n_classes = len(set(labels.tolist()))
    chance = 1.0 / n_classes
    print(f"val_200: n={len(samples)} classes={n_classes} chance={chance:.3f}")

    results = {
        "n_samples": len(samples),
        "n_classes": n_classes,
        "chance_recall_at_1": chance,
        "backbones": {},
    }
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    for backbone in args.backbones:
        try:
            feats = extract(backbone, samples, args.layers, args.max_len,
                            args.device)
        except Exception as e:  # noqa: BLE001
            # Best-effort: one backbone failing to load or run must not
            # discard the results already gathered for the others.
            print(f" SKIP {backbone}: {e}")
            results["backbones"][backbone] = {"error": str(e)}
        else:
            per_layer = {}
            for layer, X in feats.items():
                r1 = loo_knn_recall_at_1(X, labels)
                per_layer[str(layer)] = {
                    "recall_at_1": r1,
                    "x_chance": r1 / chance,
                    "feature_dim": int(X.shape[1]),
                }
                print(f" layer {layer:>2} D={X.shape[1]:>5} "
                      f"recall@1={r1:.3f} ({r1/chance:.1f}x chance)")
            results["backbones"][backbone] = {"per_layer": per_layer}
        # Persist after every backbone, success or failure, so partial
        # results and recorded errors survive a later crash.
        out_path.write_text(json.dumps(results, indent=2), encoding="utf-8")

    print(f"\nwrote {args.out}")


if __name__ == "__main__":
    main()