| """Sequencer modules — input processing for all modalities.""" |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
| from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _HAS_TRITON, _HAS_TILELANG |
| if _HAS_TRITON: |
| import triton |
| import triton.language as tl |
| else: |
| triton = None |
| tl = None |
| try: |
| from .kernel.ternary_scale import _TritonTernaryEmbedFn |
| except ImportError: |
| _TritonTernaryEmbedFn = None |
| from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary |
| from math import ceil as _ceil |
|
|
| _ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0 |
| from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_SR, AUDIO_FRAME_RATE |
|
|
|
|
class ByteEmbedding(nn.Module):
    """Byte-level embedding via packed ternary weights + BigInt correlation.

    The effective weight is ``w_eff = S * T``. ``T`` is a ternary (-1/0/+1)
    matrix of shape ``(VOCAB, EMBEDDING_DIM)``, stored packed in
    ``T_packed``. ``S`` is a power-of-two scale shared by each run of
    ``group_size`` columns within a row.

    All training state is integer. The former T_accum/E_accum pair is
    replaced by ``corr_accum`` (one int64 per group, never clipped or reset)
    and ``step_counter``:

        S = 2^(E + K * mean_corr),  mean_corr = corr_accum / (step * gs)

    where ``K = _bigint_corr_strength()`` and ``gs = group_size``.

    Gradients are not applied to a float weight. ``forward`` records the sign
    of the gradient w.r.t. ``w_eff``, and ``ternary_step`` folds it into
    ``corr_accum``.
    """

    def __init__(self, tscale_type=TScaleType.T32):
        super().__init__()
        self.tscale_type = tscale_type
        # Unknown scale types fall back to the T64 group size.
        self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64])
        out_dim, in_dim = VOCAB, EMBEDDING_DIM

        # Initial ternarisation of a small Gaussian: entries with
        # |w| <= threshold become 0, the rest keep their sign. The nominal
        # 0.05 threshold is capped at half the init std (0.01 here).
        init_std = 0.02
        self.threshold = min(0.05, 0.5 * init_std)
        w_init = torch.randn(out_dim, in_dim) * init_std
        T_init = w_init.sign() * (w_init.abs() > self.threshold).to(w_init.dtype)
        packed_T, _, T_pad = pack_ternary(T_init)
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([out_dim, in_dim], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))

        # Per-group exponent E = log2(mean |w_init|) over each group of
        # `group_size` columns. Rows are zero-padded to a multiple of
        # group_size, so the padding lowers the mean of a partial last group.
        # Non-positive means fall back to 1.0 (E = 0).
        # NOTE(review): the int8 cast truncates toward zero rather than
        # rounding; confirm that is intended.
        gpr = _ceil_div(in_dim, self.group_size)
        total_in = gpr * self.group_size
        padded = torch.zeros(out_dim, total_in)
        padded[:, :in_dim] = w_init.abs()
        grp_means = padded.view(out_dim, gpr, self.group_size).mean(dim=2)
        E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
        self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8))

        # Integer training state: the signed correlation per group, and the
        # total |corr_step| folded in so far (the `step` in mean_corr).
        n_grp = out_dim * gpr
        self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))

        self.norm = TernaryRMSNorm(EMBEDDING_DIM, tscale_type=tscale_type)

    def _get_T(self):
        """Unpack ``T_packed`` into a dense ternary tensor of the stored shape."""
        return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))

    def _get_S(self):
        """Return the dense (VOCAB, EMBEDDING_DIM) scale matrix ``2 ** E_adj``.

        ``E_adj`` is the per-group exponent ``E``. Once ``step_counter > 0``,
        it is shifted by ``K * corr_accum / (step * group_size)``. It is
        broadcast over each group's columns and trimmed to EMBEDDING_DIM.
        """
        gpr = _ceil_div(EMBEDDING_DIM, self.group_size)
        e_adj = self.E.float()
        step = int(self.step_counter.item())
        if step > 0:
            from .kernel.ternary_scale import _bigint_corr_strength
            denom = max(step * self.group_size, 1)
            e_adj = e_adj + (self.corr_accum.float() / denom) * _bigint_corr_strength()
        E_exp = e_adj.view(VOCAB, gpr).repeat_interleave(self.group_size, dim=1)
        if E_exp.shape[1] > EMBEDDING_DIM:
            E_exp = E_exp[:, :EMBEDDING_DIM]
        return torch.exp2(E_exp)

    @torch.no_grad()
    def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1):
        """Fold one gradient-sign snapshot into ``corr_accum``.

        Per element, ``grad_sign * T`` is +1 when the gradient sign agrees
        with the ternary weight (a descent step would shrink |w|), -1 when it
        disagrees, and 0 when either is zero. The values are summed per group
        and subtracted from ``corr_accum``, scaled by ``corr_step``.
        ``step_counter`` then advances by ``|corr_step|``.

        Args:
            grad_sign: sign tensor with the same shape as T. ``None`` or a
                shape mismatch is silently ignored.
            corr_step: integer weight applied to this snapshot.
        """
        if grad_sign is None:
            return
        shape = tuple(self._T_shape.tolist())
        out_dim, in_dim = shape
        if tuple(grad_sign.shape) != shape:
            return
        gs = self.group_size
        T = self._get_T().to(device=grad_sign.device, dtype=torch.int16)
        signed = grad_sign.to(torch.int16) * T
        gpr = _ceil_div(in_dim, gs)
        total_in = gpr * gs
        if total_in > in_dim:
            signed = F.pad(signed, (0, total_in - in_dim))
        # An int16 group sum is safe: |sum| <= group_size.
        score = signed.view(out_dim, gpr, gs).sum(dim=2, dtype=torch.int16)
        # Align devices so a gradient arriving elsewhere cannot break the in-place update.
        self.corr_accum -= score.flatten().to(device=self.corr_accum.device, dtype=torch.int64) * int(corr_step)
        self.step_counter += abs(int(corr_step))

    def forward(self, x):
        """Embed token ids ``x`` and apply ``self.norm``.

        On CUDA, with Triton available, this delegates to
        ``_TritonTernaryEmbedFn``. Otherwise it materialises ``w_eff = S * T``.
        When autograd is enabled, it registers a hook that stores the
        gradient sign (int8) on ``self._hook_grad_T_sign`` for
        ``ternary_step``.
        """
        if x.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None:
            _dummy = torch.zeros(1, device=x.device, requires_grad=True)
            emb = _TritonTernaryEmbedFn.apply(x, _dummy, self)
            return self.norm(emb)
        T = self._get_T()
        S = self._get_S()
        w_eff = S * T.float()
        if not torch.is_grad_enabled():
            # No backward pass can happen, so the grad leaf and hook would be dead work.
            return self.norm(F.embedding(x, w_eff))
        # Detached leaf: the gradient is only observed (its sign captured),
        # never propagated into S or T.
        w_eff_grad = w_eff.detach().requires_grad_(True)

        def capture_w_grad(grad_w):
            self._hook_grad_T_sign = grad_w.sign().to(torch.int8)

        w_eff_grad.register_hook(capture_w_grad)
        return self.norm(F.embedding(x, w_eff_grad))

    def ternary_step(self, accum_threshold=3):
        """Consume the gradient sign captured by the last backward, if any.

        ``accum_threshold`` is accepted for interface compatibility and is
        unused.
        """
        if hasattr(self, "_hook_grad_T_sign"):
            self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
            del self._hook_grad_T_sign

    def update_E(self, loss_signal=None):
        """No-op; ``E`` stays fixed and ``corr_accum`` supplies the adjustment."""
        pass
|
|
|
|
class Sequencer(nn.Module):
    """Base class for modality-specific input sequencers.

    Holds the modality tag, the window size and the ternary scale type shared
    by every concrete sequencer. Subclasses must override ``forward``.
    """

    def __init__(self, modality, window_size, tscale_type=TScaleType.T32):
        super().__init__()
        self.modality, self.window_size, self.tscale_type = modality, window_size, tscale_type

    def forward(self, x):
        """Turn raw modality input into a token sequence (subclass responsibility)."""
        raise NotImplementedError
|
|
|
|
class TextSequencer(Sequencer):
    """Project sliding windows of adjacent embeddings to HIDDEN_DIM, then normalise.

    ``unfold`` with step 1 yields ``t - window_size + 1`` windows along dim 1,
    so the output sequence is ``window_size - 1`` positions shorter than the
    input.
    """

    def __init__(self, tscale_type=TScaleType.T32):
        super().__init__(modality='text', window_size=3, tscale_type=tscale_type)
        window_features = EMBEDDING_DIM * self.window_size
        self.projection = TernaryScaleTensor(window_features, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, x):
        # Assumes x is (batch, time, EMBEDDING_DIM); the rearrange below needs
        # a rank-4 tensor after unfold.
        windows = x.unfold(dimension=1, size=self.window_size, step=1)
        # "(d w)" keeps the w window values of each embedding dim adjacent.
        flattened = rearrange(windows, 'b t d w -> b t (d w)')
        return self.norm(self.projection(flattened))
class VAE2DSequencer(Sequencer):
    """Image sequencer: VAE latent grid -> flattened token sequence -> HIDDEN_DIM."""

    def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
        super().__init__(modality='image', window_size=1, tscale_type=tscale_type)
        from .encoders.vae2d import load_vae2d
        self.vae = load_vae2d(device=device, quantize=quantize)
        self.vae_device = torch.device(device)
        # Input width 4: assumes the VAE emits 4 latent channels (TODO confirm).
        self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, x):
        # The VAE may live on a different device than the input.
        x = x if x.device == self.vae_device else x.to(self.vae_device)
        # Each spatial (h, w) cell of the latent becomes one token of c features.
        tokens = rearrange(self.vae(x), 'b c h w -> b (h w) c')
        return self.norm(self.project(tokens))
|
|
|
|
class VAEAudioSequencer(Sequencer):
    """Audio sequencer: waveform -> 3-band mel -> VAE latent -> tokens -> HIDDEN_DIM."""

    def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
        super().__init__(modality='audio', window_size=1, tscale_type=tscale_type)
        from .encoders.vae2d import load_vae2d
        from .encoders.mel_frontend import MelSpectrogram3Band
        self.vae = load_vae2d(device=device, quantize=quantize)
        self.vae_device = torch.device(device)
        self.mel = MelSpectrogram3Band(sample_rate=AUDIO_SR)
        # Input width 4: assumes the VAE emits 4 latent channels (TODO confirm).
        self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, waveform):
        # Normalise to a 2-D batch: add a batch dim to 1-D input; for 3-D
        # input, drop a singleton dim 1 or average across it.
        rank = waveform.dim()
        if rank == 1:
            waveform = waveform.unsqueeze(0)
        elif rank == 3:
            waveform = waveform.squeeze(1) if waveform.shape[1] == 1 else waveform.mean(dim=1)
        spec = self.mel(waveform)
        if spec.device != self.vae_device:
            spec = spec.to(self.vae_device)
        tokens = rearrange(self.vae(spec), 'b c h w -> b (h w) c')
        return self.norm(self.project(tokens))
|
|
|
|
class MultimodalSequencer(nn.Module):
    """Run the enabled per-modality sequencers over a dict of inputs.

    Args:
        tscale_type: ternary scale type passed to every sub-sequencer.
        enable_text / enable_image / enable_audio: which sub-sequencers to
            build. A disabled modality's attribute is ``None``.
        quantize: forwarded to the image/audio VAE loaders (default ``None``,
            matching the previous hard-coded behaviour).
        device: device for the image/audio VAEs (default ``"cpu"``, matching
            the previous hard-coded behaviour).
    """

    def __init__(self, tscale_type=TScaleType.T32, enable_text=True, enable_image=True,
                 enable_audio=True, quantize=None, device="cpu"):
        super().__init__()
        self.text = TextSequencer(tscale_type=tscale_type) if enable_text else None
        self.image = (
            VAE2DSequencer(tscale_type=tscale_type, quantize=quantize, device=device)
            if enable_image else None
        )
        self.audio = (
            VAEAudioSequencer(tscale_type=tscale_type, quantize=quantize, device=device)
            if enable_audio else None
        )
        # Fixed order: text, image, audio.
        flags = (('text', enable_text), ('image', enable_image), ('audio', enable_audio))
        self.enabled_modalities = [name for name, enabled in flags if enabled]

    def forward(self, modality_inputs):
        """Return ``{modality: output}`` for each enabled modality with a non-None input.

        Modalities missing from ``modality_inputs``, or mapped to ``None``,
        are skipped.
        """
        outputs = {}
        for mod in self.enabled_modalities:
            seq = getattr(self, mod)
            if mod in modality_inputs and modality_inputs[mod] is not None and seq is not None:
                outputs[mod] = seq(modality_inputs[mod])
        return outputs
|
|