"""Minimal inference for the published LeNEPA *encoder* checkpoint (no projector). Published IO contract: - x_waveform: torch.float32 [B, 1, 5000] at 500 Hz, channel order: ["I"] - outputs: patch_tokens: [B, 200, 192] embedding: [B, 192] This code intentionally does NOT: - resample / crop / pad / normalize inputs - support other checkpoints or architectures """ from __future__ import annotations from dataclasses import dataclass from pathlib import Path import torch from safetensors.torch import load_file as safetensors_load from torch import nn from torch.nn import functional as F # ----------------------------- # Published constants (no knobs) # ----------------------------- SAMPLING_FREQUENCY_HZ = 500 CHANNELS = ("I",) NUM_CHANNELS = 1 CHANNEL_SIZE = 5000 PATCH_SIZE = 25 NUM_PATCHES = 200 # 5000 / 25 DIM = 192 DEPTH = 8 NUM_HEADS = 4 MLP_RATIO = 4.0 QKV_BIAS = True NORM_EPS = 1e-6 ROPE_BASE = 10_000 QK_NORM_EPS = 1e-6 @dataclass(frozen=True) class LeNEPAEncoderOutput: """Outputs of the published LeNEPA encoder.""" patch_tokens: torch.Tensor # [B, T=200, D=192] embedding: torch.Tensor # [B, D=192] class RotaryEmbedding(nn.Module): """Rotary positional embeddings (RoPE) applied to Q/K.""" def __init__(self, *, dim: int, base: int) -> None: super().__init__() if dim % 2 != 0: raise ValueError(f"RoPE requires even head_dim, got {dim}") inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) # [Dh/2] self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached: int | None = None self._cos_cached: torch.Tensor | None = None self._sin_cached: torch.Tensor | None = None self._device_cached: torch.device | None = None self._dtype_cached: torch.dtype | None = None def _build_cache(self, *, seq_len: int, device: torch.device, dtype: torch.dtype) -> None: positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) # [T] freqs = torch.einsum("i,j->ij", positions, self.inv_freq) # [T, Dh/2] self._cos_cached = freqs.cos().to(dtype) # [T, Dh/2] self._sin_cached = freqs.sin().to(dtype) # [T, Dh/2] self._seq_len_cached = seq_len self._device_cached = device self._dtype_cached = dtype def _get_cos_sin( self, *, seq_len: int, device: torch.device, dtype: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: if ( self._cos_cached is None or self._sin_cached is None or self._seq_len_cached != seq_len or self._device_cached != device or self._dtype_cached != dtype ): self._build_cache(seq_len=seq_len, device=device, dtype=dtype) if self._cos_cached is None or self._sin_cached is None: raise RuntimeError("RoPE cache was not built; this is a bug") return self._cos_cached, self._sin_cached def _apply_rotary(self, x: torch.Tensor, *, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: # x: [B, H, T, Dh] B, H, T, Dh = x.shape x_2 = x.view(B, H, T, Dh // 2, 2) # [B, H, T, Dh/2, 2] x1 = x_2[..., 0] # [B, H, T, Dh/2] x2 = x_2[..., 1] # [B, H, T, Dh/2] cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, T, Dh/2] sin = sin.unsqueeze(0).unsqueeze(0) # [1, 1, T, Dh/2] out1 = x1 * cos - x2 * sin out2 = x1 * sin + x2 * cos return torch.stack((out1, out2), dim=-1).flatten(-2) # [B, H, T, Dh] def apply(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Apply RoPE to Q/K.""" cos, sin = self._get_cos_sin(seq_len=q.size(-2), device=q.device, dtype=q.dtype) # [T, Dh/2] return self._apply_rotary(q, cos=cos, sin=sin), self._apply_rotary(k, cos=cos, sin=sin) class Attention(nn.Module): """Causal self-attention with RoPE + QK-Norm (no dropout).""" def __init__(self) -> None: super().__init__() if DIM % NUM_HEADS != 0: raise ValueError(f"DIM must be divisible by NUM_HEADS, got DIM={DIM} NUM_HEADS={NUM_HEADS}") head_dim = DIM // NUM_HEADS self.num_heads = NUM_HEADS self.rope = RotaryEmbedding(dim=head_dim, base=ROPE_BASE) self.qk_norm = nn.LayerNorm(head_dim, eps=QK_NORM_EPS, elementwise_affine=False) self.qkv = nn.Linear(DIM, DIM * 3, bias=QKV_BIAS) self.proj = nn.Linear(DIM, DIM, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B, T, D] B, T, D = x.shape qkv = ( self.qkv(x) # [B, T, 3*D] .reshape(B, T, 3, self.num_heads, D // self.num_heads) # [B, T, 3, H, Dh] .permute(2, 0, 3, 1, 4) # [3, B, H, T, Dh] ) q, k, v = qkv[0], qkv[1], qkv[2] # each [B, H, T, Dh] q, k = self.rope.apply(q, k) # [B, H, T, Dh] each q = self.qk_norm(q) # [B, H, T, Dh] k = self.qk_norm(k) # [B, H, T, Dh] attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True) # [B, H, T, Dh] out = attn.transpose(1, 2).reshape(B, T, D) # [B, T, D] return self.proj(out) # [B, T, D] class GatedMLP(nn.Module): """SwiGLU MLP used in this checkpoint.""" def __init__(self) -> None: super().__init__() hidden_dim = int((2 / 3) * MLP_RATIO * DIM) if hidden_dim <= 0: raise ValueError(f"hidden_dim must be > 0, got {hidden_dim}") self.fc1 = nn.Linear(DIM, hidden_dim * 2, bias=True) self.fc2 = nn.Linear(hidden_dim, DIM, bias=True) self.act = nn.SiLU() def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B, T, D] gate_and_value = self.fc1(x) # [B, T, 2H] gate, value = gate_and_value.chunk(2, dim=-1) # each [B, T, H] return self.fc2(self.act(gate) * value) # [B, T, D] class Block(nn.Module): """Transformer block: LN -> Attn -> residual -> LN -> MLP -> residual.""" def __init__(self) -> None: super().__init__() self.norm1 = nn.LayerNorm(DIM, eps=NORM_EPS) self.attn = Attention() self.norm2 = nn.LayerNorm(DIM, eps=NORM_EPS) self.mlp = GatedMLP() def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B, T, D] x = x + self.attn(self.norm1(x)) # [B, T, D] x = x + self.mlp(self.norm2(x)) # [B, T, D] return x class PatchEmbedding(nn.Module): """Conv patch embedding: Conv1d(C->D, kernel=stride=patch_size).""" def __init__(self) -> None: super().__init__() self.proj = nn.Conv1d( in_channels=NUM_CHANNELS, out_channels=DIM, kernel_size=PATCH_SIZE, stride=PATCH_SIZE, bias=True, ) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B, C, L] z_t = self.proj(x) # [B, D, T] return z_t.transpose(1, 2) # [B, T, D] class LeNEPAEncoder(nn.Module): """LeNEPA encoder trunk for this exact checkpoint (static conv patch embed, causal).""" def __init__(self) -> None: super().__init__() if CHANNEL_SIZE % PATCH_SIZE != 0: raise ValueError("CHANNEL_SIZE must be divisible by PATCH_SIZE") self.patch_embed = PatchEmbedding() self.blocks = nn.ModuleList([Block() for _ in range(DEPTH)]) self.norm = nn.LayerNorm(DIM, eps=NORM_EPS) def forward(self, x: torch.Tensor) -> torch.Tensor: """Return final-layer patch tokens (post-final-norm).""" z = self.patch_embed(x) # [B, T, D] for block in self.blocks: z = block(z) # [B, T, D] return self.norm(z) # [B, T, D] @torch.inference_mode() def encode_lenepa(*, model: LeNEPAEncoder, x_waveform: torch.Tensor) -> LeNEPAEncoderOutput: """Encode a batch of waveforms. Args: model: LeNEPA encoder (on the same device as x_waveform). x_waveform: [B, 1, 5000] float32. """ if x_waveform.dtype is not torch.float32: raise ValueError(f"x_waveform must be float32, got {x_waveform.dtype}") if x_waveform.dim() != 3: raise ValueError(f"x_waveform must be [B, C, L], got {tuple(x_waveform.shape)}") B, C, L = x_waveform.shape if C != NUM_CHANNELS or L != CHANNEL_SIZE: raise ValueError( "Input must match the published contract: " f"expected [B, {NUM_CHANNELS}, {CHANNEL_SIZE}], got {tuple(x_waveform.shape)}" ) model_device = next(model.parameters()).device if x_waveform.device != model_device: raise ValueError( "x_waveform must be on the same device as the model. " f"x_waveform.device={x_waveform.device} model.device={model_device}" ) patch_tokens = model(x_waveform) # [B, T, D] embedding = patch_tokens.mean(dim=1) # [B, D] return LeNEPAEncoderOutput(patch_tokens=patch_tokens, embedding=embedding) def load_lenepa_encoder(*, weights_path: Path, device: torch.device) -> LeNEPAEncoder: """Load the published encoder weights from a safetensors file.""" if not weights_path.is_file(): raise ValueError(f"weights_path does not exist: {str(weights_path)!r}") state = safetensors_load(str(weights_path)) model = LeNEPAEncoder() model.load_state_dict(state, strict=True) model.eval() model.requires_grad_(False) return model.to(device) def _smoke_test() -> None: """Small end-to-end smoke test (random input, prints output shapes).""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") here = Path(__file__).resolve().parent model = load_lenepa_encoder(weights_path=here / "lenepa_encoder.safetensors", device=device) x = torch.randn(2, 1, 5000, device=device, dtype=torch.float32) # [B=2, C=1, L=5000] out = encode_lenepa(model=model, x_waveform=x) print("patch_tokens", tuple(out.patch_tokens.shape)) print("embedding", tuple(out.embedding.shape)) if __name__ == "__main__": _smoke_test()