"""BitRASP core math.

Research MVP for an addition-first recurrent language layer:

- ternary {-1, 0, +1} trainable projections with a straight-through estimator (STE);
- fixed-size integer recurrent state, no KV cache;
- selective decay approximated by shifts: ``s <- s - (s >> k)``;
- KAN-like per-channel lookup tables instead of dense MLP activations;
- hard sparse MoE routing hooks driven by token/byte classes.

The training path is PyTorch-friendly and differentiable where practical. The
``step_int8`` path shows the intended no-FP-multiply inference contract;
replacing the scatter/gather pieces with packed C/AVX/SVE/Elbrus kernels is the
next step.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

Tensor = torch.Tensor

# Integer dtypes accepted by the reference add/sub kernel.
_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
# Saturation bounds of the int16 recurrent state used by the integer path.
_INT16_MIN = -32768
_INT16_MAX = 32767


@dataclass(frozen=True)
class QuantSpec:
    """Symmetric signed integer quantization spec."""

    bits: int = 8  # total bit width, including the sign bit
    eps: float = 1e-6  # lower bound on the abs-max used to derive the scale

    @property
    def qmax(self) -> int:
        """Largest representable code, ``2**(bits - 1) - 1``."""
        return (1 << (self.bits - 1)) - 1

    @property
    def qmin(self) -> int:
        """Smallest representable code, ``-2**(bits - 1)``."""
        return -(1 << (self.bits - 1))


def ste_round(x: Tensor) -> Tensor:
    """Round to nearest in the forward pass; pass gradients through unchanged."""
    return x + (torch.round(x) - x).detach()


def fake_quant_int(x: Tensor, spec: QuantSpec = QuantSpec()) -> Tensor:
    """Symmetric per-token fake quantization with a straight-through gradient.

    The scale is the detached abs-max over the last dimension divided by
    ``spec.qmax``; values are rounded to integer codes clamped to
    ``[spec.qmin, spec.qmax]`` and mapped back by the same scale. The forward
    value is the dequantized tensor; the backward pass is the identity.
    """
    scale = x.detach().abs().amax(dim=-1, keepdim=True).clamp_min(spec.eps) / spec.qmax
    q = ste_round(x / scale).clamp(spec.qmin, spec.qmax)
    return x + (q * scale - x).detach()


def _ternary_code(w: Tensor, threshold: float) -> Tuple[Tensor, Tensor]:
    """Split a 2-D weight into {-1, 0, +1} codes and a per-row scale.

    The scale is the detached mean absolute value of each row, clamped to
    1e-6, with shape ``(out, 1)``. Entries above ``threshold * scale`` become
    +1, entries below ``-threshold * scale`` become -1, everything else 0.
    Codes are returned in ``w.dtype`` so low-precision weights stay in their
    own precision.
    """
    scale = w.detach().abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
    cut = threshold * scale
    code = (w > cut).to(w.dtype) - (w < -cut).to(w.dtype)
    return code, scale


def ternarize_weight(w: Tensor, threshold: float = 0.55) -> Tuple[Tensor, Tensor]:
    """BitNet-style ternary weight plus per-output scale.

    Returns ``(w_q, scale)``: ``w_q`` equals ``code * scale`` in the forward
    pass and passes gradients straight through to ``w``; ``scale`` has shape
    ``(out,)``. ``w_q`` keeps ``w.dtype`` so it can be fed to ``F.linear``
    together with activations of the same precision.

    A production CPU kernel would pack only the sign codes and keep scale as a
    power-of-two shift.
    """
    code, scale = _ternary_code(w, threshold)
    w_q = w + (code * scale - w).detach()
    return w_q, scale.squeeze(1)


def additive_ternary_linear_int8(
    x_q: Tensor,
    weight_code: Tensor,
    bias: Optional[Tensor] = None,
    out_shift: int = 0,
) -> Tensor:
    """Reference no-multiply ternary linear for int tensors.

    y[o] = sum(x[i] where W[o,i] == +1) - sum(x[i] where W[o,i] == -1)

    This is intentionally simple and readable. It uses gathers and integer
    sums; high performance requires packing ternary codes and vectorizing the
    positive and negative accumulation loops in C/C++.

    Args:
        x_q: integer activations; the last dimension is the input feature axis.
        weight_code: integer codes of shape ``(out_features, in_features)``;
            only the sign of each entry is used.
        bias: optional per-output values, cast to int32 and added before the shift.
        out_shift: arithmetic right shift applied when positive; ignored otherwise.

    Returns:
        int32 tensor of shape ``(*x_q.shape[:-1], out_features)``.

    Raises:
        TypeError: if ``x_q`` or ``weight_code`` is not an integer tensor.
        ValueError: if ``weight_code`` is not 2-D or its column count does not
            match the last dimension of ``x_q``.
    """
    if x_q.dtype not in _INT_DTYPES:
        raise TypeError("x_q must be an integer tensor")
    if weight_code.dtype not in _INT_DTYPES:
        raise TypeError("weight_code must be an integer tensor")
    if weight_code.dim() != 2:
        raise ValueError("weight_code must be 2-D (out_features, in_features)")
    in_features = x_q.shape[-1]
    if weight_code.shape[1] != in_features:
        # Without this check a narrower weight silently sums only a prefix of x.
        raise ValueError(
            f"weight_code expects {weight_code.shape[1]} input features, got {in_features}"
        )

    out_features = weight_code.shape[0]
    flat = x_q.reshape(-1, in_features).to(torch.int32)
    out = torch.empty((flat.shape[0], out_features), dtype=torch.int32, device=x_q.device)
    for o in range(out_features):
        row = weight_code[o]
        pos = torch.nonzero(row > 0, as_tuple=False).flatten()
        neg = torch.nonzero(row < 0, as_tuple=False).flatten()
        acc = flat.new_zeros(flat.shape[0])
        if pos.numel():
            acc += flat.index_select(1, pos).sum(dim=1)
        if neg.numel():
            acc -= flat.index_select(1, neg).sum(dim=1)
        if bias is not None:
            acc += bias[o].to(torch.int32)
        if out_shift > 0:
            acc = acc >> out_shift
        out[:, o] = acc
    return out.reshape(*x_q.shape[:-1], out_features)


class TernaryLinear(nn.Module):
    """Trainable ternary projection.

    The default forward is the STE training path. `weight_code()` and
    `forward_int8()` expose the add/sub inference semantics.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, threshold: float = 0.55):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Ternarization cut, as a fraction of each row's mean |w|.
        self.threshold = threshold
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize weights from a normal distribution with std ``1/sqrt(in_features)``."""
        nn.init.normal_(self.weight, mean=0.0, std=1.0 / math.sqrt(self.in_features))

    def forward(self, x: Tensor) -> Tensor:
        """STE training path: ``F.linear(fake_quant_int(x), ternarized W, bias)``."""
        w_q, _ = ternarize_weight(self.weight, self.threshold)
        x_q = fake_quant_int(x)
        return F.linear(x_q, w_q, self.bias)

    @torch.no_grad()
    def weight_code(self) -> Tensor:
        """Return int8 {-1, 0, +1} codes using the same rule as ``ternarize_weight``."""
        code, _ = _ternary_code(self.weight, self.threshold)
        return code.to(torch.int8)

    @torch.no_grad()
    def forward_int8(self, x_q: Tensor, out_shift: int = 0) -> Tensor:
        """Integer add/sub path returning int32 accumulators (optionally ``>> out_shift``).

        The float bias is rounded to the nearest integer and added to the raw
        accumulator without rescaling. NOTE(review): the float path adds the
        bias after ``x * code * scale``; confirm the intended scale mapping.
        """
        bias = None
        if self.bias is not None:
            bias = torch.round(self.bias).to(device=x_q.device, dtype=torch.int32)
        return additive_ternary_linear_int8(x_q, self.weight_code().to(x_q.device), bias, out_shift)


class AbsMaxNorm(nn.Module):
    """RMSNorm replacement that avoids square/multiply in the target inference path.

    Divides by the detached abs-max over the last dimension (clamped to
    ``eps``), applies a learnable per-channel gain and fake-quantizes the result.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` into [-1, 1] per row, scale by ``gain`` and fake-quantize."""
        denom = x.detach().abs().amax(dim=-1, keepdim=True).clamp_min(self.eps)
        y = x / denom
        return fake_quant_int(y * self.gain)


class ChannelLUT(nn.Module):
    """KAN-ish learnable univariate function per channel.

    ``table`` has shape ``(channels, bins)`` and starts as a linear ramp over
    ``[-value_scale, value_scale]`` (the identity on [-1, 1] when
    ``value_scale == 1``). Inference quantizes each channel to a bin and
    gathers table values; the training path uses linear interpolation so
    gradients reach the table.
    """

    def __init__(self, channels: int, bins: int = 16, value_scale: float = 1.0):
        super().__init__()
        self.channels = channels
        self.bins = bins
        grid = torch.linspace(-value_scale, value_scale, bins)
        self.table = nn.Parameter(grid.repeat(channels, 1))

    def forward(self, x: Tensor) -> Tensor:
        """Interpolated lookup; the last dimension of ``x`` must equal ``channels``."""
        # Map the clamped input onto fractional bin positions in [0, bins - 1].
        pos = (x.clamp(-1.0, 1.0) + 1.0) * (self.bins - 1) * 0.5
        lo = torch.floor(pos).long().clamp(0, self.bins - 1)
        hi = (lo + 1).clamp(0, self.bins - 1)
        frac = (pos - lo.to(pos.dtype)).reshape(-1, self.channels)

        table = self.table.t().contiguous()  # (bins, channels) so rows are bins
        channel_idx = torch.arange(self.channels, device=x.device).view(1, -1)
        y_lo = table[lo.reshape(-1, self.channels), channel_idx]
        y_hi = table[hi.reshape(-1, self.channels), channel_idx]
        y = y_lo + (y_hi - y_lo) * frac
        return fake_quant_int(y.reshape_as(x))

    @torch.no_grad()
    def forward_int8(self, x_q: Tensor) -> Tensor:
        """Lower-bin lookup for int8-range inputs, returning int8 table values.

        ``x_q`` in [-128, 127] is mapped to ``(x_q + 128) * (bins - 1) // 255``
        (clamped to valid bins); the table is clamped to [-1, 1] and quantized
        with a fixed scale of 127. The index arithmetic runs in int32 because
        ``255 * (bins - 1)`` exceeds the int16 range once ``bins > 129``.
        """
        idx = ((x_q.to(torch.int32) + 128) * (self.bins - 1) // 255).clamp(0, self.bins - 1)
        table_q = torch.round(self.table.detach().clamp(-1, 1) * 127).to(torch.int8)
        channel_idx = torch.arange(self.channels, device=x_q.device).view(1, -1)
        y = table_q.t().contiguous()[idx.reshape(-1, self.channels).long(), channel_idx]
        return y.reshape_as(x_q)


class ShiftSelectiveState(nn.Module):
    """Mamba/RWKV-inspired recurrent state, discretized to shifts and additions.

    Each step computes ``s <- s * (1 - 2**-k) + drive`` where the shift ``k``
    in ``[min_shift, max_shift]`` is selected per state channel from the
    input, then reads out ``out_proj(s)``. The integer path implements the
    decay as ``s - (s >> k)``.
    """

    def __init__(self, d_model: int, state_dim: int, min_shift: int = 1, max_shift: int = 6):
        super().__init__()
        self.d_model = d_model
        self.state_dim = state_dim
        self.min_shift = min_shift
        self.max_shift = max_shift
        self.in_proj = TernaryLinear(d_model, state_dim, bias=False)  # input drive
        self.out_proj = TernaryLinear(state_dim, d_model, bias=False)  # state readout
        self.shift_proj = TernaryLinear(d_model, state_dim, bias=True)  # decay-shift selector

    def init_state(self, batch: int, device: torch.device, dtype: torch.dtype = torch.float32) -> Tensor:
        """Return a zero state of shape ``(batch, state_dim)``."""
        return torch.zeros(batch, self.state_dim, device=device, dtype=dtype)

    def forward_step(self, x_t: Tensor, state: Tensor) -> Tuple[Tensor, Tensor]:
        """Advance the float/STE recurrence by one step.

        Returns ``(y_t, new_state)``; the new state is fake-quantized and
        clamped to [-8, 8].
        """
        drive = self.in_proj(x_t)
        shift_bins = torch.sigmoid(self.shift_proj(x_t))
        # ste_round keeps an integer shift in the forward pass but lets the
        # gradient reach shift_proj; torch.round alone has a zero gradient.
        shift = ste_round(shift_bins * (self.max_shift - self.min_shift) + self.min_shift)
        decay = 1.0 - torch.pow(2.0, -shift)
        new_state = fake_quant_int(state * decay + drive).clamp(-8.0, 8.0)
        y_t = self.out_proj(new_state)
        return y_t, new_state

    def forward(self, x: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Run ``forward_step`` over ``x`` of shape ``(batch, steps, d_model)``.

        Returns ``(y, final_state)`` with ``y`` of shape ``(batch, steps, d_model)``.
        A zero-length sequence yields an empty ``y`` and the unchanged state.
        """
        batch, steps, _ = x.shape
        if state is None:
            state = self.init_state(batch, x.device, x.dtype)
        if steps == 0:
            # torch.stack rejects an empty list.
            return x.new_zeros((batch, 0, self.d_model)), state
        outs = []
        for t in range(steps):
            y_t, state = self.forward_step(x[:, t], state)
            outs.append(y_t)
        return torch.stack(outs, dim=1), state

    @torch.no_grad()
    def step_int8(self, x_q: Tensor, state_i16: Tensor) -> Tuple[Tensor, Tensor]:
        """One integer timestep: integer input, int16 state, int8 output.

        The shift score from ``shift_proj`` is clamped to [-128, 127] and
        mapped onto ``[min_shift, max_shift]``. The state update is computed
        in int32 and saturated to the int16 range before narrowing, so large
        drives clip instead of wrapping around.

        Returns:
            ``(y_q, new_state)`` as int8 and int16 tensors.
        """
        drive = self.in_proj.forward_int8(x_q, out_shift=2)  # int32
        score = self.shift_proj.forward_int8(x_q, out_shift=4)  # int32
        span = self.max_shift - self.min_shift
        shift = (score.clamp(-128, 127) + 128) * span // 255
        shift = (shift + self.min_shift).clamp(self.min_shift, self.max_shift)

        state32 = state_i16.to(torch.int32)
        decayed = state32 - torch.bitwise_right_shift(state32, shift)
        new_state = (decayed + drive).clamp(_INT16_MIN, _INT16_MAX).to(torch.int16)

        state_q = torch.bitwise_right_shift(new_state, 2).clamp(-128, 127).to(torch.int8)
        y_q = self.out_proj.forward_int8(state_q, out_shift=3).clamp(-128, 127).to(torch.int8)
        return y_q, new_state


class TernaryExpert(nn.Module):
    """Expert FFN: ternary up-projection, per-channel LUT activation, ternary down-projection."""

    def __init__(self, d_model: int, hidden_dim: int, bins: int = 16):
        super().__init__()
        self.up = TernaryLinear(d_model, hidden_dim)
        self.act = ChannelLUT(hidden_dim, bins=bins)
        self.down = TernaryLinear(hidden_dim, d_model)

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``down(act(up(x)))``."""
        return self.down(self.act(self.up(x)))


def byte_regex_routes(input_ids: Tensor, num_experts: int) -> Tensor:
    """Hard byte-class router.

    0 digits, 1 whitespace (tab, LF, CR, space), 2 latin letters,
    3 ASCII punctuation/operators (the POSIX ``[:punct:]`` set), 4 non-ascii;
    every other id, and any class whose expert index is ``>= num_experts``,
    keeps a hashed route in ``[0, num_experts)``.

    This approximates regex pre-routing without a tokenizer dependency and
    works with the byte-level trainer in train_ghetto.py.

    Raises:
        ValueError: if ``num_experts`` is not positive.
    """
    if num_experts <= 0:
        raise ValueError("num_experts must be positive")
    ids = input_ids.long()
    route = torch.remainder(ids * 1103515245 + 12345, num_experts)

    def set_if(mask: Tensor, expert: int) -> None:
        if expert < num_experts:
            route.masked_fill_(mask, expert)

    set_if((ids >= 48) & (ids <= 57), 0)
    set_if((ids == 9) | (ids == 10) | (ids == 13) | (ids == 32), 1)
    set_if(((ids >= 65) & (ids <= 90)) | ((ids >= 97) & (ids <= 122)), 2)
    set_if(
        ((ids >= 33) & (ids <= 47))
        | ((ids >= 58) & (ids <= 64))
        | ((ids >= 91) & (ids <= 96))
        | ((ids >= 123) & (ids <= 126)),
        3,
    )
    set_if(ids >= 128, 4)
    return route


class HardSparseMoE(nn.Module):
    """Extreme sparse MoE with deterministic routes and no softmax router."""

    def __init__(self, d_model: int, num_experts: int = 64, hidden_dim: int = 128, active_experts: int = 1):
        super().__init__()
        if active_experts < 1:
            raise ValueError("active_experts must be >= 1")
        self.d_model = d_model
        self.num_experts = num_experts
        self.active_experts = active_experts
        self.experts = nn.ModuleList([TernaryExpert(d_model, hidden_dim) for _ in range(num_experts)])

    def route(self, input_ids: Tensor) -> Tensor:
        """Return ``(*input_ids.shape, active_experts)`` expert indices.

        Slot 0 is the byte-class route; further slots add consecutive offsets
        modulo ``num_experts`` (so an expert can repeat when
        ``active_experts > num_experts``).
        """
        primary = byte_regex_routes(input_ids, self.num_experts)
        if self.active_experts == 1:
            return primary.unsqueeze(-1)
        offsets = torch.arange(self.active_experts, device=input_ids.device).view(*([1] * input_ids.ndim), -1)
        return torch.remainder(primary.unsqueeze(-1) + offsets, self.num_experts)

    def forward(self, x: Tensor, input_ids: Tensor) -> Tensor:
        """Average the routed experts' outputs; ``input_ids`` needs one id per row of ``x``."""
        routes = self.route(input_ids)
        flat_x = x.reshape(-1, x.shape[-1])
        flat_routes = routes.reshape(-1, self.active_experts)
        out = torch.zeros_like(flat_x)
        for slot in range(self.active_experts):
            slot_routes = flat_routes[:, slot]
            for expert_idx_t in torch.unique(slot_routes):
                expert_idx = int(expert_idx_t.item())
                expert = self.experts[expert_idx]
                mask = slot_routes == expert_idx
                out[mask] += expert(flat_x[mask]) / float(self.active_experts)
        return out.reshape_as(x)


class BitRaspBlock(nn.Module):
    """One BitRASP layer: norm -> shift-state mixer -> LUT -> hard sparse MoE."""

    def __init__(
        self,
        d_model: int,
        state_dim: int,
        num_experts: int,
        expert_hidden: int,
        active_experts: int = 1,
        lut_bins: int = 16,
    ):
        super().__init__()
        self.norm_a = AbsMaxNorm(d_model)
        self.mixer = ShiftSelectiveState(d_model, state_dim)
        self.lut = ChannelLUT(d_model, bins=lut_bins)
        self.norm_b = AbsMaxNorm(d_model)
        self.moe = HardSparseMoE(d_model, num_experts, expert_hidden, active_experts)

    def init_state(self, batch: int, device: torch.device, dtype: torch.dtype = torch.float32) -> Tensor:
        """Return the mixer's zero state for ``batch`` sequences."""
        return self.mixer.init_state(batch, device, dtype)

    def forward(self, x: Tensor, input_ids: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Apply the two fake-quantized residual branches; returns ``(x, next_state)``."""
        mixed, next_state = self.mixer(self.norm_a(x), state)
        x = fake_quant_int(x + self.lut(mixed))
        x = fake_quant_int(x + self.moe(self.norm_b(x), input_ids))
        return x, next_state