| | """Minimal inference for the published LeNEPA *encoder* checkpoint (no projector). |
| | |
| | Published IO contract: |
| | - x_waveform: torch.float32 [B, 1, 5000] at 500 Hz, channel order: ["I"] |
| | - outputs: |
| | patch_tokens: [B, 200, 192] |
| | embedding: [B, 192] |
| | |
| | This code intentionally does NOT: |
| | - resample / crop / pad / normalize inputs |
| | - support other checkpoints or architectures |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | from dataclasses import dataclass |
| | from pathlib import Path |
| |
|
| | import torch |
| | from safetensors.torch import load_file as safetensors_load |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| | |
| | |
| | |
| |
|
| | SAMPLING_FREQUENCY_HZ = 500 |
| | CHANNELS = ("I",) |
| |
|
| | NUM_CHANNELS = 1 |
| | CHANNEL_SIZE = 5000 |
| | PATCH_SIZE = 25 |
| | NUM_PATCHES = 200 |
| |
|
| | DIM = 192 |
| | DEPTH = 8 |
| | NUM_HEADS = 4 |
| | MLP_RATIO = 4.0 |
| |
|
| | QKV_BIAS = True |
| | NORM_EPS = 1e-6 |
| |
|
| | ROPE_BASE = 10_000 |
| | QK_NORM_EPS = 1e-6 |
| |
|
| |
|
| | @dataclass(frozen=True) |
| | class LeNEPAEncoderOutput: |
| | """Outputs of the published LeNEPA encoder.""" |
| |
|
| | patch_tokens: torch.Tensor |
| | embedding: torch.Tensor |
| |
|
| |
|
class RotaryEmbedding(nn.Module):
    """Rotary positional embeddings (RoPE) applied to Q/K.

    The cos/sin tables are built lazily and memoized; they are rebuilt
    whenever the requested (seq_len, device, dtype) combination changes.
    """

    def __init__(self, *, dim: int, base: int) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE requires even head_dim, got {dim}")
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        self.register_buffer("inv_freq", 1.0 / (base**exponents), persistent=False)
        # Memoization key and the matching cos/sin tables ([seq_len, dim // 2]).
        self._cache_key: tuple[int, torch.device, torch.dtype] | None = None
        self._cos: torch.Tensor | None = None
        self._sin: torch.Tensor | None = None

    def _tables(
        self, *, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cached cos/sin tables, rebuilding them if the key changed."""
        key = (seq_len, device, dtype)
        if self._cos is None or self._sin is None or self._cache_key != key:
            positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            angles = torch.outer(positions, self.inv_freq)
            self._cos = angles.cos().to(dtype)
            self._sin = angles.sin().to(dtype)
            self._cache_key = key
        return self._cos, self._sin

    @staticmethod
    def _rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate interleaved (even, odd) pairs of x's last dim by the given angles."""
        # x: [B, H, T, Dh] -> pairs: [B, H, T, Dh // 2, 2]
        pairs = x.unflatten(-1, (-1, 2))
        even, odd = pairs[..., 0], pairs[..., 1]
        cos = cos[None, None]  # broadcast over batch and head dims
        sin = sin[None, None]
        rotated = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
        return rotated.flatten(-2)

    def apply(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to Q/K."""
        cos, sin = self._tables(seq_len=q.size(-2), device=q.device, dtype=q.dtype)
        return self._rotate(q, cos, sin), self._rotate(k, cos, sin)
| |
|
| |
|
class Attention(nn.Module):
    """Causal self-attention with RoPE + QK-Norm (no dropout)."""

    def __init__(self) -> None:
        super().__init__()
        if DIM % NUM_HEADS != 0:
            raise ValueError(f"DIM must be divisible by NUM_HEADS, got DIM={DIM} NUM_HEADS={NUM_HEADS}")
        head_dim = DIM // NUM_HEADS
        self.num_heads = NUM_HEADS
        self.rope = RotaryEmbedding(dim=head_dim, base=ROPE_BASE)
        # Affine-free LayerNorm shared by queries and keys.
        self.qk_norm = nn.LayerNorm(head_dim, eps=QK_NORM_EPS, elementwise_affine=False)
        self.qkv = nn.Linear(DIM, DIM * 3, bias=QKV_BIAS)
        self.proj = nn.Linear(DIM, DIM, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, dim = x.shape
        head_dim = dim // self.num_heads
        # Fused projection: [B, T, 3*D] -> three [B, H, T, Dh] tensors.
        projected = self.qkv(x).view(batch, seq_len, 3, self.num_heads, head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)
        # RoPE first, then QK-Norm — this ordering is part of the checkpoint.
        q, k = self.rope.apply(q, k)
        q, k = self.qk_norm(q), self.qk_norm(k)
        context = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
        merged = context.transpose(1, 2).reshape(batch, seq_len, dim)
        return self.proj(merged)
| |
|
| |
|
class GatedMLP(nn.Module):
    """SwiGLU MLP used in this checkpoint.

    fc1 produces a concatenated (gate, value) pair; the output is
    fc2(SiLU(gate) * value).
    """

    def __init__(self) -> None:
        super().__init__()
        # SwiGLU convention: shrink the hidden width by 2/3 to keep the
        # parameter count comparable to a plain MLP at the same ratio.
        hidden_dim = int((2 / 3) * MLP_RATIO * DIM)
        if hidden_dim <= 0:
            raise ValueError(f"hidden_dim must be > 0, got {hidden_dim}")
        self.fc1 = nn.Linear(DIM, hidden_dim * 2, bias=True)
        self.fc2 = nn.Linear(hidden_dim, DIM, bias=True)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = torch.chunk(self.fc1(x), 2, dim=-1)
        return self.fc2(self.act(gate) * value)
| |
|
| |
|
class Block(nn.Module):
    """Pre-norm transformer block: LN -> Attn -> residual -> LN -> MLP -> residual."""

    def __init__(self) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(DIM, eps=NORM_EPS)
        self.attn = Attention()
        self.norm2 = nn.LayerNorm(DIM, eps=NORM_EPS)
        self.mlp = GatedMLP()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
| |
|
| |
|
class PatchEmbedding(nn.Module):
    """Conv patch embedding: Conv1d(C->D, kernel=stride=patch_size).

    Non-overlapping patches: kernel_size == stride, so each output position
    covers exactly one PATCH_SIZE-sample window.
    """

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Conv1d(
            NUM_CHANNELS,
            DIM,
            kernel_size=PATCH_SIZE,
            stride=PATCH_SIZE,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv output is channel-major [B, D, T]; return token-major [B, T, D].
        return self.proj(x).transpose(1, 2)
| |
|
| |
|
class LeNEPAEncoder(nn.Module):
    """LeNEPA encoder trunk for this exact checkpoint (static conv patch embed, causal)."""

    def __init__(self) -> None:
        super().__init__()
        # Guard against a contract mismatch that would silently drop samples.
        if CHANNEL_SIZE % PATCH_SIZE != 0:
            raise ValueError("CHANNEL_SIZE must be divisible by PATCH_SIZE")
        self.patch_embed = PatchEmbedding()
        self.blocks = nn.ModuleList(Block() for _ in range(DEPTH))
        self.norm = nn.LayerNorm(DIM, eps=NORM_EPS)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return final-layer patch tokens (post-final-norm)."""
        tokens = self.patch_embed(x)
        for block in self.blocks:
            tokens = block(tokens)
        return self.norm(tokens)
| |
|
| |
|
@torch.inference_mode()
def encode_lenepa(*, model: LeNEPAEncoder, x_waveform: torch.Tensor) -> LeNEPAEncoderOutput:
    """Encode a batch of waveforms.

    Args:
        model: LeNEPA encoder (on the same device as x_waveform).
        x_waveform: [B, 1, 5000] float32.

    Raises:
        ValueError: if the input violates the published IO contract
            (dtype, rank, shape) or lives on a different device than the model.
    """
    # Validate the published contract strictly — no resampling/casting here.
    if x_waveform.dtype is not torch.float32:
        raise ValueError(f"x_waveform must be float32, got {x_waveform.dtype}")
    if x_waveform.dim() != 3:
        raise ValueError(f"x_waveform must be [B, C, L], got {tuple(x_waveform.shape)}")
    _, channels, length = x_waveform.shape
    if channels != NUM_CHANNELS or length != CHANNEL_SIZE:
        raise ValueError(
            "Input must match the published contract: "
            f"expected [B, {NUM_CHANNELS}, {CHANNEL_SIZE}], got {tuple(x_waveform.shape)}"
        )
    model_device = next(model.parameters()).device
    if x_waveform.device != model_device:
        raise ValueError(
            "x_waveform must be on the same device as the model. "
            f"x_waveform.device={x_waveform.device} model.device={model_device}"
        )

    tokens = model(x_waveform)
    # Sequence embedding: mean-pool the patch tokens over the patch axis.
    return LeNEPAEncoderOutput(patch_tokens=tokens, embedding=tokens.mean(dim=1))
| |
|
| |
|
def load_lenepa_encoder(*, weights_path: Path, device: torch.device) -> LeNEPAEncoder:
    """Load the published encoder weights from a safetensors file.

    Returns a frozen (requires_grad=False) model in eval mode on *device*.

    Raises:
        ValueError: if weights_path is not an existing file.
    """
    if not weights_path.is_file():
        raise ValueError(f"weights_path does not exist: {str(weights_path)!r}")
    model = LeNEPAEncoder()
    # strict=True: the checkpoint must match this architecture exactly.
    model.load_state_dict(safetensors_load(str(weights_path)), strict=True)
    model.eval()
    model.requires_grad_(False)
    return model.to(device)
| |
|
| |
|
def _smoke_test() -> None:
    """Small end-to-end smoke test (random input, prints output shapes)."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weights = Path(__file__).resolve().parent / "lenepa_encoder.safetensors"
    model = load_lenepa_encoder(weights_path=weights, device=device)
    batch = torch.randn(2, 1, 5000, device=device, dtype=torch.float32)
    result = encode_lenepa(model=model, x_waveform=batch)
    print("patch_tokens", tuple(result.patch_tokens.shape))
    print("embedding", tuple(result.embedding.shape))
| |
|
| |
|
| | if __name__ == "__main__": |
| | _smoke_test() |
| |
|