# BitRASP-18M-Ghetto / core_math.py
# livadies's picture
# Initial commit: The Ghetto Architecture is alive. MatMul is dead.
# 19358e0 verified
"""
BitRASP core math.
Research MVP for an addition-first recurrent language layer:
- ternary {-1, 0, +1} trainable projections with STE;
- fixed-size integer recurrent state, no KV cache;
- selective decay approximated by shifts: s <- s - (s >> k);
- KAN-like per-channel lookup tables instead of dense MLP activations;
- hard sparse MoE routing hooks driven by token/byte classes.
The training path is PyTorch-friendly and differentiable where practical. The
step_int8 path shows the intended no-FP-multiply inference contract; replacing
the scatter/gather pieces with packed C/AVX/SVE/Elbrus kernels is the next step.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
Tensor = torch.Tensor
@dataclass(frozen=True)
class QuantSpec:
    """Immutable description of a signed integer quantization grid.

    ``bits`` is the width of the signed code; ``eps`` is a small positive
    floor that callers can use to keep a scale away from zero.
    """

    bits: int = 8
    eps: float = 1e-6

    @property
    def qmax(self) -> int:
        """Largest code of the signed grid, i.e. ``2**(bits - 1) - 1``."""
        half_range = 1 << (self.bits - 1)
        return half_range - 1

    @property
    def qmin(self) -> int:
        """Smallest code of the signed grid, i.e. ``-2**(bits - 1)``."""
        half_range = 1 << (self.bits - 1)
        return -half_range
def ste_round(x: Tensor) -> Tensor:
    """Round to the nearest integer with a straight-through gradient.

    The forward value is ``torch.round(x)``; the rounding residual is detached
    so the backward pass sees the identity function.
    """
    residual = (torch.round(x) - x).detach()
    return x + residual
def fake_quant_int(x: Tensor, spec: QuantSpec = QuantSpec()) -> Tensor:
    """Fake-quantize ``x`` onto a symmetric integer grid, straight-through.

    One scale is derived per slice along the last dimension from that slice's
    absolute maximum (floored at ``spec.eps``). The returned tensor carries the
    dequantized values in the forward pass, while the gradient with respect to
    ``x`` is the identity because the quantization error is detached.
    """
    peak = x.detach().abs().amax(dim=-1, keepdim=True)
    scale = peak.clamp_min(spec.eps) / spec.qmax
    codes = ste_round(x / scale).clamp(spec.qmin, spec.qmax)
    dequantized = codes * scale
    return x + (dequantized - x).detach()
def ternarize_weight(w: Tensor, threshold: float = 0.55) -> Tuple[Tensor, Tensor]:
"""BitNet-style ternary weight plus per-output scale.
Returns a straight-through quantized weight and its scale. A production CPU
kernel would pack only the sign codes and keep scale as a power-of-two shift.
"""
scale = w.detach().abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
code = torch.where(w > threshold * scale, 1.0, torch.where(w < -threshold * scale, -1.0, 0.0))
w_q = w + (code * scale - w).detach()
return w_q, scale.squeeze(1)
def additive_ternary_linear_int8(
    x_q: Tensor,
    weight_code: Tensor,
    bias: Optional[Tensor] = None,
    out_shift: int = 0,
) -> Tensor:
    """Reference multiply-free ternary linear layer on integer tensors.

    For every output row ``o`` of ``weight_code``:

        y[o] = sum(x[i] where W[o, i] > 0) - sum(x[i] where W[o, i] < 0)

    followed by an optional bias add and an optional right shift.

    Args:
        x_q: Integer tensor; the last dimension is the input feature axis.
        weight_code: Integer tensor indexed as ``[output, input]``; only the
            sign of each entry matters.
        bias: Optional per-output bias, added before the shift.
        out_shift: When positive, the accumulator is right-shifted by this
            many bits; non-positive values leave it untouched.

    Returns:
        An int32 tensor shaped like ``x_q`` with the last dimension replaced
        by the number of rows of ``weight_code``.

    Raises:
        TypeError: If ``x_q`` or ``weight_code`` is not int8/16/32/64.

    This is deliberately a readable reference; a fast version would pack the
    ternary codes and vectorize the add/subtract loops in C/C++.
    """
    integer_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)
    if x_q.dtype not in integer_dtypes:
        raise TypeError("x_q must be an integer tensor")
    if weight_code.dtype not in integer_dtypes:
        raise TypeError("weight_code must be an integer tensor")

    out_features = weight_code.shape[0]
    rows = x_q.reshape(-1, x_q.shape[-1]).to(torch.int32)
    result = torch.empty((rows.shape[0], out_features), dtype=torch.int32, device=x_q.device)

    for o in range(out_features):
        code_row = weight_code[o]
        plus_idx = torch.nonzero(code_row > 0, as_tuple=True)[0]
        minus_idx = torch.nonzero(code_row < 0, as_tuple=True)[0]

        # Accumulate in place so the running total stays int32; selecting zero
        # columns sums to zero, so empty index sets need no special casing.
        total = rows.new_zeros(rows.shape[0])
        total += rows.index_select(1, plus_idx).sum(dim=1)
        total -= rows.index_select(1, minus_idx).sum(dim=1)
        if bias is not None:
            total += bias[o].to(torch.int32)
        if out_shift > 0:
            total = total >> out_shift
        result[:, o] = total

    return result.reshape(*x_q.shape[:-1], out_features)
class TernaryLinear(nn.Module):
    """Trainable ternary projection.

    The default forward is the STE training path: latent float weights are
    ternarized on the fly and activations are fake-quantized. `weight_code()`
    and `forward_int8()` expose the add/sub inference semantics.

    Args:
        in_features: Size of the input feature axis.
        out_features: Number of output features (rows of the weight).
        bias: Whether to learn an additive per-output bias.
        threshold: Fraction of the per-row mean absolute weight below which an
            entry is coded as zero.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, threshold: float = 0.55):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.threshold = threshold
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the latent weight and, if present, zero the bias."""
        nn.init.normal_(self.weight, mean=0.0, std=1.0 / math.sqrt(self.in_features))
        # Zero the bias as well, so calling this on an already-trained layer
        # gives a genuinely fresh layer rather than keeping stale bias values.
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Training path: fake-quantized input times STE-ternarized weight."""
        w_q, _ = ternarize_weight(self.weight, self.threshold)
        x_q = fake_quant_int(x)
        return F.linear(x_q, w_q, self.bias)

    @torch.no_grad()
    def weight_code(self) -> Tensor:
        """Return the int8 ternary codes {-1, 0, +1} of the current weight.

        Uses the same per-row mean-absolute scale and threshold rule as the
        training path, but returns only the codes (no scale).
        """
        scale = self.weight.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
        cutoff = self.threshold * scale
        return torch.where(
            self.weight > cutoff,
            1,
            torch.where(self.weight < -cutoff, -1, 0),
        ).to(torch.int8)

    @torch.no_grad()
    def forward_int8(self, x_q: Tensor, out_shift: int = 0) -> Tensor:
        """Integer inference path built on `additive_ternary_linear_int8`.

        The float bias is rounded to the nearest integer and added before the
        optional right shift. Gradients are disabled here, so plain rounding
        is used instead of the straight-through helper.
        """
        bias = None
        if self.bias is not None:
            bias = torch.round(self.bias.detach()).to(device=x_q.device, dtype=torch.int32)
        return additive_ternary_linear_int8(x_q, self.weight_code().to(x_q.device), bias, out_shift)
class AbsMaxNorm(nn.Module):
    """RMSNorm replacement that avoids square/multiply in the target inference path.

    Each slice along the last dimension is divided by its own absolute maximum
    (floored at ``eps``), multiplied by a learnable per-channel gain, and then
    fake-quantized.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` by its detached abs-max and apply the gain."""
        # The peak is computed on a detached copy, so autograd treats the
        # divisor as a constant.
        peak = x.detach().abs().amax(dim=-1, keepdim=True)
        normalized = x / peak.clamp_min(self.eps)
        return fake_quant_int(normalized * self.gain)
class ChannelLUT(nn.Module):
    """KAN-ish learnable univariate function per channel.

    Every channel owns a table of ``bins`` values, initialised to an evenly
    spaced ramp over ``[-value_scale, value_scale]``. The training path
    interpolates linearly between neighbouring bins so gradients reach the
    table; the integer path quantizes each input to one bin and gathers it.

    Args:
        channels: Size of the last (channel) axis of the inputs.
        bins: Number of table entries per channel.
        value_scale: Half-width of the initial ramp stored in the table.
    """

    def __init__(self, channels: int, bins: int = 16, value_scale: float = 1.0):
        super().__init__()
        self.channels = channels
        self.bins = bins
        grid = torch.linspace(-value_scale, value_scale, bins)
        self.table = nn.Parameter(grid.repeat(channels, 1))

    def forward(self, x: Tensor) -> Tensor:
        """Training path: clamp to [-1, 1], interpolate the table, fake-quantize."""
        clipped = x.clamp(-1.0, 1.0)
        # Map [-1, 1] onto the fractional bin coordinate [0, bins - 1].
        pos = (clipped + 1.0) * (self.bins - 1) * 0.5
        lo = torch.floor(pos).long().clamp(0, self.bins - 1)
        hi = (lo + 1).clamp(0, self.bins - 1)
        frac = pos - lo.to(pos.dtype)
        # Transposed to [bin, channel] so a (row, channel) pair indexes a value.
        table = self.table.t().contiguous()
        flat_lo = lo.reshape(-1, self.channels)
        flat_hi = hi.reshape(-1, self.channels)
        channel_idx = torch.arange(self.channels, device=x.device).view(1, -1)
        y_lo = table[flat_lo, channel_idx]
        y_hi = table[flat_hi, channel_idx]
        y = y_lo + (y_hi - y_lo) * frac.reshape(-1, self.channels)
        return fake_quant_int(y.reshape_as(x))

    @torch.no_grad()
    def forward_int8(self, x_q: Tensor) -> Tensor:
        """Integer path: pick one bin per input and gather an int8 table value.

        Inputs are offset by 128 and scaled so that -128 maps to bin 0 and
        127 maps to the last bin (floor division). Table values are clamped
        to [-1, 1] and stored as round(value * 127).
        """
        # Index math is done in int32: with int16 the product
        # (x + 128) * (bins - 1) wraps once bins exceeds 129, which sent
        # large inputs to bin 0.
        idx = (x_q.to(torch.int32) + 128) * (self.bins - 1) // 255
        idx = idx.clamp(0, self.bins - 1)
        table_q = torch.round(self.table.detach().clamp(-1, 1) * 127).to(torch.int8)
        flat_idx = idx.reshape(-1, self.channels).long()
        channel_idx = torch.arange(self.channels, device=x_q.device).view(1, -1)
        y = table_q.t().contiguous()[flat_idx, channel_idx]
        return y.reshape_as(x_q)
class ShiftSelectiveState(nn.Module):
    """Mamba/RWKV-inspired recurrent state, discretized to shifts and additions.

    A fixed-size state is decayed per channel by ``s <- s - s * 2**-k`` where
    the shift ``k`` in ``[min_shift, max_shift]`` is chosen from the current
    input, then driven by a ternary projection of that input.

    Args:
        d_model: Feature size of inputs and outputs.
        state_dim: Number of recurrent state channels.
        min_shift: Smallest decay shift (fastest forgetting).
        max_shift: Largest decay shift (slowest forgetting).
    """

    def __init__(self, d_model: int, state_dim: int, min_shift: int = 1, max_shift: int = 6):
        super().__init__()
        self.d_model = d_model
        self.state_dim = state_dim
        self.min_shift = min_shift
        self.max_shift = max_shift
        self.in_proj = TernaryLinear(d_model, state_dim, bias=False)
        self.out_proj = TernaryLinear(state_dim, d_model, bias=False)
        self.shift_proj = TernaryLinear(d_model, state_dim, bias=True)

    def init_state(self, batch: int, device: torch.device, dtype: torch.dtype = torch.float32) -> Tensor:
        """Return an all-zero state of shape ``(batch, state_dim)``."""
        return torch.zeros(batch, self.state_dim, device=device, dtype=dtype)

    def forward_step(self, x_t: Tensor, state: Tensor) -> Tuple[Tensor, Tensor]:
        """Advance the float training recurrence by one time step.

        Returns the step output and the updated state.
        """
        drive = self.in_proj(x_t)
        shift_score = self.shift_proj(x_t)
        span = self.max_shift - self.min_shift
        # Straight-through rounding: the forward value is still an integer
        # shift, but the gradient reaches shift_proj. A plain torch.round has
        # zero gradient, which would leave the shift selection untrainable.
        shift = ste_round(torch.sigmoid(shift_score) * span + self.min_shift)
        decay = 1.0 - torch.pow(2.0, -shift)
        new_state = fake_quant_int(state * decay + drive).clamp(-8.0, 8.0)
        y_t = self.out_proj(new_state)
        return y_t, new_state

    def forward(self, x: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Run the recurrence over a ``(batch, steps, features)`` input.

        Args:
            x: Three-dimensional input; the middle axis is iterated in order.
            state: Optional starting state; zeros are used when omitted.

        Returns:
            The per-step outputs stacked along dim 1 and the final state.
        """
        batch, steps, _ = x.shape
        if state is None:
            state = self.init_state(batch, x.device, x.dtype)
        if steps == 0:
            # torch.stack rejects an empty list; return an empty sequence and
            # hand the state back unchanged.
            return x.new_zeros((batch, 0, self.d_model)), state
        outs = []
        for t in range(steps):
            y_t, state = self.forward_step(x[:, t], state)
            outs.append(y_t)
        return torch.stack(outs, dim=1), state

    @torch.no_grad()
    def step_int8(self, x_q: Tensor, state_i16: Tensor) -> Tuple[Tensor, Tensor]:
        """Integer-only single step: shifts, adds and saturating casts.

        Args:
            x_q: Integer input for one step.
            state_i16: Current recurrent state as int16.

        Returns:
            The int8 step output and the new int16 state.
        """
        drive = self.in_proj.forward_int8(x_q, out_shift=2)
        score = self.shift_proj.forward_int8(x_q, out_shift=4)
        span = self.max_shift - self.min_shift
        # Map the clamped score from [-128, 127] onto [min_shift, max_shift].
        shift = (score.clamp(-128, 127) + 128) * span // 255
        shift = (shift + self.min_shift).clamp(self.min_shift, self.max_shift)
        # Do the update in int32: adding two int16 tensors wraps around before
        # a clamp can saturate, and casting the drive to int16 first could
        # truncate it.
        state_i32 = state_i16.to(torch.int32)
        decayed = state_i32 - torch.bitwise_right_shift(state_i32, shift)
        new_state = (decayed + drive).clamp(-32768, 32767).to(torch.int16)
        state_q = torch.bitwise_right_shift(new_state, 2).clamp(-128, 127).to(torch.int8)
        y_q = self.out_proj.forward_int8(state_q, out_shift=3).clamp(-128, 127).to(torch.int8)
        return y_q, new_state
class TernaryExpert(nn.Module):
    """Feed-forward expert: ternary up-projection, channel LUT, ternary down-projection."""

    def __init__(self, d_model: int, hidden_dim: int, bins: int = 16):
        super().__init__()
        self.up = TernaryLinear(d_model, hidden_dim)
        self.act = ChannelLUT(hidden_dim, bins=bins)
        self.down = TernaryLinear(hidden_dim, d_model)

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``up``, then the per-channel LUT, then ``down``."""
        hidden = self.up(x)
        activated = self.act(hidden)
        return self.down(activated)
def byte_regex_routes(input_ids: Tensor, num_experts: int) -> Tensor:
    """Hard byte-class router.

    Every id first receives a hashed default route,
    ``(id * 1103515245 + 12345) mod num_experts``. Ids belonging to one of the
    fixed classes below are then overridden with that class's expert index,
    provided the index is smaller than ``num_experts``:

        0  digits                  48-57
        1  whitespace              9, 10, 13, 32
        2  latin letters           65-90, 97-122
        3  punctuation/operators   33-47, 58-64
        4  non-ascii               >= 128

    This approximates regex pre-routing without a tokenizer dependency and
    works with the byte-level trainer in train_ghetto.py.

    NOTE(review): ASCII punctuation 91-96 and 123-126 is not part of class 3
    and therefore keeps its hashed route - confirm this is intended.

    Raises:
        ValueError: If ``num_experts`` is not positive.
    """
    if num_experts <= 0:
        raise ValueError("num_experts must be positive")

    ids = input_ids.long()
    route = torch.remainder(ids * 1103515245 + 12345, num_experts)

    is_digit = (ids >= 48) & (ids <= 57)
    is_space = (ids == 9) | (ids == 10) | (ids == 13) | (ids == 32)
    is_letter = ((ids >= 65) & (ids <= 90)) | ((ids >= 97) & (ids <= 122))
    is_punct = ((ids >= 33) & (ids <= 47)) | ((ids >= 58) & (ids <= 64))
    is_non_ascii = ids >= 128

    # The position in this tuple is the expert index reserved for the class.
    class_masks = (is_digit, is_space, is_letter, is_punct, is_non_ascii)
    for expert, mask in enumerate(class_masks):
        if expert >= num_experts:
            # Indices only grow from here, so no later class can fit either.
            break
        route.masked_fill_(mask, expert)
    return route
class HardSparseMoE(nn.Module):
    """Extreme sparse MoE with deterministic routes and no softmax router.

    Routing comes from `byte_regex_routes` on the token ids, not from the
    activations. With more than one active expert, the extra slots are the
    consecutive expert indices after the primary one (wrapping around), and
    the expert outputs are averaged.
    """

    def __init__(self, d_model: int, num_experts: int = 64, hidden_dim: int = 128, active_experts: int = 1):
        super().__init__()
        if active_experts < 1:
            raise ValueError("active_experts must be >= 1")
        self.d_model = d_model
        self.num_experts = num_experts
        self.active_experts = active_experts
        experts = [TernaryExpert(d_model, hidden_dim) for _ in range(num_experts)]
        self.experts = nn.ModuleList(experts)

    def route(self, input_ids: Tensor) -> Tensor:
        """Return expert indices with a trailing axis of size ``active_experts``."""
        primary = byte_regex_routes(input_ids, self.num_experts).unsqueeze(-1)
        if self.active_experts == 1:
            return primary
        # Shape the offsets as (1, ..., 1, active_experts) so they broadcast
        # against the trailing slot axis.
        leading_ones = [1] * input_ids.ndim
        offsets = torch.arange(self.active_experts, device=input_ids.device).view(*leading_ones, -1)
        return torch.remainder(primary + offsets, self.num_experts)

    def forward(self, x: Tensor, input_ids: Tensor) -> Tensor:
        """Send every row of ``x`` through its routed expert(s) and average.

        ``input_ids`` presumably holds one id per feature vector of ``x``;
        verify against the caller.
        """
        tokens = x.reshape(-1, x.shape[-1])
        slot_table = self.route(input_ids).reshape(-1, self.active_experts)
        divisor = float(self.active_experts)
        combined = torch.zeros_like(tokens)
        for slot in range(self.active_experts):
            chosen = slot_table[:, slot]
            # Run each expert once on the batch of rows assigned to it.
            for expert_idx in torch.unique(chosen).tolist():
                expert_idx = int(expert_idx)
                selected = chosen == expert_idx
                contribution = self.experts[expert_idx](tokens[selected])
                combined[selected] += contribution / divisor
        return combined.reshape_as(x)
class BitRaspBlock(nn.Module):
    """One BitRASP layer: norm -> shift-state mixer -> LUT -> hard sparse MoE.

    Both the mixer branch and the MoE branch are added residually, and each
    residual sum is fake-quantized.
    """

    def __init__(
        self,
        d_model: int,
        state_dim: int,
        num_experts: int,
        expert_hidden: int,
        active_experts: int = 1,
        lut_bins: int = 16,
    ):
        super().__init__()
        self.norm_a = AbsMaxNorm(d_model)
        self.mixer = ShiftSelectiveState(d_model, state_dim)
        self.lut = ChannelLUT(d_model, bins=lut_bins)
        self.norm_b = AbsMaxNorm(d_model)
        self.moe = HardSparseMoE(d_model, num_experts, expert_hidden, active_experts)

    def init_state(self, batch: int, device: torch.device, dtype: torch.dtype = torch.float32) -> Tensor:
        """Delegate to the mixer to build a fresh recurrent state."""
        return self.mixer.init_state(batch, device, dtype)

    def forward(self, x: Tensor, input_ids: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Apply the recurrent mixer branch, then the MoE branch.

        Args:
            x: Layer input.
            input_ids: Token ids forwarded to the MoE for routing.
            state: Optional recurrent state passed through to the mixer.

        Returns:
            The layer output and the mixer's next state.
        """
        # Branch 1: normalized input through the shift-state mixer and the LUT.
        normed_in = self.norm_a(x)
        mixed, next_state = self.mixer(normed_in, state)
        after_mixer = fake_quant_int(x + self.lut(mixed))

        # Branch 2: id-routed sparse experts on the re-normalized residual.
        expert_out = self.moe(self.norm_b(after_mixer), input_ids)
        after_moe = fake_quant_int(after_mixer + expert_out)
        return after_moe, next_state