"""
model.py — MechanismBase
========================
The transformer decoder implementing P / G → Q.
Three configurations:
SmallConfig (~10M params): appropriate for ~100K-500K tokens.
Generalizes rather than memorizes. Recommended for the current corpus (~200K tokens).
MediumConfig (~50M params): appropriate for ~500K-2M tokens.
Use after expanding the training corpus.
FullConfig (~235M params): appropriate for 2M+ tokens.
Use only after expanding the training corpus substantially.
Architecture maps to PL terminology:
wte — token embedding: seeds patterns P with initial loaded history
wpe — position encoding: adds positional loaded history
PropagationBlock — one complete P / G → Q step:
attention = gradient family G applied to P
residual = loaded history H_P accumulating
pre-norm = coherence check before each propagation
MLP = reconfiguration toward coherent state
ln_f — final coherence check
lm_head — output: weight-tied to wte (same carrier in and out)
Parameter counts (approximate):
SmallConfig:  ~10.5M params
MediumConfig: ~50M params
FullConfig:   ~235M params
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
# =============================================================================
# CONFIGURATIONS
# =============================================================================
@dataclass
class SmallConfig:
"""
~10M params. Appropriate for 100K–500K tokens.
This is the working configuration for the current corpus (~200K tokens).
Trains in ~30 minutes on RTX 4060 Ti.
Will generalize, not just memorize.
"""
vocab_size: int = 16384 # Carrier V — BPE tokenizer
n_embd: int = 256 # Loaded history vector dimension
n_layer: int = 8 # Propagation steps
n_head: int = 8 # Gradient families per step
block_size: int = 256 # Context window
dropout: float = 0.1
name: str = "SmallBase"
@dataclass
class MediumConfig:
"""
~50M params. Appropriate for 500K–2M tokens.
Use after expanding generate_data.py to produce more derivation traces.
Trains in ~2-3 hours on RTX 4060 Ti.
"""
vocab_size: int = 16384
n_embd: int = 512
n_layer: int = 12
n_head: int = 8
block_size: int = 256
dropout: float = 0.1
name: str = "MediumBase"
@dataclass
class FullConfig:
"""
~235M params. The full AGI Base V1.
Appropriate for 2M+ tokens.
Requires expanding generate_data.py significantly (see comments there).
Trains in ~6 hours on RTX 4060 Ti when data is sufficient.
"""
vocab_size: int = 16384
n_embd: int = 1024
n_layer: int = 16
n_head: int = 16
block_size: int = 256
dropout: float = 0.1
name: str = "FullBase"
# Default: SmallConfig for the current corpus
MechanismConfig = SmallConfig
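

# Illustrative sketch, not part of the training or inference path: a rough
# parameter estimate derived from the config fields, assuming the
# PropagationBlock / MechanismBase layout defined below and a weight-tied
# lm_head. It is intended to agree with MechanismBase.count_parameters(),
# but treat it as an approximation.
def estimate_params(config) -> int:
    d = config.n_embd
    per_block = (
        3 * d * d + 3 * d        # attention in-projection (Q, K, V) + bias
        + d * d + d              # attention out-projection + bias
        + 2 * 2 * d              # two LayerNorms (weight + bias each)
        + d * (4 * d) + 4 * d    # MLP up-projection + bias
        + (4 * d) * d + d        # MLP down-projection + bias
    )
    return (
        config.vocab_size * d        # wte (lm_head is tied, so counted once)
        + config.block_size * d      # wpe
        + config.n_layer * per_block
        + 2 * d                      # ln_f
    )
# For SmallConfig this works out to roughly 10.6M parameters, in line with
# the ~10.5M quoted in the module docstring.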
# =============================================================================
# PROPAGATION BLOCK
# =============================================================================
class PropagationBlock(nn.Module):
"""
One complete P / G → Q propagation step.
Attention : gradient family G applied to pattern P
Residual : loaded history H_P accumulating
LayerNorm : coherence threshold check (pre-norm: check BEFORE propagating)
MLP : reconfiguration toward coherent state
"""
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.attn = nn.MultiheadAttention(
config.n_embd,
config.n_head,
dropout=config.dropout,
batch_first=True,
)
self.ln2 = nn.LayerNorm(config.n_embd)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(),
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.dropout),
)
self.drop = nn.Dropout(config.dropout)
def forward(self, x, attn_mask=None):
# Pre-norm: coherence check before gradient application
normed = self.ln1(x)
attn_out, _ = self.attn(
normed, normed, normed,
attn_mask=attn_mask,
need_weights=False,
)
# Residual accumulates loaded history
x = x + self.drop(attn_out)
x = x + self.mlp(self.ln2(x))
return x
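

# Illustrative helper, not used by the model: builds the same float additive
# causal mask that MechanismBase.forward obtains from
# nn.Transformer.generate_square_subsequent_mask. Zeros sit on and below the
# diagonal and -inf above it; nn.MultiheadAttention adds the mask to the
# attention scores before the softmax, so each position can attend only to
# itself and earlier positions.
def causal_float_mask(T: int, device=None) -> torch.Tensor:
    return torch.triu(
        torch.full((T, T), float("-inf"), device=device), diagonal=1
    )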
# =============================================================================
# MECHANISMBASE
# =============================================================================
class MechanismBase(nn.Module):
"""
The mechanism instantiated in the weight carrier.
wte : token embedding — seeds patterns
wpe : position encoding — adds positional loaded history
h : propagation blocks
ln_f : final coherence check
lm_head : output (weight-tied to wte)
"""
def __init__(self, config):
super().__init__()
self.config = config
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.block_size, config.n_embd)
self.drop = nn.Dropout(config.dropout)
self.h = nn.ModuleList(
[PropagationBlock(config) for _ in range(config.n_layer)]
)
self.ln_f = nn.LayerNorm(config.n_embd)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Weight tying: input and output in the same carrier
self.lm_head.weight = self.wte.weight
self.apply(self._init_weights)
    def _init_weights(self, module):
        # nn.MultiheadAttention's fused in-projection is a raw Parameter, not an
        # nn.Linear, so it keeps PyTorch's default (Xavier) init; the attention
        # out-projection and the MLP Linears are re-initialized here.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx, targets=None):
B, T = idx.shape
assert T <= self.config.block_size, \
f"Sequence length {T} exceeds block_size {self.config.block_size}"
positions = torch.arange(T, device=idx.device)
x = self.drop(self.wte(idx) + self.wpe(positions))
# Causal mask: patterns attend only to prior loaded history
causal_mask = nn.Transformer.generate_square_subsequent_mask(
T, device=idx.device
)
for block in self.h:
x = block(x, attn_mask=causal_mask)
x = self.ln_f(x)
logits = self.lm_head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
)
return logits, loss
@torch.no_grad()
def generate(
self,
idx,
max_new_tokens: int = 200,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 0.9,
):
"""
Autoregressive generation with temperature + top-k + top-p sampling.
"""
self.eval()
for _ in range(max_new_tokens):
x = idx[:, -self.config.block_size:]
logits, _ = self(x, None)
next_logits = logits[0, -1, :] / temperature
# Top-k
if top_k > 0:
k = min(top_k, next_logits.size(-1))
topk_vals, _ = torch.topk(next_logits, k)
next_logits[next_logits < topk_vals[-1]] = float("-inf")
            # Top-p (nucleus): keep the smallest set of tokens whose cumulative
            # probability reaches top_p. Subtracting each token's own probability
            # shifts the cutoff so the token that crosses the threshold is kept.
            if top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(sorted_probs, dim=-1)
                remove = (cumprobs - sorted_probs) > top_p
                sorted_logits[remove] = float("-inf")
                # Scatter the filtered logits back to their original positions.
                next_logits = torch.zeros_like(next_logits).scatter_(
                    0, sorted_idx, sorted_logits
                )
probs = F.softmax(next_logits, dim=-1)
next_id = torch.multinomial(probs, num_samples=1)
idx = torch.cat([idx, next_id.unsqueeze(0)], dim=1)
return idx
    def count_parameters(self) -> int:
        # parameters() de-duplicates shared tensors, so the tied wte / lm_head
        # weight is counted once.
        return sum(p.numel() for p in self.parameters())
def parameter_summary(self) -> str:
total = self.count_parameters()
embed = self.wte.weight.numel()
lines = [
f" Configuration: {self.config.name}",
f" Total params: {total:,}",
f" Embed params: {embed:,} ({embed/total:.1%} of total)",
f" n_embd={self.config.n_embd}, "
f"n_layer={self.config.n_layer}, "
f"n_head={self.config.n_head}",
]
return "\n".join(lines)
if __name__ == "__main__":
for ConfigClass in [SmallConfig, MediumConfig, FullConfig]:
config = ConfigClass()
model = MechanismBase(config)
print(model.parameter_summary())
print()
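
    # Optional sanity check (illustrative, assuming the SmallConfig defaults
    # above): run one forward pass on random token ids and sample a few tokens
    # to confirm output shapes and end-to-end generation.
    small = MechanismBase(SmallConfig())
    dummy = torch.randint(0, SmallConfig().vocab_size, (1, 32))
    logits, loss = small(dummy, targets=dummy)
    print(f"logits: {tuple(logits.shape)}, loss: {loss.item():.3f}")
    sample = small.generate(dummy[:, :8], max_new_tokens=16)
    print(f"generated length: {sample.shape[1]}")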