| """ |
| BitRASP language model assembly. |
| |
| The model is byte-level on purpose: it lets the hard router use cheap regex-like |
| classes directly from token ids and avoids a tokenizer dependency in the MVP. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Dict, Iterable, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.checkpoint import checkpoint |
|
|
| from core_math import AbsMaxNorm, BitRaspBlock, TernaryLinear, fake_quant_int |
|
|
|
|
# Short alias used throughout the annotations in this module.
Tensor = torch.Tensor

# A device map sends layer indices or string keys (e.g. "embed", "layers.0",
# "head", "default") to a device spec given as a string or torch.device.
_DeviceMapKey = Union[int, str]
_DeviceSpec = Union[str, torch.device]
DeviceMap = Optional[Dict[_DeviceMapKey, _DeviceSpec]]
|
|
|
|
@dataclass
class BitRaspConfig:
    """Hyper-parameters for the BitRASP byte-level language model.

    Field order is part of the positional constructor signature, so new
    fields must only ever be appended at the end.
    """

    # Size of the embedding table and output head; 258 = 256 byte values
    # plus the two special ids used by ByteTokenizer (BOS=256, EOS/PAD=257).
    vocab_size: int = 258
    # Width of the residual stream fed through every block.
    d_model: int = 256
    # Number of stacked BitRaspBlock layers.
    n_layers: int = 8
    # Forwarded to BitRaspBlock(state_dim=...).
    state_dim: int = 256
    # Forwarded to BitRaspBlock(num_experts=...).
    num_experts: int = 512
    # Forwarded to BitRaspBlock(expert_hidden=...).
    expert_hidden: int = 128
    # Forwarded to BitRaspBlock(active_experts=...); count_active_parameters
    # treats it as the number of experts used per token.
    active_experts: int = 1
    # Forwarded to BitRaspBlock(lut_bins=...); meaning is defined by the block.
    lut_bins: int = 16
    # Stored for callers; BitRaspLM in this module does not read it.
    max_seq_len: int = 1024
    # Request sharing the embedding table with the LM head; the model only
    # applies it when its tying condition is satisfied.
    tie_embeddings: bool = False
|
|
|
|
class ByteTokenizer:
    """Minimal byte tokenizer.

    ids 0..255 are raw UTF-8 bytes, 256 is BOS, 257 is EOS/PAD for simple runs.
    Because EOS and PAD share id 257, anything that ignores ``pad_id`` (for
    example a loss using it as ``ignore_index``) ignores EOS as well.
    """

    bos_id = 256
    eos_id = 257
    pad_id = 257
    vocab_size = 258

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]:
        """Encode ``text`` as its UTF-8 byte values, optionally framed by BOS/EOS.

        Characters that cannot be encoded (e.g. lone surrogates) become ``?``
        because of ``errors="replace"``.
        """
        ids = list(text.encode("utf-8", errors="replace"))
        if add_bos:
            ids = [self.bos_id] + ids
        if add_eos:
            ids.append(self.eos_id)
        return ids

    def decode(self, ids: Iterable[int]) -> str:
        """Decode ids back to text, silently dropping special (non-byte) ids.

        Each id is coerced with ``int()`` once, and that coerced value is what
        gets written. The filter already used ``int(i)``, but the raw object was
        previously passed to ``bytes()``, which rejected e.g. float ids with
        ``TypeError`` even though they had passed the filter. Invalid UTF-8
        sequences decode to U+FFFD.
        """
        values = (int(token) for token in ids)
        raw = bytes(value for value in values if 0 <= value < 256)
        return raw.decode("utf-8", errors="replace")
|
|
|
|
class BitRaspLM(nn.Module):
    """Byte-level BitRASP language model.

    Pipeline: ``nn.Embedding`` -> ``fake_quant_int`` -> ``config.n_layers``
    stateful ``BitRaspBlock`` layers -> ``AbsMaxNorm`` -> ternary LM head.
    Every layer carries its own state tensor; ``forward`` threads the states
    through and returns the updated list so decoding can proceed token by token.
    """

    def __init__(self, config: BitRaspConfig):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList(
            [
                BitRaspBlock(
                    d_model=config.d_model,
                    state_dim=config.state_dim,
                    num_experts=config.num_experts,
                    expert_hidden=config.expert_hidden,
                    active_experts=config.active_experts,
                    lut_bins=config.lut_bins,
                )
                for _ in range(config.n_layers)
            ]
        )
        self.final_norm = AbsMaxNorm(config.d_model)
        self.lm_head = TernaryLinear(config.d_model, config.vocab_size, bias=False)
        # Tie whenever the head weight has exactly the embedding table's shape
        # (vocab_size, d_model). The old ``d_model == vocab_size`` test disabled
        # tying for practically every real config. NOTE(review): this assumes
        # TernaryLinear stores its weight as (out, in) like nn.Linear; the shape
        # check keeps it safe either way.
        if config.tie_embeddings and self.lm_head.weight.shape == self.embed.weight.shape:
            self.lm_head.weight = self.embed.weight

    def init_state(
        self,
        batch: int,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> List[Tensor]:
        """Return one fresh state per layer via ``BitRaspBlock.init_state``.

        ``device`` defaults to the embedding's device.
        """
        if device is None:
            device = self.embed.weight.device
        return [layer.init_state(batch, device, dtype) for layer in self.layers]

    def apply_device_map(self, device_map: DeviceMap) -> "BitRaspLM":
        """Move model pieces with a tiny accelerate-like map.

        Recognized keys: ``"embed"``; an int ``i`` or ``"layers.{i}"`` for layer
        ``i`` (the int key wins); ``"head"`` or ``"lm_head"`` for the LM head,
        which ``final_norm`` follows; ``"default"`` for anything unspecified
        (``"cpu"`` if absent). An empty/None map is a no-op.

        Examples:
            {"embed": "cuda:0", 0: "cuda:0", 1: "cpu", "head": "cpu"}
            {"layers.0": "cuda:0", "layers.1": "cuda:0", "default": "cpu"}
        """
        if not device_map:
            return self

        default = torch.device(device_map.get("default", "cpu"))
        self.embed.to(torch.device(device_map.get("embed", default)))
        for i, layer in enumerate(self.layers):
            dev = device_map.get(i, device_map.get(f"layers.{i}", default))
            layer.to(torch.device(dev))
        head_device = torch.device(device_map.get("head", device_map.get("lm_head", default)))
        self.final_norm.to(head_device)
        self.lm_head.to(head_device)
        return self

    def _layer_forward(self, layer: BitRaspBlock, x: Tensor, ids: Tensor, state: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
        """Run one block; a separate hook so checkpointing and subclasses share it."""
        return layer(x, ids, state)

    def forward(
        self,
        input_ids: Tensor,
        targets: Optional[Tensor] = None,
        state: Optional[List[Tensor]] = None,
        use_checkpoint: bool = False,
    ) -> Dict[str, Union[Tensor, List[Tensor]]]:
        """Run the model and optionally compute the cross-entropy loss.

        Args:
            input_ids: Integer token ids, cast to ``long``. ``generate`` passes
                shape (1, 1), so presumably (batch, seq) — TODO confirm.
            targets: Optional ids compared position-by-position with the logits
                (no internal shift). Positions equal to ``ByteTokenizer.pad_id``
                (257, shared with EOS) are ignored by the loss.
            state: Optional per-layer states, indexed by layer. ``None`` entries
                (or ``state=None``) are passed to the layer as ``None``, except
                under checkpointing where ``layer.init_state`` supplies them.
            use_checkpoint: Recompute each layer in backward to save memory;
                only honored in training mode.

        Returns:
            ``{"logits", "loss", "state"}``. ``loss`` is None without targets.
            ``state`` holds each layer's new state, detached outside training.
        """
        if input_ids.dtype != torch.long:
            input_ids = input_ids.long()

        first_device = self.embed.weight.device
        ids_for_embed = input_ids.to(first_device, non_blocking=True)
        x = fake_quant_int(self.embed(ids_for_embed))

        if state is None:
            state = [None] * len(self.layers)
        next_state: List[Tensor] = []

        for i, layer in enumerate(self.layers):
            # Layers may live on different devices (see apply_device_map).
            layer_device = next(layer.parameters()).device
            x = x.to(layer_device, non_blocking=True)
            ids_current = input_ids.to(layer_device, non_blocking=True)
            layer_state = state[i]
            if layer_state is not None:
                layer_state = layer_state.to(layer_device, non_blocking=True)

            if use_checkpoint and self.training:
                if layer_state is None:
                    layer_state = layer.init_state(x.shape[0], layer_device, x.dtype)
                # Bind ``layer`` as a default argument: checkpoint re-invokes this
                # callable during backward, after the loop has advanced, so a plain
                # closure over the loop variable would recompute every checkpointed
                # layer with the *last* layer (wrong gradients or shape errors).
                x, layer_next = checkpoint(
                    lambda a, b, c, _layer=layer: self._layer_forward(_layer, a, b, c),
                    x,
                    ids_current,
                    layer_state,
                    use_reentrant=False,
                )
            else:
                x, layer_next = self._layer_forward(layer, x, ids_current, layer_state)
            # Keep the graph through the state only while training.
            next_state.append(layer_next.detach() if not self.training else layer_next)

        head_device = self.lm_head.weight.device
        x = self.final_norm(x.to(head_device, non_blocking=True))
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            targets = targets.to(logits.device, non_blocking=True).long()
            loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1), ignore_index=ByteTokenizer.pad_id)

        return {"logits": logits, "loss": loss, "state": next_state}

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: List[int],
        max_new_tokens: int = 128,
        temperature: float = 1.0,
        device: Optional[torch.device] = None,
    ) -> List[int]:
        """Autoregressively extend ``prompt_ids`` by exactly ``max_new_tokens`` ids.

        The prompt is fed one token at a time to build the recurrent state. An
        empty prompt starts from ``ByteTokenizer.bos_id``, which is not included
        in the result. ``temperature <= 0`` selects greedy argmax; otherwise the
        next id is sampled from ``softmax(logits / temperature)``. Runs in eval
        mode, and the caller's train/eval mode is restored afterwards.

        Returns:
            The prompt ids followed by the generated ids.
        """
        was_training = self.training
        self.eval()
        try:
            if device is None:
                device = self.embed.weight.device
            ids = list(prompt_ids)
            state = None
            # Warm up the state on all but the last prompt token; the last one
            # is fed inside the sampling loop to produce the first prediction.
            for token in ids[:-1]:
                inp = torch.tensor([[token]], device=device, dtype=torch.long)
                state = self(inp, state=state)["state"]

            cur = ids[-1] if ids else ByteTokenizer.bos_id
            for _ in range(max_new_tokens):
                inp = torch.tensor([[cur]], device=device, dtype=torch.long)
                out = self(inp, state=state)
                state = out["state"]
                logits = out["logits"][:, -1, :]
                if temperature <= 0:
                    nxt = int(torch.argmax(logits, dim=-1).item())
                else:
                    probs = torch.softmax(logits / temperature, dim=-1)
                    nxt = int(torch.multinomial(probs, num_samples=1).item())
                ids.append(nxt)
                cur = nxt
            return ids
        finally:
            # Previously the model was left in eval mode, silently disabling
            # training-only behavior (e.g. checkpointing) for later steps.
            self.train(was_training)
|
|
|
|
def tiny_config() -> BitRaspConfig:
    """Return a small BitRASP configuration for smoke tests and quick runs.

    Fields not listed here keep their BitRaspConfig defaults
    (vocab_size, lut_bins, tie_embeddings).
    """
    small_overrides = {
        "d_model": 128,
        "n_layers": 4,
        "state_dim": 128,
        "num_experts": 128,
        "expert_hidden": 64,
        "active_experts": 1,
        "max_seq_len": 256,
    }
    return BitRaspConfig(**small_overrides)
|
|
|
|
def count_active_parameters(config: BitRaspConfig) -> float:
    """Approximate the fraction of parameters active for a single token.

    Convention (the usual one for mixture-of-experts models): every non-expert
    parameter counts as active. That covers the per-layer dense part, the token
    embedding, and the LM head, which is applied to every token. Of the experts,
    only ``min(active_experts, num_experts)`` per layer are active. The earlier
    version added embedding/head parameters to the total but not to the active
    count, understating the fraction.

    The per-layer dense cost ``3 * d_model * state_dim`` and the per-expert cost
    (two ``d_model x expert_hidden`` matrices) are rough estimates, not an exact
    count of BitRaspBlock's parameters.

    Returns:
        ``active / total``, or 0.0 for a degenerate config with no parameters.
    """
    dense_per_layer = 3 * config.d_model * config.state_dim
    per_expert = 2 * config.d_model * config.expert_hidden
    # Cannot route to more experts than exist.
    routed_experts = min(config.active_experts, config.num_experts)
    embed_and_head = 2 * config.vocab_size * config.d_model

    shared = config.n_layers * dense_per_layer + embed_and_head
    active = shared + config.n_layers * routed_experts * per_expert
    total = shared + config.n_layers * config.num_experts * per_expert
    if total == 0:
        return 0.0
    return active / total
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the tiny model and run one next-token loss step.
    torch.manual_seed(0)  # reproducible random ids and initialization
    cfg = tiny_config()
    model = BitRaspLM(cfg)
    x = torch.randint(0, cfg.vocab_size, (2, 32))
    # forward() compares logits[:, t] with targets[:, t] without shifting, so
    # shift here; targets=x would only measure how well the model copies input.
    out = model(x[:, :-1], targets=x[:, 1:])
    print("logits", tuple(out["logits"].shape), "loss", float(out["loss"]))
    print("active parameter fraction", f"{count_active_parameters(cfg) * 100:.2f}%")
|
|