#!/usr/bin/env python3
# Full weight release: 9 probes × 3 architectures + production adapter + training code
# (commit 297244f, verified)
"""
Lie Holonomy Transformer - Geometric Analysis of Hidden States
==============================================================
Tests whether geometric properties (velocity, curvature, holonomy)
predict model behavior better than raw hidden state probes.
This is the experiment that could change everything.
"""
# Standard library
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Tuple, Optional

# Third-party
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# =============================================================================
# CONFIGURATION
# =============================================================================
@dataclass
class GeometryConfig:
    """Configuration for the geometric-analysis experiment.

    Fix: this was a plain class with annotated class attributes, so
    ``GeometryConfig(model_path=..., n_sequences=..., max_length=...)`` in
    ``main()`` raised TypeError (no generated ``__init__``). ``@dataclass``
    (already imported at module top) provides the keyword constructor while
    keeping the same attribute names and defaults.
    """

    model_path: str = "."                  # local model directory
    n_sequences: int = 100                 # number of sequences to generate
    max_length: int = 256                  # max tokens generated per sequence
    device: str = "cuda"
    output_dir: str = "geometry_results"

    # Geometric thresholds
    curvature_window: int = 3              # tokens to compute curvature over
    holonomy_threshold: float = 0.95       # cosine similarity to detect "loops"
# =============================================================================
# GEOMETRIC COMPUTATIONS
# =============================================================================
class ManifoldAnalyzer:
    """
    Analyzes the geometry of hidden state trajectories.

    Key concepts:
    - Velocity: direction of movement in hidden space (first derivative)
    - Curvature: how sharply the path bends (second derivative)
    - Holonomy: what you lose going around a loop (parallel transport failure)

    All derivatives are finite differences between consecutive tokens of a
    trajectory shaped [seq_len, hidden_dim].
    """

    def __init__(self, config: "GeometryConfig"):
        self.config = config

    def compute_velocities(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Compute velocity vectors (tangent vectors to the trajectory).

        Args:
            hidden_states: [seq_len, hidden_dim]
        Returns:
            velocities: [seq_len-1, hidden_dim]
        """
        return hidden_states[1:] - hidden_states[:-1]

    def compute_speeds(self, velocities: torch.Tensor) -> torch.Tensor:
        """Magnitudes of velocity vectors: [n] from [n, hidden_dim]."""
        return torch.norm(velocities, dim=-1)

    def compute_accelerations(self, velocities: torch.Tensor) -> torch.Tensor:
        """Second derivative - how velocity changes. [n-1, hidden_dim]."""
        return velocities[1:] - velocities[:-1]

    def compute_curvature(self, velocities: torch.Tensor) -> torch.Tensor:
        """
        Curvature κ = |dT/ds| where T is unit tangent, s is arc length.
        Simplified: κ ≈ |a| / |v|² where a is acceleration, v is velocity.
        High curvature = sharp turns in semantic space.

        Returns: [n-1] curvature values (one fewer than velocities).
        """
        speeds = self.compute_speeds(velocities)
        accelerations = self.compute_accelerations(velocities)
        accel_magnitudes = torch.norm(accelerations, dim=-1)
        # Epsilon avoids division by zero at (near-)stationary points.
        speeds_squared = speeds[:-1] ** 2 + 1e-8
        return accel_magnitudes / speeds_squared

    def compute_torsion(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Torsion proxy: how much the jerk (third derivative) leaves the
        osculating plane spanned by velocity and acceleration.

        Fix vs. original: v and a are generally NOT orthogonal, so summing
        the projections of j onto each unit vector double-counts their shared
        component and is not a projection onto the v-a plane (the old code
        produced nonzero "torsion" even for planar 2-D curves). We now
        Gram-Schmidt a against v to get an orthonormal plane basis first.
        """
        v = self.compute_velocities(hidden_states)
        a = self.compute_accelerations(v)
        if len(a) < 2:
            # Too short for a third derivative.
            return torch.tensor([0.0])
        # Jerk (third derivative)
        j = a[1:] - a[:-1]
        # Align lengths: each jerk sample pairs with the preceding v and a.
        v_trimmed = v[:-2]
        a_trimmed = a[:-1]
        # Orthonormal basis of the osculating plane via Gram-Schmidt.
        v_norm = F.normalize(v_trimmed, dim=-1)
        a_perp = a_trimmed - (a_trimmed * v_norm).sum(dim=-1, keepdim=True) * v_norm
        a_norm = F.normalize(a_perp, dim=-1)
        # Project j onto the plane; torsion proxy = out-of-plane residual.
        j_in_plane = ((j * v_norm).sum(dim=-1, keepdim=True) * v_norm
                      + (j * a_norm).sum(dim=-1, keepdim=True) * a_norm)
        j_out_of_plane = j - j_in_plane
        return torch.norm(j_out_of_plane, dim=-1)

    def detect_loops(self, hidden_states: torch.Tensor,
                     threshold: float = 0.95) -> List[Tuple[int, int]]:
        """
        Detect semantic loops: pairs (i, j) with j >= i + 5 whose hidden
        states have cosine similarity strictly above `threshold`.

        Vectorized with an upper-triangular mask instead of the original
        O(n^2) Python double loop; the pair ordering (row-major) is the
        same as the original nested loops produced.
        """
        h_norm = F.normalize(hidden_states, dim=-1)
        similarity = torch.mm(h_norm, h_norm.t())
        # diagonal=5 keeps only pairs at least 5 tokens apart
        # (and therefore also excludes the diagonal itself).
        mask = torch.triu(torch.ones_like(similarity, dtype=torch.bool), diagonal=5)
        pairs = torch.nonzero(mask & (similarity > threshold), as_tuple=False)
        return [(int(i), int(j)) for i, j in pairs.tolist()]

    def compute_holonomy(self, hidden_states: torch.Tensor,
                         loop: Tuple[int, int]) -> float:
        """
        Compute holonomy around a detected loop (i, j).

        If we parallel transport a vector around a loop and it comes back
        unchanged, the space is flat. If it rotates, there's curvature.
        Simplified: compare the backward-difference velocity at i vs. j.

        Returns a value in [0, 1]: 0 = flat, 1 = maximally curved. Boundary
        pairs (i == 0 or j at the final token) return 0.0 because a backward
        difference is unavailable there.
        """
        i, j = loop
        if i > 0 and j < len(hidden_states) - 1:
            v_start = hidden_states[i] - hidden_states[i - 1]
            v_end = hidden_states[j] - hidden_states[j - 1]
            # Holonomy = angle between velocity vectors at the "same" point.
            v_start_norm = F.normalize(v_start, dim=-1)
            v_end_norm = F.normalize(v_end, dim=-1)
            cos_angle = (v_start_norm * v_end_norm).sum()
            return (1 - cos_angle.abs()).item()
        return 0.0

    def analyze_sequence(self, hidden_states: torch.Tensor) -> Dict:
        """Full geometric analysis of one trajectory [seq_len, hidden_dim]."""
        # Basic derivatives
        velocities = self.compute_velocities(hidden_states)
        speeds = self.compute_speeds(velocities)
        curvature = self.compute_curvature(velocities)
        torsion = self.compute_torsion(hidden_states)
        # Loop detection and holonomy
        loops = self.detect_loops(hidden_states, self.config.holonomy_threshold)
        holonomies = [self.compute_holonomy(hidden_states, loop) for loop in loops]
        return {
            "velocities": velocities,
            "speeds": speeds,
            "curvature": curvature,
            "torsion": torsion,
            "loops": loops,
            "holonomies": holonomies,
            # Summary statistics. Guards added: the mean of an empty tensor
            # and the std of a single element are both NaN in torch.
            "mean_speed": speeds.mean().item() if len(speeds) > 0 else 0,
            "std_speed": speeds.std().item() if len(speeds) > 1 else 0,
            "mean_curvature": curvature.mean().item() if len(curvature) > 0 else 0,
            "max_curvature": curvature.max().item() if len(curvature) > 0 else 0,
            "mean_torsion": torsion.mean().item() if len(torsion) > 0 else 0,
            "n_loops": len(loops),
            "mean_holonomy": np.mean(holonomies) if holonomies else 0,
        }
# =============================================================================
# REPETITION DETECTION (Ground Truth)
# =============================================================================
def detect_repetitions(token_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """
    Per-token binary labels for repetition.

    labels[i] is 1.0 when token i already occurred within the previous
    `window` tokens, else 0.0; position 0 is always 0.0.
    """
    seq_len = token_ids.shape[0]
    labels = torch.zeros(seq_len)
    for pos in range(1, seq_len):
        lookback = token_ids[max(0, pos - window):pos]
        if (lookback == token_ids[pos]).any():
            labels[pos] = 1.0
    return labels
def detect_ngram_repetitions(token_ids: torch.Tensor, n: int = 3) -> torch.Tensor:
    """
    Label positions whose trailing n-gram occurred earlier in the sequence.

    labels[i] is 1.0 iff the n-gram ending at i (tokens i-n+1 .. i) was
    already seen ending at some earlier position; the first n-1 positions
    are always 0.0.
    """
    seq_len = token_ids.shape[0]
    labels = torch.zeros(seq_len)
    seen = set()
    ids = token_ids.tolist()
    for end in range(n - 1, seq_len):
        gram = tuple(ids[end - n + 1:end + 1])
        labels[end] = float(gram in seen)
        seen.add(gram)
    return labels
# =============================================================================
# PROBE TRAINING
# =============================================================================
class GeometricProbe(nn.Module):
    """
    Small MLP probe over geometric features (velocity, curvature, ...)
    instead of raw hidden states. Maps [*, input_dim] -> [*, 1] logits.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64):
        super().__init__()
        # Build the same Linear/GELU stack the original declared inline:
        # input -> hidden -> hidden -> 1, with GELU between linear layers.
        widths = [input_dim, hidden_dim, hidden_dim]
        layers = []
        for d_in, d_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(d_in, d_out))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class CurvatureProbe(nn.Module):
    """
    Probe over explicit kinematic features at a position:
    the velocity vector, the acceleration vector, and the scalar curvature.
    """

    def __init__(self, d_model: int):
        super().__init__()
        # velocity (d_model) + acceleration (d_model) + curvature (1)
        feature_dim = d_model * 2 + 1
        self.net = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.GELU(),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 1),
        )

    def forward(self, velocity, acceleration, curvature_scalar):
        # Curvature arrives as a scalar per position; give it a feature axis.
        curv_column = curvature_scalar.unsqueeze(-1)
        features = torch.cat([velocity, acceleration, curv_column], dim=-1)
        return self.net(features)
# =============================================================================
# MAIN EXPERIMENT
# =============================================================================
class LieHolonomyExperiment:
    """
    Main experiment: compare geometric probes vs raw hidden state probes.

    Pipeline: load a 4-bit quantized causal LM, sample continuations while
    recording last-layer hidden states, run ManifoldAnalyzer over each
    trajectory, label token repetitions, correlate curvature with
    repetition, and save a JSON summary to config.output_dir.
    """
    def __init__(self, config: GeometryConfig):
        self.config = config
        self.analyzer = ManifoldAnalyzer(config)
        self.device = config.device
        # Results storage
        self.results = {
            "sequences": [],       # per-sequence scalar stats (dicts)
            "geometry_stats": [],  # NOTE(review): never populated in this file
            "correlations": {},    # filled by _compute_correlations()
        }
        # Load model (sets self.tokenizer/model/d_model/n_layers)
        self._load_model()

    def _load_model(self):
        """Load the tokenizer and 4-bit NF4 quantized model from config.model_path."""
        print("Loading model...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_path,
            local_files_only=True
        )
        # Many causal LMs ship without a pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            local_files_only=True,
        )
        self.model.eval()
        self.d_model = self.model.config.hidden_size
        self.n_layers = self.model.config.num_hidden_layers
        print(f"Model loaded: {self.d_model} hidden dim, {self.n_layers} layers")

    def generate_with_hidden_states(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample up to config.max_length tokens and record the last-layer
        hidden state at the final position of every step.

        Args:
            prompt: text to condition generation on.
        Returns:
            (token_ids, hidden_states): the full prompt+generated id tensor
            and a [n_steps, d_model] trajectory, both moved to CPU.
            (Annotation fixed: the second element is a stacked Tensor,
            not a List[Tensor].)

        NOTE(review): each step re-runs the model on the entire prefix with
        no KV cache, so generation cost is O(n^2) in sequence length.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        all_hidden_states = []
        generated_ids = inputs.input_ids.clone()
        for step in range(self.config.max_length):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=generated_ids,
                    output_hidden_states=True,
                    return_dict=True,
                )
            # Get hidden states from last layer, last position
            hidden = outputs.hidden_states[-1][:, -1, :]  # [1, d_model]
            all_hidden_states.append(hidden.squeeze(0).cpu())
            # Sample next token (fixed temperature 0.8)
            logits = outputs.logits[:, -1, :]
            probs = F.softmax(logits / 0.8, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            if next_token.item() == self.tokenizer.eos_token_id:
                break
        hidden_states = torch.stack(all_hidden_states)  # [seq_len, d_model]
        return generated_ids.squeeze(0).cpu(), hidden_states

    def run_experiment(self, prompts: Optional[List[str]] = None):
        """
        Generate all sequences, pair curvature values with repetition
        labels, compute correlations, and save a JSON summary.

        Args:
            prompts: prompt strings; defaults to 10 stock prompts x 10.
        Returns:
            self.results (also mutated in place).
        """
        if prompts is None:
            prompts = [
                "Once upon a time",
                "The meaning of life is",
                "In the beginning there was",
                "To be or not to be",
                "The quick brown fox",
                "Explain how neural networks",
                "Write a story about",
                "The most important thing",
                "Scientists discovered that",
                "In a world where",
            ] * 10  # 100 sequences
        print(f"\nRunning experiment on {len(prompts)} sequences...")
        all_curvatures = []
        all_repetition_labels = []
        all_holonomies = []   # NOTE(review): declared but never appended to below
        all_speeds = []       # NOTE(review): declared but never appended to below
        for i, prompt in enumerate(tqdm(prompts)):
            try:
                # Generate
                token_ids, hidden_states = self.generate_with_hidden_states(prompt)
                # Geometric analysis
                geometry = self.analyzer.analyze_sequence(hidden_states)
                # Repetition labels
                rep_labels = detect_repetitions(token_ids)
                ngram_labels = detect_ngram_repetitions(token_ids)
                # Align lengths (curvature is shorter due to derivatives)
                min_len = min(len(geometry["curvature"]), len(rep_labels) - 2)
                if min_len > 0:
                    curvature = geometry["curvature"][:min_len]
                    labels = rep_labels[2:2+min_len]  # Offset by 2 for second derivative
                    all_curvatures.extend(curvature.tolist())
                    all_repetition_labels.extend(labels.tolist())
                # Store sequence data (scalar entries only; tensor/list
                # values from the geometry dict are filtered out).
                self.results["sequences"].append({
                    "prompt": prompt,
                    "length": len(token_ids),
                    "n_repetitions": int(rep_labels.sum()),
                    "n_ngram_repetitions": int(ngram_labels.sum()),
                    **{k: v for k, v in geometry.items()
                       if isinstance(v, (int, float))}
                })
            except Exception as e:
                # Best-effort: log and skip a failed prompt rather than
                # aborting the whole experiment.
                print(f"Error on prompt {i}: {e}")
                continue
        # Compute correlations
        self._compute_correlations(all_curvatures, all_repetition_labels)
        # Save results
        self._save_results()
        return self.results

    def _compute_correlations(self, curvatures: List[float], labels: List[float]):
        """
        Correlate curvature with repetition labels, store the summary in
        self.results["correlations"], and print a human-readable verdict.
        """
        curvatures = np.array(curvatures)
        labels = np.array(labels)
        # Basic correlation
        if len(curvatures) > 0 and len(labels) > 0:
            correlation = np.corrcoef(curvatures, labels)[0, 1]
        else:
            correlation = 0
        # Split by label and compare means
        rep_indices = labels > 0.5
        non_rep_indices = labels < 0.5
        if rep_indices.sum() > 0 and non_rep_indices.sum() > 0:
            mean_curv_rep = curvatures[rep_indices].mean()
            mean_curv_nonrep = curvatures[non_rep_indices].mean()
            # Ratio > 1: repeated tokens sit on higher-curvature segments.
            separation = mean_curv_rep / (mean_curv_nonrep + 1e-8)
        else:
            mean_curv_rep = 0
            mean_curv_nonrep = 0
            separation = 1.0
        self.results["correlations"] = {
            "curvature_repetition_correlation": float(correlation),
            "mean_curvature_at_repetition": float(mean_curv_rep),
            "mean_curvature_at_non_repetition": float(mean_curv_nonrep),
            "curvature_separation_ratio": float(separation),
            "n_samples": len(curvatures),
            "n_repetitions": int(labels.sum()),
        }
        print("\n" + "="*60)
        print("GEOMETRIC ANALYSIS RESULTS")
        print("="*60)
        print(f"Correlation (curvature <-> repetition): {correlation:.4f}")
        print(f"Mean curvature at repetitions: {mean_curv_rep:.6f}")
        print(f"Mean curvature at non-repetitions: {mean_curv_nonrep:.6f}")
        print(f"Separation ratio: {separation:.2f}x")
        print(f"Total samples: {len(curvatures)}")
        print(f"Total repetitions: {int(labels.sum())}")
        print("="*60)
        # Interpretation (thresholds mirror the ones in main())
        if separation > 2.0:
            print("\n🎯 STRONG SIGNAL: Curvature predicts repetition!")
            print("   This validates the geometric hypothesis.")
        elif separation > 1.3:
            print("\n📊 MODERATE SIGNAL: Some predictive power.")
            print("   Worth investigating further.")
        else:
            print("\n⚠️ WEAK SIGNAL: Curvature alone may not be enough.")
            print("   Try holonomy or learned geometric features.")

    def _save_results(self):
        """Write a JSON summary (config, correlations, aggregate stats) to disk.

        NOTE(review): np.mean over an empty "sequences" list yields NaN with
        a runtime warning if every prompt errored out.
        """
        output_dir = Path(self.config.output_dir)
        output_dir.mkdir(exist_ok=True)
        # Save JSON summary
        summary = {
            "config": {
                "n_sequences": self.config.n_sequences,
                "max_length": self.config.max_length,
            },
            "correlations": self.results["correlations"],
            "sequence_stats": {
                "mean_length": np.mean([s["length"] for s in self.results["sequences"]]),
                "mean_repetitions": np.mean([s["n_repetitions"] for s in self.results["sequences"]]),
                "mean_curvature": np.mean([s["mean_curvature"] for s in self.results["sequences"]]),
                "mean_n_loops": np.mean([s["n_loops"] for s in self.results["sequences"]]),
                "mean_holonomy": np.mean([s["mean_holonomy"] for s in self.results["sequences"]]),
            }
        }
        with open(output_dir / "geometry_results.json", "w") as f:
            json.dump(summary, f, indent=2)
        print(f"\nResults saved to {output_dir}/geometry_results.json")
# =============================================================================
# CONNECTION NETWORK - THE KEY TO HOLONOMY
# =============================================================================
class ConnectionNetwork(nn.Module):
    """
    Learns a (Levi-Civita-style) connection on the hidden state manifold.

    The idea: if we can learn how to parallel transport vectors along
    paths, holonomy (the failure of a vector to return to itself around a
    closed loop) reveals curvature. The network answers: "moving from
    point A to point B, how should a vector at A transform to stay
    'parallel'?"
    """

    def __init__(self, d_model: int, d_connection: int = 256):
        super().__init__()
        # Encode the (start, end) pair describing one path segment.
        self.path_encoder = nn.Sequential(
            nn.Linear(d_model * 2, d_connection),
            nn.GELU(),
            nn.Linear(d_connection, d_connection),
        )
        # Predict a dense d_model x d_model transport matrix per segment.
        self.transport_predictor = nn.Sequential(
            nn.Linear(d_connection, d_connection),
            nn.GELU(),
            nn.Linear(d_connection, d_model * d_model),  # Full matrix
        )
        self.d_model = d_model

    def forward(self, h_start: torch.Tensor, h_end: torch.Tensor,
                v_start: torch.Tensor) -> torch.Tensor:
        """
        Parallel transport vector v_start from h_start to h_end.

        Args:
            h_start: Starting hidden state [batch, d_model]
            h_end: Ending hidden state [batch, d_model]
            v_start: Vector to transport [batch, d_model]
        Returns:
            Transported vector at h_end [batch, d_model]
        """
        n = h_start.shape[0]
        segment = torch.cat([h_start, h_end], dim=-1)
        encoding = self.path_encoder(segment)
        matrix = self.transport_predictor(encoding).view(n, self.d_model, self.d_model)
        transported = torch.bmm(matrix, v_start.unsqueeze(-1)).squeeze(-1)
        # Residual keeps the learned map near the identity for stability.
        return transported + v_start

    def compute_holonomy(self, hidden_states: torch.Tensor,
                         loop_indices: List[int]) -> torch.Tensor:
        """
        Holonomy around the closed loop defined by `loop_indices`:
        transport a random unit vector segment-by-segment around the loop
        (closing back to the first index) and compare it to the original.
        Returns 1 - cosine similarity; loops shorter than 3 indices give 0.
        """
        if len(loop_indices) < 3:
            return torch.tensor(0.0)
        # Random unit probe vector (note: not seeded here by design).
        probe = torch.randn(1, self.d_model, device=hidden_states.device)
        probe = F.normalize(probe, dim=-1)
        reference = probe.clone()
        # Transport along consecutive segments of the loop.
        carried = probe
        for a, b in zip(loop_indices[:-1], loop_indices[1:]):
            carried = self.forward(hidden_states[a].unsqueeze(0),
                                   hidden_states[b].unsqueeze(0),
                                   carried)
        # Close the loop back to the starting index.
        carried = self.forward(hidden_states[loop_indices[-1]].unsqueeze(0),
                               hidden_states[loop_indices[0]].unsqueeze(0),
                               carried)
        # Holonomy magnitude: deviation from perfect return.
        return 1 - F.cosine_similarity(carried, reference, dim=-1)
# =============================================================================
# RUN EXPERIMENT
# =============================================================================
def main():
    """Run the Lie Holonomy experiment end-to-end and print next steps."""
    bar = "=" * 60
    print(bar)
    print("LIE HOLONOMY TRANSFORMER - GEOMETRIC ANALYSIS")
    print(bar)
    print("\nHypothesis: Hidden state GEOMETRY (curvature, holonomy)")
    print("predicts model behavior better than raw states.\n")

    config = GeometryConfig(
        model_path=".",  # Current directory
        n_sequences=100,
        max_length=256,
    )
    experiment = LieHolonomyExperiment(config)
    results = experiment.run_experiment()

    print("\n" + bar)
    print("EXPERIMENT COMPLETE")
    print(bar)
    print("\nNext steps based on results:")

    # Same thresholds as _compute_correlations' printed verdict.
    ratio = results["correlations"]["curvature_separation_ratio"]
    if ratio > 2.0:
        print("""
✅ STRONG SIGNAL DETECTED
The geometric approach shows promise. Next:
1. Train a CurvatureProbe to beat your 80x probe
2. Implement the ConnectionNetwork for learned parallel transport
3. Use holonomy as a NEW training signal for self-improvement
This could be the breakthrough.
""")
    elif ratio > 1.3:
        print("""
📊 MODERATE SIGNAL
Worth pursuing. Try:
1. Different geometric features (torsion, geodesic deviation)
2. Multi-layer analysis (geometry at each transformer layer)
3. Larger sample sizes
""")
    else:
        print("""
⚠️ WEAK SIGNAL
Raw curvature may not be the right feature. Try:
1. Learned geometric features (train the connection network)
2. Sectional curvature (curvature in specific 2D planes)
3. Ricci curvature (average over all directions)
""")
    return results


if __name__ == "__main__":
    main()