# BitRASP-18M-Ghetto / model.py
# livadies's picture
# Initial commit: The Ghetto Architecture is alive. MatMul is dead.
# 19358e0 verified
"""
BitRASP language model assembly.
The model is byte-level on purpose: it lets the hard router use cheap regex-like
classes directly from token ids and avoids a tokenizer dependency in the MVP.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from core_math import AbsMaxNorm, BitRaspBlock, TernaryLinear, fake_quant_int
Tensor = torch.Tensor
DeviceMap = Optional[Dict[Union[int, str], Union[str, torch.device]]]
@dataclass
class BitRaspConfig:
    """Hyper-parameters consumed when assembling a ``BitRaspLM``.

    Field order and defaults are part of the constructor signature and must
    not be rearranged.
    """

    # Embedding rows and LM-head outputs (matches ByteTokenizer.vocab_size).
    vocab_size: int = 258
    # Width of the embedding and of every block's input/output.
    d_model: int = 256
    # Number of stacked BitRaspBlock layers.
    n_layers: int = 8
    # The next five values are forwarded unchanged to each BitRaspBlock.
    state_dim: int = 256
    num_experts: int = 512
    expert_hidden: int = 128
    active_experts: int = 1
    lut_bins: int = 16
    # NOTE(review): not read anywhere in the visible part of this file.
    max_seq_len: int = 1024
    # Requests sharing the embedding matrix with the LM head
    # (read by BitRaspLM.__init__).
    tie_embeddings: bool = False
class ByteTokenizer:
    """Minimal byte tokenizer.

    ids 0..255 are raw UTF-8 bytes, 256 is BOS, 257 is EOS/PAD for simple runs.
    """

    bos_id = 256
    eos_id = 257
    pad_id = 257  # deliberately the same id as EOS
    vocab_size = 258

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]:
        """Return the UTF-8 bytes of *text* as ids, optionally wrapped in BOS/EOS.

        Characters that cannot be encoded (e.g. lone surrogates) are replaced
        rather than raising, because of ``errors="replace"``.
        """
        ids: List[int] = [self.bos_id] if add_bos else []
        ids.extend(text.encode("utf-8", errors="replace"))
        if add_eos:
            ids.append(self.eos_id)
        return ids

    def decode(self, ids: Iterable[int]) -> str:
        """Turn ids back into text, silently dropping anything outside 0..255.

        Special ids (BOS/EOS/PAD) are therefore skipped. Invalid UTF-8 byte
        sequences are rendered with the replacement character instead of
        raising.
        """
        # Coerce each id exactly once so the range filter and the byte
        # construction see the same integer value. The previous version
        # filtered on int(i) but handed the uncoerced object to bytes(), so a
        # value that passed the filter could still raise TypeError.
        as_ints = (int(i) for i in ids)
        raw = bytes(v for v in as_ints if 0 <= v < 256)
        return raw.decode("utf-8", errors="replace")
class BitRaspLM(nn.Module):
    """Byte-level language model assembled from ``BitRaspBlock`` layers.

    Data flow: embedding -> ``fake_quant_int`` -> each block in order (each
    block also receives the raw token ids and a per-layer state) ->
    ``AbsMaxNorm`` -> ``TernaryLinear`` head producing vocabulary logits.

    Every piece may live on a different device (see ``apply_device_map``);
    ``forward`` moves activations, ids and state to each piece's device as it
    goes.
    """

    def __init__(self, config: BitRaspConfig):
        """Build all sub-modules from *config*."""
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList(
            [
                BitRaspBlock(
                    d_model=config.d_model,
                    state_dim=config.state_dim,
                    num_experts=config.num_experts,
                    expert_hidden=config.expert_hidden,
                    active_experts=config.active_experts,
                    lut_bins=config.lut_bins,
                )
                for _ in range(config.n_layers)
            ]
        )
        self.final_norm = AbsMaxNorm(config.d_model)
        self.lm_head = TernaryLinear(config.d_model, config.vocab_size, bias=False)
        if config.tie_embeddings:
            # Tying is possible whenever both weights have the same shape.
            # The previous gate (d_model == vocab_size) had nothing to do with
            # shape compatibility and silently ignored the flag for ordinary
            # configurations.
            # NOTE(review): assumes TernaryLinear.weight is laid out
            # (out_features, in_features) like nn.Linear; if it is not, the
            # shapes differ and we fall through to the warning below.
            # NOTE(review): checkpoints saved while the flag was being ignored
            # contain a separate head weight; loading them into a tied model
            # keeps only one of the two tensors.
            if self.lm_head.weight.shape == self.embed.weight.shape:
                self.lm_head.weight = self.embed.weight
            else:
                import warnings

                warnings.warn(
                    "tie_embeddings=True but lm_head.weight has shape "
                    f"{tuple(self.lm_head.weight.shape)} while embed.weight has shape "
                    f"{tuple(self.embed.weight.shape)}; weights are left untied.",
                    stacklevel=2,
                )

    def init_state(
        self,
        batch: int,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> List[Tensor]:
        """Return one freshly initialised state per layer.

        Args:
            batch: Batch size forwarded to each layer's ``init_state``.
            device: Target device; defaults to the embedding's device.
            dtype: Dtype forwarded to each layer's ``init_state``.
        """
        if device is None:
            device = self.embed.weight.device
        return [layer.init_state(batch, device, dtype) for layer in self.layers]

    def apply_device_map(self, device_map: DeviceMap) -> "BitRaspLM":
        """Move model pieces with a tiny accelerate-like map.

        Recognised keys: ``"default"``, ``"embed"``, a layer index ``i`` or the
        string ``"layers.i"``, and ``"head"`` / ``"lm_head"`` (which also
        places ``final_norm``). Anything not listed goes to ``"default"``,
        which itself falls back to ``"cpu"``. An empty or ``None`` map is a
        no-op.

        Examples:
            {"embed": "cuda:0", 0: "cuda:0", 1: "cpu", "head": "cpu"}
            {"layers.0": "cuda:0", "layers.1": "cuda:0", "default": "cpu"}

        Returns:
            ``self``, to allow chaining.
        """
        if not device_map:
            return self
        default = torch.device(device_map.get("default", "cpu"))
        self.embed.to(torch.device(device_map.get("embed", default)))
        for i, layer in enumerate(self.layers):
            # The integer key wins over the "layers.i" spelling.
            dev = device_map.get(i, device_map.get(f"layers.{i}", default))
            layer.to(torch.device(dev))
        head_device = torch.device(device_map.get("head", device_map.get("lm_head", default)))
        self.final_norm.to(head_device)
        self.lm_head.to(head_device)
        return self

    def _layer_forward(
        self,
        layer: BitRaspBlock,
        x: Tensor,
        ids: Tensor,
        state: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        """Run a single block; kept as a method so checkpointing can call it."""
        return layer(x, ids, state)

    def forward(
        self,
        input_ids: Tensor,
        targets: Optional[Tensor] = None,
        state: Optional[List[Tensor]] = None,
        use_checkpoint: bool = False,
    ) -> Dict[str, Union[Tensor, List[Tensor]]]:
        """Compute logits, optional loss, and the next per-layer state.

        Args:
            input_ids: Token ids; cast to ``long`` if needed. The callers in
                this file pass ``(batch, seq)`` tensors.
            targets: Optional ids to score against. The loss is cross-entropy
                over the flattened logits, ignoring positions equal to
                ``ByteTokenizer.pad_id`` (which is also the EOS id).
            state: Optional list with one state tensor (or ``None``) per
                layer; ``None`` means every layer starts without a state.
            use_checkpoint: Activation-checkpoint each layer. Only takes
                effect while the module is in training mode.

        Returns:
            Dict with ``"logits"``, ``"loss"`` (``None`` without targets) and
            ``"state"`` (list of per-layer states, detached when not
            training).
        """
        if input_ids.dtype != torch.long:
            input_ids = input_ids.long()
        first_device = self.embed.weight.device
        x = fake_quant_int(self.embed(input_ids.to(first_device, non_blocking=True)))

        if state is None:
            state = [None] * len(self.layers)
        next_state: List[Tensor] = []

        for i, layer in enumerate(self.layers):
            # Each layer may sit on its own device; bring its inputs along.
            layer_device = next(layer.parameters()).device
            x = x.to(layer_device, non_blocking=True)
            ids_current = input_ids.to(layer_device, non_blocking=True)
            layer_state = state[i]
            if layer_state is not None:
                layer_state = layer_state.to(layer_device, non_blocking=True)

            if use_checkpoint and self.training:
                if layer_state is None:
                    # The checkpointed path always receives a concrete state.
                    layer_state = layer.init_state(x.shape[0], layer_device, x.dtype)
                # Pass `layer` as an explicit argument. Non-reentrant
                # checkpointing re-runs the function during backward; a lambda
                # closing over the loop variable would by then see the *last*
                # layer and recompute every earlier block with the wrong
                # module.
                x, layer_next = checkpoint(
                    self._layer_forward,
                    layer,
                    x,
                    ids_current,
                    layer_state,
                    use_reentrant=False,
                )
            else:
                x, layer_next = self._layer_forward(layer, x, ids_current, layer_state)

            # Outside training the returned state is cut from the graph.
            next_state.append(layer_next if self.training else layer_next.detach())

        head_device = self.lm_head.weight.device
        x = self.final_norm(x.to(head_device, non_blocking=True))
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            targets = targets.to(logits.device, non_blocking=True).long()
            loss = F.cross_entropy(
                logits.reshape(-1, logits.shape[-1]),
                targets.reshape(-1),
                ignore_index=ByteTokenizer.pad_id,
            )
        return {"logits": logits, "loss": loss, "state": next_state}

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: List[int],
        max_new_tokens: int = 128,
        temperature: float = 1.0,
        device: Optional[torch.device] = None,
    ) -> List[int]:
        """Sample a continuation one token at a time, carrying layer state.

        Args:
            prompt_ids: Prompt token ids; an empty prompt starts from BOS
                (the BOS id is fed to the model but not added to the result).
            max_new_tokens: Number of tokens to append.
            temperature: ``<= 0`` selects greedy argmax; otherwise logits are
                divided by it before softmax sampling.
            device: Where the single-token inputs are created; defaults to the
                embedding's device.

        Returns:
            The prompt ids followed by the generated ids. Generation does not
            stop early on EOS.
        """
        # Generation runs in eval mode, but the caller's mode is restored
        # afterwards so sampling mid-training does not leave the model in eval.
        was_training = self.training
        self.eval()
        try:
            if device is None:
                device = self.embed.weight.device
            ids = list(prompt_ids)
            state = None

            # Warm the state on everything except the last prompt token.
            for token in ids[:-1]:
                inp = torch.tensor([[token]], device=device, dtype=torch.long)
                state = self(inp, state=state)["state"]  # type: ignore[index]

            cur = ids[-1] if ids else ByteTokenizer.bos_id
            for _ in range(max_new_tokens):
                inp = torch.tensor([[cur]], device=device, dtype=torch.long)
                out = self(inp, state=state)
                state = out["state"]  # type: ignore[assignment]
                logits = out["logits"][:, -1, :]  # type: ignore[index]
                if temperature <= 0:
                    nxt = int(torch.argmax(logits, dim=-1).item())
                else:
                    probs = torch.softmax(logits / temperature, dim=-1)
                    nxt = int(torch.multinomial(probs, num_samples=1).item())
                ids.append(nxt)
                cur = nxt
        finally:
            self.train(was_training)
        return ids
def tiny_config() -> BitRaspConfig:
    """Return a scaled-down configuration used by the smoke test in this file.

    Fields not listed here keep their ``BitRaspConfig`` defaults.
    """
    overrides = {
        "d_model": 128,
        "n_layers": 4,
        "state_dim": 128,
        "num_experts": 128,
        "expert_hidden": 64,
        "active_experts": 1,
        "max_seq_len": 256,
    }
    return BitRaspConfig(**overrides)
def count_active_parameters(config: BitRaspConfig) -> float:
    """Approximate active parameter fraction per token.

    Per layer, the estimate counts ``3 * d_model * state_dim`` shared weights
    plus two ``d_model x expert_hidden`` matrices per expert. The numerator
    uses ``active_experts`` experts, the denominator all ``num_experts``.
    The denominator additionally includes ``2 * vocab_size * d_model`` for
    the embedding and head; the numerator does not.
    """
    shared = 3 * config.d_model * config.state_dim
    # Two projections per expert: d_model -> expert_hidden and back.
    per_expert = (
        config.d_model * config.expert_hidden
        + config.expert_hidden * config.d_model
    )

    used_per_layer = shared + config.active_experts * per_expert
    all_per_layer = shared + config.num_experts * per_expert

    active = config.n_layers * used_per_layer
    total = config.n_layers * all_per_layer + config.vocab_size * config.d_model * 2
    return active / total
if __name__ == "__main__":
    # Smoke test: build the tiny model, run one forward pass on random ids
    # (used as both inputs and targets), and report shapes and the estimate.
    smoke_cfg = tiny_config()
    lm = BitRaspLM(smoke_cfg)
    tokens = torch.randint(0, smoke_cfg.vocab_size, (2, 32))
    result = lm(tokens, targets=tokens)
    logits_shape = tuple(result["logits"].shape)
    loss_value = float(result["loss"])
    print("logits", logits_shape, "loss", loss_value)
    active_pct = count_active_parameters(smoke_cfg) * 100
    print("active parameter fraction", f"{active_pct:.2f}%")