File size: 3,327 Bytes
29da5d5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 | """
Export a trained H4 BitLinear model for deployment.
Freezes every BitLinear layer to pure ternary weights and saves the
ternary weights together with the frozen geometric constants and the
float parameters.
Output format: PyTorch state dict with:
- geometry.* : model buffers (frozen H4/E8 geometry, float32)
- ternary.*.weight : int8 with values {-1, 0, +1}
- ternary.*.scale : float32 per-layer scale
- params.* : float parameters (embeddings, norms, chamber_bonus), which by
  default also include the float shadow weights of the BitLinear layers
"""
import torch
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from bitlinear import BitLinear
def export(model, path, *, include_shadow_weights=True):
    """Export model with frozen ternary weights.

    Calls ``freeze()`` on every BitLinear layer of ``model``, gathers the
    model's buffers, parameters and frozen ternary tensors into one flat
    dict, writes it with ``torch.save`` and prints a size report.

    Keys in the saved dict:
        ``geometry.<name>``        -- every buffer from ``model.named_buffers()``
        ``params.<name>``          -- parameters from ``model.named_parameters()``
        ``ternary.<layer>.weight`` -- the layer's ``_frozen_ternary`` tensor
        ``ternary.<layer>.scale``  -- the layer's ``_frozen_scale`` tensor

    Args:
        model: ``torch.nn.Module`` containing BitLinear layers.
        path: destination passed to ``torch.save``.
        include_shadow_weights: keyword-only. When True (the default, and the
            historical behaviour) the float ``weight`` parameter of each
            frozen BitLinear layer is saved under ``params.*`` next to its
            ternary copy. When False it is left out of the file.

    Returns:
        dict: the state dict that was saved.
    """
    # Freeze all BitLinear layers. named_modules() yields the same
    # de-duplicated modules, in the same order, as model.modules().
    bitlinears = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, BitLinear)
    ]
    for _, module in bitlinears:
        module.freeze()
    frozen_count = len(bitlinears)

    # Layers whose freeze() actually produced a ternary tensor.
    frozen = [
        (name, module)
        for name, module in bitlinears
        if getattr(module, '_frozen_ternary', None) is not None
    ]

    # Track, by object identity, tensors that would otherwise be counted twice
    # in the size report:
    #  - the frozen ternary/scale tensors, in case freeze() registers them as
    #    buffers (they would then also be yielded by named_buffers());
    #  - the float shadow weights, assumed to be the layer's ``weight``
    #    attribute. NOTE(review): confirm against BitLinear; if it has no
    #    ``weight`` tensor, shadow weights are counted as ordinary params.
    ternary_ids = set()
    shadow_ids = set()
    for _, module in frozen:
        ternary_ids.add(id(module._frozen_ternary))
        ternary_ids.add(id(module._frozen_scale))
        shadow = getattr(module, 'weight', None)
        if isinstance(shadow, torch.Tensor):
            shadow_ids.add(id(shadow))

    state = {}

    # Frozen geometry (static lookup tables). Every buffer is saved; frozen
    # ternary tensors are excluded from the count because they are counted
    # under ternary.* below.
    geom_elems = 0
    for name, buf in model.named_buffers():
        state[f'geometry.{name}'] = buf.cpu()
        if id(buf) not in ternary_ids:
            geom_elems += buf.numel()

    # Float parameters (embeddings, norms, etc., plus optionally the shadow
    # weights). Shadow weights are tallied separately: in deployment they are
    # replaced by their ternary copy, so counting them as learned floats would
    # double-count every BitLinear weight.
    learned_params = 0
    shadow_params = 0
    for name, param in model.named_parameters():
        if id(param) in shadow_ids:
            shadow_params += param.numel()
            if not include_shadow_weights:
                continue
        else:
            learned_params += param.numel()
        state[f'params.{name}'] = param.data.cpu()

    # Frozen ternary weights and per-layer scales. Counted straight from the
    # modules rather than by string-matching keys, so a layer whose name
    # happens to contain "scale" is still counted.
    ternary_params = 0
    for name, module in frozen:
        state[f'ternary.{name}.weight'] = module._frozen_ternary.cpu()
        state[f'ternary.{name}.scale'] = module._frozen_scale.cpu()
        ternary_params += module._frozen_ternary.numel()

    torch.save(state, path)

    # Report. Ternary size is the theoretical log2(3) ~= 1.58 bits per weight,
    # not the on-disk size of the saved tensor.
    ternary_kb = ternary_params * 1.58 / 8 / 1024
    geom_kb = geom_elems * 4 / 1024
    learned_kb = learned_params * 4 / 1024
    total_mixed = ternary_kb + geom_kb + learned_kb
    # All-float32 baseline: every weight counted exactly once.
    total_float = (ternary_params + geom_elems + learned_params) * 4 / 1024

    print(f"Exported to {path}")
    print(f" Frozen BitLinear layers: {frozen_count}")
    print(f" Ternary weights: {ternary_params:,} ({ternary_kb:.1f} KB at 1.58 bits)")
    print(f" Geometry buffers: {geom_elems:,} ({geom_kb:.1f} KB float32)")
    print(f" Learned float: {learned_params:,} ({learned_kb:.1f} KB float32)")
    if shadow_params:
        status = 'also saved' if include_shadow_weights else 'omitted'
        print(f" Float shadow weights: {shadow_params:,} "
              f"({shadow_params * 4 / 1024:.1f} KB float32, {status})")
    print(f" Total mixed: {total_mixed:.1f} KB")
    print(f" vs all-float32: {total_float:.1f} KB")
    if total_mixed:
        print(f" Compression: {total_float / total_mixed:.1f}x")
    else:
        # Empty model: nothing to compare against.
        print(" Compression: n/a")
    return state
if __name__ == '__main__':
    from h4_language_model import H4LanguageModel

    # Positional CLI arguments: [checkpoint] [output_path]
    cli_args = sys.argv[1:]
    ckpt = cli_args[0] if len(cli_args) >= 1 else None
    out = cli_args[1] if len(cli_args) >= 2 else 'h4_ternary_model.pt'

    # Small demo configuration. It must match the architecture the checkpoint
    # was trained with, because load_state_dict() is strict by default.
    demo_config = {
        'vocab_size': 256,
        'd_model': 64,
        'n_heads': 8,
        'n_layers': 2,
        'd_value': 16,
        'use_bitlinear': True,
    }
    model = H4LanguageModel(**demo_config)

    if not ckpt:
        # No (or empty) checkpoint path: export fresh weights, which is only
        # useful for checking the output format.
        print("No checkpoint provided, exporting untrained model (for format verification)")
    else:
        checkpoint_state = torch.load(ckpt, map_location='cpu')
        model.load_state_dict(checkpoint_state)
        print(f"Loaded checkpoint from {ckpt}")

    export(model, out)
|