# ============================================================================
# INTERNAL ANALYZER: CaptionBERT-8192
#
# Sees inside the model, not just the output. Four diagnostic lenses:
#   1. Spectral trajectories    — singular-value spectrum of the token
#                                 representations, per layer
#   2. Effective dimensionality — how deeply each input is understood, read
#                                 from the local neighborhood geometry of the
#                                 output embeddings
#   3. Cross-layer divergence   — where computation actually happens
#   4. Token influence          — which input tokens drive the output
#
# Usage:
#   analyzer = InternalAnalyzer(model, tokenizer)
#   report = analyzer.analyze(["girl", "woman", "subtraction", "multiplication"])
#   analyzer.print_report(report)
#   analyzer.compare(report, "girl", "subtraction")
# ============================================================================

import torch
import torch.nn.functional as F
import numpy as np
from collections import defaultdict

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class InternalAnalyzer:
    """Diagnostic probes over a CaptionBERT-style encoder's internals.

    The model is moved to DEVICE and put in eval mode on construction.
    ``token_influence`` reaches directly into ``model.token_emb``,
    ``model.pos_emb``, ``model.emb_norm``, ``model.emb_drop``,
    ``model.encoder.layers`` and ``model.output_proj``, so it only works with
    models exposing those attributes.
    """

    def __init__(self, model, tokenizer, max_len=512):
        self.model = model.to(DEVICE).eval()
        self.tokenizer = tokenizer
        self.max_len = max_len

    # ══════════════════════════════════════════════════════════════
    # CORE: Extract all layer representations
    # ══════════════════════════════════════════════════════════════

    @torch.no_grad()
    def extract_layers(self, texts):
        """Run one forward pass and collect per-layer representations.

        Args:
            texts: a string or a list of strings.

        Returns:
            dict with keys
              "texts"           -- the input list,
              "layer_pooled"    -- list (one per hidden state) of (B, H)
                                   attention-masked mean-pooled CPU tensors,
              "layer_raw"       -- tuple of (B, L, H) token-level CPU tensors,
              "final_embedding" -- ``outputs.last_hidden_state`` on CPU,
              "attention_mask"  -- (B, L) CPU tensor,
              "n_tokens"        -- (B,) count of real (unpadded) tokens.
        """
        if isinstance(texts, str):
            texts = [texts]

        inputs = self.tokenizer(
            texts, max_length=self.max_len, padding="max_length",
            truncation=True, return_tensors="pt").to(DEVICE)

        outputs = self.model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_hidden_states=True)

        mask = inputs["attention_mask"].unsqueeze(-1).float()
        n_tokens = inputs["attention_mask"].sum(-1)

        # Masked mean-pool every hidden state so padding doesn't dilute it.
        layer_pooled = []
        for h in outputs.hidden_states:
            pooled = (h * mask).sum(1) / mask.sum(1).clamp(min=1)
            layer_pooled.append(pooled.cpu())

        return {
            "texts": texts,
            "layer_pooled": layer_pooled,
            # Moved to CPU like every other entry, so the full per-layer
            # (B, L, H) stack doesn't keep occupying accelerator memory
            # while token_influence runs its own forward/backward passes.
            "layer_raw": tuple(h.cpu() for h in outputs.hidden_states),
            # NOTE(review): assumed to be the model's final pooled/projected
            # sentence embedding, shape (B, D_out) — effective_dimensionality
            # and compare() treat it as 2-D. Confirm against the remote model.
            "final_embedding": outputs.last_hidden_state.cpu(),
            "attention_mask": inputs["attention_mask"].cpu(),
            "n_tokens": n_tokens.cpu(),
        }

    # ══════════════════════════════════════════════════════════════
    # 1. SPECTRAL TRAJECTORIES
    # ══════════════════════════════════════════════════════════════

    def spectral_trajectory(self, data):
        """Singular-value spectrum of each input's token matrix at every layer.

        For each input and each hidden state, the real (unpadded) token
        vectors are mean-centered and their singular values computed. From
        those we report the top-20 values, a participation-ratio style
        effective dimension, the spectral entropy and the share held by the
        top singular value. Inputs with fewer than two real tokens (or where
        the SVD fails) get zero-valued placeholders with the same keys.

        Args:
            data: the dict returned by ``extract_layers``.

        Returns:
            list (one per input) of {"text", "trajectory"}, where "trajectory"
            is a list of per-layer dicts with keys "spectrum", "eff_dim",
            "entropy", "top1_ratio".
        """
        empty = {"spectrum": [], "eff_dim": 0, "entropy": 0, "top1_ratio": 0}

        results = []
        n_layers = len(data["layer_pooled"])
        B = data["layer_pooled"][0].shape[0]

        for b in range(B):
            # Select real tokens by mask rather than by prefix length, so the
            # result is correct for left- as well as right-padded inputs.
            real = data["attention_mask"][b].bool().cpu()
            trajectory = []
            for layer_idx in range(n_layers):
                h = data["layer_raw"][layer_idx][b].cpu().float()  # (L, H)
                h = h[real]
                if h.shape[0] < 2:
                    trajectory.append(dict(empty))
                    continue

                h_centered = h - h.mean(0, keepdim=True)
                try:
                    S = torch.linalg.svdvals(h_centered)
                except RuntimeError:  # torch.linalg.LinAlgError is a RuntimeError
                    trajectory.append(dict(empty))
                    continue

                total = S.sum()
                S_norm = S / (total + 1e-12)
                # NOTE(review): (sum s)^2 / sum s^2 over *singular values*;
                # the textbook participation ratio uses eigenvalues (s^2).
                # Kept as-is so numbers stay comparable with earlier runs.
                eff_dim = (total ** 2) / (S.pow(2).sum() + 1e-12)
                # Shannon entropy of the normalized spectrum (zeros skipped).
                S_pos = S_norm[S_norm > 1e-12]
                entropy = -(S_pos * S_pos.log()).sum()

                trajectory.append({
                    "spectrum": S[:20].tolist(),  # top 20 singular values
                    "eff_dim": eff_dim.item(),
                    "entropy": entropy.item(),
                    "top1_ratio": (S[0] / (total + 1e-12)).item(),
                })

            results.append({
                "text": data["texts"][b],
                "trajectory": trajectory,
            })

        return results

    # ══════════════════════════════════════════════════════════════
    # 2. EFFECTIVE DIMENSIONALITY (output space)
    # ══════════════════════════════════════════════════════════════

    def effective_dimensionality(self, data, k_neighbors=50):
        """Local effective dimensionality around each output embedding.

        For each embedding, take its ``k_neighbors`` most similar *other*
        embeddings (by dot product — cosine if the embeddings are unit-norm),
        mean-center them and measure the spread of their singular values.
        Interpretation used by this tool: high = rich understanding,
        low = surface-level placement.

        ``k_neighbors`` is capped at B - 1. With a single input there are no
        neighbours, so zero-valued placeholders are returned.

        Returns:
            list (one per input) of {"text", "eff_dim", "decay_rate",
            "local_spread"}.
        """
        embeddings = data["final_embedding"].float()  # (B, D)
        B = embeddings.shape[0]
        k = min(k_neighbors, B - 1)

        def placeholder(b):
            return {"text": data["texts"][b], "eff_dim": 0.0,
                    "decay_rate": 0.0, "local_spread": 0.0}

        if k < 1:
            return [placeholder(b) for b in range(B)]

        # Pairwise dot-product similarities.
        sim = embeddings @ embeddings.T

        results = []
        for b in range(B):
            sims = sim[b].clone()
            # -inf (not -1) guarantees self is never selected, even when the
            # embeddings are unnormalized and other similarities are < -1.
            sims[b] = float("-inf")
            _, topk_idx = sims.topk(k)
            neighbors = embeddings[topk_idx]  # (k, D)

            # Local PCA
            centered = neighbors - neighbors.mean(0, keepdim=True)
            try:
                S = torch.linalg.svdvals(centered)
            except RuntimeError:  # torch.linalg.LinAlgError is a RuntimeError
                results.append(placeholder(b))
                continue

            # Participation ratio (over singular values, as in lens 1).
            eff_dim = (S.sum() ** 2) / (S.pow(2).sum() + 1e-12)

            # How fast do singular values decay? Share held by the top 5.
            # Guarded: an all-zero spectrum would otherwise give 0/0 = NaN.
            S_norm = S / (S.sum() + 1e-12)
            norm_total = S_norm.sum()
            decay_rate = ((S_norm[:5].sum() / norm_total).item()
                          if norm_total > 0 else 0.0)

            results.append({
                "text": data["texts"][b],
                "eff_dim": eff_dim.item(),
                "decay_rate": decay_rate,  # high = concentrated, low = spread
                "local_spread": centered.norm(dim=-1).mean().item(),
            })

        return results

    # ══════════════════════════════════════════════════════════════
    # 3. CROSS-LAYER DIVERGENCE
    # ══════════════════════════════════════════════════════════════

    def cross_layer_divergence(self, data):
        """How much does the pooled representation change between layers?

        High change = computation happening. Low change = pass-through.

        Returns:
            list (one per input) of {"text", "profile", "total_path",
            "max_shift_layer", "input_output_cos"}. "profile" holds one dict
            per consecutive-layer transition ("layer", "cosine", "l2_shift",
            "angle_rad"). "max_shift_layer" is the index of the transition
            with the largest L2 shift, or -1 if there are no transitions.
        """
        results = []
        n_layers = len(data["layer_pooled"])
        B = data["layer_pooled"][0].shape[0]

        for b in range(B):
            profile = []
            for i in range(n_layers - 1):
                h_curr = data["layer_pooled"][i][b].float()
                h_next = data["layer_pooled"][i + 1][b].float()

                # Cosine between consecutive layers
                cos = F.cosine_similarity(h_curr.unsqueeze(0),
                                          h_next.unsqueeze(0)).item()
                # L2 distance
                l2 = (h_next - h_curr).norm().item()
                # Direction change (how much the direction rotates)
                h_curr_n = F.normalize(h_curr, dim=0)
                h_next_n = F.normalize(h_next, dim=0)
                angle = torch.acos(torch.clamp(
                    (h_curr_n * h_next_n).sum(), -1, 1)).item()

                profile.append({
                    "layer": f"{i}→{i+1}",
                    "cosine": cos,
                    "l2_shift": l2,
                    "angle_rad": angle,
                })

            # Total path length through representation space
            total_path = sum(p["l2_shift"] for p in profile)
            # Where did most change happen? (-1 when there is no transition)
            max_shift_layer = max(range(len(profile)),
                                  key=lambda i: profile[i]["l2_shift"],
                                  default=-1)

            results.append({
                "text": data["texts"][b],
                "profile": profile,
                "total_path": total_path,
                "max_shift_layer": max_shift_layer,
                "input_output_cos": F.cosine_similarity(
                    data["layer_pooled"][0][b].unsqueeze(0).float(),
                    data["layer_pooled"][-1][b].unsqueeze(0).float()
                ).item(),
            })

        return results

    # ══════════════════════════════════════════════════════════════
    # 4. TOKEN INFLUENCE (gradient-based)
    # ══════════════════════════════════════════════════════════════

    def token_influence(self, texts):
        """Which tokens influence the output most?

        Differentiates the sum of the components of the L2-normalized output
        embedding with respect to the input embeddings (taken after
        ``emb_norm``/``emb_drop``). The output's norm itself is constant (1)
        after normalization, so its component sum is used as the scalar
        target. Per-token influence is the gradient norm, normalized to sum
        to 1 over real tokens.

        Gradients are taken with ``torch.autograd.grad`` against a detached
        copy of the embeddings, so this works even if the model's parameters
        are frozen or the caller is inside ``torch.no_grad()``, and it never
        touches any parameter's ``.grad``.

        Returns:
            list (one per input) of {"text", "tokens", "influence",
            "top_tokens", "concentration"}.
        """
        if isinstance(texts, str):
            texts = [texts]

        results = []
        for text in texts:
            inputs = self.tokenizer(
                [text], max_length=self.max_len, padding="max_length",
                truncation=True, return_tensors="pt").to(DEVICE)

            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
            real = attention_mask[0].bool().cpu()  # real (unpadded) positions

            with torch.enable_grad():
                # NOTE(review): this re-implements the model's forward by hand
                # (token_emb + pos_emb -> emb_norm -> emb_drop ->
                # encoder.layers -> masked mean -> output_proj -> normalize);
                # confirm it still matches the remote model's forward().
                positions = torch.arange(
                    input_ids.shape[1], device=DEVICE).unsqueeze(0)
                emb = self.model.token_emb(input_ids) + self.model.pos_emb(positions)
                emb = self.model.emb_drop(self.model.emb_norm(emb))
                # Make the encoder input a leaf we differentiate against.
                emb = emb.detach().requires_grad_(True)

                # Forward through encoder
                kpm = ~attention_mask.bool()
                x = emb
                for layer in self.model.encoder.layers:
                    x = layer(x, src_key_padding_mask=kpm)

                # Pool and project
                mask = attention_mask.unsqueeze(-1).float()
                pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
                output = F.normalize(self.model.output_proj(pooled), dim=-1)

                (grad,) = torch.autograd.grad(output.sum(), emb)

            # Per-token influence = gradient norm, real tokens only
            influence = grad[0].cpu().norm(dim=-1)[real]
            influence = influence / (influence.sum() + 1e-12)  # normalize

            # Decode tokens
            token_ids = input_ids[0].cpu()[real].tolist()
            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)

            # max/mean ratio; 0 when there is nothing to compare (avoids NaN)
            if influence.numel() > 0 and influence.mean() > 0:
                concentration = (influence.max() / influence.mean()).item()
            else:
                concentration = 0.0

            results.append({
                "text": text,
                "tokens": tokens,
                "influence": influence.tolist(),
                "top_tokens": sorted(zip(tokens, influence.tolist()),
                                     key=lambda x: -x[1])[:10],
                "concentration": concentration,
            })

        return results

    # ══════════════════════════════════════════════════════════════
    # 5. FULL ANALYSIS
    # ══════════════════════════════════════════════════════════════

    def analyze(self, texts):
        """Run all analyses on a set of texts.

        Returns a dict keyed by input text (duplicate texts collapse into a
        single entry), each value holding "embedding", "n_tokens",
        "spectral", "eff_dim", "divergence" and "influence".
        """
        if isinstance(texts, str):
            texts = [texts]

        print(f" Analyzing {len(texts)} inputs...")

        data = self.extract_layers(texts)
        spectral = self.spectral_trajectory(data)
        eff_dim = self.effective_dimensionality(data)
        divergence = self.cross_layer_divergence(data)
        influence = self.token_influence(texts)

        report = {}
        for i, text in enumerate(texts):
            report[text] = {
                "embedding": data["final_embedding"][i],
                "n_tokens": data["n_tokens"][i].item(),
                "spectral": spectral[i],
                "eff_dim": eff_dim[i] if i < len(eff_dim) else {},
                "divergence": divergence[i],
                "influence": influence[i],
            }

        return report

    # ══════════════════════════════════════════════════════════════
    # PRINTING
    # ══════════════════════════════════════════════════════════════

    def print_report(self, report):
        """Print full analysis report (as produced by ``analyze``)."""
        print(f"\n{'='*70}")
        print("INTERNAL ANALYSIS REPORT")
        print(f"{'='*70}")

        if not report:
            # Later sections read the layer count from the first entry.
            print(" (no inputs)")
            return

        # Summary table
        print(f"\n {'Text':<25} {'Tokens':>6} {'EffDim':>7} {'Path':>7} "
              f"{'MaxShift':>9} {'InOutCos':>8} {'Concentrate':>11}")
        print(f" {'-'*75}")
        for text, r in report.items():
            label = text[:24]
            ed = r["eff_dim"].get("eff_dim", 0)
            tp = r["divergence"]["total_path"]
            ms = r["divergence"]["max_shift_layer"]
            ioc = r["divergence"]["input_output_cos"]
            conc = r["influence"]["concentration"]
            print(f" {label:<25} {r['n_tokens']:>6} {ed:>7.1f} {tp:>7.2f} "
                  f" layer {ms:>2} {ioc:>7.3f} {conc:>10.1f}")

        # Spectral evolution
        print(f"\n SPECTRAL TRAJECTORY (effective dim per layer):")
        print(f" {'Text':<25}", end="")
        n_layers = len(next(iter(report.values()))["spectral"]["trajectory"])
        for i in range(n_layers):
            print(f" L{i:>2}", end="")
        print()
        print(f" {'-'*75}")
        for text, r in report.items():
            label = text[:24]
            print(f" {label:<25}", end="")
            for step in r["spectral"]["trajectory"]:
                ed = step.get("eff_dim", 0)
                print(f" {ed:>4.0f}", end="")
            print()

        # Spectral entropy per layer
        print(f"\n SPECTRAL ENTROPY (information content per layer):")
        print(f" {'Text':<25}", end="")
        for i in range(n_layers):
            print(f" L{i:>2}", end="")
        print()
        print(f" {'-'*75}")
        for text, r in report.items():
            label = text[:24]
            print(f" {label:<25}", end="")
            for step in r["spectral"]["trajectory"]:
                ent = step.get("entropy", 0)
                print(f" {ent:>4.1f}", end="")
            print()

        # Cross-layer divergence profiles
        print(f"\n COMPUTATION PROFILE (L2 shift between layers):")
        print(f" {'Text':<25}", end="")
        for i in range(n_layers - 1):
            print(f" {i}→{i+1:>2}", end="")
        print()
        print(f" {'-'*75}")
        for text, r in report.items():
            label = text[:24]
            print(f" {label:<25}", end="")
            for step in r["divergence"]["profile"]:
                print(f" {step['l2_shift']:>4.1f}", end="")
            print()

        # Token influence for each input
        print(f"\n TOKEN INFLUENCE (top contributing tokens):")
        for text, r in report.items():
            top = r["influence"]["top_tokens"][:5]
            tok_str = " ".join(f"{t}={v:.3f}" for t, v in top)
            print(f" {text[:40]:<42} {tok_str}")

    def compare(self, report, text_a, text_b):
        """Compare internal representations of two specific inputs.

        Both texts must be keys of ``report`` (KeyError otherwise).
        """
        a = report[text_a]
        b = report[text_b]

        cos = F.cosine_similarity(
            a["embedding"].unsqueeze(0), b["embedding"].unsqueeze(0)).item()

        print(f"\n{'='*70}")
        print(f"COMPARISON: '{text_a}' vs '{text_b}'")
        print(f"{'='*70}")
        print(f" Output cosine: {cos:.4f}")
        print(f" Tokens: {a['n_tokens']} vs {b['n_tokens']}")

        # Effective dim comparison
        ed_a = a["eff_dim"].get("eff_dim", 0)
        ed_b = b["eff_dim"].get("eff_dim", 0)
        print(f" Effective dim: {ed_a:.1f} vs {ed_b:.1f} (Δ={abs(ed_a-ed_b):.1f})")

        # Path comparison
        pa = a["divergence"]["total_path"]
        pb = b["divergence"]["total_path"]
        print(f" Total path: {pa:.2f} vs {pb:.2f} (Δ={abs(pa-pb):.2f})")

        # Layer-by-layer spectral comparison
        print(f"\n Effective dim trajectory:")
        print(f" {'Layer':<8} {'A':>8} {'B':>8} {'Δ':>8}")
        traj_a = a["spectral"]["trajectory"]
        traj_b = b["spectral"]["trajectory"]
        for i in range(len(traj_a)):
            ea = traj_a[i].get("eff_dim", 0)
            eb = traj_b[i].get("eff_dim", 0)
            print(f" L{i:<6} {ea:>8.1f} {eb:>8.1f} {abs(ea-eb):>8.1f}")

        # Divergence profile comparison
        print(f"\n Computation profile (L2 shift):")
        print(f" {'Transition':<10} {'A':>8} {'B':>8} {'Δ':>8}")
        for i in range(len(a["divergence"]["profile"])):
            sa = a["divergence"]["profile"][i]["l2_shift"]
            sb = b["divergence"]["profile"][i]["l2_shift"]
            label = a["divergence"]["profile"][i]["layer"]
            print(f" {label:<10} {sa:>8.2f} {sb:>8.2f} {abs(sa-sb):>8.2f}")

        # Token influence comparison
        print(f"\n Top tokens:")
        print(f" A: {' '.join(f'{t}={v:.3f}' for t,v in a['influence']['top_tokens'][:5])}")
        print(f" B: {' '.join(f'{t}={v:.3f}' for t,v in b['influence']['top_tokens'][:5])}")


# ══════════════════════════════════════════════════════════════════
# RUN
# ══════════════════════════════════════════════════════════════════

if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    REPO_ID = "AbstractPhil/geolip-captionbert-8192"

    print("Loading model...")
    model = AutoModel.from_pretrained(REPO_ID, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(REPO_ID)

    analyzer = InternalAnalyzer(model, tokenizer)

    # Test words spanning known-domain and unknown-domain
    test_words = [
        # Known domain (captions)
        "girl", "woman", "dog", "sunset", "painting",
        # Unknown domain (abstract)
        "subtraction", "multiplication", "prophetic", "differential", "adjacency",
        # Phrases
        "a girl sitting near a window",
        "a dog playing on the beach",
        "the differential equation of motion",
    ]

    report = analyzer.analyze(test_words)
    analyzer.print_report(report)

    # Direct comparisons
    analyzer.compare(report, "girl", "woman")
    analyzer.compare(report, "girl", "subtraction")
    analyzer.compare(report, "a girl sitting near a window",
                     "the differential equation of motion")

    print(f"\n{'='*70}")
    print("DONE")
    print(f"{'='*70}")