"""tilelli.core.tilelli_lm — minimal byte-level language model built on
ternary primitives + heterogeneous-pathway blocks.

Stacks TilelliBlock layers between a byte embedding (summed with a learned
positional embedding) and a ternary unembedding, with a final LayerNorm
applied before the unembedding.
"""
from __future__ import annotations
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from tilelli.core.ternary_linear import TernaryLinear
from tilelli.core.tilelli_block import TilelliBlock
class TilelliLM(nn.Module):
    """Byte-level Tilelli language model.

    Pipeline (as built in ``__init__`` and run in ``forward``):
    token embedding + learned positional embedding -> ``n_layers`` x
    ``TilelliBlock`` -> ``LayerNorm`` -> ``TernaryLinear`` unembedding that
    produces ``vocab_size`` logits per position.

    Most constructor arguments are forwarded unchanged to every
    ``TilelliBlock`` and, where applicable, to the ``TernaryLinear``
    unembedding. Those modules define what the arguments mean precisely.

    Args:
        vocab_size: Number of distinct token ids (default 256, one per byte).
        d_model: Width of the embeddings and of the block stack.
        n_layers: Number of stacked ``TilelliBlock`` layers.
        d_head, top_k, pathways, n_banks, dense_expand, fp_attention,
        top_k_routing: Forwarded to each ``TilelliBlock``.
        max_seq_len: Rows in the positional-embedding table, which is also
            the longest sequence ``forward`` accepts.
        quantize, per_row, hadamard, lsq: Forwarded to each ``TilelliBlock``
            and to the ``TernaryLinear`` unembedding.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 128,
        n_layers: int = 4,
        d_head: int = 32,
        top_k: int = 8,
        pathways: int = 5,
        max_seq_len: int = 512,
        quantize: bool = True,
        n_banks: int = 1,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
        dense_expand: int = 2,
        fp_attention: bool = False,
        top_k_routing: int = 0,
    ) -> None:
        super().__init__()
        # Keep every hyper-parameter on the instance so the configuration can
        # be inspected (or used to rebuild the model) after construction.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.d_head = d_head
        self.top_k = top_k
        self.pathways = pathways
        self.max_seq_len = max_seq_len
        self.quantize = quantize
        self.n_banks = n_banks
        self.per_row = per_row
        self.hadamard = hadamard
        self.lsq = lsq
        self.dense_expand = dense_expand
        self.fp_attention = fp_attention
        self.top_k_routing = top_k_routing

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList(
            [
                TilelliBlock(
                    d_model=d_model,
                    d_head=d_head,
                    top_k=top_k,
                    pathways=pathways,
                    n_banks=n_banks,
                    quantize=quantize,
                    per_row=per_row,
                    hadamard=hadamard,
                    lsq=lsq,
                    dense_expand=dense_expand,
                    fp_attention=fp_attention,
                    top_k_routing=top_k_routing,
                )
                for _ in range(n_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(d_model)
        self.unembed = TernaryLinear(
            d_model, vocab_size,
            quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq,
        )

    def _embed(self, ids: Tensor) -> Tensor:
        """Validate *ids* and return token + positional embeddings.

        Shared by ``forward`` and ``router_entropies`` so both reject the same
        malformed inputs with the same messages.

        Args:
            ids: Token ids of shape ``(B, L)``.

        Returns:
            Tensor of shape ``(B, L, d_model)``.

        Raises:
            ValueError: If *ids* is not 2-D or ``L > max_seq_len``.
        """
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        seq_len = ids.size(1)
        # Without this check, nn.Embedding would fail on an out-of-range
        # position index with a far less helpful error.
        if seq_len > self.max_seq_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
        positions = torch.arange(seq_len, device=ids.device)
        # (L, d_model) -> (1, L, d_model) so it broadcasts over the batch.
        return self.token_emb(ids) + self.pos_emb(positions).unsqueeze(0)

    def forward(self, ids: Tensor) -> Tensor:
        """Return logits of shape ``(B, L, vocab_size)`` for ids of shape ``(B, L)``.

        Raises:
            ValueError: If *ids* is not 2-D or is longer than ``max_seq_len``.
        """
        x = self._embed(ids)
        for block in self.blocks:
            x = block(x)
        return self.unembed(self.norm_out(x))

    @property
    def aux_loss(self) -> Tensor:
        """Sum of the per-block load-balancing aux losses.

        Always returns a tensor on the embedding weights' device. It is zero
        when ``n_banks <= 1`` (banking disabled) or when there are no blocks.
        """
        total = torch.zeros((), device=self.token_emb.weight.device)
        if self.n_banks <= 1:
            return total
        # Explicit tensor accumulator: sum() would start from the int 0 and
        # return a plain int for an empty block list.
        for block in self.blocks:
            total = total + block.aux_loss
        return total

    def loss(self, ids: Tensor, targets: Tensor) -> Tensor:
        """Cross-entropy over all positions, plus ``aux_loss`` when banking is on.

        No shifting happens here. *targets* must already be aligned
        element-for-element with the positions of *ids*.

        Args:
            ids: Input ids of shape ``(B, L)``.
            targets: Class indices with ``B * L`` elements once flattened.

        Returns:
            Scalar loss tensor (``F.cross_entropy`` with its default mean
            reduction), plus ``aux_loss`` if ``n_banks > 1``.
        """
        logits = self.forward(ids)
        ce = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        if self.n_banks > 1:
            return ce + self.aux_loss
        return ce

    @torch.no_grad()
    def generate(self, ids: Tensor, n_new_tokens: int) -> Tensor:
        """Greedily (argmax) extend *ids* by ``n_new_tokens`` tokens.

        Each step feeds only the last ``max_seq_len`` tokens, so the output
        may grow longer than the positional-embedding table. The model runs
        in eval mode, and its previous training mode is restored afterwards,
        even if an exception is raised.

        Args:
            ids: Prompt ids of shape ``(B, L)``.
            n_new_tokens: Number of tokens to append. If it is 0 or negative,
                *ids* is returned unchanged.

        Returns:
            The prompt followed by the generated ids, shape
            ``(B, L + n_new_tokens)``.

        Raises:
            ValueError: If tokens are requested from an empty prompt
                (``L == 0``), which has no last position to read logits from,
                or if *ids* is not 2-D.
        """
        if n_new_tokens > 0 and ids.dim() == 2 and ids.size(1) == 0:
            raise ValueError("generate() needs a non-empty prompt (got L == 0)")
        was_training = self.training
        self.eval()
        try:
            for _ in range(n_new_tokens):
                # Sliding window: keep the input within the positional table.
                window = ids[:, -self.max_seq_len:]
                last_logits = self.forward(window)[:, -1, :]
                next_id = last_logits.argmax(dim=-1, keepdim=True)
                ids = torch.cat([ids, next_id], dim=1)
            return ids
        finally:
            if was_training:
                self.train()

    @torch.no_grad()
    def router_entropies(self, ids: Tensor) -> list[Tensor]:
        """Collect ``block.router_entropy`` for every block, in order.

        Each block's entropy is computed on that block's input, i.e. the
        output of the preceding block or the embeddings for the first block.
        The final norm and unembedding are not run. Unlike ``generate``, this
        does not switch the model into eval mode.

        Raises:
            ValueError: If *ids* is not 2-D or is longer than ``max_seq_len``.
        """
        x = self._embed(ids)
        entropies: list[Tensor] = []
        for block in self.blocks:
            entropies.append(block.router_entropy(x))
            x = block(x)
        return entropies

    def parameter_count(self, trainable_only: bool = False) -> int:
        """Return the total number of parameter elements.

        Args:
            trainable_only: If True, count only parameters with
                ``requires_grad=True``. The default counts every parameter.
        """
        return sum(
            p.numel()
            for p in self.parameters()
            if p.requires_grad or not trainable_only
        )