# Source: ARBS / arbitor / sequencers.py
# Uploaded by CLIWorks using huggingface_hub (commit d8bc908, verified).
"""Sequencer modules — input processing for all modalities."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _HAS_TRITON, _HAS_TILELANG
if _HAS_TRITON:
import triton
import triton.language as tl
else:
triton = None
tl = None
try:
from .kernel.ternary_scale import _TritonTernaryEmbedFn
except ImportError:
_TritonTernaryEmbedFn = None
from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary
from math import ceil as _ceil
def _ceil_div(a, b):
    """Return ``ceil(a / b)`` when ``b > 0`` and ``0`` otherwise.

    Plain Python ints are divided with exact integer arithmetic, so very
    large operands are not rounded by an intermediate float division.
    Any other numeric type keeps the float-based ``ceil(a / b)`` path.
    """
    if not b > 0:
        # Zero or negative divisors yield 0 rather than raising.
        return 0
    if isinstance(a, int) and isinstance(b, int):
        # -(-a // b) is the integer ceiling of a / b without going through float.
        return -(-a // b)
    return _ceil(a / b)
from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_SR, AUDIO_FRAME_RATE
class ByteEmbedding(nn.Module):
    """Byte-level embedding via packed ternary signs + BigInt correlation.

    Every piece of persistent training state is an integer buffer:

    * ``T_packed`` / ``_T_shape`` / ``_T_pad`` -- the packed ternary matrix
      and the metadata ``unpack_ternary`` needs to restore it.
    * ``E`` -- one int8 base exponent per (row, group).
    * ``corr_accum`` -- one int64 correlation counter per (row, group); it
      replaces the former T_accum/E_accum and is never clipped or reset here.
    * ``step_counter`` -- total accumulated correlation steps.

    The effective scale is ``S = 2^(E + K * mean_corr)`` with
    ``mean_corr = corr_accum / (step * gs)``, where ``gs`` is the group size
    and ``K`` is the value returned by ``_bigint_corr_strength()``.
    """

    def __init__(self, tscale_type=TScaleType.T32):
        super().__init__()
        self.tscale_type = tscale_type
        # Unknown scale types fall back to the T64 group size.
        self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64])

        rows, cols = VOCAB, EMBEDDING_DIM
        init_std = 0.02
        # The ternarisation threshold is capped at half the init std.
        self.threshold = min(0.05, 0.5 * init_std)

        # Random dense init, then keep only the sign of entries above threshold.
        dense_init = torch.randn(rows, cols) * init_std
        magnitude = dense_init.abs()
        keep_mask = (magnitude > self.threshold).to(dense_init.dtype)
        ternary_init = dense_init.sign() * keep_mask
        packed, _packed_shape, pad = pack_ternary(ternary_init)
        self.register_buffer("T_packed", packed)
        self.register_buffer("_T_shape", torch.tensor([rows, cols], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(pad, dtype=torch.long))

        # Per-group exponent: log2 of the mean |w| over each group of columns,
        # with the last group zero-padded up to a full group.
        groups_per_row = _ceil_div(cols, self.group_size)
        padded_abs = torch.zeros(rows, groups_per_row * self.group_size)
        padded_abs[:, :cols] = magnitude
        group_mean = padded_abs.view(rows, groups_per_row, self.group_size).mean(dim=2)
        # Groups whose mean is not positive use 1.0, i.e. an exponent of 0.
        safe_mean = torch.where(group_mean > 0, group_mean, torch.ones_like(group_mean))
        # NOTE(review): the int8 cast discards the fractional part of log2 -- confirm intended.
        exponent = safe_mean.flatten().log2().clamp(-128, 127).to(torch.int8)
        self.register_buffer("E", exponent)

        # BigInt correlation accumulator (replaces T_accum + E_accum)
        self.register_buffer("corr_accum", torch.zeros(rows * groups_per_row, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))
        self.norm = TernaryRMSNorm(EMBEDDING_DIM, tscale_type=tscale_type)

    def _get_T(self):
        """Unpack ``T_packed`` back into the full ternary matrix."""
        full_shape = tuple(self._T_shape.tolist())
        pad = int(self._T_pad.item())
        return unpack_ternary(self.T_packed, full_shape, pad)

    def _get_S(self):
        """Return the per-element scale ``2^exponent`` of shape (VOCAB, EMBEDDING_DIM).

        The exponent is ``E``; once at least one step has been accumulated it
        is shifted by the mean correlation times ``_bigint_corr_strength()``.
        """
        groups_per_row = _ceil_div(EMBEDDING_DIM, self.group_size)
        exponent = self.E.float()
        steps = int(self.step_counter.item())
        if steps > 0:
            from .kernel.ternary_scale import _bigint_corr_strength
            denom = max(steps * self.group_size, 1)
            mean_corr = self.corr_accum.float() / denom
            exponent = exponent + mean_corr * _bigint_corr_strength()
        # Broadcast each group exponent over its columns, then drop padding columns.
        per_column = exponent.view(VOCAB, groups_per_row).repeat_interleave(self.group_size, dim=1)
        per_column = per_column[:, :EMBEDDING_DIM]
        return torch.exp2(per_column)

    @torch.no_grad()
    def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1):
        """Fold one gradient-sign tensor into ``corr_accum``.

        ``grad_sign`` must have the shape of the ternary matrix; ``None`` or a
        mismatching shape is silently ignored. For every group, the sum of
        ``grad_sign * T`` times ``corr_step`` is subtracted from
        ``corr_accum``, and ``step_counter`` grows by ``abs(corr_step)``.
        """
        if grad_sign is None:
            return
        expected = tuple(self._T_shape.tolist())
        if tuple(grad_sign.shape) != expected:
            return
        rows, cols = expected
        gs = self.group_size

        ternary = self._get_T().to(device=grad_sign.device, dtype=torch.int16)
        agreement = grad_sign.to(torch.int16) * ternary

        # Zero-pad the column axis so it splits evenly into groups.
        groups_per_row = _ceil_div(cols, gs)
        pad_cols = groups_per_row * gs - cols
        if pad_cols > 0:
            agreement = F.pad(agreement, (0, pad_cols))

        group_score = agreement.view(rows, groups_per_row, gs).sum(dim=2, dtype=torch.int16)
        self.corr_accum -= group_score.flatten().to(dtype=torch.int64) * int(corr_step)
        self.step_counter += abs(int(corr_step))

    def forward(self, x):
        """Embed index tensor ``x`` and normalise the result.

        CUDA inputs use ``_TritonTernaryEmbedFn`` when Triton and that function
        are available. Otherwise the dense weight ``S * T`` is materialised as
        a detached leaf, and a hook stores the sign of its gradient in
        ``_hook_grad_T_sign`` for ``ternary_step`` to consume.
        """
        use_triton = x.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None
        if use_triton:
            # NOTE(review): the dummy grad-requiring tensor presumably keeps the custom
            # Function in the autograd graph -- confirm against the kernel module.
            _dummy = torch.zeros(1, device=x.device, requires_grad=True)
            embedded = _TritonTernaryEmbedFn.apply(x, _dummy, self)
            return self.norm(embedded)

        ternary = self._get_T()
        scale = self._get_S()
        # Fresh leaf: gradients stop here and are only observed through the hook.
        weight = (scale * ternary.float()).detach().requires_grad_(True)

        def _remember_grad_sign(grad):
            self._hook_grad_T_sign = grad.sign().to(torch.int8)

        weight.register_hook(_remember_grad_sign)
        return self.norm(F.embedding(x, weight))

    def ternary_step(self, accum_threshold=3):
        """Consume the gradient sign captured by the last backward pass, if any.

        ``accum_threshold`` is accepted but not used by this implementation.
        """
        if not hasattr(self, "_hook_grad_T_sign"):
            return
        if hasattr(self, "_accumulate_corr_from_grad_sign"):
            self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
        # Drop the captured sign so it is not accumulated twice.
        del self._hook_grad_T_sign

    def update_E(self, loss_signal=None):
        """No-op: E is fixed; S is adjusted via corr_accum."""
        return None
class Sequencer(nn.Module):
    """Base class for the per-modality sequencers in this module.

    It only records the modality name, the window size and the ternary scale
    type; subclasses provide the actual ``forward``.
    """

    def __init__(self, modality, window_size, tscale_type=TScaleType.T32):
        super().__init__()
        self.tscale_type = tscale_type
        self.window_size = window_size
        self.modality = modality

    def forward(self, x):
        """Abstract hook: always raises ``NotImplementedError`` in the base class."""
        raise NotImplementedError
class TextSequencer(Sequencer):
    """Text sequencer: projects sliding windows of three steps to HIDDEN_DIM."""

    def __init__(self, tscale_type=TScaleType.T32):
        super().__init__(modality='text', window_size=3, tscale_type=tscale_type)
        # Each window is flattened, so the projection input is window_size embeddings wide.
        window_features = EMBEDDING_DIM * self.window_size
        self.projection = TernaryScaleTensor(window_features, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, x):
        """Return the normalised projection of every stride-1 window along dim 1.

        The rearrange pattern requires the unfolded tensor to be 4-D, so ``x``
        must be 3-D -- presumably (batch, seq, EMBEDDING_DIM); verify against
        the caller.
        """
        # unfold appends the window axis last: (b, windows, d, w).
        windows = x.unfold(1, self.window_size, 1)
        flat_windows = rearrange(windows, 'b t d w -> b t (d w)')
        return self.norm(self.projection(flat_windows))
class VAE2DSequencer(Sequencer):
    """Image sequencer: encodes with the 2-D VAE and projects each latent position."""

    def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
        super().__init__(modality='image', window_size=1, tscale_type=tscale_type)
        # Imported lazily so the encoder package is only needed when this class is built.
        from .encoders.vae2d import load_vae2d as _load_vae2d

        self.vae = _load_vae2d(device=device, quantize=quantize)
        # Device recorded at construction time; inputs are moved to it in forward.
        self.vae_device = torch.device(device)
        # The projection input width of 4 must match the channel count of the VAE latent.
        self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, x):
        """Encode ``x`` and return normalised tokens, one per latent (h, w) position."""
        vae_input = x if x.device == self.vae_device else x.to(self.vae_device)
        latent = self.vae(vae_input)
        # Flatten the spatial grid into a sequence with channels as features.
        tokens = rearrange(latent, 'b c h w -> b (h w) c')
        return self.norm(self.project(tokens))
class VAEAudioSequencer(Sequencer):
    """Audio sequencer: mel front-end -> 2-D VAE -> per-position projection."""

    def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
        super().__init__(modality='audio', window_size=1, tscale_type=tscale_type)
        # Lazy imports: the encoder package is only needed when this class is built.
        from .encoders.vae2d import load_vae2d as _load_vae2d
        from .encoders.mel_frontend import MelSpectrogram3Band as _Mel3Band

        self.vae = _load_vae2d(device=device, quantize=quantize)
        # Device recorded at construction time; spectrograms are moved to it in forward.
        self.vae_device = torch.device(device)
        self.mel = _Mel3Band(sample_rate=AUDIO_SR)
        # The projection input width of 4 must match the channel count of the VAE latent.
        self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, waveform):
        """Turn a waveform into normalised tokens.

        A 1-D input gains a leading axis. A 3-D input loses axis 1: it is
        squeezed when that axis has size 1 and averaged otherwise. Inputs of
        any other rank are passed to the mel front-end unchanged.
        """
        rank = waveform.dim()
        if rank == 1:
            waveform = waveform.unsqueeze(0)
        elif rank == 3:
            single_channel = waveform.shape[1] == 1
            waveform = waveform.squeeze(1) if single_channel else waveform.mean(dim=1)

        spec = self.mel(waveform)
        vae_input = spec if spec.device == self.vae_device else spec.to(self.vae_device)
        latent = self.vae(vae_input)
        # Flatten the latent grid into a sequence with channels as features.
        tokens = rearrange(latent, 'b c h w -> b (h w) c')
        return self.norm(self.project(tokens))
class MultimodalSequencer(nn.Module):
    """Container that dispatches per-modality inputs to the enabled sequencers.

    Args:
        tscale_type: Ternary scale type forwarded to every sequencer.
        enable_text: Build the ``TextSequencer``.
        enable_image: Build the ``VAE2DSequencer``.
        enable_audio: Build the ``VAEAudioSequencer``.
        quantize: Forwarded to the image and audio sequencers (keyword-only).
            The default ``None`` matches their own default.
        device: Forwarded to the image and audio sequencers (keyword-only).
            The default ``"cpu"`` matches their own default.

    Disabled modalities are stored as ``None`` and left out of
    ``enabled_modalities``.
    """

    def __init__(self, tscale_type=TScaleType.T32, enable_text=True, enable_image=True,
                 enable_audio=True, *, quantize=None, device="cpu"):
        super().__init__()
        self.text = TextSequencer(tscale_type=tscale_type) if enable_text else None
        # The VAE-backed sequencers take loader options; pass them through so
        # callers are not restricted to the default CPU / unquantized VAEs.
        self.image = (
            VAE2DSequencer(tscale_type=tscale_type, quantize=quantize, device=device)
            if enable_image else None
        )
        self.audio = (
            VAEAudioSequencer(tscale_type=tscale_type, quantize=quantize, device=device)
            if enable_audio else None
        )
        # The order here fixes the iteration order in forward.
        switches = (('text', enable_text), ('image', enable_image), ('audio', enable_audio))
        self.enabled_modalities = [name for name, enabled in switches if enabled]

    def forward(self, modality_inputs):
        """Run every enabled sequencer that has a non-``None`` input.

        Args:
            modality_inputs: Mapping from modality name ('text', 'image',
                'audio') to that sequencer's input. Missing keys and ``None``
                values are skipped.

        Returns:
            dict mapping each processed modality name to its sequencer output.
        """
        outputs = {}
        for mod in self.enabled_modalities:
            seq = getattr(self, mod)
            if seq is None or mod not in modality_inputs:
                continue
            data = modality_inputs[mod]
            if data is None:
                continue
            outputs[mod] = seq(data)
        return outputs