| """Minimal inference for the published LeNEPA encoder checkpoint. |
| |
| Published IO contract: |
| - x_waveform: torch.float32 [B, 1, 5000], channel order: ["c0"] |
| - outputs: |
| patch_tokens: [B, 625, 256] |
| embedding: [B, 256] |
| |
| This code intentionally does NOT: |
| - resample / crop / pad inputs |
| - support other checkpoints or architectures |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import torch |
| from safetensors.torch import load_file as safetensors_load |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| |
| |
| |
|
|
# --- Published IO contract (see module docstring) ---------------------------
SAMPLING_FREQUENCY = 1  # NOTE(review): units not stated here — confirm against the data pipeline
CHANNELS = ("c0",)


# Input geometry: waveforms are [B, NUM_CHANNELS, CHANNEL_SIZE].
NUM_CHANNELS = 1
CHANNEL_SIZE = 5000
PATCH_SIZE = 8
NUM_PATCHES = 625  # == CHANNEL_SIZE // PATCH_SIZE


# Transformer trunk dimensions.
DIM = 256
DEPTH = 8
NUM_HEADS = 4
MLP_RATIO = 4.0


# Tokenizer widths: conv features + mean encoding + std encoding must sum to
# DIM (192 + 32 + 32 = 256); LeNEPAEncoder._tokenize asserts this at runtime.
PATCH_EMBED_CNN_DIM = 192
SCALAR_HIDDEN_DIM = 32
SCALAR_SCALES = (0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0, 10000.0)
SCALAR_EPSILON = 1.1  # > 1 keeps log(ratio + epsilon) positive for any ratio >= 0


QKV_BIAS = True
BIAS = True
NORM_EPS = 1e-6


ROPE_BASE = 10_000
QK_NORM_EPS = 1e-6
|
|
|
|
@dataclass(frozen=True)
class LeNEPAEncoderOutput:
    """Outputs of the published LeNEPA encoder.

    Attributes:
        patch_tokens: per-patch token sequence, float32 [B, 625, 256].
        embedding: mean over the token axis of ``patch_tokens``, [B, 256].
    """

    # Per-patch tokens produced by the encoder trunk.
    patch_tokens: torch.Tensor
    # Pooled summary embedding (token mean, see encode_lenepa).
    embedding: torch.Tensor
|
|
|
|
class ScalarEncoder(nn.Module):
    """Affine + LayerNorm scalar encoder used inside the MSSE blocks.

    Maps a scalar ``x[..., 1]`` to ``x[..., hidden_dim]`` via the learned
    affine map ``z = x * w + k * b`` followed by LayerNorm, where ``k`` is
    the fixed scale this encoder specialises in.

    Args:
        k: fixed scale factor applied to the bias term.
        hidden_dim: output width of the encoder (must be >= 1).
        eps: LayerNorm epsilon (must be > 0).

    Raises:
        ValueError: if ``hidden_dim`` or ``eps`` is out of range.
    """

    def __init__(self, *, k: float, hidden_dim: int, eps: float) -> None:
        super().__init__()
        if hidden_dim < 1:
            raise ValueError(f"hidden_dim must be >= 1, got {hidden_dim}")
        if eps <= 0:
            raise ValueError(f"eps must be > 0, got {eps}")
        self.k = float(k)
        # nn.Parameter already defaults to requires_grad=True, so passing
        # requires_grad=True into torch.rand was redundant. The random init
        # is overwritten when the published checkpoint is loaded.
        self.w = nn.Parameter(torch.rand((1, hidden_dim), dtype=torch.float32))
        self.b = nn.Parameter(torch.rand((1, hidden_dim), dtype=torch.float32))
        self.norm = nn.LayerNorm(hidden_dim, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x[..., 1]`` into ``x[..., hidden_dim]``.

        Raises:
            ValueError: if the trailing dimension of ``x`` is not 1.
        """
        if x.size(-1) != 1:
            raise ValueError(f"Expected x[..., 1], got x.shape={tuple(x.shape)}")
        # Broadcast: [..., 1] * [1, H] -> [..., H]; normalize over H.
        z = x * self.w + self.k * self.b
        return self.norm(z)
|
|
|
|
class MultiScaledScalarEncoder(nn.Module):
    """Mixture of per-scale ScalarEncoders, blended by scalar magnitude."""

    def __init__(
        self,
        *,
        scales: tuple[float, ...],
        hidden_dim: int,
        epsilon: float,
        eps: float,
    ) -> None:
        super().__init__()
        if not scales:
            raise ValueError("scales must be non-empty")
        if any(scale <= 0 for scale in scales):
            raise ValueError(f"All scales must be > 0, got scales={scales}")
        if hidden_dim < 1:
            raise ValueError(f"hidden_dim must be >= 1, got {hidden_dim}")
        if epsilon <= 0:
            raise ValueError(f"epsilon must be > 0, got {epsilon}")
        if eps <= 0:
            raise ValueError(f"eps must be > 0, got {eps}")

        # Non-persistent buffer: moves with the module but is not checkpointed.
        self.register_buffer(
            "scales", torch.tensor(scales, dtype=torch.float32), persistent=False
        )
        self.epsilon = float(epsilon)
        # One dedicated affine+LayerNorm encoder per scale.
        self.encoders = nn.ModuleList(
            [ScalarEncoder(k=float(scale), hidden_dim=hidden_dim, eps=eps) for scale in scales]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a scalar tensor ``x[..., 1]`` into ``x[..., hidden_dim]``."""
        if x.size(-1) != 1:
            raise ValueError(f"Expected x[..., 1], got x.shape={tuple(x.shape)}")

        # Blending weights: encoders whose scale is close to |x| get the
        # largest |1 / log(ratio + epsilon)| weight; weights are then
        # normalized to sum to 1 over the scale axis.
        magnitude = x.abs().to(dtype=torch.float32)
        per_scale = torch.matmul(
            magnitude, (1.0 / self.scales.to(device=x.device)).reshape(1, -1)
        )
        weights = (1.0 / torch.log(per_scale + self.epsilon)).abs()
        weights = weights / weights.sum(dim=-1, keepdim=True)

        # [..., S, H] stack of per-scale encodings, mixed down to [..., H].
        stacked = torch.stack([encoder(x) for encoder in self.encoders], dim=-2)
        blended = (stacked.to(dtype=torch.float32) * weights.unsqueeze(-1)).sum(dim=-2)
        return blended.to(dtype=x.dtype)
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary positional embeddings (RoPE) applied to Q/K."""

    def __init__(self, *, dim: int, base: int) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE requires even head_dim, got {dim}")
        # One rotation frequency per (even, odd) pair of the head dimension.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # cos/sin tables are memoised per (seq_len, device, dtype).
        self._seq_len_cached: int | None = None
        self._cos_cached: torch.Tensor | None = None
        self._sin_cached: torch.Tensor | None = None
        self._device_cached: torch.device | None = None
        self._dtype_cached: torch.dtype | None = None

    def _build_cache(self, *, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        """Recompute the cos/sin tables for the requested configuration."""
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        self._cos_cached = angles.cos().to(dtype)
        self._sin_cached = angles.sin().to(dtype)
        self._seq_len_cached = seq_len
        self._device_cached = device
        self._dtype_cached = dtype

    def _get_cos_sin(
        self, *, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cached cos/sin tables, rebuilding them on any mismatch."""
        cache_valid = (
            self._cos_cached is not None
            and self._sin_cached is not None
            and self._seq_len_cached == seq_len
            and self._device_cached == device
            and self._dtype_cached == dtype
        )
        if not cache_valid:
            self._build_cache(seq_len=seq_len, device=device, dtype=dtype)
        if self._cos_cached is None or self._sin_cached is None:
            raise RuntimeError("RoPE cache was not built; this is a bug")
        return self._cos_cached, self._sin_cached

    def _apply_rotary(self, x: torch.Tensor, *, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate interleaved (even, odd) pairs of x: [B, H, T, Dh]."""
        batch, heads, seq_len, head_dim = x.shape
        pairs = x.view(batch, heads, seq_len, head_dim // 2, 2)
        even, odd = pairs.unbind(dim=-1)
        cos_b = cos[None, None]  # broadcast over batch and heads
        sin_b = sin[None, None]
        rot_even = even * cos_b - odd * sin_b
        rot_odd = even * sin_b + odd * cos_b
        return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)

    def apply(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the same positional rotation to both q and k."""
        cos, sin = self._get_cos_sin(seq_len=q.size(-2), device=q.device, dtype=q.dtype)
        return self._apply_rotary(q, cos=cos, sin=sin), self._apply_rotary(k, cos=cos, sin=sin)
|
|
|
|
class Attention(nn.Module):
    """Causal multi-head self-attention with RoPE and QK-Norm (no dropout)."""

    def __init__(self) -> None:
        super().__init__()
        if DIM % NUM_HEADS != 0:
            raise ValueError(f"DIM must be divisible by NUM_HEADS, got DIM={DIM} NUM_HEADS={NUM_HEADS}")
        head_dim = DIM // NUM_HEADS
        self.num_heads = NUM_HEADS
        self.rope = RotaryEmbedding(dim=head_dim, base=ROPE_BASE)
        # Non-affine LayerNorm on Q and K (applied after RoPE in forward).
        self.qk_norm = nn.LayerNorm(head_dim, eps=QK_NORM_EPS, elementwise_affine=False)
        self.qkv = nn.Linear(DIM, DIM * 3, bias=QKV_BIAS)
        self.proj = nn.Linear(DIM, DIM, bias=BIAS)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """[B, T, DIM] -> [B, T, DIM]; attention is masked causally over T."""
        batch, seq_len, dim = x.shape
        head_dim = dim // self.num_heads
        # Fused projection, then split into per-head q/k/v: [B, Hn, T, Dh].
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        q, k = self.rope.apply(q, k)
        q = self.qk_norm(q)
        k = self.qk_norm(k)
        context = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
        merged = context.transpose(1, 2).reshape(batch, seq_len, dim)
        return self.proj(merged)
|
|
|
|
class GatedMLP(nn.Module):
    """SwiGLU MLP used in this checkpoint: fc1 -> SiLU(gate) * value -> fc2."""

    def __init__(self) -> None:
        super().__init__()
        # Hidden width is (2/3) * MLP_RATIO * DIM, truncated to an int.
        hidden_dim = int((2 / 3) * MLP_RATIO * DIM)
        if hidden_dim <= 0:
            raise ValueError(f"hidden_dim must be > 0, got {hidden_dim}")
        # fc1 emits gate and value halves in a single projection.
        self.fc1 = nn.Linear(DIM, hidden_dim * 2, bias=BIAS)
        self.fc2 = nn.Linear(hidden_dim, DIM, bias=BIAS)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """[..., DIM] -> [..., DIM] through the gated hidden layer."""
        gate, value = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(self.act(gate) * value)
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(self) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(DIM, eps=NORM_EPS)
        self.attn = Attention()
        self.norm2 = nn.LayerNorm(DIM, eps=NORM_EPS)
        self.mlp = GatedMLP()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply LN -> attention -> residual, then LN -> MLP -> residual."""
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
|
|
|
|
class PatchEmbedding(nn.Module):
    """Conv patch embedding: Conv1d(C->D_cnn, kernel=stride=patch_size)."""

    def __init__(self) -> None:
        super().__init__()
        # kernel_size == stride, so each output position covers one
        # non-overlapping patch of the input.
        self.proj = nn.Conv1d(
            in_channels=NUM_CHANNELS,
            out_channels=PATCH_EMBED_CNN_DIM,
            kernel_size=PATCH_SIZE,
            stride=PATCH_SIZE,
            bias=BIAS,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """[B, C, L] -> [B, L // patch_size, D_cnn] (channels moved last)."""
        return self.proj(x).transpose(1, 2)
|
|
|
|
class LeNEPAEncoder(nn.Module):
    """LeNEPA encoder trunk for this exact checkpoint.

    Pipeline: per-patch mean/std normalization of the waveform, Conv1d patch
    features concatenated with the encoded per-patch mean and std, then DEPTH
    pre-norm transformer blocks and a final LayerNorm. Token width is
    PATCH_EMBED_CNN_DIM + 2 * SCALAR_HIDDEN_DIM (192 + 32 + 32 = 256 = DIM),
    checked at runtime in _tokenize.
    """

    def __init__(self) -> None:
        super().__init__()
        if CHANNEL_SIZE % PATCH_SIZE != 0:
            raise ValueError("CHANNEL_SIZE must be divisible by PATCH_SIZE")
        # Scalar encoders re-inject the per-patch statistics that the
        # patch-wise normalization in _tokenize removes from the waveform.
        self.nepa_patch_embed_mean_encoder = MultiScaledScalarEncoder(
            scales=SCALAR_SCALES,
            hidden_dim=SCALAR_HIDDEN_DIM,
            epsilon=SCALAR_EPSILON,
            eps=NORM_EPS,
        )
        self.nepa_patch_embed_std_encoder = MultiScaledScalarEncoder(
            scales=SCALAR_SCALES,
            hidden_dim=SCALAR_HIDDEN_DIM,
            epsilon=SCALAR_EPSILON,
            eps=NORM_EPS,
        )
        self.patch_embed = PatchEmbedding()
        self.blocks = nn.ModuleList([Block() for _ in range(DEPTH)])
        self.norm = nn.LayerNorm(DIM, eps=NORM_EPS)

    def _tokenize(self, x: torch.Tensor) -> torch.Tensor:
        """Turn waveforms [B, NUM_CHANNELS, CHANNEL_SIZE] into tokens [B, T, DIM].

        Raises:
            ValueError: if the input shape violates the published contract.
            RuntimeError: if the concatenated token width is not DIM.
        """
        B, C, L = x.shape
        if C != NUM_CHANNELS or L != CHANNEL_SIZE:
            raise ValueError(
                "Input must match the published contract: "
                f"expected [B, {NUM_CHANNELS}, {CHANNEL_SIZE}], got {tuple(x.shape)}"
            )
        if L % PATCH_SIZE != 0:
            raise ValueError(f"Expected L divisible by PATCH_SIZE, got L={L}, PATCH_SIZE={PATCH_SIZE}")


        T = L // PATCH_SIZE
        # Statistics are computed in float32 regardless of the input dtype.
        x_f32 = x.to(dtype=torch.float32)
        x_patches = x_f32.reshape(B, C, T, PATCH_SIZE)
        # Per patch index, pooled over channels AND within-patch samples
        # (dims 1 and 3); std is the population std (unbiased=False).
        mean_patch = x_patches.mean(dim=(1, 3))
        std_patch = x_patches.std(dim=(1, 3), unbiased=False)
        mean_broadcast = mean_patch[:, None, :, None]
        std_broadcast = std_patch[:, None, :, None]
        # NORM_EPS floors the denominator so constant patches do not divide by 0.
        x_norm_patches = (x_patches - mean_broadcast) / (std_broadcast + NORM_EPS)
        x_norm = x_norm_patches.reshape(B, C, L).to(dtype=x.dtype)


        # Conv features from the normalized waveform, plus the encoded
        # statistics (scalars passed with a trailing singleton dim: [B, T, 1]).
        z_cnn = self.patch_embed(x_norm)
        e_mean = self.nepa_patch_embed_mean_encoder(mean_patch.unsqueeze(-1))
        e_std = self.nepa_patch_embed_std_encoder(std_patch.unsqueeze(-1))
        tokens = torch.cat([z_cnn, e_mean.to(dtype=z_cnn.dtype), e_std.to(dtype=z_cnn.dtype)], dim=-1)
        if tokens.size(-1) != DIM:
            raise RuntimeError(
                "Tokenizer produced unexpected dim. "
                f"Got tokens.shape={tuple(tokens.shape)}, expected last dim={DIM}"
            )
        return tokens

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode waveforms into normalized patch tokens [B, NUM_PATCHES, DIM]."""
        z = self._tokenize(x)
        for block in self.blocks:
            z = block(z)
        return self.norm(z)
|
|
|
|
@torch.inference_mode()
def encode_lenepa(*, model: LeNEPAEncoder, x_waveform: torch.Tensor) -> LeNEPAEncoderOutput:
    """Encode a batch of waveforms into patch tokens and a pooled embedding.

    Raises:
        ValueError: on wrong dtype, wrong rank, or device mismatch.
    """
    if x_waveform.dtype != torch.float32:
        raise ValueError(f"x_waveform must be float32, got {x_waveform.dtype}")
    if x_waveform.dim() != 3:
        raise ValueError(f"x_waveform must be [B, C, L], got {tuple(x_waveform.shape)}")
    expected_device = next(model.parameters()).device
    if x_waveform.device != expected_device:
        raise ValueError(
            "x_waveform must be on the same device as the model. "
            f"x_waveform.device={x_waveform.device} model.device={expected_device}"
        )

    # One embedding per waveform: mean-pool the token axis.
    tokens = model(x_waveform)
    return LeNEPAEncoderOutput(patch_tokens=tokens, embedding=tokens.mean(dim=1))
|
|
|
|
def load_lenepa_encoder(*, weights_path: Path, device: torch.device) -> LeNEPAEncoder:
    """Load the published encoder weights from a safetensors file.

    Raises:
        ValueError: if ``weights_path`` is not an existing file.
    """
    if not weights_path.is_file():
        raise ValueError(f"weights_path does not exist: {str(weights_path)!r}")
    encoder = LeNEPAEncoder()
    # strict=True: the checkpoint must match the architecture exactly.
    encoder.load_state_dict(safetensors_load(str(weights_path)), strict=True)
    # Inference-only: freeze parameters and switch to eval before moving.
    encoder.eval()
    encoder.requires_grad_(False)
    return encoder.to(device)
|
|
|
|
def _smoke_test() -> None:
    """Small end-to-end smoke test (random input, prints output shapes)."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The checkpoint is expected next to this module.
    weights = Path(__file__).resolve().parent / "lenepa_encoder.safetensors"
    model = load_lenepa_encoder(weights_path=weights, device=device)
    waveform = torch.randn(2, 1, 5000, device=device, dtype=torch.float32)
    out = encode_lenepa(model=model, x_waveform=waveform)
    print("patch_tokens", tuple(out.patch_tokens.shape))
    print("embedding", tuple(out.embedding.shape))
|
|
|
|
# Run as a script: executes the smoke test against the local checkpoint file.
if __name__ == "__main__":
    _smoke_test()
|
|