File size: 6,250 Bytes
f742bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""
H4 Language Model — Transformer LM with H4 geometric attention.

Architecture:
    - Token embedding + golden-angle positional encoding (PhiPositionalEncoding)
    - N × H4TransformerBlock (H4 attention + FFN)
    - LM head (Linear to vocab_size)

The frozen H4 geometry handles spatial partitioning of attention space.
Trainable adapters (nudge matrices, chamber bonuses, projections) learn
which directions to query and how to weight chambers.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from h4_hybrid_attention import H4TransformerBlock
from utils.phi_positional import PhiPositionalEncoding
from bitlinear import BitLinear


class H4LanguageModel(nn.Module):
    """
    Full language model with H4 polytopic attention.

    Pipeline: token embedding (scaled by sqrt(d_model)) + golden-angle
    positional encoding -> n_layers x H4TransformerBlock -> final LayerNorm
    -> LM head (weight-tied to the token embedding).

    Args:
        vocab_size: vocabulary size
        d_model: model dimension
        n_heads: number of H4 attention heads per layer
        n_layers: number of transformer blocks
        d_value: value dimension per head
        d_ffn: FFN hidden dimension (default: 4 * d_model)
        top_k: max candidates per query in ChamberTree lookup
        max_seq_len: max sequence length for positional encoding cache;
            also used as the generation context window
        dropout: dropout rate
        use_bitlinear: if True, blocks use BitLinear projections
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 64,
        n_heads: int = 8,
        n_layers: int = 4,
        d_value: int = 16,
        d_ffn: "int | None" = None,
        top_k: int = 32,
        max_seq_len: int = 8192,
        dropout: float = 0.1,
        use_bitlinear: bool = False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.use_bitlinear = use_bitlinear
        # Kept so generate() can crop the conditioning window to the
        # positional-encoding cache limit.
        self.max_seq_len = max_seq_len

        if d_ffn is None:
            d_ffn = d_model * 4

        # Token embedding (always float — lookup table, not a matmul)
        self.token_emb = nn.Embedding(vocab_size, d_model)
        # Scale embedding by sqrt(d_model) as in the original transformer
        self.emb_scale = math.sqrt(d_model)

        # Golden-angle positional encoding
        self.pos_enc = PhiPositionalEncoding(d_model, max_cached=max_seq_len)

        # Embedding dropout
        self.emb_dropout = nn.Dropout(dropout)

        # Transformer blocks with H4 attention
        self.blocks = nn.ModuleList([
            H4TransformerBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_value=d_value,
                d_ffn=d_ffn,
                top_k=top_k,
                dropout=dropout,
                use_bitlinear=use_bitlinear,
            )
            for _ in range(n_layers)
        ])

        # Final layer norm
        self.ln_f = nn.LayerNorm(d_model)

        # LM head (tied with token embedding weights — stays float)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: lm_head and token_emb share one parameter tensor.
        self.lm_head.weight = self.token_emb.weight

        self._init_weights()

    def _init_weights(self):
        """Initialize weights following GPT-2 conventions (N(0, 0.02), zero bias)."""
        for module in self.modules():
            if isinstance(module, BitLinear):
                # Overwrite BitLinear's own init with the GPT-2 normal init
                # so all projection weights start from the same distribution.
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        use_tree: bool = True,
        return_diagnostics: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            input_ids: (batch, seq_len) token indices
            use_tree: if True, use ChamberTree for O(log t) attention
            return_diagnostics: if True, return (logits, list_of_diag_dicts)

        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        B, T = input_ids.shape

        # Token + positional embedding
        tok_emb = self.token_emb(input_ids) * self.emb_scale  # (B, T, D)
        pos_emb = self.pos_enc(T).unsqueeze(0).to(tok_emb.device)  # (1, T, D)
        x = self.emb_dropout(tok_emb + pos_emb)

        # Transformer blocks
        diagnostics = []
        for block in self.blocks:
            if return_diagnostics:
                x, diag = block(x, use_tree=use_tree, return_diagnostics=True)
                diagnostics.append(diag)
            else:
                x = block(x, use_tree=use_tree)

        # Final norm + LM head
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if return_diagnostics:
            return logits, diagnostics
        return logits

    def count_params(self):
        """Count trainable and frozen parameters (and buffer elements)."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
        buffers = sum(b.numel() for b in self.buffers())
        return {
            'trainable': trainable,
            'frozen': frozen,
            'buffers': buffers,
            'total': trainable + frozen,
        }

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k_sample: int = 0,
    ) -> torch.Tensor:
        """
        Autoregressive generation.

        Args:
            input_ids: (batch, seq_len) prompt token indices
            max_new_tokens: number of tokens to append
            temperature: softmax temperature; <= 0 selects greedy (argmax)
                decoding instead of dividing by zero
            top_k_sample: if > 0, restrict sampling to the k highest logits

        Returns:
            (batch, seq_len + max_new_tokens) tensor of token indices
        """
        for _ in range(max_new_tokens):
            # Crop the conditioning window to max_seq_len — the positional
            # encoding cache only covers that many positions.
            if input_ids.size(1) > self.max_seq_len:
                idx_cond = input_ids[:, -self.max_seq_len:]
            else:
                idx_cond = input_ids

            logits = self.forward(idx_cond, use_tree=False)
            logits = logits[:, -1, :]  # last position only: (B, vocab)

            if temperature <= 0:
                # Greedy decoding; avoids division by zero.
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k_sample > 0:
                    v, _ = torch.topk(logits, min(top_k_sample, logits.size(-1)))
                    # Mask everything below the k-th largest logit.
                    logits[logits < v[:, [-1]]] = float('-inf')
                probs = F.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_id], dim=1)

        return input_ids