"""
Diagnostics for ternary quantization x H4 geometry interaction.
Key questions answered:
1. Does ternary W_nudge preserve chamber navigation?
2. What's the weight distribution structure?
3. Does the rank-1 pattern from Phase 5 survive ternary?
4. How does geo_alignment change: float vs ternary?
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from bitlinear import ternary_quantize, BitLinear
from h4_polytopic_attention import generate_600_cell_vertices, build_coxeter_chambers
def chamber_preservation(model, n_test=10000):
    """
    Measure how often float vs ternary nudge assigns the same chamber.

    Generates random queries on S^3, applies float and ternary nudge,
    computes chamber indices via sign(dot(q, root_i)), reports agreement.

    Phase 5 showed 96.5% geo alignment with float nudge.
    Theory predicts >90% preservation under ternary because
    chambers only depend on signs, and ternary preserves signs.

    Args:
        model: module containing an H4AttentionLayer (found recursively).
        n_test: number of random unit queries sampled per head.

    Returns:
        dict mapping 'head_{h}' -> agreement rate in [0, 1]; empty dict
        when the model has no H4AttentionLayer.
    """
    # Get simple roots from the model's first attention layer
    attn = _find_h4_attention(model)
    if attn is None:
        print(" No H4AttentionLayer found in model")
        return {}
    roots = attn.simple_roots  # (4, 4)
    results = {}
    for h in range(attn.n_heads):
        # Generate random unit queries on S^3.
        # BUG FIX: sample on the same device as the model's tensors —
        # previously q was created on CPU, which breaks the matmuls below
        # whenever the model lives on GPU.
        q = torch.randn(n_test, 4, device=roots.device)
        q = F.normalize(q, dim=-1)
        W_float = attn.W_nudge[h].detach()  # (4, 4)
        with torch.no_grad():
            # Float nudge path
            q_float = F.normalize(q @ W_float.T, dim=-1)
            ch_float = (q_float @ roots.T >= 0).int()  # (n_test, 4) sign bits
            # Ternary nudge path: W ~= w_q * w_s (quantized values x scale)
            w_q, w_s = ternary_quantize(W_float)
            W_tern = w_q * w_s
            q_tern = F.normalize(q @ W_tern.T, dim=-1)
            ch_tern = (q_tern @ roots.T >= 0).int()
        # Agreement: all 4 sign bits (i.e. the chamber index) must match
        agree = (ch_float == ch_tern).all(dim=-1).float().mean().item()
        results[f'head_{h}'] = agree
    return results
def nudge_ternary_structure(model):
    """
    Analyze ternary W_nudge: SVD, zero pattern, 600-cell alignment.

    Expected from Phase 5 observations (nudge_rank=1.68, geo_align=0.965):
    - Ternary SVD should show a dominant sigma_1
    - High zero% = head learned to ignore some Coxeter directions
    - Dominant direction should still align with 600-cell vertices

    Args:
        model: module containing an H4AttentionLayer (found recursively).

    Returns:
        dict mapping 'head_{h}' to per-head diagnostics; empty dict when
        the model has no H4AttentionLayer.
    """
    attn = _find_h4_attention(model)
    if attn is None:
        return {}
    verts = attn.vertices  # (120, 4) 600-cell vertices
    results = {}
    for h in range(attn.n_heads):
        W = attn.W_nudge[h].detach()
        eye = torch.eye(4, device=W.device)
        delta = W - eye
        # Float SVD of the nudge delta (deviation from identity)
        s_float = torch.linalg.svdvals(delta)
        # Ternary quantization: W ~= w_q * w_s
        w_q, w_s = ternary_quantize(W)
        W_tern = w_q.float() * w_s
        # BUG FIX: the ternary SVD previously decomposed the raw quantized
        # matrix while the float path decomposed delta = W - I (and the
        # computed delta_tern was never used), so the two columns measured
        # different quantities. Decompose the ternary delta so svd_ternary,
        # rank_ratio_ternary and alignment_ternary are directly comparable
        # to their float counterparts.
        delta_tern = W_tern - eye
        U, S_tern, Vt = torch.linalg.svd(delta_tern)
        s_tern = S_tern
        # Ternary weight distribution over {-1, 0, +1}
        n = w_q.numel()
        stats = {
            'neg1': (w_q == -1).sum().item() / n,
            'zero': (w_q == 0).sum().item() / n,
            'pos1': (w_q == 1).sum().item() / n,
        }
        # Dominant right-singular direction vs the 600-cell vertices
        dominant = Vt[0]
        alignment_tern = (dominant @ verts.T).abs().max().item()
        # Float dominant direction for comparison
        _, _, Vt_float = torch.linalg.svd(delta)
        dominant_float = Vt_float[0]
        alignment_float = (dominant_float @ verts.T).abs().max().item()
        # Rank ratio sigma_1/sigma_2 (inf when effectively rank 1)
        rank_ratio_tern = (s_tern[0] / s_tern[1]).item() if s_tern[1] > 1e-8 else float('inf')
        rank_ratio_float = (s_float[0] / s_float[1]).item() if s_float[1] > 1e-8 else float('inf')
        results[f'head_{h}'] = {
            'svd_float': s_float.tolist(),
            'svd_ternary': s_tern.tolist(),
            'rank_ratio_float': rank_ratio_float,
            'rank_ratio_ternary': rank_ratio_tern,
            'alignment_float': alignment_float,
            'alignment_ternary': alignment_tern,
            'weight_dist': stats,
            'zero_rows': (w_q == 0).all(dim=1).sum().item(),
            'zero_cols': (w_q == 0).all(dim=0).sum().item(),
        }
    return results
def bitlinear_layer_stats(model):
    """Collect ternary stats across all BitLinear layers in the model."""
    # One entry per BitLinear module, keyed by its dotted module path;
    # merges the layer's own ternary_stats dict into shape/param info.
    return {
        name: {
            'shape': list(layer.weight.shape),
            'params': layer.weight.numel(),
            **layer.ternary_stats,
        }
        for name, layer in model.named_modules()
        if isinstance(layer, BitLinear)
    }
def size_comparison(model):
    """Compare model size: float32 vs ternary."""
    ternary_count = 0
    float_count = 0
    # BitLinear is checked before nn.Linear in case it subclasses it.
    for _name, mod in model.named_modules():
        if isinstance(mod, BitLinear):
            ternary_count += mod.weight.numel()
        elif isinstance(mod, nn.Linear):
            float_count += mod.weight.numel()
            if mod.bias is not None:
                float_count += mod.bias.numel()
    buffer_count = sum(buf.numel() for buf in model.buffers())
    # Everything not counted above: embeddings, layer norms, W_nudge,
    # chamber_bonus, etc.
    other_count = (
        sum(p.numel() for p in model.parameters())
        - ternary_count
        - float_count
    )
    # 1.58 bits/weight is the information content of a ternary value.
    bits_all_float = (float_count + ternary_count + other_count) * 32
    bits_mixed = ternary_count * 1.58 + (float_count + other_count) * 32
    bits_buffers = buffer_count * 32
    return {
        'ternary_params': ternary_count,
        'float_params': float_count,
        'other_params': other_count,
        'buffer_elems': buffer_count,
        'float32_kb': bits_all_float / 8 / 1024,
        'mixed_kb': bits_mixed / 8 / 1024,
        'buffer_kb': bits_buffers / 8 / 1024,
        'compression': bits_all_float / bits_mixed if bits_mixed > 0 else 1.0,
    }
def full_report(model):
    """Print full ternary diagnostic report to stdout."""
    rule = "=" * 64
    print(rule)
    print(" TERNARY DIAGNOSTICS: H4 x BitNet b1.58")
    print(rule)
    # Does any module in the model use BitLinear?
    has_bitlinear = any(isinstance(m, BitLinear) for m in model.modules())
    if _find_h4_attention(model) is None:
        print(" No H4AttentionLayer found in model")
        return
    # --- Chamber preservation -------------------------------------------
    print("\n Chamber Preservation (float vs ternary nudge):")
    print(" (Agreement rate: same Coxeter chamber under both paths)")
    cp = chamber_preservation(model)
    for head, agree in cp.items():
        if agree > 0.90:
            status = "OK"
        elif agree > 0.80:
            status = "WARN"
        else:
            status = "FAIL"
        print(f" {head}: {agree:.1%} [{status}]")
    if cp:
        print(f" MEAN: {sum(cp.values()) / len(cp):.1%}")
    # --- Nudge structure -------------------------------------------------
    print("\n Ternary Nudge Structure:")
    for head, info in nudge_ternary_structure(model).items():
        print(f" {head}:")
        print(f" Float SVD: [{', '.join(f'{s:.3f}' for s in info['svd_float'])}]")
        print(f" Ternary SVD: [{', '.join(f'{s:.3f}' for s in info['svd_ternary'])}]")
        print(f" Rank ratio: float={info['rank_ratio_float']:.1f}x ternary={info['rank_ratio_ternary']:.1f}x")
        print(f" 600-cell alignment: float={info['alignment_float']:.3f} ternary={info['alignment_ternary']:.3f}")
        dist = info['weight_dist']
        print(f" Ternary weights: -1={dist['neg1']:.0%} 0={dist['zero']:.0%} +1={dist['pos1']:.0%}")
        if info['zero_rows'] > 0 or info['zero_cols'] > 0:
            print(f" Zero rows/cols: {info['zero_rows']}/4, {info['zero_cols']}/4")
    # --- Per-layer BitLinear stats ---------------------------------------
    if has_bitlinear:
        print("\n BitLinear Layers:")
        for name, s in bitlinear_layer_stats(model).items():
            print(f" {name:40s} {s['params']:6d} params "
                  f"-1={s['neg1']:.0%} 0={s['zero']:.0%} +1={s['pos1']:.0%}")
    # --- Size comparison -------------------------------------------------
    print("\n Size Comparison:")
    sz = size_comparison(model)
    print(f" Ternary params: {sz['ternary_params']:,} ({sz['ternary_params'] * 1.58 / 8 / 1024:.1f} KB at 1.58 bits)")
    print(f" Float params: {sz['float_params']:,} ({sz['float_params'] * 4 / 1024:.1f} KB)")
    print(f" Other params: {sz['other_params']:,} ({sz['other_params'] * 4 / 1024:.1f} KB)")
    print(f" Buffers: {sz['buffer_elems']:,} ({sz['buffer_kb']:.1f} KB)")
    print(f" All-float32: {sz['float32_kb']:.1f} KB")
    print(f" Mixed ternary: {sz['mixed_kb']:.1f} KB")
    print(f" Compression: {sz['compression']:.1f}x")
    print(rule)
def _find_h4_attention(model):
    """Return the first H4AttentionLayer found in *model*, or None."""
    # Imported lazily to avoid a circular import at module load time.
    from h4_hybrid_attention import H4AttentionLayer
    return next(
        (m for m in model.modules() if isinstance(m, H4AttentionLayer)),
        None,
    )
if __name__ == '__main__':
    from h4_language_model import H4LanguageModel

    # Run the full report against both variants: all-float and BitLinear.
    runs = [
        ("Testing with float model...", False),
        ("\n\nTesting with BitLinear model...", True),
    ]
    for banner, use_ternary in runs:
        print(banner)
        lm = H4LanguageModel(
            vocab_size=128, d_model=64, n_heads=8, n_layers=2,
            use_bitlinear=use_ternary,
        )
        full_report(lm)