| """ |
| BitRASP core math. |
| |
| Research MVP for an addition-first recurrent language layer: |
| - ternary {-1, 0, +1} trainable projections with STE; |
| - fixed-size integer recurrent state, no KV cache; |
| - selective decay approximated by shifts: s <- s - (s >> k); |
| - KAN-like per-channel lookup tables instead of dense MLP activations; |
| - hard sparse MoE routing hooks driven by token/byte classes. |
| |
| The training path is PyTorch-friendly and differentiable where practical. The |
| step_int8 path shows the intended no-FP-multiply inference contract; replacing |
| the scatter/gather pieces with packed C/AVX/SVE/Elbrus kernels is the next step. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| Tensor = torch.Tensor |
|
|
|
|
@dataclass(frozen=True)
class QuantSpec:
    """Signed-integer grid description used for symmetric fake quantization.

    Attributes:
        bits: Width of the signed integer grid. Must be at least 2 so that
            ``qmax`` is non-zero and can be used as a divisor.
        eps: Small floor value; callers use it to keep a derived scale away
            from zero.

    Raises:
        ValueError: If ``bits`` is smaller than 2.
    """

    bits: int = 8
    eps: float = 1e-6

    def __post_init__(self) -> None:
        # bits == 1 yields qmax == 0 (a zero divisor when a scale is derived
        # from it) and bits <= 0 makes the shift below invalid, so reject
        # both up front with a clear message.
        if self.bits < 2:
            raise ValueError("bits must be >= 2")

    @property
    def qmax(self) -> int:
        """Largest representable code, ``2**(bits-1) - 1``."""
        return (1 << (self.bits - 1)) - 1

    @property
    def qmin(self) -> int:
        """Smallest representable code, ``-2**(bits-1)``."""
        return -(1 << (self.bits - 1))
|
|
|
|
def ste_round(x: Tensor) -> Tensor:
    """Round to the nearest integer with a straight-through gradient.

    The rounding residual ``round(x) - x`` is detached, so the result carries
    the rounded values while autograd only sees the ``x`` term.
    """
    residual = torch.round(x) - x
    return x + residual.detach()
|
|
|
|
def fake_quant_int(x: Tensor, spec: QuantSpec = QuantSpec()) -> Tensor:
    """Fake-quantize ``x`` symmetrically along its last dimension.

    The scale is the detached abs-max over the last dimension (floored at
    ``spec.eps``) divided by ``spec.qmax``. The forward value is the
    dequantized integer code; the quantization error is detached, so the
    gradient with respect to ``x`` is the identity.
    """
    # Statistics come from detached data: autograd treats the scale as a constant.
    peak = x.detach().abs().amax(dim=-1, keepdim=True)
    scale = peak.clamp_min(spec.eps) / spec.qmax

    levels = ste_round(x / scale).clamp(spec.qmin, spec.qmax)
    dequantized = levels * scale
    return x + (dequantized - x).detach()
|
|
|
|
def ternarize_weight(w: Tensor, threshold: float = 0.55) -> Tuple[Tensor, Tensor]:
    """BitNet-style ternary weight plus per-output scale.

    Each row of ``w`` gets a scale equal to its mean absolute value (floored
    at 1e-6). Entries above ``threshold * scale`` map to +1, entries below
    ``-threshold * scale`` map to -1, everything else to 0.

    Args:
        w: Weight tensor; rows are indexed by dimension 0 and reduced over
            dimension 1.
        threshold: Fraction of the row scale used as the dead-zone half-width.

    Returns:
        ``(w_q, scale)`` where ``w_q`` has the forward value ``code * scale``
        with a straight-through (identity) gradient to ``w``, and ``scale``
        has dimension 1 squeezed out. Both keep the dtype of ``w``.

    A production CPU kernel would pack only the sign codes and keep scale as a
    power-of-two shift.
    """
    scale = w.detach().abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
    cut = threshold * scale

    # Build the {-1, 0, +1} code directly in w's dtype. Selecting between
    # Python float scalars with torch.where yields a float32 tensor, which
    # silently promoted half/bfloat16 weights to float32.
    code = (w > cut).to(w.dtype) - (w < -cut).to(w.dtype)

    # Straight-through estimator: forward is code * scale, backward is identity.
    w_q = w + (code * scale - w).detach()
    return w_q, scale.squeeze(1)
|
|
|
|
def additive_ternary_linear_int8(
    x_q: Tensor,
    weight_code: Tensor,
    bias: Optional[Tensor] = None,
    out_shift: int = 0,
) -> Tensor:
    """Reference no-multiply ternary linear for int tensors.

        y[o] = sum(x[i] where W[o,i] == +1) - sum(x[i] where W[o,i] == -1)

    This is intentionally simple and readable. It uses gathers and integer sums;
    high performance requires packing ternary codes and vectorizing the positive
    and negative accumulation loops in C/C++.

    Args:
        x_q: Integer activations; the last dimension is the feature dimension.
        weight_code: Integer matrix of shape ``(out_features, in_features)``;
            positive entries add the input, negative entries subtract it,
            zeros skip it.
        bias: Optional per-output bias, cast to int32 and added before the shift.
        out_shift: When positive, the accumulator is right-shifted by this many
            bits; other values leave it untouched.

    Returns:
        int32 tensor shaped like ``x_q`` with the last dimension replaced by
        ``out_features``. Accumulation is done in int32.

    Raises:
        TypeError: If ``x_q`` or ``weight_code`` is not a signed integer tensor.
        ValueError: If ``weight_code`` is not 2-D or the feature sizes disagree.
    """
    int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)
    if x_q.dtype not in int_dtypes:
        raise TypeError("x_q must be an integer tensor")
    if weight_code.dtype not in int_dtypes:
        raise TypeError("weight_code must be an integer tensor")
    if weight_code.dim() != 2:
        raise ValueError("weight_code must have shape (out_features, in_features)")

    out_features, in_features = weight_code.shape
    # Without this check an x_q with *more* features than weight_code would be
    # accepted and the surplus features silently dropped by index_select.
    if x_q.dim() == 0 or x_q.shape[-1] != in_features:
        raise ValueError(
            f"x_q last dimension must be {in_features}, got shape {tuple(x_q.shape)}"
        )

    flat = x_q.reshape(-1, in_features).to(torch.int32)
    bias_i32 = None if bias is None else bias.to(torch.int32)
    out = torch.empty((flat.shape[0], out_features), dtype=torch.int32, device=x_q.device)

    for o in range(out_features):
        row = weight_code[o]
        pos = torch.nonzero(row > 0, as_tuple=False).flatten()
        neg = torch.nonzero(row < 0, as_tuple=False).flatten()
        acc = flat.new_zeros(flat.shape[0])
        if pos.numel():
            acc += flat.index_select(1, pos).sum(dim=1)
        if neg.numel():
            acc -= flat.index_select(1, neg).sum(dim=1)
        if bias_i32 is not None:
            acc += bias_i32[o]
        if out_shift > 0:
            acc = acc >> out_shift
        out[:, o] = acc
    return out.reshape(*x_q.shape[:-1], out_features)
|
|
|
|
class TernaryLinear(nn.Module):
    """Trainable ternary projection.

    The default forward is the STE training path. `weight_code()` and
    `forward_int8()` expose the add/sub inference semantics.

    Args:
        in_features: Size of the input feature dimension.
        out_features: Size of the output feature dimension.
        bias: Whether to allocate a learnable bias (initialised to zero).
        threshold: Dead-zone factor applied to the per-row mean absolute weight
            when deciding between -1, 0 and +1.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, threshold: float = 0.55):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the weight (normal, std 1/sqrt(in_features)) and zero the bias."""
        nn.init.normal_(self.weight, mean=0.0, std=1.0 / math.sqrt(self.in_features))
        # Also reset the bias so that re-initialising a used layer does not
        # leave a stale bias behind.
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Training path: fake-quantized input times STE-ternarized weight, plus bias."""
        w_q, _ = ternarize_weight(self.weight, self.threshold)
        x_q = fake_quant_int(x)
        return F.linear(x_q, w_q, self.bias)

    @torch.no_grad()
    def weight_code(self) -> Tensor:
        """Return the int8 {-1, 0, +1} code of the current weight.

        Uses the same per-row mean-absolute scale (floored at 1e-6) and
        threshold as the training path.
        """
        scale = self.weight.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
        cut = self.threshold * scale
        return torch.where(
            self.weight > cut,
            1,
            torch.where(self.weight < -cut, -1, 0),
        ).to(torch.int8)

    @torch.no_grad()
    def forward_int8(self, x_q: Tensor, out_shift: int = 0) -> Tensor:
        """Integer add/sub path over the ternary code.

        The float bias is rounded to the nearest integer and added as-is; no
        rescaling into the activation's integer domain is applied here.
        """
        bias = None
        if self.bias is not None:
            # Gradients are disabled in this method, so plain rounding is
            # sufficient (no straight-through trick needed).
            bias = torch.round(self.bias.detach()).to(device=x_q.device, dtype=torch.int32)
        code = self.weight_code().to(x_q.device)
        return additive_ternary_linear_int8(x_q, code, bias, out_shift)
|
|
|
|
class AbsMaxNorm(nn.Module):
    """RMSNorm replacement that avoids square/multiply in the target inference path.

    Divides by the detached abs-max over the last dimension (floored at
    ``eps``), applies a learnable per-channel gain, then fake-quantizes.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # The normaliser is computed on detached data, so it is a constant
        # from autograd's point of view.
        peak = x.detach().abs().amax(dim=-1, keepdim=True)
        normalized = x / peak.clamp_min(self.eps)
        return fake_quant_int(normalized * self.gain)
|
|
|
|
class ChannelLUT(nn.Module):
    """KAN-ish learnable univariate function per channel.

    Inference can quantize each channel to a bin and gather table values. The
    training path uses linear interpolation so gradients reach the table.

    Args:
        channels: Number of channels; inputs must have this size in their last
            dimension.
        bins: Number of table entries per channel.
        value_scale: The table is initialised to ``linspace(-value_scale,
            value_scale, bins)`` for every channel.

    Raises:
        ValueError: If ``channels`` or ``bins`` is smaller than 1, or (at call
            time) if the input's last dimension is not ``channels``.
    """

    def __init__(self, channels: int, bins: int = 16, value_scale: float = 1.0):
        super().__init__()
        if channels < 1:
            raise ValueError("channels must be >= 1")
        if bins < 1:
            raise ValueError("bins must be >= 1")
        self.channels = channels
        self.bins = bins
        grid = torch.linspace(-value_scale, value_scale, bins)
        self.table = nn.Parameter(grid.repeat(channels, 1))

    def _require_channel_dim(self, x: Tensor) -> None:
        """Reject inputs whose last dimension is not ``channels``."""
        # reshape(-1, channels) below would otherwise accept any input whose
        # element count happens to be divisible by `channels`.
        if x.dim() == 0 or x.shape[-1] != self.channels:
            raise ValueError(
                f"expected last dimension of size {self.channels}, got shape {tuple(x.shape)}"
            )

    def forward(self, x: Tensor) -> Tensor:
        """Interpolated lookup: clamp to [-1, 1], blend the two nearest bins, fake-quantize."""
        self._require_channel_dim(x)
        clipped = x.clamp(-1.0, 1.0)
        # Map [-1, 1] onto the fractional bin coordinate [0, bins - 1].
        pos = (clipped + 1.0) * (self.bins - 1) * 0.5
        lo = torch.floor(pos).long().clamp(0, self.bins - 1)
        hi = (lo + 1).clamp(0, self.bins - 1)
        frac = pos - lo.to(pos.dtype)

        # (bins, channels) layout so that [bin_index, channel_index] picks one entry.
        table = self.table.t().contiguous()
        flat_lo = lo.reshape(-1, self.channels)
        flat_hi = hi.reshape(-1, self.channels)
        channel_idx = torch.arange(self.channels, device=x.device).view(1, -1)
        y_lo = table[flat_lo, channel_idx]
        y_hi = table[flat_hi, channel_idx]
        y = y_lo + (y_hi - y_lo) * frac.reshape(-1, self.channels)
        return fake_quant_int(y.reshape_as(x))

    @torch.no_grad()
    def forward_int8(self, x_q: Tensor) -> Tensor:
        """Integer lookup: bin the int8 code and gather an int8 table value.

        The table is clamped to [-1, 1] and scaled by 127 before rounding to
        int8. The bin index is ``(x + 128) * (bins - 1) // 255``.
        """
        self._require_channel_dim(x_q)
        # Index math is done in int32: in int16 the product (x + 128) * (bins - 1)
        # wraps around once bins exceeds 129 (255 * 129 > 32767).
        widened = x_q.to(torch.int32) + 128
        idx = (widened * (self.bins - 1) // 255).clamp(0, self.bins - 1)
        table_q = torch.round(self.table.detach().clamp(-1, 1) * 127).to(torch.int8)
        flat_idx = idx.reshape(-1, self.channels).long()
        channel_idx = torch.arange(self.channels, device=x_q.device).view(1, -1)
        y = table_q.t().contiguous()[flat_idx, channel_idx]
        return y.reshape_as(x_q)
|
|
|
|
class ShiftSelectiveState(nn.Module):
    """Mamba/RWKV-inspired recurrent state, discretized to shifts and additions.

    Args:
        d_model: Feature size of the per-step input and output.
        state_dim: Size of the recurrent state vector.
        min_shift: Smallest decay shift selectable per state channel.
        max_shift: Largest decay shift selectable per state channel.
    """

    def __init__(self, d_model: int, state_dim: int, min_shift: int = 1, max_shift: int = 6):
        super().__init__()
        self.d_model = d_model
        self.state_dim = state_dim
        self.min_shift = min_shift
        self.max_shift = max_shift
        self.in_proj = TernaryLinear(d_model, state_dim, bias=False)
        self.out_proj = TernaryLinear(state_dim, d_model, bias=False)
        self.shift_proj = TernaryLinear(d_model, state_dim, bias=True)

    def init_state(self, batch: int, device: torch.device, dtype: torch.dtype = torch.float32) -> Tensor:
        """Return an all-zero state of shape ``(batch, state_dim)``."""
        return torch.zeros(batch, self.state_dim, device=device, dtype=dtype)

    def forward_step(self, x_t: Tensor, state: Tensor) -> Tuple[Tensor, Tensor]:
        """Advance the float (training) recurrence by one step.

        Returns ``(y_t, new_state)`` where the state is decayed by
        ``1 - 2**(-shift)``, driven by ``in_proj(x_t)``, fake-quantized and
        clamped to [-8, 8].
        """
        drive = self.in_proj(x_t)
        gate = torch.sigmoid(self.shift_proj(x_t))
        span = self.max_shift - self.min_shift
        # Round with a straight-through gradient. Plain torch.round has a zero
        # gradient, which would cut shift_proj off from the loss entirely.
        shift = ste_round(gate * span + self.min_shift)
        decay = 1.0 - torch.pow(2.0, -shift)
        new_state = fake_quant_int(state * decay + drive).clamp(-8.0, 8.0)
        y_t = self.out_proj(new_state)
        return y_t, new_state

    def forward(self, x: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Run the recurrence over a ``(batch, steps, features)`` input.

        Returns the per-step outputs stacked along dimension 1 and the final
        state. A zero-length sequence yields an empty ``(batch, 0, d_model)``
        output and the unchanged (or freshly initialised) state.
        """
        batch, steps, _ = x.shape
        if state is None:
            state = self.init_state(batch, x.device, x.dtype)
        if steps == 0:
            # torch.stack cannot stack an empty list.
            return x.new_zeros(batch, 0, self.d_model), state
        outs = []
        for t in range(steps):
            y_t, state = self.forward_step(x[:, t], state)
            outs.append(y_t)
        return torch.stack(outs, dim=1), state

    @torch.no_grad()
    def step_int8(self, x_q: Tensor, state_i16: Tensor) -> Tuple[Tensor, Tensor]:
        """Integer reference step: ``s <- sat16(s - (s >> k) + drive)``.

        Returns ``(y_q, new_state)`` with ``y_q`` as int8 and ``new_state`` as
        int16.
        """
        drive = self.in_proj.forward_int8(x_q, out_shift=2)
        score = self.shift_proj.forward_int8(x_q, out_shift=4)

        # Map the clamped score from [-128, 127] onto [min_shift, max_shift].
        span = self.max_shift - self.min_shift
        shift = (score.clamp(-128, 127) + 128) * span // 255
        shift = (shift + self.min_shift).clamp(self.min_shift, self.max_shift)

        # Do the update in int32 and saturate afterwards. Adding in int16 wraps
        # around on overflow, which turns the clamp below into a no-op.
        state_i32 = state_i16.to(torch.int32)
        decayed = state_i32 - torch.bitwise_right_shift(state_i32, shift)
        new_state = (decayed + drive).clamp(-32768, 32767).to(torch.int16)

        state_q = torch.bitwise_right_shift(new_state, 2).clamp(-128, 127).to(torch.int8)
        y_q = self.out_proj.forward_int8(state_q, out_shift=3).clamp(-128, 127).to(torch.int8)
        return y_q, new_state
|
|
|
|
class TernaryExpert(nn.Module):
    """Single MoE expert: ternary up-projection, per-channel LUT, ternary down-projection."""

    def __init__(self, d_model: int, hidden_dim: int, bins: int = 16):
        super().__init__()
        self.up = TernaryLinear(d_model, hidden_dim)
        self.act = ChannelLUT(hidden_dim, bins=bins)
        self.down = TernaryLinear(hidden_dim, d_model)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.up(x)
        activated = self.act(hidden)
        return self.down(activated)
|
|
|
|
def byte_regex_routes(input_ids: Tensor, num_experts: int) -> Tensor:
    """Hard byte-class router.

    Class experts, applied only when the expert index is below ``num_experts``:
      0  digits (48-57)
      1  whitespace (tab 9, LF 10, CR 13, space 32)
      2  latin letters (65-90, 97-122)
      3  punctuation/operators (33-47, 58-64)
      4  non-ascii (ids >= 128)
    Every other id keeps the hashed route
    ``(id * 1103515245 + 12345) mod num_experts``.

    This approximates regex pre-routing without a tokenizer dependency and
    works with the byte-level trainer in train_ghetto.py.

    Returns an int64 tensor shaped like ``input_ids``.

    Raises:
        ValueError: If ``num_experts`` is not positive.
    """
    if num_experts <= 0:
        raise ValueError("num_experts must be positive")

    ids = input_ids.long()
    route = torch.remainder(ids * 1103515245 + 12345, num_experts)

    is_digit = (ids >= 48) & (ids <= 57)
    is_space = (ids == 9) | (ids == 10) | (ids == 13) | (ids == 32)
    is_letter = ((ids >= 65) & (ids <= 90)) | ((ids >= 97) & (ids <= 122))
    is_punct = ((ids >= 33) & (ids <= 47)) | ((ids >= 58) & (ids <= 64))
    is_high = ids >= 128

    # Position in this tuple is the expert index of the class.
    class_masks = (is_digit, is_space, is_letter, is_punct, is_high)
    for expert, mask in enumerate(class_masks):
        if expert >= num_experts:
            # Experts are numbered in order, so every later class is out of range too.
            break
        route = route.masked_fill(mask, expert)
    return route
|
|
|
|
class HardSparseMoE(nn.Module):
    """Extreme sparse MoE with deterministic routes and no softmax router.

    Each token is sent to ``active_experts`` experts chosen purely from its
    input id; the expert outputs are averaged with equal weights.
    """

    def __init__(self, d_model: int, num_experts: int = 64, hidden_dim: int = 128, active_experts: int = 1):
        super().__init__()
        if active_experts < 1:
            raise ValueError("active_experts must be >= 1")
        self.d_model = d_model
        self.num_experts = num_experts
        self.active_experts = active_experts
        expert_list = [TernaryExpert(d_model, hidden_dim) for _ in range(num_experts)]
        self.experts = nn.ModuleList(expert_list)

    def route(self, input_ids: Tensor) -> Tensor:
        """Return expert indices with a trailing ``active_experts`` dimension.

        Slot 0 is the byte-class route; slot ``k`` is that route plus ``k``,
        wrapped modulo ``num_experts``.
        """
        primary = byte_regex_routes(input_ids, self.num_experts).unsqueeze(-1)
        if self.active_experts == 1:
            return primary
        leading_ones = [1] * input_ids.ndim
        offsets = torch.arange(self.active_experts, device=input_ids.device).view(*leading_ones, -1)
        return torch.remainder(primary + offsets, self.num_experts)

    def forward(self, x: Tensor, input_ids: Tensor) -> Tensor:
        """Apply the routed experts token by token and average over the slots."""
        flat_x = x.reshape(-1, x.shape[-1])
        flat_routes = self.route(input_ids).reshape(-1, self.active_experts)
        out = torch.zeros_like(flat_x)
        slot_weight = float(self.active_experts)

        for slot_routes in flat_routes.unbind(dim=1):
            # Only visit experts that actually received tokens in this slot.
            for expert_idx in torch.unique(slot_routes).tolist():
                rows = slot_routes == expert_idx
                expert_out = self.experts[expert_idx](flat_x[rows])
                out[rows] += expert_out / slot_weight
        return out.reshape_as(x)
|
|
|
|
class BitRaspBlock(nn.Module):
    """One BitRASP layer: norm -> shift-state mixer -> LUT -> hard sparse MoE.

    Both the mixer branch and the MoE branch are added residually, and each
    residual sum is fake-quantized.
    """

    def __init__(
        self,
        d_model: int,
        state_dim: int,
        num_experts: int,
        expert_hidden: int,
        active_experts: int = 1,
        lut_bins: int = 16,
    ):
        super().__init__()
        self.norm_a = AbsMaxNorm(d_model)
        self.mixer = ShiftSelectiveState(d_model, state_dim)
        self.lut = ChannelLUT(d_model, bins=lut_bins)
        self.norm_b = AbsMaxNorm(d_model)
        self.moe = HardSparseMoE(d_model, num_experts, expert_hidden, active_experts)

    def init_state(self, batch: int, device: torch.device, dtype: torch.dtype = torch.float32) -> Tensor:
        """Delegate to the mixer's zero-state constructor."""
        return self.mixer.init_state(batch, device, dtype)

    def forward(self, x: Tensor, input_ids: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Return ``(output, next_state)`` for one layer application."""
        # Recurrent branch: normalise, mix through the shift state, pass through the LUT.
        mixed, next_state = self.mixer(self.norm_a(x), state)
        after_mixer = fake_quant_int(x + self.lut(mixed))

        # Expert branch: normalise, route by input id, add residually.
        expert_out = self.moe(self.norm_b(after_mixer), input_ids)
        after_moe = fake_quant_int(after_mixer + expert_out)
        return after_moe, next_state
|
|