| """ |
| Export trained H4 BitLinear model for deployment. |
| |
| Freezes all BitLinear layers to pure ternary, packs weights, |
| saves alongside frozen geometric constants. |
| |
| Output format: PyTorch state dict with: |
| - geometry.* : frozen H4/E8 buffers (float32) |
| - ternary.*.weight : int8 with values {-1, 0, +1} |
| - ternary.*.scale : float32 per-layer scale |
| - params.* : remaining float parameters (embeddings, norms, chamber_bonus) |
| """ |
|
|
| import torch |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from bitlinear import BitLinear |
|
|
|
|
def export(model, path):
    """Export *model* with frozen ternary weights to *path*.

    Freezes every BitLinear layer to pure ternary, then saves a state dict
    (via ``torch.save``) with three key namespaces:

    - ``geometry.*``       : registered buffers (frozen H4/E8 constants)
    - ``ternary.*.weight`` : int8 ternary weights, values in {-1, 0, +1}
    - ``ternary.*.scale``  : float32 per-layer scale
    - ``params.*``         : remaining float parameters (embeddings, norms, ...)

    Args:
        model: module tree containing BitLinear layers to freeze and export.
        path: destination path passed to ``torch.save``.

    Prints a size/compression summary after saving.
    """
    # Freeze all BitLinear layers to their pure-ternary form.
    frozen_count = 0
    for module in model.modules():
        if isinstance(module, BitLinear):
            module.freeze()
            frozen_count += 1

    state = {}

    # Frozen geometric constants live in buffers.
    # NOTE(review): if BitLinear.freeze() registers _frozen_ternary/_frozen_scale
    # as buffers, they would also be captured here under geometry.* — confirm
    # against bitlinear.py.
    for name, buf in model.named_buffers():
        state[f'geometry.{name}'] = buf.cpu()

    # Remaining float parameters (embeddings, norms, chamber_bonus).
    # NOTE(review): if frozen BitLinear layers still expose their float master
    # weights as parameters, those are duplicated here alongside ternary.* —
    # confirm against bitlinear.py.
    for name, param in model.named_parameters():
        state[f'params.{name}'] = param.data.cpu()

    # Packed ternary weights and per-layer scales for frozen layers.
    for name, module in model.named_modules():
        if isinstance(module, BitLinear) and module._frozen_ternary is not None:
            state[f'ternary.{name}.weight'] = module._frozen_ternary.cpu()
            state[f'ternary.{name}.scale'] = module._frozen_scale.cpu()

    torch.save(state, path)

    # --- Size accounting (reporting only) ---
    # Match exact key suffixes: a substring test like `'scale' not in k` would
    # silently drop any layer whose *name* contains "scale" (e.g. scale_proj).
    ternary_params = sum(
        t.numel() for k, t in state.items()
        if k.startswith('ternary.') and k.endswith('.weight')
    )
    scale_elems = sum(
        t.numel() for k, t in state.items()
        if k.startswith('ternary.') and k.endswith('.scale')
    )
    geom_elems = sum(
        t.numel() for k, t in state.items()
        if k.startswith('geometry.')
    )
    learned_params = sum(
        t.numel() for k, t in state.items()
        if k.startswith('params.')
    )

    # Ternary weights cost log2(3) ~ 1.58 bits each; everything else float32.
    # Per-layer scales are saved too, so include them in both totals.
    ternary_kb = ternary_params * 1.58 / 8 / 1024
    geom_kb = geom_elems * 4 / 1024
    learned_kb = (learned_params + scale_elems) * 4 / 1024
    total_mixed = ternary_kb + geom_kb + learned_kb
    total_float = (
        ternary_params + scale_elems + geom_elems + learned_params
    ) * 4 / 1024

    print(f"Exported to {path}")
    print(f"  Frozen BitLinear layers: {frozen_count}")
    print(f"  Ternary weights: {ternary_params:,} ({ternary_kb:.1f} KB at 1.58 bits)")
    print(f"  Geometry buffers: {geom_elems:,} ({geom_kb:.1f} KB float32)")
    print(f"  Learned float: {learned_params:,} ({learned_kb:.1f} KB float32)")
    print(f"  Total mixed: {total_mixed:.1f} KB")
    print(f"  vs all-float32: {total_float:.1f} KB")
    # Guard against an empty model (no tensors saved at all).
    if total_mixed > 0:
        print(f"  Compression: {total_float / total_mixed:.1f}x")
|
|
|
|
if __name__ == '__main__':
    from h4_language_model import H4LanguageModel

    # Usage: python <script> [checkpoint.pt] [output.pt]
    ckpt = sys.argv[1] if len(sys.argv) > 1 else None
    out = sys.argv[2] if len(sys.argv) > 2 else 'h4_ternary_model.pt'

    # Architecture hyperparameters must match the training configuration
    # of the checkpoint being loaded.
    model = H4LanguageModel(
        vocab_size=256,
        d_model=64,
        n_heads=8,
        n_layers=2,
        d_value=16,
        use_bitlinear=True,
    )

    if ckpt:
        # weights_only=True: the checkpoint is a plain state dict, so restrict
        # unpickling to tensor data — avoids arbitrary code execution from an
        # untrusted checkpoint file. Requires PyTorch >= 1.13.
        state_dict = torch.load(ckpt, map_location='cpu', weights_only=True)
        model.load_state_dict(state_dict)
        print(f"Loaded checkpoint from {ckpt}")
    else:
        print("No checkpoint provided, exporting untrained model (for format verification)")

    export(model, out)
|
|