""" BitRASP language model assembly. The model is byte-level on purpose: it lets the hard router use cheap regex-like classes directly from token ids and avoids a tokenizer dependency in the MVP. """ from __future__ import annotations from dataclasses import dataclass from typing import Dict, Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from core_math import AbsMaxNorm, BitRaspBlock, TernaryLinear, fake_quant_int Tensor = torch.Tensor DeviceMap = Optional[Dict[Union[int, str], Union[str, torch.device]]] @dataclass class BitRaspConfig: vocab_size: int = 258 d_model: int = 256 n_layers: int = 8 state_dim: int = 256 num_experts: int = 512 expert_hidden: int = 128 active_experts: int = 1 lut_bins: int = 16 max_seq_len: int = 1024 tie_embeddings: bool = False class ByteTokenizer: """Minimal byte tokenizer. ids 0..255 are raw UTF-8 bytes, 256 is BOS, 257 is EOS/PAD for simple runs. """ bos_id = 256 eos_id = 257 pad_id = 257 vocab_size = 258 def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]: ids = list(text.encode("utf-8", errors="replace")) if add_bos: ids.insert(0, self.bos_id) if add_eos: ids.append(self.eos_id) return ids def decode(self, ids: Iterable[int]) -> str: raw = bytes([i for i in ids if 0 <= int(i) < 256]) return raw.decode("utf-8", errors="replace") class BitRaspLM(nn.Module): def __init__(self, config: BitRaspConfig): super().__init__() self.config = config self.embed = nn.Embedding(config.vocab_size, config.d_model) self.layers = nn.ModuleList( [ BitRaspBlock( d_model=config.d_model, state_dim=config.state_dim, num_experts=config.num_experts, expert_hidden=config.expert_hidden, active_experts=config.active_experts, lut_bins=config.lut_bins, ) for _ in range(config.n_layers) ] ) self.final_norm = AbsMaxNorm(config.d_model) self.lm_head = TernaryLinear(config.d_model, config.vocab_size, bias=False) if config.tie_embeddings and config.d_model == config.vocab_size: self.lm_head.weight = self.embed.weight def init_state(self, batch: int, device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32) -> List[Tensor]: if device is None: device = self.embed.weight.device return [layer.init_state(batch, device, dtype) for layer in self.layers] def apply_device_map(self, device_map: DeviceMap) -> "BitRaspLM": """Move model pieces with a tiny accelerate-like map. Examples: {"embed": "cuda:0", 0: "cuda:0", 1: "cpu", "head": "cpu"} {"layers.0": "cuda:0", "layers.1": "cuda:0", "default": "cpu"} """ if not device_map: return self default = torch.device(device_map.get("default", "cpu")) self.embed.to(torch.device(device_map.get("embed", default))) for i, layer in enumerate(self.layers): dev = device_map.get(i, device_map.get(f"layers.{i}", default)) layer.to(torch.device(dev)) head_device = torch.device(device_map.get("head", device_map.get("lm_head", default))) self.final_norm.to(head_device) self.lm_head.to(head_device) return self def _layer_forward(self, layer: BitRaspBlock, x: Tensor, ids: Tensor, state: Optional[Tensor]) -> Tuple[Tensor, Tensor]: return layer(x, ids, state) def forward( self, input_ids: Tensor, targets: Optional[Tensor] = None, state: Optional[List[Tensor]] = None, use_checkpoint: bool = False, ) -> Dict[str, Union[Tensor, List[Tensor]]]: if input_ids.dtype != torch.long: input_ids = input_ids.long() first_device = self.embed.weight.device ids_for_embed = input_ids.to(first_device, non_blocking=True) x = fake_quant_int(self.embed(ids_for_embed)) if state is None: state = [None] * len(self.layers) next_state: List[Tensor] = [] ids_current = ids_for_embed for i, layer in enumerate(self.layers): layer_device = next(layer.parameters()).device x = x.to(layer_device, non_blocking=True) ids_current = input_ids.to(layer_device, non_blocking=True) layer_state = state[i] if layer_state is not None: layer_state = layer_state.to(layer_device, non_blocking=True) if use_checkpoint and self.training: x, layer_next = checkpoint( lambda a, b, c: self._layer_forward(layer, a, b, c), x, ids_current, layer_state if layer_state is not None else layer.init_state(x.shape[0], layer_device, x.dtype), use_reentrant=False, ) else: x, layer_next = self._layer_forward(layer, x, ids_current, layer_state) next_state.append(layer_next.detach() if not self.training else layer_next) head_device = self.lm_head.weight.device x = self.final_norm(x.to(head_device, non_blocking=True)) logits = self.lm_head(x) loss = None if targets is not None: targets = targets.to(logits.device, non_blocking=True).long() loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1), ignore_index=ByteTokenizer.pad_id) return {"logits": logits, "loss": loss, "state": next_state} @torch.no_grad() def generate( self, prompt_ids: List[int], max_new_tokens: int = 128, temperature: float = 1.0, device: Optional[torch.device] = None, ) -> List[int]: self.eval() if device is None: device = self.embed.weight.device ids = list(prompt_ids) state = None for token in ids[:-1]: inp = torch.tensor([[token]], device=device, dtype=torch.long) state = self(inp, state=state)["state"] # type: ignore[index] cur = ids[-1] if ids else ByteTokenizer.bos_id for _ in range(max_new_tokens): inp = torch.tensor([[cur]], device=device, dtype=torch.long) out = self(inp, state=state) state = out["state"] # type: ignore[assignment] logits = out["logits"][:, -1, :] # type: ignore[index] if temperature <= 0: nxt = int(torch.argmax(logits, dim=-1).item()) else: probs = torch.softmax(logits / temperature, dim=-1) nxt = int(torch.multinomial(probs, num_samples=1).item()) ids.append(nxt) cur = nxt return ids def tiny_config() -> BitRaspConfig: return BitRaspConfig( d_model=128, n_layers=4, state_dim=128, num_experts=128, expert_hidden=64, active_experts=1, max_seq_len=256, ) def count_active_parameters(config: BitRaspConfig) -> float: """Approximate active parameter fraction per token.""" dense_per_layer = 3 * config.d_model * config.state_dim expert_per_layer = config.active_experts * (config.d_model * config.expert_hidden + config.expert_hidden * config.d_model) total_expert_per_layer = config.num_experts * (config.d_model * config.expert_hidden + config.expert_hidden * config.d_model) active = config.n_layers * (dense_per_layer + expert_per_layer) total = config.n_layers * (dense_per_layer + total_expert_per_layer) total += config.vocab_size * config.d_model * 2 return active / total if __name__ == "__main__": cfg = tiny_config() model = BitRaspLM(cfg) x = torch.randint(0, cfg.vocab_size, (2, 32)) out = model(x, targets=x) print("logits", tuple(out["logits"].shape), "loss", float(out["loss"])) print("active parameter fraction", f"{count_active_parameters(cfg) * 100:.2f}%")