File size: 3,327 Bytes
29da5d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
Export trained H4 BitLinear model for deployment.

Freezes all BitLinear layers to pure ternary, packs weights,
saves alongside frozen geometric constants.

Output format: PyTorch state dict with:
  - geometry.* : frozen H4/E8 buffers (float32)
  - ternary.*.weight : int8 with values {-1, 0, +1}
  - ternary.*.scale : float32 per-layer scale
  - params.* : remaining float parameters (embeddings, norms, chamber_bonus)
"""

import torch
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from bitlinear import BitLinear


def export(model, path):
    """Export *model* with all BitLinear layers frozen to ternary weights.

    Freezes every BitLinear module in place (a side effect on *model*),
    then saves a flat state dict to *path* with three key prefixes:

      - ``geometry.*``        : registered buffers (float32 lookup tables)
      - ``params.*``          : all float parameters (embeddings, norms,
                                shadow weights, chamber_bonus)
      - ``ternary.*.weight``  : frozen int8 ternary weights {-1, 0, +1}
      - ``ternary.*.scale``   : float32 per-layer scale

    Finally prints a size report comparing the mixed-precision layout
    against an all-float32 baseline.

    Args:
        model: module tree containing BitLinear layers; each must expose
            ``freeze()`` and the ``_frozen_ternary`` / ``_frozen_scale``
            attributes after freezing.
        path: destination path passed to ``torch.save``.
    """
    # Freeze all BitLinear layers (in place — quantizes shadow weights).
    frozen_count = 0
    for module in model.modules():
        if isinstance(module, BitLinear):
            module.freeze()
            frozen_count += 1

    # Collect all tensors into one flat, prefixed state dict.
    state = {}

    # Frozen geometry (float32, static lookup tables).
    for name, buf in model.named_buffers():
        state[f'geometry.{name}'] = buf.cpu()

    # All parameters (float shadow weights, embeddings, norms, etc.).
    for name, param in model.named_parameters():
        state[f'params.{name}'] = param.data.cpu()

    # Frozen ternary buffers (int8) plus their per-layer float scales.
    for name, module in model.named_modules():
        if isinstance(module, BitLinear) and module._frozen_ternary is not None:
            state[f'ternary.{name}.weight'] = module._frozen_ternary.cpu()
            state[f'ternary.{name}.scale'] = module._frozen_scale.cpu()

    torch.save(state, path)

    # Size report. Match exact key suffixes rather than substring tests so
    # a module whose *name* happens to contain "scale" is not misclassified
    # (the old `'scale' not in k` check would drop its weight count).
    ternary_params = sum(
        t.numel() for k, t in state.items()
        if k.startswith('ternary.') and k.endswith('.weight')
    )
    geom_elems = sum(
        t.numel() for k, t in state.items()
        if k.startswith('geometry.')
    )
    learned_params = sum(
        t.numel() for k, t in state.items()
        if k.startswith('params.')
    )

    # 1.58 bits/weight is the information content of a ternary symbol
    # (log2(3)); float tensors are counted at 4 bytes/element.
    ternary_kb = ternary_params * 1.58 / 8 / 1024
    geom_kb = geom_elems * 4 / 1024
    learned_kb = learned_params * 4 / 1024
    total_mixed = ternary_kb + geom_kb + learned_kb
    total_float = (ternary_params + geom_elems + learned_params) * 4 / 1024

    print(f"Exported to {path}")
    print(f"  Frozen BitLinear layers: {frozen_count}")
    print(f"  Ternary weights: {ternary_params:,} ({ternary_kb:.1f} KB at 1.58 bits)")
    print(f"  Geometry buffers: {geom_elems:,} ({geom_kb:.1f} KB float32)")
    print(f"  Learned float:   {learned_params:,} ({learned_kb:.1f} KB float32)")
    print(f"  Total mixed:     {total_mixed:.1f} KB")
    print(f"  vs all-float32:  {total_float:.1f} KB")
    # Guard against a degenerate empty export (no tensors at all).
    if total_mixed > 0:
        print(f"  Compression:     {total_float / total_mixed:.1f}x")


if __name__ == '__main__':
    # Deferred import: the model class is only needed when run as a script.
    from h4_language_model import H4LanguageModel

    # CLI: export.py [checkpoint.pt] [output.pt]
    # Both arguments are optional; output defaults to h4_ternary_model.pt.
    ckpt = sys.argv[1] if len(sys.argv) > 1 else None
    out = sys.argv[2] if len(sys.argv) > 2 else 'h4_ternary_model.pt'

    # Default config for demo — presumably this must match the architecture
    # the checkpoint was trained with; TODO confirm against training config.
    model = H4LanguageModel(
        vocab_size=256,
        d_model=64,
        n_heads=8,
        n_layers=2,
        d_value=16,
        use_bitlinear=True,
    )

    if ckpt:
        # NOTE(review): torch.load unpickles arbitrary objects on older
        # torch versions — only load trusted checkpoint files (or pass
        # weights_only=True where available).
        model.load_state_dict(torch.load(ckpt, map_location='cpu'))
        print(f"Loaded checkpoint from {ckpt}")
    else:
        print("No checkpoint provided, exporting untrained model (for format verification)")

    export(model, out)