File size: 9,176 Bytes
b5d4048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
"""
model.py — MechanismBase
========================

The transformer decoder implementing P / G → Q.

Three configurations:
    SmallConfig  (~10.6M params) — appropriate for ~200K tokens.
                                   Generalizes. Recommended for the current corpus.

    MediumConfig (~46M params)   — appropriate for ~500K–2M tokens.

    FullConfig   (~219M params)  — appropriate for ~2M+ tokens.
                                   Use after expanding the training corpus.

Architecture maps to PL terminology:
    wte          — token embedding: seeds patterns P with initial loaded history
    wpe          — position encoding: adds positional loaded history
    PropagationBlock — one complete P / G → Q step:
                       attention = gradient family G applied to P
                       residual  = loaded history H_P accumulating
                       pre-norm  = coherence check before each propagation
                       MLP       = reconfiguration toward coherent state
    ln_f         — final coherence check
    lm_head      — output: weight-tied to wte (same carrier in and out)

Parameter counts (approximate; the tied wte/lm_head matrix is counted once,
as MechanismBase.count_parameters does):
    SmallConfig:   10.6M params
    MediumConfig:  46.3M params
    FullConfig:   218.6M params
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


# =============================================================================
# CONFIGURATIONS
# =============================================================================

@dataclass
class SmallConfig:
    """
    Small preset (~10.6M parameters, tied embedding counted once).

    Intended for corpora of roughly 100K–500K tokens. This is the working
    choice for the current ~200K-token corpus, where it should generalize
    rather than merely memorize. Author's estimate: about 30 minutes of
    training on an RTX 4060 Ti.
    """
    # Size of the BPE vocabulary (carrier V).
    vocab_size: int = 16384
    # Width of every loaded-history vector (embedding dimension).
    n_embd: int = 256
    # Number of propagation steps (transformer layers).
    n_layer: int = 8
    # Gradient families per step (attention heads); must divide n_embd.
    n_head: int = 8
    # Maximum context length, in tokens.
    block_size: int = 256
    # Dropout probability used throughout the model.
    dropout: float = 0.1
    # Label reported by MechanismBase.parameter_summary().
    name: str = "SmallBase"


@dataclass
class MediumConfig:
    """
    Medium preset (~46M parameters, tied embedding counted once).

    Intended for corpora of roughly 500K–2M tokens. Use it once
    generate_data.py has been expanded to emit more derivation traces.
    Author's estimate: about 2–3 hours of training on an RTX 4060 Ti.
    """
    # Size of the BPE vocabulary (carrier V).
    vocab_size: int = 16384
    # Width of every loaded-history vector (embedding dimension).
    n_embd: int = 512
    # Number of propagation steps (transformer layers).
    n_layer: int = 12
    # Gradient families per step (attention heads); must divide n_embd.
    n_head: int = 8
    # Maximum context length, in tokens.
    block_size: int = 256
    # Dropout probability used throughout the model.
    dropout: float = 0.1
    # Label reported by MechanismBase.parameter_summary().
    name: str = "MediumBase"


@dataclass
class FullConfig:
    """
    ~219M params — the full AGI Base V1.

    The count treats the weight-tied wte/lm_head matrix as a single
    parameter, as MechanismBase.count_parameters() does. The ~235M figure
    previously quoted here counted that 16384 x 1024 matrix twice.

    Appropriate for 2M+ tokens.
    Requires expanding generate_data.py significantly (see comments there).
    Author's estimate: ~6 hours of training on an RTX 4060 Ti when data is
    sufficient.
    """
    vocab_size:  int   = 16384   # Carrier V — BPE tokenizer
    n_embd:      int   = 1024    # Loaded history vector dimension
    n_layer:     int   = 16      # Propagation steps
    n_head:      int   = 16      # Gradient families per step (must divide n_embd)
    block_size:  int   = 256     # Context window
    dropout:     float = 0.1
    name:        str   = "FullBase"


# Default: SmallConfig for the current corpus.
# MechanismConfig is an alias for the SmallConfig class itself (not an
# instance), so MechanismConfig() constructs a SmallConfig with its defaults.
MechanismConfig = SmallConfig


# =============================================================================
# PROPAGATION BLOCK
# =============================================================================

class PropagationBlock(nn.Module):
    """
    A single pre-norm transformer layer: one P / G → Q propagation step.

    Sub-layers, in execution order:
        ln1 -> attn : gradient family G applied to pattern P (self-attention)
        drop        : dropout on the attention output before the residual add
        ln2 -> mlp  : reconfiguration toward a coherent state (4x-wide GELU MLP)

    Each sub-layer's output is added onto the residual stream, which is where
    loaded history H_P accumulates. Normalization happens before each
    sub-layer (pre-norm), i.e. the coherence check precedes propagation.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        p_drop = config.dropout

        # Registration order (ln1, attn, ln2, mlp, drop) fixes the state_dict
        # key order and the order in which init consumes the RNG; keep it.
        self.ln1 = nn.LayerNorm(width)
        self.attn = nn.MultiheadAttention(
            embed_dim=width,
            num_heads=config.n_head,
            dropout=p_drop,
            batch_first=True,  # inputs are laid out (batch, seq, feature)
        )
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(p_drop),
        )
        self.drop = nn.Dropout(p_drop)

    def forward(self, x, attn_mask=None):
        """
        Apply self-attention and then the MLP, each pre-normalized and added
        back onto the residual stream.

        Args:
            x: tensor shaped (batch, seq, n_embd), since attention is batch_first.
            attn_mask: optional mask passed unchanged to nn.MultiheadAttention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        h = self.ln1(x)
        attended = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)[0]
        residual = x + self.drop(attended)
        return residual + self.mlp(self.ln2(residual))


# =============================================================================
# MECHANISMBASE
# =============================================================================

class MechanismBase(nn.Module):
    """
    The mechanism instantiated in the weight carrier: a decoder-only
    transformer language model.

    wte        : token embedding — seeds patterns
    wpe        : position encoding — adds positional loaded history
    h          : propagation blocks (config.n_layer PropagationBlocks)
    ln_f       : final coherence check
    lm_head    : output (weight-tied to wte)

    Args:
        config: object exposing vocab_size, n_embd, n_layer, n_head,
            block_size, dropout and name (e.g. SmallConfig, MediumConfig,
            FullConfig).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Submodule creation order is kept stable: it determines state_dict
        # key order and the order in which initialization consumes the RNG.
        self.wte     = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe     = nn.Embedding(config.block_size, config.n_embd)
        self.drop    = nn.Dropout(config.dropout)
        self.h       = nn.ModuleList(
            [PropagationBlock(config) for _ in range(config.n_layer)]
        )
        self.ln_f    = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: input and output in the same carrier. Both attributes
        # now reference a single Parameter object.
        self.lm_head.weight = self.wte.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """
        Re-initialize Linear/Embedding weights to N(0, 0.02) and zero biases.

        Called on every submodule via ``self.apply``. NOTE: in PyTorch's
        nn.MultiheadAttention the packed q/k/v projection is a bare Parameter
        (in_proj_weight) rather than an nn.Linear, so it keeps PyTorch's
        default initialization; only its out_proj is re-initialized here.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """
        Run the model over a batch of token ids.

        Args:
            idx: integer tensor of shape (B, T), with T <= config.block_size.
            targets: optional integer tensor of shape (B, T) holding the
                next-token ids; when given, mean cross-entropy is computed.

        Returns:
            (logits, loss): logits has shape (B, T, vocab_size); loss is a
            scalar tensor, or None when ``targets`` is None.
        """
        _, T = idx.shape
        assert T <= self.config.block_size, \
            f"Sequence length {T} exceeds block_size {self.config.block_size}"

        positions = torch.arange(T, device=idx.device)
        x = self.drop(self.wte(idx) + self.wpe(positions))

        # Causal mask: patterns attend only to prior loaded history.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(
            T, device=idx.device
        )

        for block in self.h:
            x = block(x, attn_mask=causal_mask)

        x      = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets — e.g. a strided
            # slice of a larger buffer — are accepted too.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx,
        max_new_tokens: int = 200,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
    ):
        """
        Autoregressive generation with temperature + top-k + top-p sampling.

        Args:
            idx: integer tensor of shape (B, T) holding the prompt(s); every
                row of the batch is extended independently.
            max_new_tokens: number of tokens appended to each row.
            temperature: logits are divided by this before sampling. A value
                <= 0 selects greedy decoding (argmax) instead of sampling.
            top_k: keep only the k highest logits per row (disabled if <= 0).
            top_p: nucleus-sampling mass threshold (disabled if >= 1.0).

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).

        The module runs in eval mode while generating; its previous
        train/eval mode is restored afterwards, even if an error occurs.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop the context to the most recent block_size tokens.
                x = idx[:, -self.config.block_size:]
                logits, _ = self(x, None)
                next_id = self._sample_next(
                    logits[:, -1, :], temperature, top_k, top_p
                )
                idx = torch.cat([idx, next_id], dim=1)
        finally:
            self.train(was_training)
        return idx

    @staticmethod
    def _sample_next(next_logits, temperature, top_k, top_p):
        """Pick one token id per row of (B, V) logits; returns a (B, 1) tensor."""
        if temperature <= 0:
            # Greedy decoding; avoids dividing by zero below.
            return next_logits.argmax(dim=-1, keepdim=True)

        next_logits = next_logits / temperature

        # Top-k: mask every logit below the k-th largest one in its row.
        if top_k > 0:
            k = min(top_k, next_logits.size(-1))
            kth_vals = torch.topk(next_logits, k, dim=-1).values[:, -1:]
            next_logits = next_logits.masked_fill(
                next_logits < kth_vals, float("-inf")
            )

        # Top-p: drop a token once the mass of strictly more likely tokens
        # already exceeds top_p, so the most likely token always survives.
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(
                next_logits, dim=-1, descending=True
            )
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            remove = (torch.cumsum(sorted_probs, dim=-1) - sorted_probs) > top_p
            sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
            # sorted_idx is a permutation per row, so every slot is written.
            next_logits = torch.zeros_like(next_logits).scatter_(
                -1, sorted_idx, sorted_logits
            )

        probs = F.softmax(next_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def count_parameters(self) -> int:
        """Total parameter count; the tied wte/lm_head matrix is counted once."""
        # parameters() yields a shared Parameter only once.
        return sum(p.numel() for p in self.parameters())

    def parameter_summary(self) -> str:
        """Return a multi-line, human-readable summary of size and shape."""
        total = self.count_parameters()
        embed = self.wte.weight.numel()
        lines = [
            f"  Configuration: {self.config.name}",
            f"  Total params:  {total:,}",
            f"  Embed params:  {embed:,} ({embed/total:.1%} of total)",
            f"  n_embd={self.config.n_embd}, "
            f"n_layer={self.config.n_layer}, "
            f"n_head={self.config.n_head}",
        ]
        return "\n".join(lines)


if __name__ == "__main__":
    # Print a parameter summary for every preset, each followed by a blank line.
    for config_cls in (SmallConfig, MediumConfig, FullConfig):
        summary = MechanismBase(config_cls()).parameter_summary()
        print(summary, end="\n\n")