File size: 10,147 Bytes
c72c956
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
"""Minimal inference for the published LeNEPA *encoder* checkpoint (no projector).

Published IO contract:
  - x_waveform: torch.float32 [B, 1, 5000] at 500 Hz, channel order: ["I"]
  - outputs:
      patch_tokens: [B, 200, 192]
      embedding:    [B, 192]

This code intentionally does NOT:
  - resample / crop / pad / normalize inputs
  - support other checkpoints or architectures
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import torch
from safetensors.torch import load_file as safetensors_load
from torch import nn
from torch.nn import functional as F

# ----------------------------------------------------------------
# Published checkpoint constants. These describe the released
# encoder and are deliberately not exposed as configuration knobs.
# ----------------------------------------------------------------

# Input signal contract.
SAMPLING_FREQUENCY_HZ = 500
CHANNELS = ("I",)
NUM_CHANNELS = len(CHANNELS)  # 1

# Patching of the waveform along time.
CHANNEL_SIZE = 5000  # samples per channel
PATCH_SIZE = 25  # samples per patch
NUM_PATCHES = CHANNEL_SIZE // PATCH_SIZE  # 200

# Transformer trunk.
DIM = 192
DEPTH = 8
NUM_HEADS = 4
MLP_RATIO = 4.0

# Attention / normalization details.
QKV_BIAS = True
NORM_EPS = 1e-6
QK_NORM_EPS = 1e-6
ROPE_BASE = 10000


@dataclass(frozen=True)
class LeNEPAEncoderOutput:
    """Immutable result bundle returned by the published LeNEPA encoder.

    Fields:
      patch_tokens: per-patch token features, [B, T=200, D=192].
      embedding: one vector per input item, [B, D=192].
    """

    # Per-patch features: [B, T=200, D=192].
    patch_tokens: torch.Tensor
    # Per-item vector: [B, D=192].
    embedding: torch.Tensor


class RotaryEmbedding(nn.Module):
    """Rotary positional embeddings (RoPE) applied to Q/K.

    Each head vector is treated as ``Dh/2`` interleaved pairs ``(x[2i], x[2i+1])``.
    The pair ``i`` at sequence position ``t`` is rotated by the angle
    ``t * inv_freq[i]`` where ``inv_freq[i] = base ** (-2i / Dh)``.

    The cos/sin tables are cached and rebuilt whenever the requested sequence
    length, device or dtype differs from the cached one.

    The rotation is available as ``module(q, k)`` (``forward``) and, for existing
    callers, as ``module.apply(q, k)``. Because ``apply`` is also the name of
    ``nn.Module.apply(fn)``, the ``apply`` method here dispatches on its arguments
    so that ``parent.apply(fn)`` traversals keep working.
    """

    def __init__(self, *, dim: int, base: int) -> None:
        """Create the inverse-frequency table.

        Args:
          dim: per-head feature size (must be even).
          base: RoPE frequency base.

        Raises:
          ValueError: if ``dim`` is odd.
        """
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE requires even head_dim, got {dim}")
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))  # [Dh/2]
        # Non-persistent: recomputed from constants, so it is not part of the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached: int | None = None
        self._cos_cached: torch.Tensor | None = None
        self._sin_cached: torch.Tensor | None = None
        self._device_cached: torch.device | None = None
        self._dtype_cached: torch.dtype | None = None

    def _build_cache(self, *, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        """Compute and store cos/sin tables of shape [T, Dh/2] for ``seq_len`` positions."""
        # Angles are computed in the dtype of inv_freq and only then cast to the target dtype.
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)  # [T]
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)  # [T, Dh/2]
        self._cos_cached = freqs.cos().to(dtype)  # [T, Dh/2]
        self._sin_cached = freqs.sin().to(dtype)  # [T, Dh/2]
        self._seq_len_cached = seq_len
        self._device_cached = device
        self._dtype_cached = dtype

    def _get_cos_sin(
        self, *, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cached (cos, sin) tables, rebuilding them if the cache key changed."""
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._device_cached != device
            or self._dtype_cached != dtype
        ):
            self._build_cache(seq_len=seq_len, device=device, dtype=dtype)
        if self._cos_cached is None or self._sin_cached is None:
            raise RuntimeError("RoPE cache was not built; this is a bug")
        return self._cos_cached, self._sin_cached

    def _apply_rotary(self, x: torch.Tensor, *, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate the interleaved feature pairs of ``x`` ([B, H, T, Dh]) by the given tables."""
        B, H, T, Dh = x.shape
        # reshape (not view) so inputs whose last dim is not contiguous are accepted too;
        # for the layouts view() could handle it returns the same view.
        x_2 = x.reshape(B, H, T, Dh // 2, 2)  # [B, H, T, Dh/2, 2]
        x1 = x_2[..., 0]  # [B, H, T, Dh/2]
        x2 = x_2[..., 1]  # [B, H, T, Dh/2]
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, T, Dh/2]
        sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, T, Dh/2]
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
        # Re-interleave the rotated pairs back into the original feature layout.
        return torch.stack((out1, out2), dim=-1).flatten(-2)  # [B, H, T, Dh]

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to Q/K.

        Args:
          q: queries, [B, H, T, Dh].
          k: keys, [B, H, T, Dh].

        Returns:
          The rotated ``(q, k)`` pair, same shapes as the inputs.
        """
        # Sequence length, device and dtype are all taken from q.
        cos, sin = self._get_cos_sin(seq_len=q.size(-2), device=q.device, dtype=q.dtype)  # [T, Dh/2]
        return self._apply_rotary(q, cos=cos, sin=sin), self._apply_rotary(k, cos=cos, sin=sin)

    def apply(self, q, k=None):  # type: ignore[override]
        """Apply RoPE to Q/K, or behave like ``nn.Module.apply`` for a single callable.

        This name collides with ``nn.Module.apply(fn)``, which parent modules invoke
        positionally on every child. To keep both uses working:

          - ``apply(q, k)`` with two tensors returns the rotated ``(q, k)`` pair.
          - ``apply(fn)`` with one callable runs ``nn.Module.apply(fn)`` and returns ``self``.

        Raises:
          TypeError: if called with a single non-callable argument.
        """
        if k is None:
            if not callable(q):
                raise TypeError(
                    "RotaryEmbedding.apply expects (q, k) tensors or a single callable, "
                    f"got a single {type(q).__name__}"
                )
            return super().apply(q)
        return self.forward(q, k)


class Attention(nn.Module):
    """Multi-head causal self-attention.

    Q and K are rotated with RoPE and then passed through an affine-free
    LayerNorm over the head dimension before scaled dot-product attention.
    No dropout is applied.
    """

    def __init__(self) -> None:
        """Build the fused QKV projection, RoPE, QK-norm and output projection."""
        super().__init__()
        per_head, leftover = divmod(DIM, NUM_HEADS)
        if leftover:
            raise ValueError(f"DIM must be divisible by NUM_HEADS, got DIM={DIM} NUM_HEADS={NUM_HEADS}")
        self.num_heads = NUM_HEADS
        self.rope = RotaryEmbedding(dim=per_head, base=ROPE_BASE)
        self.qk_norm = nn.LayerNorm(per_head, eps=QK_NORM_EPS, elementwise_affine=False)
        self.qkv = nn.Linear(DIM, 3 * DIM, bias=QKV_BIAS)
        self.proj = nn.Linear(DIM, DIM, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the sequence causally.

        Args:
          x: token features, [B, T, D].

        Returns:
          Attention output after the final projection, [B, T, D].
        """
        batch, seq_len, width = x.shape
        per_head = width // self.num_heads

        # One fused projection, then split into heads and into the Q/K/V triple.
        fused = self.qkv(x)  # [B, T, 3*D]
        fused = fused.reshape(batch, seq_len, 3, self.num_heads, per_head)  # [B, T, 3, H, Dh]
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)  # each [B, H, T, Dh]

        # Positional rotation first, normalization second.
        queries, keys = self.rope.apply(queries, keys)  # [B, H, T, Dh] each
        queries = self.qk_norm(queries)  # [B, H, T, Dh]
        keys = self.qk_norm(keys)  # [B, H, T, Dh]

        context = F.scaled_dot_product_attention(
            queries, keys, values, dropout_p=0.0, is_causal=True
        )  # [B, H, T, Dh]

        # Merge heads back into the model dimension.
        merged = context.transpose(1, 2).reshape(batch, seq_len, width)  # [B, T, D]
        return self.proj(merged)  # [B, T, D]


class GatedMLP(nn.Module):
    """SwiGLU feed-forward: ``fc2(SiLU(gate) * value)`` with gate/value from one projection."""

    def __init__(self) -> None:
        """Create the fused gate/value projection, the output projection and the activation."""
        super().__init__()
        # Expression kept in this exact order: the width must match the checkpoint's shapes.
        inner = int((2 / 3) * MLP_RATIO * DIM)
        if inner < 1:
            raise ValueError(f"hidden_dim must be > 0, got {inner}")
        self.fc1 = nn.Linear(DIM, 2 * inner, bias=True)
        self.fc2 = nn.Linear(inner, DIM, bias=True)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP.

        Args:
          x: token features, [B, T, D].

        Returns:
          Transformed features, [B, T, D].
        """
        projected = self.fc1(x)  # [B, T, 2H]
        # First half gates, second half carries the value.
        gate, value = torch.chunk(projected, 2, dim=-1)  # each [B, T, H]
        gated = self.act(gate) * value  # [B, T, H]
        return self.fc2(gated)  # [B, T, D]


class Block(nn.Module):
    """Pre-norm transformer block.

    Two residual sub-layers in sequence: ``x + Attn(LN(x))`` followed by
    ``x + MLP(LN(x))``.
    """

    def __init__(self) -> None:
        """Instantiate both LayerNorms, the attention module and the gated MLP."""
        super().__init__()
        self.norm1 = nn.LayerNorm(DIM, eps=NORM_EPS)
        self.attn = Attention()
        self.norm2 = nn.LayerNorm(DIM, eps=NORM_EPS)
        self.mlp = GatedMLP()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on ``x`` ([B, T, D]) and return a tensor of the same shape."""
        attn_update = self.attn(self.norm1(x))  # [B, T, D]
        after_attn = x + attn_update  # [B, T, D]
        mlp_update = self.mlp(self.norm2(after_attn))  # [B, T, D]
        return after_attn + mlp_update  # [B, T, D]


class PatchEmbedding(nn.Module):
    """Non-overlapping patch embedding via ``Conv1d`` with ``kernel_size == stride == PATCH_SIZE``."""

    def __init__(self) -> None:
        """Create the strided convolution mapping ``NUM_CHANNELS`` inputs to ``DIM`` features."""
        super().__init__()
        self.proj = nn.Conv1d(NUM_CHANNELS, DIM, kernel_size=PATCH_SIZE, stride=PATCH_SIZE, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a waveform.

        Args:
          x: input signal, [B, C, L].

        Returns:
          One token per patch, [B, T, D].
        """
        channels_first = self.proj(x)  # [B, D, T]
        # Move the feature axis last so tokens are laid out as a sequence.
        return torch.transpose(channels_first, 1, 2)  # [B, T, D]


class LeNEPAEncoder(nn.Module):
    """Encoder trunk for this specific checkpoint.

    Pipeline: conv patch embedding -> ``DEPTH`` causal transformer blocks -> final LayerNorm.
    """

    def __init__(self) -> None:
        """Build the patch embedding, the block stack and the final norm.

        Raises:
          ValueError: if ``CHANNEL_SIZE`` is not a multiple of ``PATCH_SIZE``.
        """
        super().__init__()
        if CHANNEL_SIZE % PATCH_SIZE:
            raise ValueError("CHANNEL_SIZE must be divisible by PATCH_SIZE")
        self.patch_embed = PatchEmbedding()
        self.blocks = nn.ModuleList(Block() for _ in range(DEPTH))
        self.norm = nn.LayerNorm(DIM, eps=NORM_EPS)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the final-layer patch tokens, taken after the last LayerNorm.

        Args:
          x: input signal, [B, C, L].

        Returns:
          Patch tokens, [B, T, D].
        """
        tokens = self.patch_embed(x)  # [B, T, D]
        for layer in self.blocks:
            tokens = layer(tokens)  # [B, T, D]
        return self.norm(tokens)  # [B, T, D]


@torch.inference_mode()
def encode_lenepa(*, model: LeNEPAEncoder, x_waveform: torch.Tensor) -> LeNEPAEncoderOutput:
    """Validate a waveform batch against the published contract and encode it.

    Args:
      model: LeNEPA encoder; must live on the same device as ``x_waveform``.
      x_waveform: float32 tensor shaped [B, 1, 5000].

    Returns:
      ``LeNEPAEncoderOutput`` holding the patch tokens [B, T, D] and their mean
      over the patch axis as the embedding [B, D].

    Raises:
      ValueError: on wrong dtype, wrong rank, wrong channel count / length, or
        when the input and the model are on different devices.
    """
    if x_waveform.dtype is not torch.float32:
        raise ValueError(f"x_waveform must be float32, got {x_waveform.dtype}")

    observed_shape = tuple(x_waveform.shape)
    if len(observed_shape) != 3:
        raise ValueError(f"x_waveform must be [B, C, L], got {observed_shape}")

    # Batch size is free; channel count and length are fixed by the checkpoint.
    if observed_shape[1:] != (NUM_CHANNELS, CHANNEL_SIZE):
        raise ValueError(
            "Input must match the published contract: "
            f"expected [B, {NUM_CHANNELS}, {CHANNEL_SIZE}], got {observed_shape}"
        )

    # The device of the first parameter stands in for the device of the whole model.
    model_device = next(model.parameters()).device
    if x_waveform.device != model_device:
        raise ValueError(
            "x_waveform must be on the same device as the model. "
            f"x_waveform.device={x_waveform.device} model.device={model_device}"
        )

    tokens = model(x_waveform)  # [B, T, D]
    pooled = torch.mean(tokens, dim=1)  # [B, D]
    return LeNEPAEncoderOutput(patch_tokens=tokens, embedding=pooled)


def load_lenepa_encoder(*, weights_path: Path, device: torch.device) -> LeNEPAEncoder:
    """Build the encoder and populate it from a safetensors weights file.

    Args:
      weights_path: path to the safetensors file.
      device: device the returned model is moved to.

    Returns:
      The encoder in eval mode with gradients disabled, on ``device``.

    Raises:
      ValueError: if ``weights_path`` is not an existing file.
    """
    if not weights_path.is_file():
        raise ValueError(f"weights_path does not exist: {str(weights_path)!r}")

    # Read the tensors before constructing the module, then require an exact key match.
    tensors = safetensors_load(str(weights_path))
    encoder = LeNEPAEncoder()
    encoder.load_state_dict(tensors, strict=True)

    # Inference-only setup: eval mode, no gradient tracking.
    encoder.eval().requires_grad_(False)
    return encoder.to(device)


def _smoke_test() -> None:
    """Run one random batch through the encoder and print the output shapes.

    Uses CUDA when available, otherwise CPU, and expects
    ``lenepa_encoder.safetensors`` to sit next to this file.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")

    weights_file = Path(__file__).resolve().parent / "lenepa_encoder.safetensors"
    encoder = load_lenepa_encoder(weights_path=weights_file, device=device)

    # Random stand-in batch shaped like the published input: [B=2, C=1, L=5000].
    batch = torch.randn(2, NUM_CHANNELS, CHANNEL_SIZE, device=device, dtype=torch.float32)
    result = encode_lenepa(model=encoder, x_waveform=batch)

    for label, tensor in (("patch_tokens", result.patch_tokens), ("embedding", result.embedding)):
        print(label, tuple(tensor.shape))


# Script entry point: running this file directly executes the smoke test;
# importing it as a module does not.
if __name__ == "__main__":
    _smoke_test()