File size: 5,560 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""tilelli.core.ternary_conv β€” depthwise causal 1-D conv with ternary weights.

Depthwise (groups=channels) so input channels per group is 1, making the
Hadamard rotation trivial (identity); we only expose per_row + lsq.
"""
from __future__ import annotations

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from tilelli.core.ternary import (
    LearnableScale,
    absmean_scale,
    absmean_scale_per_row,
    ternarize,
    ternarize_lsq,
    ternarize_per_row,
    ternary_signs,
)


class TernaryCausalConv1d(nn.Module):
    """Depthwise causal 1-D conv with ternary weights and a full-precision shadow param.

    Inputs and outputs are laid out as (B, L, C). The convolution is depthwise
    (``groups=channels``) and causal: the sequence is left-padded with
    ``kernel_size - 1`` zeros, so output position ``t`` only sees inputs
    ``<= t`` and the output length equals the input length.

    The trainable parameter ``weight`` has shape ``(channels, 1, kernel_size)``.
    When ``quantize`` is true it is passed through one of the quantizers from
    ``tilelli.core.ternary`` on every forward call; which one is selected by
    the ``lsq`` / ``per_row`` flags.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 5,
        quantize: bool = True,
        per_row: bool = False,
        lsq: bool = False,
    ) -> None:
        """Create the layer.

        Args:
            channels: Number of channels C (also the number of conv groups).
            kernel_size: Temporal width of the kernel; must be >= 1.
            quantize: If false, the raw ``weight`` is used unquantized.
            per_row: Use the per-row quantizer / scale.
            lsq: Use a learnable scale (``LearnableScale``), initialised from
                the mean absolute value of the initial weights.

        Raises:
            ValueError: If both ``lsq`` and ``per_row`` are set, or if
                ``kernel_size`` is smaller than 1.
        """
        super().__init__()
        if lsq and per_row:
            raise ValueError("lsq + per_row not supported")
        if kernel_size < 1:
            # Otherwise kernel_size=0 surfaces as an opaque ZeroDivisionError
            # from the init scale below.
            raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
        self.channels = channels
        self.kernel_size = kernel_size
        self.quantize = quantize
        self.per_row = per_row
        self.lsq = lsq
        # Random normal init scaled by 1/sqrt(kernel_size).
        w = torch.randn(channels, 1, kernel_size) * (1.0 / kernel_size**0.5)
        self.weight = nn.Parameter(w)
        if lsq:
            # `or 1.0` guards against an all-zero init producing a zero scale.
            init_alpha = (w.abs().mean().item() or 1.0)
            self.lsq_scale = LearnableScale(initial=init_alpha)
        else:
            self.lsq_scale = None  # type: ignore[assignment]

    # -- internal helpers ------------------------------------------------ #

    def _check_input(self, x: Tensor) -> None:
        """Raise ValueError unless ``x`` is 3-D with ``channels`` as its last dim."""
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}")
        if x.shape[-1] != self.channels:
            raise ValueError(
                f"channel mismatch: module has {self.channels}, input has {x.shape[-1]}"
            )

    def _causal_channels_first(self, x: Tensor) -> Tensor:
        """Turn (B, L, C) into (B, C, L + k - 1) by left-padding time with zeros."""
        return F.pad(x.transpose(1, 2), (self.kernel_size - 1, 0))

    def _quantize(self, w: Tensor) -> Tensor:
        """Apply the quantizer selected by the ``lsq`` / ``per_row`` flags to ``w``."""
        if self.lsq:
            return ternarize_lsq(w, self.lsq_scale.value())
        if self.per_row:
            return ternarize_per_row(w)
        return ternarize(w)

    def _effective_weight(self) -> Tensor:
        """Weight used by the training-path convs: quantized unless disabled."""
        return self._quantize(self.weight) if self.quantize else self.weight

    # -- full-sequence paths --------------------------------------------- #

    def forward(self, x: Tensor) -> Tensor:
        """Run the causal depthwise conv over a whole sequence.

        Args:
            x: Input of shape (B, L, C) with ``C == self.channels``.

        Returns:
            Tensor of shape (B, L, C).

        Raises:
            ValueError: If ``x`` is not 3-D or its last dim is not ``channels``.
        """
        self._check_input(x)
        y = F.conv1d(
            self._causal_channels_first(x),
            self._effective_weight(),
            groups=self.channels,
        )
        return y.transpose(1, 2)

    @torch.no_grad()
    def trits(self) -> Tensor:
        """Return the integer weight codes for the current ``weight``.

        For ``lsq`` / ``per_row`` this is ``round(weight / alpha)`` clamped to
        [-1, 1] and cast to int8; otherwise it is whatever ``ternary_signs``
        returns for the weight.
        """
        if self.lsq:
            alpha = self.lsq_scale.value()
        elif self.per_row:
            alpha = absmean_scale_per_row(self.weight)
        else:
            return ternary_signs(self.weight)
        return torch.round(self.weight / alpha).clamp_(-1.0, 1.0).to(torch.int8)

    @torch.no_grad()
    def scale(self) -> Tensor:
        """Return the scale (alpha) that pairs with :meth:`trits`."""
        if self.lsq:
            return self.lsq_scale.value()
        if self.per_row:
            return absmean_scale_per_row(self.weight)
        return absmean_scale(self.weight)

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Gradient-free forward that convolves with the trits and rescales after.

        With ``quantize`` disabled this is a plain conv with the raw weight.
        Otherwise the trits are cast to ``x.dtype``, convolved, and the result
        is multiplied by the scale.

        Args:
            x: Input of shape (B, L, C) with ``C == self.channels``.

        Returns:
            Tensor of shape (B, L, C).

        Raises:
            ValueError: If ``x`` is not 3-D or its last dim is not ``channels``
                (same contract as :meth:`forward`).
        """
        self._check_input(x)
        x_ = self._causal_channels_first(x)
        if not self.quantize:
            return F.conv1d(x_, self.weight, groups=self.channels).transpose(1, 2)
        y = F.conv1d(x_, self.trits().to(x.dtype), groups=self.channels)
        alpha = self.scale()
        if self.per_row:
            # The per-row scale is reshaped to broadcast over (B, C, L), i.e.
            # it is expected to hold one value per channel.
            y = y * alpha.view(1, self.channels, 1)
        else:
            y = alpha * y
        return y.transpose(1, 2)

    # ── Incremental-decode helpers (KV-cache equivalent for conv) ──────── #
    # The conv pathway is convolutional, not attention, but it still has a
    # "state" you can cache: the last (kernel_size - 1) inputs. A single new
    # input plus that buffer is sufficient to compute the next 1-token
    # output, identical to running the full conv over the whole prefix.

    def empty_buffer(self, batch_size: int, device, dtype) -> Tensor:
        """Zero-init buffer matching what the left-pad would produce.

        Returns a (batch_size, kernel_size - 1, channels) tensor of zeros;
        ``device`` and ``dtype`` are forwarded to ``torch.zeros``.
        """
        return torch.zeros(batch_size, self.kernel_size - 1, self.channels,
                           device=device, dtype=dtype)

    def warmup_buffer(self, x: Tensor) -> Tensor:
        """Build the buffer from the FULL prompt: keep the last (k-1) inputs.

        Args:
            x: Prompt of shape (B, L, C).

        Returns:
            A (B, k-1, C) buffer ready to feed :meth:`forward_incremental`.
            Prompts shorter than k-1 are left-filled with zeros, mirroring the
            causal left-padding of :meth:`forward`.
        """
        L = x.size(1)
        k1 = self.kernel_size - 1
        if L >= k1:
            # Slice from an explicit start index: `x[:, -k1:]` would return
            # the whole prompt when k1 == 0 (kernel_size == 1), since -0 == 0.
            return x[:, L - k1:, :].contiguous()
        buf = self.empty_buffer(x.size(0), x.device, x.dtype)
        if L > 0:
            buf[:, -L:, :] = x
        return buf

    def forward_incremental(self, x_step: Tensor, buffer: Tensor) -> tuple[Tensor, Tensor]:
        """Step one token through the conv, given the buffered last (k-1) inputs.

        Args:
            x_step: The new token, shape (B, 1, C).
            buffer: The previous k-1 inputs, shape (B, k-1, C).

        Returns:
            (y_step, new_buffer) where y_step is (B, 1, C) and new_buffer
            is (B, k-1, C) ready for the next step.
        """
        # Concatenate buffer + new token -> (B, k, C). Conv with kernel size k
        # over a sequence of length k gives a single output, so no padding.
        window = torch.cat([buffer, x_step], dim=1)            # (B, k, C)
        y = F.conv1d(
            window.transpose(1, 2),                            # (B, C, k)
            self._effective_weight(),
            groups=self.channels,
        )                                                      # (B, C, 1)
        y_step = y.transpose(1, 2)                             # (B, 1, C)
        new_buffer = window[:, 1:, :].contiguous()             # drop oldest
        return y_step, new_buffer