File size: 4,687 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""tilelli.core.tilelli_lm — minimal byte-level language model built on
ternary primitives + heterogeneous-pathway blocks.

Stacks TilelliBlock layers on top of a byte embedding and a ternary
unembedding, plus a learned positional embedding.
"""
from __future__ import annotations

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from tilelli.core.ternary_linear import TernaryLinear
from tilelli.core.tilelli_block import TilelliBlock


class TilelliLM(nn.Module):
    """Byte-level Tilelli language model.

    Pipeline, as wired in ``forward``: token embedding + learned positional
    embedding -> ``n_layers`` stacked ``TilelliBlock`` layers -> ``LayerNorm``
    -> ``TernaryLinear`` unembedding producing ``vocab_size`` logits per
    position.

    Args:
        vocab_size: Number of token ids; sizes the token embedding table and
            the logit dimension.
        d_model: Width of the embeddings, blocks and final norm.
        n_layers: Number of stacked ``TilelliBlock`` layers.
        d_head: Forwarded unchanged to every ``TilelliBlock``.
        top_k: Forwarded unchanged to every ``TilelliBlock``.
        pathways: Forwarded unchanged to every ``TilelliBlock``.
        max_seq_len: Size of the positional embedding table; longer inputs
            are rejected with ``ValueError``.
        quantize: Forwarded to every block and to the unembedding.
        n_banks: Forwarded to every block; values above 1 also enable the
            auxiliary loss term in ``loss``.
        per_row: Forwarded to every block and to the unembedding.
        hadamard: Forwarded to every block and to the unembedding.
        lsq: Forwarded to every block and to the unembedding.
        dense_expand: Forwarded unchanged to every ``TilelliBlock``.
        fp_attention: Forwarded unchanged to every ``TilelliBlock``.
        top_k_routing: Forwarded unchanged to every ``TilelliBlock``.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 128,
        n_layers: int = 4,
        d_head: int = 32,
        top_k: int = 8,
        pathways: int = 5,
        max_seq_len: int = 512,
        quantize: bool = True,
        n_banks: int = 1,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
        dense_expand: int = 2,
        fp_attention: bool = False,
        top_k_routing: int = 0,
    ) -> None:
        super().__init__()
        # Keep the hyper-parameters on the instance so they can be inspected
        # (and so ``forward``/``loss``/``aux_loss`` can consult them).
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.quantize = quantize
        self.n_banks = n_banks
        self.per_row = per_row
        self.hadamard = hadamard
        self.lsq = lsq
        self.dense_expand = dense_expand
        self.fp_attention = fp_attention
        self.top_k_routing = top_k_routing

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList(
            [
                TilelliBlock(
                    d_model=d_model,
                    d_head=d_head,
                    top_k=top_k,
                    pathways=pathways,
                    n_banks=n_banks,
                    quantize=quantize,
                    per_row=per_row,
                    hadamard=hadamard,
                    lsq=lsq,
                    dense_expand=dense_expand,
                    fp_attention=fp_attention,
                    top_k_routing=top_k_routing,
                )
                for _ in range(n_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(d_model)
        self.unembed = TernaryLinear(
            d_model, vocab_size,
            quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq,
        )

    def _embed(self, ids: Tensor) -> Tensor:
        """Validate ``ids`` and return token + positional embeddings.

        Shared by ``forward`` and ``router_entropies`` so both reject bad
        input the same way, before an out-of-range position index can reach
        the positional embedding table.

        Args:
            ids: Integer token ids of shape ``(B, L)``.

        Returns:
            Tensor of shape ``(B, L, d_model)``.

        Raises:
            ValueError: If ``ids`` is not 2-D or ``L`` exceeds ``max_seq_len``.
        """
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        seq_len = ids.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
        positions = torch.arange(seq_len, device=ids.device)
        # The positional term has no batch axis; add one so it broadcasts.
        return self.token_emb(ids) + self.pos_emb(positions)[None, :, :]

    def forward(self, ids: Tensor) -> Tensor:
        """Return next-token logits for every position of ``ids``.

        Args:
            ids: Integer token ids of shape ``(B, L)`` with ``L <= max_seq_len``.

        Returns:
            The unembedding's output for the normalised block stack; logits
            over ``vocab_size`` for each of the ``(B, L)`` positions.

        Raises:
            ValueError: If ``ids`` is not 2-D or is longer than ``max_seq_len``.
        """
        x = self._embed(ids)
        for block in self.blocks:
            x = block(x)
        x = self.norm_out(x)
        return self.unembed(x)

    @property
    def aux_loss(self) -> Tensor:
        """Sum of per-block load-balancing aux losses. Zero when n_banks=1.

        Always returns a tensor: with no blocks the sum would otherwise
        collapse to the Python int ``0``.
        """
        if self.n_banks > 1 and len(self.blocks) > 0:
            return sum(b.aux_loss for b in self.blocks)
        return torch.tensor(0.0, device=self.token_emb.weight.device)

    def loss(self, ids: Tensor, targets: Tensor) -> Tensor:
        """Cross-entropy loss + load-balance aux when banking is on.

        Args:
            ids: Input token ids of shape ``(B, L)``.
            targets: Target token ids; flattened and compared position by
                position against the flattened logits.

        Returns:
            Scalar loss tensor (mean cross-entropy, plus ``aux_loss`` when
            ``n_banks > 1``).
        """
        logits = self.forward(ids)
        ce = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        if self.n_banks > 1:
            return ce + self.aux_loss
        return ce

    @torch.no_grad()
    def generate(self, ids: Tensor, n_new_tokens: int) -> Tensor:
        """Greedily append ``n_new_tokens`` tokens to ``ids``.

        Each step feeds at most the last ``max_seq_len`` tokens through the
        model and appends the argmax of the final position's logits. The
        model is put in eval mode for the duration and switched back to
        training mode afterwards if it was training on entry.

        Args:
            ids: Prompt token ids of shape ``(B, L)``.
            n_new_tokens: Number of tokens to append.

        Returns:
            Tensor of shape ``(B, L + n_new_tokens)``: prompt plus generated ids.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(n_new_tokens):
                # Crop the context so it never exceeds the positional table.
                ids_in = ids[:, -self.max_seq_len:]
                logits = self.forward(ids_in)[:, -1, :]
                next_id = logits.argmax(dim=-1, keepdim=True)
                ids = torch.cat([ids, next_id], dim=1)
            return ids
        finally:
            # Restore the caller's mode even if a forward pass raised.
            if was_training:
                self.train()

    @torch.no_grad()
    def router_entropies(self, ids: Tensor) -> list[Tensor]:
        """Collect each block's ``router_entropy`` on that block's own input.

        Args:
            ids: Integer token ids of shape ``(B, L)`` with ``L <= max_seq_len``.

        Returns:
            One entry per block, in layer order, holding whatever
            ``block.router_entropy`` returned for the activations entering
            that block.

        Raises:
            ValueError: If ``ids`` is not 2-D or is longer than ``max_seq_len``.
        """
        x = self._embed(ids)
        out = []
        for block in self.blocks:
            # Measure on the block's input, then advance to the next layer.
            out.append(block.router_entropy(x))
            x = block(x)
        return out

    def parameter_count(self) -> int:
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())