# Source: cfhot-weights / code / training_pipelines / 04_lie_holonomy_experiment_GEOMETRY.py
# Uploaded by: LoganResearch
# Release note: full weight release — 9 probes x 3 architectures + production adapter + training code
# Commit: 297244f (verified)
#!/usr/bin/env python3
"""
Lie Holonomy Transformer - Geometric Analysis of Hidden States
==============================================================
Tests whether geometric properties (velocity, curvature, holonomy)
predict model behavior better than raw hidden state probes.
This is the experiment that could change everything.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from tqdm import tqdm
# =============================================================================
# CONFIGURATION
# =============================================================================
@dataclass
class GeometryConfig:
    """Configuration for the geometric hidden-state analysis experiment."""
    model_path: str = "."  # Your local model
    n_sequences: int = 100  # Number of generated sequences to analyze
    max_length: int = 256  # Max tokens generated per sequence
    device: str = "cuda"  # Device the tokenized prompts are moved to
    output_dir: str = "geometry_results"  # Where the JSON summary is written
    # Geometric thresholds
    curvature_window: int = 3  # Tokens to compute curvature over
    holonomy_threshold: float = 0.95  # Cosine similarity to detect "loops"
# =============================================================================
# GEOMETRIC COMPUTATIONS
# =============================================================================
class ManifoldAnalyzer:
    """
    Analyzes the geometry of hidden state trajectories.

    Key concepts:
    - Velocity: direction of movement in hidden space (first derivative)
    - Curvature: how sharply the path bends (second derivative)
    - Holonomy: what you lose going around a loop (parallel transport failure)
    """

    def __init__(self, config: GeometryConfig):
        self.config = config

    def compute_velocities(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Compute velocity vectors (tangent vectors to the trajectory).

        Args:
            hidden_states: [seq_len, hidden_dim]
        Returns:
            velocities: [seq_len-1, hidden_dim]
        """
        return hidden_states[1:] - hidden_states[:-1]

    def compute_speeds(self, velocities: torch.Tensor) -> torch.Tensor:
        """Magnitude of each velocity vector: [n] for input [n, hidden_dim]."""
        return torch.norm(velocities, dim=-1)

    def compute_accelerations(self, velocities: torch.Tensor) -> torch.Tensor:
        """Second derivative — how velocity changes: [n-1, hidden_dim]."""
        return velocities[1:] - velocities[:-1]

    def compute_curvature(self, velocities: torch.Tensor) -> torch.Tensor:
        """
        Curvature κ = |dT/ds| where T is unit tangent, s is arc length.
        Simplified: κ ≈ |a| / |v|² where a is acceleration, v is velocity.
        High curvature = sharp turns in semantic space.

        Returns a tensor of length len(velocities) - 1 (empty for < 2
        velocities).
        """
        speeds = self.compute_speeds(velocities)
        accelerations = self.compute_accelerations(velocities)
        accel_magnitudes = torch.norm(accelerations, dim=-1)
        # speeds[:-1] aligns speed k with the acceleration between steps k and
        # k+1; the epsilon avoids division by zero on stationary steps.
        speeds_squared = speeds[:-1] ** 2 + 1e-8
        return accel_magnitudes / speeds_squared

    def compute_torsion(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Torsion-like measure: how much the path twists out of its osculating
        plane (third-derivative information).

        NOTE(review): this is the raw magnitude of the out-of-plane jerk
        component, not the classical normalized torsion τ.
        """
        v = self.compute_velocities(hidden_states)
        a = self.compute_accelerations(v)
        if len(a) < 2:
            return torch.tensor([0.0])
        # Jerk (third derivative)
        j = a[1:] - a[:-1]
        # Trim so v, a, j all refer to the same window of the trajectory.
        v_trimmed = v[:-2]
        a_trimmed = a[:-1]
        v_norm = F.normalize(v_trimmed, dim=-1)
        # Gram–Schmidt: orthogonalize a against v before projecting, otherwise
        # the two projections double-count the component a shares with v and a
        # purely planar path can report nonzero "out of plane" residual.
        a_perp = a_trimmed - (a_trimmed * v_norm).sum(dim=-1, keepdim=True) * v_norm
        a_perp_norm = F.normalize(a_perp, dim=-1)
        # Project j onto the orthonormal in-plane frame; the residual is the
        # out-of-plane component.
        j_proj_v = (j * v_norm).sum(dim=-1, keepdim=True) * v_norm
        j_proj_a = (j * a_perp_norm).sum(dim=-1, keepdim=True) * a_perp_norm
        j_out_of_plane = j - j_proj_v - j_proj_a
        return torch.norm(j_out_of_plane, dim=-1)

    def detect_loops(self, hidden_states: torch.Tensor,
                     threshold: float = 0.95,
                     min_separation: int = 5) -> List[Tuple[int, int]]:
        """
        Detect semantic loops: pairs (i, j), j - i >= min_separation, whose
        hidden states have cosine similarity above ``threshold``.

        Vectorized (triu + nonzero) replacement for the O(n²) Python double
        loop; pairs come back in the same row-major (i, then j) order.
        """
        h_norm = F.normalize(hidden_states, dim=-1)
        similarity = torch.mm(h_norm, h_norm.t())
        # Upper triangle with offset excludes the diagonal and nearby tokens.
        candidates = torch.triu(similarity > threshold, diagonal=min_separation)
        return [(int(i), int(j)) for i, j in torch.nonzero(candidates)]

    def compute_holonomy(self, hidden_states: torch.Tensor,
                         loop: Tuple[int, int]) -> float:
        """
        Compute (simplified) holonomy around a detected loop (i, j): compare
        the incoming velocity direction at i with the one at j. If the two
        visits to the "same" region arrive from the same direction, the space
        looks flat (0.0); orthogonal arrivals give the maximum (1.0).

        Returns 0.0 when either endpoint lacks a usable velocity.
        """
        i, j = loop
        # Need a predecessor for i (and, per the original bound, a successor
        # for j) to form finite-difference velocities.
        if i > 0 and j < len(hidden_states) - 1:
            v_start = hidden_states[i] - hidden_states[i - 1]
            v_end = hidden_states[j] - hidden_states[j - 1]
            v_start_norm = F.normalize(v_start, dim=-1)
            v_end_norm = F.normalize(v_end, dim=-1)
            cos_angle = (v_start_norm * v_end_norm).sum()
            # 0 = flat (transport preserved), 1 = maximally curved.
            return float(1 - cos_angle.abs())
        return 0.0

    def analyze_sequence(self, hidden_states: torch.Tensor) -> Dict:
        """Full geometric analysis of one trajectory [seq_len, hidden_dim]."""
        # Basic derivatives
        velocities = self.compute_velocities(hidden_states)
        speeds = self.compute_speeds(velocities)
        curvature = self.compute_curvature(velocities)
        torsion = self.compute_torsion(hidden_states)
        # Loop detection and holonomy
        loops = self.detect_loops(hidden_states, self.config.holonomy_threshold)
        holonomies = [self.compute_holonomy(hidden_states, loop) for loop in loops]
        return {
            "velocities": velocities,
            "speeds": speeds,
            "curvature": curvature,
            "torsion": torsion,
            "loops": loops,
            "holonomies": holonomies,
            # Summary statistics — guards keep 1-2 token trajectories from
            # producing NaN (mean/std of empty or singleton tensors).
            "mean_speed": speeds.mean().item() if len(speeds) > 0 else 0.0,
            "std_speed": speeds.std().item() if len(speeds) > 1 else 0.0,
            "mean_curvature": curvature.mean().item() if len(curvature) > 0 else 0,
            "max_curvature": curvature.max().item() if len(curvature) > 0 else 0,
            "mean_torsion": torsion.mean().item() if len(torsion) > 0 else 0,
            "n_loops": len(loops),
            "mean_holonomy": np.mean(holonomies) if holonomies else 0,
        }
# =============================================================================
# REPETITION DETECTION (Ground Truth)
# =============================================================================
def detect_repetitions(token_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """
    Binary repetition labels: label[i] = 1.0 when token i already occurred
    within the previous ``window`` tokens, else 0.0.

    Args:
        token_ids: 1-D tensor of token ids.
        window: how far back to look for a duplicate.
    Returns:
        Float tensor of the same length as ``token_ids``.
    """
    ids = token_ids.tolist()
    flags = [0.0] * len(ids)
    for pos, tok in enumerate(ids):
        if pos == 0:
            continue
        lo = max(0, pos - window)
        if tok in ids[lo:pos]:
            flags[pos] = 1.0
    return torch.tensor(flags)
def detect_ngram_repetitions(token_ids: torch.Tensor, n: int = 3) -> torch.Tensor:
    """
    N-gram repetition labels: label[i] = 1.0 when the n-gram ending at
    position i has appeared (ending at any earlier position) before.

    Args:
        token_ids: 1-D tensor of token ids.
        n: n-gram order.
    Returns:
        Float tensor of the same length as ``token_ids``.
    """
    ids = token_ids.tolist()
    flags = [0.0] * len(ids)
    history = set()
    # ``end`` is one past the n-gram's last token, so the gram covers
    # ids[end-n:end] and its label position is end-1.
    for end in range(n, len(ids) + 1):
        gram = tuple(ids[end - n:end])
        if gram in history:
            flags[end - 1] = 1.0
        history.add(gram)
    return torch.tensor(flags)
# =============================================================================
# PROBE TRAINING
# =============================================================================
class GeometricProbe(nn.Module):
    """
    MLP probe over geometric features (velocity, curvature) rather than raw
    hidden states. Outputs one logit per input row.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a [batch, 1] logit tensor for [batch, input_dim] features."""
        return self.net(x)
class CurvatureProbe(nn.Module):
    """
    Probe over concatenated [velocity, acceleration, curvature_scalar]
    features; outputs one logit per position.
    """

    def __init__(self, d_model: int):
        super().__init__()
        # velocity (d_model) + acceleration (d_model) + curvature scalar (1)
        self.net = nn.Sequential(
            nn.Linear(2 * d_model + 1, 128),
            nn.GELU(),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 1),
        )

    def forward(self, velocity, acceleration, curvature_scalar):
        """Concatenate the three features on the last dim and score them."""
        features = torch.cat(
            [velocity, acceleration, curvature_scalar.unsqueeze(-1)],
            dim=-1,
        )
        return self.net(features)
# =============================================================================
# MAIN EXPERIMENT
# =============================================================================
class LieHolonomyExperiment:
    """
    Main experiment: compare geometric probes vs raw hidden state probes.

    Loads a 4-bit quantized causal LM, generates continuations for a set of
    prompts while recording last-layer hidden states, then correlates
    trajectory curvature with token repetition and saves a JSON summary.
    """

    def __init__(self, config: GeometryConfig):
        self.config = config
        self.analyzer = ManifoldAnalyzer(config)
        self.device = config.device
        # Results storage
        self.results = {
            "sequences": [],       # per-sequence scalar summaries
            "geometry_stats": [],
            "correlations": {},    # filled by _compute_correlations
        }
        # Load model
        self._load_model()

    def _load_model(self):
        """Load tokenizer and 4-bit (nf4) quantized model from config.model_path."""
        print("Loading model...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_path,
            local_files_only=True
        )
        # Many causal LMs ship without a pad token; reuse EOS.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            local_files_only=True,
        )
        self.model.eval()
        self.d_model = self.model.config.hidden_size
        self.n_layers = self.model.config.num_hidden_layers
        print(f"Model loaded: {self.d_model} hidden dim, {self.n_layers} layers")

    def generate_with_hidden_states(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate up to config.max_length tokens, recording the last layer's
        hidden state at the final position of each step.

        NOTE(review): every step re-runs the full forward pass without a KV
        cache, so generation is O(n²) in sequence length — acceptable for
        short analysis runs.

        Returns:
            (token_ids [prompt_len + n_steps], hidden_states [n_steps, d_model]),
            both on CPU. (Annotation fixed: the second element is a stacked
            tensor, not a list.)
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        all_hidden_states = []
        generated_ids = inputs.input_ids.clone()
        for step in range(self.config.max_length):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=generated_ids,
                    output_hidden_states=True,
                    return_dict=True,
                )
            # Last layer, last position: the state that produces the next token.
            hidden = outputs.hidden_states[-1][:, -1, :]  # [1, d_model]
            all_hidden_states.append(hidden.squeeze(0).cpu())
            # Sample the next token at temperature 0.8.
            logits = outputs.logits[:, -1, :]
            probs = F.softmax(logits / 0.8, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            if next_token.item() == self.tokenizer.eos_token_id:
                break
        hidden_states = torch.stack(all_hidden_states)  # [n_steps, d_model]
        return generated_ids.squeeze(0).cpu(), hidden_states

    def run_experiment(self, prompts: List[str] = None):
        """Run the full experiment over ``prompts`` and return self.results."""
        if prompts is None:
            prompts = [
                "Once upon a time",
                "The meaning of life is",
                "In the beginning there was",
                "To be or not to be",
                "The quick brown fox",
                "Explain how neural networks",
                "Write a story about",
                "The most important thing",
                "Scientists discovered that",
                "In a world where",
            ] * 10  # 100 sequences
        print(f"\nRunning experiment on {len(prompts)} sequences...")
        all_curvatures = []
        all_repetition_labels = []
        for i, prompt in enumerate(tqdm(prompts)):
            try:
                # Generate
                token_ids, hidden_states = self.generate_with_hidden_states(prompt)
                # Geometric analysis
                geometry = self.analyzer.analyze_sequence(hidden_states)
                # Repetition labels (token-level and n-gram ground truth)
                rep_labels = detect_repetitions(token_ids)
                ngram_labels = detect_ngram_repetitions(token_ids)
                # Align lengths: curvature is 2 shorter than the trajectory
                # because of the two discrete derivatives.
                # NOTE(review): hidden_states covers only the *generated*
                # steps while token_ids includes the prompt, so curvature and
                # labels are offset by the prompt length — confirm intended.
                min_len = min(len(geometry["curvature"]), len(rep_labels) - 2)
                if min_len > 0:
                    curvature = geometry["curvature"][:min_len]
                    labels = rep_labels[2:2 + min_len]  # Offset by 2 for second derivative
                    all_curvatures.extend(curvature.tolist())
                    all_repetition_labels.extend(labels.tolist())
                # Store per-sequence scalar stats only (tensors stay out of JSON).
                self.results["sequences"].append({
                    "prompt": prompt,
                    "length": len(token_ids),
                    "n_repetitions": int(rep_labels.sum()),
                    "n_ngram_repetitions": int(ngram_labels.sum()),
                    **{k: v for k, v in geometry.items()
                       if isinstance(v, (int, float))}
                })
            except Exception as e:
                # Best-effort: report and skip failing sequences, keep going.
                print(f"Error on prompt {i}: {e}")
                continue
        # Compute correlations
        self._compute_correlations(all_curvatures, all_repetition_labels)
        # Save results
        self._save_results()
        return self.results

    def _compute_correlations(self, curvatures: List[float], labels: List[float]):
        """Correlate curvature with repetition labels and print a summary."""
        curvatures = np.array(curvatures)
        labels = np.array(labels)
        # Pearson correlation is undefined (NaN) when either series is
        # constant or has fewer than 2 points; report 0.0 instead of
        # propagating NaN into the JSON summary.
        if len(curvatures) > 1 and curvatures.std() > 0 and labels.std() > 0:
            correlation = np.corrcoef(curvatures, labels)[0, 1]
        else:
            correlation = 0.0
        # Split by label and compare means
        rep_indices = labels > 0.5
        non_rep_indices = labels < 0.5
        if rep_indices.sum() > 0 and non_rep_indices.sum() > 0:
            mean_curv_rep = curvatures[rep_indices].mean()
            mean_curv_nonrep = curvatures[non_rep_indices].mean()
            separation = mean_curv_rep / (mean_curv_nonrep + 1e-8)
        else:
            mean_curv_rep = 0
            mean_curv_nonrep = 0
            separation = 1.0
        self.results["correlations"] = {
            "curvature_repetition_correlation": float(correlation),
            "mean_curvature_at_repetition": float(mean_curv_rep),
            "mean_curvature_at_non_repetition": float(mean_curv_nonrep),
            "curvature_separation_ratio": float(separation),
            "n_samples": len(curvatures),
            "n_repetitions": int(labels.sum()),
        }
        print("\n" + "="*60)
        print("GEOMETRIC ANALYSIS RESULTS")
        print("="*60)
        print(f"Correlation (curvature <-> repetition): {correlation:.4f}")
        print(f"Mean curvature at repetitions: {mean_curv_rep:.6f}")
        print(f"Mean curvature at non-repetitions: {mean_curv_nonrep:.6f}")
        print(f"Separation ratio: {separation:.2f}x")
        print(f"Total samples: {len(curvatures)}")
        print(f"Total repetitions: {int(labels.sum())}")
        print("="*60)
        # Interpretation
        if separation > 2.0:
            print("\n🎯 STRONG SIGNAL: Curvature predicts repetition!")
            print("   This validates the geometric hypothesis.")
        elif separation > 1.3:
            print("\n📊 MODERATE SIGNAL: Some predictive power.")
            print("   Worth investigating further.")
        else:
            print("\n⚠️ WEAK SIGNAL: Curvature alone may not be enough.")
            print("   Try holonomy or learned geometric features.")

    def _save_results(self):
        """Save a JSON summary (config, correlations, sequence stats) to disk."""
        output_dir = Path(self.config.output_dir)
        output_dir.mkdir(exist_ok=True)

        def _mean_of(key: str) -> float:
            # Mean over per-sequence stats; 0.0 when no sequence succeeded
            # (np.mean of an empty list is NaN and corrupts the JSON).
            values = [s[key] for s in self.results["sequences"]]
            return float(np.mean(values)) if values else 0.0

        summary = {
            "config": {
                "n_sequences": self.config.n_sequences,
                "max_length": self.config.max_length,
            },
            "correlations": self.results["correlations"],
            "sequence_stats": {
                "mean_length": _mean_of("length"),
                "mean_repetitions": _mean_of("n_repetitions"),
                "mean_curvature": _mean_of("mean_curvature"),
                "mean_n_loops": _mean_of("n_loops"),
                "mean_holonomy": _mean_of("mean_holonomy"),
            }
        }
        with open(output_dir / "geometry_results.json", "w") as f:
            json.dump(summary, f, indent=2)
        print(f"\nResults saved to {output_dir}/geometry_results.json")
# =============================================================================
# CONNECTION NETWORK - THE KEY TO HOLONOMY
# =============================================================================
class ConnectionNetwork(nn.Module):
    """
    Learns the Levi-Civita connection on the hidden state manifold.

    The connection answers: "if I move from point A to point B, how should a
    vector at A transform to stay 'parallel'?" Once we can parallel transport
    vectors along paths, holonomy around closed loops reveals curvature.
    """

    def __init__(self, d_model: int, d_connection: int = 256):
        super().__init__()
        # Encodes the (start, end) pair describing one path segment.
        self.path_encoder = nn.Sequential(
            nn.Linear(d_model * 2, d_connection),
            nn.GELU(),
            nn.Linear(d_connection, d_connection),
        )
        # Predicts the transport map as a dense d_model x d_model matrix.
        self.transport_predictor = nn.Sequential(
            nn.Linear(d_connection, d_connection),
            nn.GELU(),
            nn.Linear(d_connection, d_model * d_model),  # Full matrix
        )
        self.d_model = d_model

    def forward(self, h_start: torch.Tensor, h_end: torch.Tensor,
                v_start: torch.Tensor) -> torch.Tensor:
        """
        Parallel transport vector v_start from h_start to h_end.

        Args:
            h_start: Starting hidden state [batch, d_model]
            h_end: Ending hidden state [batch, d_model]
            v_start: Vector to transport [batch, d_model]
        Returns:
            Transported vector at h_end [batch, d_model]
        """
        n = h_start.shape[0]
        segment = torch.cat([h_start, h_end], dim=-1)
        encoding = self.path_encoder(segment)
        # Reshape the flat prediction into a per-sample transport matrix.
        matrix = self.transport_predictor(encoding).view(n, self.d_model, self.d_model)
        transported = torch.bmm(matrix, v_start.unsqueeze(-1)).squeeze(-1)
        # Residual keeps the learned map close to identity for stability.
        return transported + v_start

    def compute_holonomy(self, hidden_states: torch.Tensor,
                         loop_indices: List[int]) -> torch.Tensor:
        """
        Holonomy around a loop: transport a random unit vector along every
        edge of the loop (including the closing edge back to the start) and
        measure how far it ends up from the original (1 - cosine similarity).
        """
        if len(loop_indices) < 3:
            return torch.tensor(0.0)
        # Random unit probe vector to carry around the loop.
        v = torch.randn(1, self.d_model, device=hidden_states.device)
        v = F.normalize(v, dim=-1)
        reference = v.clone()
        # Walk consecutive edges, then close the loop back to the first index.
        path = list(loop_indices) + [loop_indices[0]]
        for src, dst in zip(path[:-1], path[1:]):
            v = self.forward(
                hidden_states[src].unsqueeze(0),
                hidden_states[dst].unsqueeze(0),
                v,
            )
        return 1 - F.cosine_similarity(v, reference, dim=-1)
# =============================================================================
# RUN EXPERIMENT
# =============================================================================
def main():
    """Run the Lie Holonomy experiment end to end and return its results dict."""
    print("="*60)
    print("LIE HOLONOMY TRANSFORMER - GEOMETRIC ANALYSIS")
    print("="*60)
    print("\nHypothesis: Hidden state GEOMETRY (curvature, holonomy)")
    print("predicts model behavior better than raw states.\n")
    # Experiment configuration; the model is loaded from the current directory.
    config = GeometryConfig(
        model_path=".",  # Current directory
        n_sequences=100,
        max_length=256,
    )
    experiment = LieHolonomyExperiment(config)
    results = experiment.run_experiment()
    print("\n" + "="*60)
    print("EXPERIMENT COMPLETE")
    print("="*60)
    print("\nNext steps based on results:")
    # Interpret the separation ratio computed by _compute_correlations.
    sep = results["correlations"]["curvature_separation_ratio"]
    if sep > 2.0:
        print("""
✅ STRONG SIGNAL DETECTED
The geometric approach shows promise. Next:
1. Train a CurvatureProbe to beat your 80x probe
2. Implement the ConnectionNetwork for learned parallel transport
3. Use holonomy as a NEW training signal for self-improvement
This could be the breakthrough.
""")
    elif sep > 1.3:
        print("""
📊 MODERATE SIGNAL
Worth pursuing. Try:
1. Different geometric features (torsion, geodesic deviation)
2. Multi-layer analysis (geometry at each transformer layer)
3. Larger sample sizes
""")
    else:
        print("""
⚠️ WEAK SIGNAL
Raw curvature may not be the right feature. Try:
1. Learned geometric features (train the connection network)
2. Sectional curvature (curvature in specific 2D planes)
3. Ricci curvature (average over all directions)
""")
    return results
# Entry-point guard: lets the module be imported without launching the run.
if __name__ == "__main__":
    main()