#!/usr/bin/env python3
"""
CF-HoT Universal Probe Loader

Load any probe from this repo and run it on a model's hidden states.
Works with all suppression probes (LLaMA 8B) and cognitive enhancement
probes (Qwen, Mamba, Mistral).

Usage:
    python inference.py --probe suppression/hedging_168x
    python inference.py --probe cognitive/mistral/depth
    python inference.py --probe suppression/repetition_125x --prompt "Tell me about AI"
"""

import torch
import torch.nn as nn
import argparse
import os
import glob


# ─── Architecture definitions ───────────────────────────────────────


class FiberProjection(nn.Module):
    """Projects hidden states from multiple layers into fiber space.

    Each of the ``num_layers`` input tensors gets its own linear projection to
    ``fiber_dim``; the projections are combined with softmax-normalised,
    learned per-layer weights.
    """

    def __init__(self, hidden_dim, fiber_dim=16, num_layers=3, bias=True):
        super().__init__()
        self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projections = nn.ModuleList([
            nn.Linear(hidden_dim, fiber_dim, bias=bias) for _ in range(num_layers)
        ])

    def forward(self, hidden_states_list):
        weights = torch.softmax(self.layer_weights, dim=0)
        # zip() stops at the shortest input: surplus hidden states are ignored.
        return sum(
            w * proj(h.float())
            for w, h, proj in zip(weights, hidden_states_list, self.projections)
        )


class ProbeHead(nn.Module):
    """Classifies fiber-space vectors into behavioral risk scores in (0, 1)."""

    def __init__(self, fiber_dim=16, hidden_dim=64):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(fiber_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return torch.sigmoid(self.classifier(x))


class RiskPredictor(nn.Module):
    """Full risk predictor (used by repetition_125x). All-layer version.

    One bias-free projection per layer, softmax-weighted into fiber space,
    followed by a small MLP and a sigmoid.
    """

    def __init__(self, hidden_dim=4096, fiber_dim=16, n_layers=32):
        super().__init__()
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
        self.fiber_projs = nn.ModuleList([
            nn.Linear(hidden_dim, fiber_dim, bias=False) for _ in range(n_layers)
        ])
        self.predictor = nn.Sequential(
            nn.Linear(fiber_dim, 64),
            nn.GELU(),
            nn.Linear(64, 64),
            nn.GELU(),
            nn.Linear(64, 1),
        )

    def forward(self, hidden_states_list):
        weights = torch.softmax(self.layer_weights, dim=0)
        # zip() stops at the shortest input: only the first n_layers states are used.
        fiber = sum(
            w * proj(h.float())
            for w, h, proj in zip(weights, hidden_states_list, self.fiber_projs)
        )
        return torch.sigmoid(self.predictor(fiber))


# ─── Loader ─────────────────────────────────────────────────────────

# Base models and their configs
MODEL_CONFIGS = {
    "llama": {
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        "hidden_dim": 4096,
        "n_layers": 32,
        "probe_layers": [10, 20, 30],  # default for 3-layer probes
    },
    "qwen": {
        "model_id": "Qwen/Qwen2.5-7B-Instruct",
        "hidden_dim": 3584,
        "n_layers": 28,
        "probe_layers": [9, 18, 27],
    },
    "mamba": {
        "model_id": "tiiuae/falcon-mamba-7b-instruct",
        "hidden_dim": 4096,
        "n_layers": 64,
        "probe_layers": [16, 32, 48],
    },
    "mistral": {
        "model_id": "mistralai/Mistral-7B-Instruct-v0.3",
        "hidden_dim": 4096,
        "n_layers": 32,
        "probe_layers": [8, 16, 24],
    },
}


def detect_probe_type(probe_path):
    """Auto-detect what kind of probe checkpoint a directory holds.

    Returns:
        'risk_predictor' if the directory contains ``risk_predictor.pt``;
        'suppression' if it has a ``*_head.pt`` file plus ``fiber_proj.pt``;
        'cognitive' if it has a ``*_head.pt`` file but no ``fiber_proj.pt``;
        'unknown' otherwise (including when ``probe_path`` is not a directory).
    """
    files = os.listdir(probe_path) if os.path.isdir(probe_path) else []

    # Repetition uses risk_predictor.pt
    if "risk_predictor.pt" in files:
        return "risk_predictor"

    if any(f.endswith("_head.pt") for f in files):
        # Suppression probes ship a separate fiber_proj.pt; cognitive probes
        # bundle fiber projection and head into the single *_head.pt file.
        return "suppression" if "fiber_proj.pt" in files else "cognitive"

    return "unknown"


def detect_architecture(probe_path):
    """Detect which base model architecture a probe targets from its path.

    Matching is a case-insensitive substring test on the whole path, checked in
    the order qwen, mamba, mistral; anything else falls back to 'llama'.
    """
    path_lower = probe_path.lower()
    for arch in ("qwen", "mamba", "mistral"):
        if arch in path_lower:
            return arch
    return "llama"  # suppression probes default to LLaMA


def _load_checkpoint(path, device):
    """Load a checkpoint file onto ``device``.

    SECURITY: ``weights_only=False`` un-pickles arbitrary Python objects, so
    only point this at checkpoints from a trusted source.
    """
    return torch.load(path, map_location=device, weights_only=False)


def _find_head_file(probe_path):
    """Return the path of the alphabetically first ``*_head.pt`` in ``probe_path``."""
    # Sorted so the choice is deterministic when several head files exist.
    names = sorted(f for f in os.listdir(probe_path) if f.endswith("_head.pt"))
    return os.path.join(probe_path, names[0])


def _sub_state_dict(ckpt, name):
    """Extract the sub-state-dict stored under ``name`` in a checkpoint.

    Supports both layouts: a nested dict at ``ckpt[name]``, or flat keys of the
    form ``"<name>.<param>"`` (returned with the ``"<name>."`` prefix stripped).
    """
    nested = ckpt.get(name)
    if isinstance(nested, dict):
        return dict(nested)
    prefix = name + "."
    # Strip the prefix only at the start of the key (str.replace would also
    # rewrite any later occurrence inside the key).
    return {key[len(prefix):]: value for key, value in ckpt.items() if key.startswith(prefix)}


def _has_bias(state):
    """Return True if any parameter in ``state`` is a ``bias`` tensor."""
    return any(key.rsplit(".", 1)[-1] == "bias" for key in state)


def _infer_hidden_dim(fiber_state, default):
    """Return the input width of the first fiber projection, else ``default``."""
    weight = fiber_state.get("projections.0.weight")
    # nn.Linear stores weight as (out_features, in_features).
    return int(weight.shape[1]) if weight is not None else default


def _build_fiber(fiber_state, hidden_dim, device):
    """Instantiate a 3-layer FiberProjection from a state dict, in eval mode."""
    fiber = FiberProjection(
        hidden_dim=_infer_hidden_dim(fiber_state, hidden_dim),
        fiber_dim=16,
        num_layers=3,
        bias=_has_bias(fiber_state),
    ).to(device)
    fiber.load_state_dict(fiber_state)
    fiber.eval()
    return fiber


def _build_head(head_state, device):
    """Instantiate a ProbeHead from a state dict, in eval mode."""
    head = ProbeHead(fiber_dim=16, hidden_dim=64).to(device)
    head.load_state_dict(head_state)
    head.eval()
    return head


def load_probe(probe_path, device="cuda"):
    """
    Load any CF-HoT probe from a directory.

    Args:
        probe_path: directory containing the probe checkpoint file(s).
        device: torch device the probe modules are loaded onto.

    Returns:
        dict with keys:
            - 'type': str ('risk_predictor', 'suppression', or 'cognitive')
            - 'arch': str ('llama', 'qwen', 'mamba', 'mistral')
            - 'config': dict (copy of the model config)
            - 'fiber': FiberProjection or None
            - 'head': ProbeHead or None
            - 'risk_predictor': RiskPredictor or None
            - 'probe_layers': list[int] (hidden-state indices to feed the probe)
            - 'metadata': dict (step, separation, etc.)

    Raises:
        FileNotFoundError: if ``probe_path`` is not a directory.
        ValueError: if the directory layout matches no known probe type.
    """
    if not os.path.isdir(probe_path):
        raise FileNotFoundError(f"Probe directory not found: {probe_path}")

    probe_type = detect_probe_type(probe_path)
    if probe_type == "unknown":
        raise ValueError(
            f"Unrecognized probe layout in {probe_path}: expected risk_predictor.pt "
            "or a *_head.pt file (optionally with fiber_proj.pt)"
        )

    arch = detect_architecture(probe_path)
    # Copy so callers mutating the result cannot corrupt MODEL_CONFIGS.
    config = dict(MODEL_CONFIGS[arch])
    config["probe_layers"] = list(config["probe_layers"])

    result = {
        "type": probe_type,
        "arch": arch,
        "config": config,
        "fiber": None,
        "head": None,
        "risk_predictor": None,
        "probe_layers": list(config["probe_layers"]),
        "metadata": {},
    }

    if probe_type == "risk_predictor":
        ckpt = _load_checkpoint(os.path.join(probe_path, "risk_predictor.pt"), device)
        rp = RiskPredictor(
            hidden_dim=config["hidden_dim"], fiber_dim=16, n_layers=config["n_layers"]
        ).to(device)
        # Parameters are stored under 'risk_predictor.*'.
        rp.load_state_dict(_sub_state_dict(ckpt, "risk_predictor"))
        rp.eval()
        result["risk_predictor"] = rp
        result["probe_layers"] = list(range(config["n_layers"]))
        if "step" in ckpt:
            result["metadata"]["step"] = ckpt["step"]

    elif probe_type == "suppression":
        # Separate head + fiber_proj files, each a plain state dict.
        head_ckpt = _load_checkpoint(_find_head_file(probe_path), device)
        fiber_ckpt = _load_checkpoint(os.path.join(probe_path, "fiber_proj.pt"), device)
        result["fiber"] = _build_fiber(fiber_ckpt, config["hidden_dim"], device)
        result["head"] = _build_head(head_ckpt, device)

    else:  # cognitive: single file with fiber_projection + head_state + metadata
        ckpt = _load_checkpoint(_find_head_file(probe_path), device)

        for key in ("step", "separation", "loss", "probe_name",
                    "hidden_dim", "probe_layers", "architecture"):
            if key in ckpt:
                result["metadata"][key] = ckpt[key]

        # Override probe_layers if stored in checkpoint
        if "probe_layers" in ckpt:
            result["probe_layers"] = ckpt["probe_layers"]

        fiber_state = _sub_state_dict(ckpt, "fiber_projection")
        result["fiber"] = _build_fiber(
            fiber_state, ckpt.get("hidden_dim", config["hidden_dim"]), device
        )

        # Cognitive probes use either 'classifier.*' or 'net.*' naming for the
        # head; normalize the 'net.' prefix to 'classifier.'.
        head_state = {}
        for key, value in _sub_state_dict(ckpt, "head_state").items():
            if key.startswith("net."):
                key = "classifier." + key[len("net."):]
            head_state[key] = value
        result["head"] = _build_head(head_state, device)

    return result


def _module_device(module):
    """Return the device holding ``module``'s parameters."""
    return next(module.parameters()).device


def score_hidden_states(probe, hidden_states, position=-1):
    """
    Score hidden states using a loaded probe.

    Args:
        probe: dict returned by load_probe()
        hidden_states: sequence of tensors from model(output_hidden_states=True);
            each is indexed as ``[:, position, :]``, i.e. rank-3
            (batch, seq, hidden). Batch size must be 1 for ``.item()``.
        position: token position to score (default: last token)

    Returns:
        float: risk/behavioral score between 0 and 1
    """
    if probe["type"] == "risk_predictor":
        predictor = probe["risk_predictor"]
        device = _module_device(predictor)
        # Every hidden state is passed; RiskPredictor zips them with its
        # n_layers projections, so only the first n_layers are used.
        # NOTE(review): for HF models hidden_states[0] is the embedding output,
        # so this covers embeddings + the first n_layers-1 blocks — confirm this
        # matches how the predictor was trained.
        hs = [h[:, position, :].to(device) for h in hidden_states]
        with torch.no_grad():
            return predictor(hs).item()

    fiber, head = probe["fiber"], probe["head"]
    # Hidden states of a sharded (device_map="auto") model may live on other
    # devices than the probe, so move the selected slices over first.
    device = _module_device(fiber)
    hs = [hidden_states[i][:, position, :].to(device) for i in probe["probe_layers"]]
    with torch.no_grad():
        return head(fiber(hs)).item()


# ─── CLI demo ───────────────────────────────────────────────────────


def main():
    """CLI entry point: load a probe, optionally its base model, and score a prompt."""
    parser = argparse.ArgumentParser(description="CF-HoT Probe Inference")
    parser.add_argument("--probe", required=True,
                        help="Path to probe directory (e.g. suppression/hedging_168x)")
    parser.add_argument("--prompt", default="Can you explain quantum computing?",
                        help="Text prompt to analyze")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--info-only", action="store_true",
                        help="Just print probe info, don't load base model")
    args = parser.parse_args()

    print(f"Loading probe from: {args.probe}")
    try:
        probe = load_probe(args.probe, device=args.device)
    except (FileNotFoundError, ValueError) as err:
        # Report bad probe paths/layouts as a CLI usage error (exits with status 2).
        parser.error(str(err))

    print(f" Type: {probe['type']}")
    print(f" Architecture: {probe['arch']}")
    print(f" Base model: {probe['config']['model_id']}")
    print(f" Probe layers: {probe['probe_layers']}")
    for k, v in probe["metadata"].items():
        print(f" {k}: {v}")

    if args.info_only:
        return

    # Load base model (heavy dependency, imported only when needed)
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_id = probe["config"]["model_id"]
    print(f"\nLoading {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        ),
        device_map="auto",
        output_hidden_states=True,
    )
    model.eval()

    # Tokenize and run; with device_map="auto" the inputs belong on the
    # model's (first) device rather than the probe's --device.
    inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    score = score_hidden_states(probe, outputs.hidden_states)
    print(f"\nPrompt: {args.prompt}")
    print(f"Score: {score:.4f}")
    print(" (>0.5 = behavioral pattern detected, <0.5 = normal)")


if __name__ == "__main__":
    main()