File size: 5,560 Bytes
"""tilelli.core.ternary_conv — depthwise causal 1-D conv with ternary weights.

Depthwise (groups=channels), so each group sees a single input channel. That
makes a Hadamard rotation trivial (the identity), so only the per_row and lsq
quantization options are exposed.
"""
from __future__ import annotations
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from tilelli.core.ternary import (
LearnableScale,
absmean_scale,
absmean_scale_per_row,
ternarize,
ternarize_lsq,
ternarize_per_row,
ternary_signs,
)
class TernaryCausalConv1d(nn.Module):
    """Depthwise causal 1-D conv with ternary weights and an FP32 shadow param.

    Inputs and outputs are channels-last ``(B, L, C)`` tensors. The sequence is
    left-padded with ``kernel_size - 1`` zeros, so output step ``t`` depends
    only on input steps ``<= t``.

    Args:
        channels: number of channels ``C``. This is also the conv group count,
            so each channel has its own ``kernel_size``-tap filter (the weight
            is ``(channels, 1, kernel_size)``).
        kernel_size: number of taps ``k``; must be >= 1.
        quantize: when False, the raw shadow weight is used on every path.
        per_row: quantize with one scale per channel
            (``absmean_scale_per_row``).
        lsq: quantize with a scale held by a ``LearnableScale`` module.

    Raises:
        ValueError: if ``lsq`` and ``per_row`` are both set, or if
            ``channels`` / ``kernel_size`` is < 1.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 5,
        quantize: bool = True,
        per_row: bool = False,
        lsq: bool = False,
    ) -> None:
        super().__init__()
        if lsq and per_row:
            raise ValueError("lsq + per_row not supported")
        if channels < 1:
            raise ValueError(f"channels must be >= 1, got {channels}")
        if kernel_size < 1:
            # kernel_size == 0 would otherwise divide by zero just below.
            raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
        self.channels = channels
        self.kernel_size = kernel_size
        self.quantize = quantize
        self.per_row = per_row
        self.lsq = lsq
        w = torch.randn(channels, 1, kernel_size) * (1.0 / kernel_size**0.5)
        self.weight = nn.Parameter(w)
        if lsq:
            # Fall back to 1.0 if the initial mean magnitude happens to be 0.
            init_alpha = (w.abs().mean().item() or 1.0)
            self.lsq_scale = LearnableScale(initial=init_alpha)
        else:
            self.lsq_scale = None  # type: ignore[assignment]

    # ── Internal helpers ─────────────────────────────────────────────── #

    def _quantize(self, w: Tensor) -> Tensor:
        """Ternarize ``w`` with the scheme selected at construction time."""
        if self.lsq:
            return ternarize_lsq(w, self.lsq_scale.value())
        if self.per_row:
            return ternarize_per_row(w)
        return ternarize(w)

    def _effective_weight(self) -> Tensor:
        """Return the conv weight for the training/incremental paths.

        This is the quantized weight when ``quantize`` is set, otherwise the
        raw shadow weight.
        """
        return self._quantize(self.weight) if self.quantize else self.weight

    def _check_input(self, x: Tensor) -> None:
        """Raise ValueError unless ``x`` is rank-3 with ``channels`` last."""
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}")
        if x.shape[-1] != self.channels:
            raise ValueError(
                f"channel mismatch: module has {self.channels}, input has {x.shape[-1]}"
            )

    def _causal_pad(self, x: Tensor) -> Tensor:
        """Convert ``(B, L, C)`` to ``(B, C, L + k - 1)``, zero-padded on the left."""
        return F.pad(x.transpose(1, 2), (self.kernel_size - 1, 0))

    # ── Full-sequence paths ──────────────────────────────────────────── #

    def forward(self, x: Tensor) -> Tensor:
        """Apply the causal depthwise conv to ``x`` of shape ``(B, L, C)``.

        Returns a ``(B, L, C)`` tensor.

        Raises:
            ValueError: if ``x`` is not rank 3 or its last dim != ``channels``.
        """
        self._check_input(x)
        x_ = self._causal_pad(x)
        y = F.conv1d(x_, self._effective_weight(), groups=self.channels)
        return y.transpose(1, 2)

    @torch.no_grad()
    def trits(self) -> Tensor:
        """Return the ternary weight codes as an int8 tensor shaped like ``weight``.

        In lsq/per_row mode these are ``round(weight / scale())`` clamped to
        ``[-1, 1]``. Otherwise they come from ``ternary_signs``, which is
        assumed to yield the same kind of codes (TODO confirm).
        """
        if not (self.lsq or self.per_row):
            return ternary_signs(self.weight)
        alpha = self.scale()
        return torch.round(self.weight / alpha).clamp_(-1.0, 1.0).to(torch.int8)

    @torch.no_grad()
    def scale(self) -> Tensor:
        """Return the scale that multiplies ``trits()`` to approximate the weight.

        In per_row mode it holds one value per channel (``infer`` views it as
        ``(1, channels, 1)``).
        """
        if self.lsq:
            return self.lsq_scale.value()
        if self.per_row:
            return absmean_scale_per_row(self.weight)
        return absmean_scale(self.weight)

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Gradient-free forward pass using the int codes plus an explicit rescale.

        Convolves with ``trits()`` cast to ``x.dtype`` and then multiplies by
        ``scale()``. It is presumably meant to match ``forward`` in quantized
        mode (NOTE(review): depends on the ``ternarize*`` helpers — confirm).

        Raises:
            ValueError: if ``x`` is not rank 3 or its last dim != ``channels``.
        """
        self._check_input(x)
        x_ = self._causal_pad(x)
        if not self.quantize:
            y = F.conv1d(x_, self.weight, groups=self.channels)
            return y.transpose(1, 2)
        trits = self.trits().to(x.dtype)
        alpha = self.scale()
        y = F.conv1d(x_, trits, groups=self.channels)
        if self.per_row:
            y = y * alpha.view(1, self.channels, 1)
        else:
            y = alpha * y
        return y.transpose(1, 2)

    # ── Incremental-decode helpers (KV-cache equivalent for conv) ──────── #
    # The conv pathway is convolutional, not attention, but it still has a
    # "state" you can cache: the last (kernel_size - 1) inputs. New input
    # steps plus that buffer are sufficient to compute the matching outputs,
    # identical to running the full conv over the whole prefix.

    def empty_buffer(self, batch_size: int, device, dtype) -> Tensor:
        """Return a zero ``(B, k-1, C)`` buffer, matching what the left-pad produces."""
        return torch.zeros(batch_size, self.kernel_size - 1, self.channels,
                           device=device, dtype=dtype)

    def warmup_buffer(self, x: Tensor) -> Tensor:
        """Build the buffer from the FULL prompt — keep its last (k-1) inputs.

        ``x`` is ``(B, L, C)``. Returns a ``(B, k-1, C)`` tensor ready to feed
        ``forward_incremental``. Prompts shorter than ``k-1`` are left-padded
        with zeros.
        """
        length = x.size(1)
        k1 = self.kernel_size - 1
        if length >= k1:
            # Slice from an explicit start index: ``x[:, -k1:]`` with k1 == 0
            # would be ``x[:, 0:]`` (the whole prompt), not an empty buffer.
            return x[:, length - k1:, :].contiguous()
        buf = self.empty_buffer(x.size(0), x.device, x.dtype)
        if length > 0:
            buf[:, -length:, :] = x
        return buf

    def forward_incremental(self, x_step: Tensor, buffer: Tensor) -> tuple[Tensor, Tensor]:
        """Step new token(s) through the conv, given the buffered last (k-1) inputs.

        ``x_step`` is ``(B, T, C)``; ``T == 1`` is the usual single-token
        decode, and ``T > 1`` processes a chunk at once. Returns
        ``(y_step, new_buffer)``, where ``y_step`` is ``(B, T, C)`` and
        ``new_buffer`` is ``(B, k-1, C)`` ready for the next call.

        Raises:
            ValueError: if ``x_step`` is malformed or ``buffer`` is not
                ``(B, k-1, C)``-shaped along dims 0..1.
        """
        self._check_input(x_step)
        k1 = self.kernel_size - 1
        if buffer.dim() != 3 or buffer.size(1) != k1:
            raise ValueError(
                f"expected buffer of shape (B, {k1}, C), got {tuple(buffer.shape)}"
            )
        # Concatenate buffer + new steps → (B, k-1+T, C). A valid (unpadded)
        # conv with kernel size k over that length yields exactly T outputs.
        full = torch.cat([buffer, x_step], dim=1)
        y = F.conv1d(full.transpose(1, 2), self._effective_weight(), groups=self.channels)
        # Drop the T oldest entries so exactly the last k-1 inputs remain.
        new_buffer = full[:, x_step.size(1):, :].contiguous()
        return y.transpose(1, 2), new_buffer
|