# lenepa-encoder-aiono / inference.py
# Last commit: c72c956 (verified), by alexanderchemeris:
#   "LeNEPA Encoder trained on Aionoscope balanced dataset"
"""Minimal inference for the published LeNEPA *encoder* checkpoint (no projector).

Published IO contract:
  - input x_waveform: torch.float32 tensor [B, 1, 5000] sampled at 500 Hz
    (10 s of a single channel), channel order: ["I"]
  - outputs:
      patch_tokens: [B, 200, 192]  (one token per 25-sample patch)
      embedding:    [B, 192]       (mean of patch_tokens over the 200 tokens)

This code intentionally does NOT:
  - resample / crop / pad / normalize inputs
  - support other checkpoints or architectures
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import torch
from safetensors.torch import load_file as safetensors_load
from torch import nn
from torch.nn import functional as F
# -----------------------------
# Published constants (no knobs)
# -----------------------------
# These describe the published checkpoint's input contract and architecture. DIM, DEPTH,
# NUM_HEADS, MLP_RATIO, PATCH_SIZE, NUM_CHANNELS and QKV_BIAS determine parameter shapes and
# therefore must match the weights for the strict load in load_lenepa_encoder to succeed.
SAMPLING_FREQUENCY_HZ = 500
CHANNELS = ("I",)  # channel names, in input channel order
NUM_CHANNELS = len(CHANNELS)  # = 1; derived so it cannot drift from CHANNELS
CHANNEL_SIZE = 5000  # samples per channel (10 s at 500 Hz)
PATCH_SIZE = 25  # samples per patch token
NUM_PATCHES = CHANNEL_SIZE // PATCH_SIZE  # = 200; derived so it cannot drift from the above
DIM = 192  # token width
DEPTH = 8  # number of transformer blocks
NUM_HEADS = 4  # attention heads per block (head_dim = DIM // NUM_HEADS = 48)
MLP_RATIO = 4.0
QKV_BIAS = True
NORM_EPS = 1e-6
ROPE_BASE = 10_000
QK_NORM_EPS = 1e-6
@dataclass(frozen=True)
class LeNEPAEncoderOutput:
    """Frozen container for the two tensors returned by ``encode_lenepa``.

    Attributes:
        patch_tokens: final-layer token features, one per 25-sample patch.
        embedding: per-example vector; ``encode_lenepa`` fills it with the mean of
            ``patch_tokens`` over the token axis.
    """

    patch_tokens: torch.Tensor  # shape [B, T, D] = [B, 200, 192]
    embedding: torch.Tensor  # shape [B, D] = [B, 192]
class RotaryEmbedding(nn.Module):
    """Rotary positional embeddings (RoPE) applied to Q/K.

    Each head's channels are rotated in interleaved pairs ``(x[..., 2i], x[..., 2i + 1])`` by the
    angle ``t * inv_freq[i]`` at sequence position ``t``, with
    ``inv_freq[i] = base ** (-2i / dim)``. The pairing layout affects the outputs, so it must stay
    as-is for the published weights (switching to e.g. a half-split layout changes results).

    cos/sin tables are cached and rebuilt whenever ``seq_len``, device or dtype changes.
    """

    def __init__(self, *, dim: int, base: int) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE requires even head_dim, got {dim}")
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))  # [Dh/2]
        # Derived from (dim, base), so it is excluded from the state dict (persistent=False);
        # as a buffer it still follows the module across .to(device).
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached: int | None = None
        self._cos_cached: torch.Tensor | None = None
        self._sin_cached: torch.Tensor | None = None
        self._device_cached: torch.device | None = None
        self._dtype_cached: torch.dtype | None = None

    def _build_cache(self, *, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        """(Re)build the cos/sin tables, each of shape [seq_len, Dh/2], in ``dtype``."""
        # Build with inference mode disabled so the cached tables are ordinary tensors. If the
        # first forward runs under torch.inference_mode() (as encode_lenepa does), tables built
        # there would be inference tensors, and reusing them in a later autograd-recorded
        # forward fails with "Inference tensors cannot be saved for backward".
        with torch.inference_mode(False):
            positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)  # [T]
            freqs = torch.einsum("i,j->ij", positions, self.inv_freq)  # [T, Dh/2]
            self._cos_cached = freqs.cos().to(dtype)  # [T, Dh/2]
            self._sin_cached = freqs.sin().to(dtype)  # [T, Dh/2]
        self._seq_len_cached = seq_len
        self._device_cached = device
        self._dtype_cached = dtype

    def _get_cos_sin(
        self, *, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cached (cos, sin) tables for this seq_len/device/dtype, rebuilding on mismatch."""
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._device_cached != device
            or self._dtype_cached != dtype
        ):
            self._build_cache(seq_len=seq_len, device=device, dtype=dtype)
        if self._cos_cached is None or self._sin_cached is None:
            raise RuntimeError("RoPE cache was not built; this is a bug")
        return self._cos_cached, self._sin_cached

    def _apply_rotary(self, x: torch.Tensor, *, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate interleaved channel pairs of ``x`` ([..., T, Dh]) by the [T, Dh/2] tables."""
        # reshape (not view) also works when x's last dim is not contiguous in memory.
        x_pairs = x.reshape(*x.shape[:-1], -1, 2)  # [..., T, Dh/2, 2]
        x1 = x_pairs[..., 0]  # [..., T, Dh/2]
        x2 = x_pairs[..., 1]  # [..., T, Dh/2]
        # cos/sin are [T, Dh/2] and broadcast over all leading dims (e.g. batch and heads).
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
        return torch.stack((out1, out2), dim=-1).flatten(-2)  # [..., T, Dh]

    def apply(  # type: ignore[override]
        self, q: object, k: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor] | RotaryEmbedding:
        """Apply RoPE to Q/K: ``apply(q, k) -> (q_rotated, k_rotated)``.

        Args:
            q: query tensor ``[B, H, T, Dh]`` (any leading dims work); T is taken from dim -2.
            k: key tensor with the same shape as ``q``.

        This name shadows ``nn.Module.apply(fn)``, which PyTorch invokes recursively on every
        child module (e.g. ``encoder.apply(init_fn)``). A single non-tensor argument is
        therefore forwarded to ``nn.Module.apply`` so that recursion keeps working.

        Raises:
            TypeError: if called with a tensor but no ``k``, or with a callable and a ``k``.
        """
        if not isinstance(q, torch.Tensor):
            if k is not None:
                raise TypeError("RotaryEmbedding.apply(fn) takes a single callable argument")
            return super().apply(q)
        if k is None:
            raise TypeError("RotaryEmbedding.apply(q, k) requires both q and k tensors")
        cos, sin = self._get_cos_sin(seq_len=q.size(-2), device=q.device, dtype=q.dtype)  # [T, Dh/2]
        return self._apply_rotary(q, cos=cos, sin=sin), self._apply_rotary(k, cos=cos, sin=sin)
class Attention(nn.Module):
    """Multi-head causal self-attention for this checkpoint (no dropout).

    Pipeline: fused QKV projection -> RoPE on Q/K -> affine-free LayerNorm on Q/K (applied
    *after* RoPE) -> causal scaled-dot-product attention -> output projection.
    """

    def __init__(self) -> None:
        super().__init__()
        if DIM % NUM_HEADS != 0:
            raise ValueError(f"DIM must be divisible by NUM_HEADS, got DIM={DIM} NUM_HEADS={NUM_HEADS}")
        per_head = DIM // NUM_HEADS
        self.num_heads = NUM_HEADS
        self.rope = RotaryEmbedding(dim=per_head, base=ROPE_BASE)
        self.qk_norm = nn.LayerNorm(per_head, eps=QK_NORM_EPS, elementwise_affine=False)
        self.qkv = nn.Linear(DIM, DIM * 3, bias=QKV_BIAS)
        self.proj = nn.Linear(DIM, DIM, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map tokens ``[B, T, D]`` to attended tokens ``[B, T, D]`` (token t sees tokens <= t)."""
        batch, seq_len, width = x.shape
        per_head = width // self.num_heads
        # [B, T, 3*D] -> [B, T, 3, H, Dh] -> [3, B, H, T, Dh], then split into q/k/v.
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, per_head)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)  # each [B, H, T, Dh]
        query, key = self.rope.apply(query, key)
        query, key = self.qk_norm(query), self.qk_norm(key)
        mixed = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=True)
        merged = mixed.transpose(1, 2).reshape(batch, seq_len, width)  # heads back into D
        return self.proj(merged)
class GatedMLP(nn.Module):
    """SwiGLU feed-forward: ``fc2(SiLU(gate) * value)`` where ``[gate | value] = fc1(x)``."""

    def __init__(self) -> None:
        super().__init__()
        # Hidden width = int((2/3) * MLP_RATIO * DIM). fc1/fc2 shapes depend on it, so the
        # exact expression must match the one the checkpoint was saved with.
        width = int((2 / 3) * MLP_RATIO * DIM)
        if not width > 0:
            raise ValueError(f"hidden_dim must be > 0, got {width}")
        self.fc1 = nn.Linear(DIM, 2 * width, bias=True)
        self.fc2 = nn.Linear(width, DIM, bias=True)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[B, T, D]`` to ``[B, T, D]``; the first half of fc1's output is the gate."""
        gate, value = torch.chunk(self.fc1(x), 2, dim=-1)  # each [B, T, hidden]
        gated = self.act(gate) * value
        return self.fc2(gated)
class Block(nn.Module):
    """Pre-norm transformer block: ``x + Attn(LN(x))`` followed by ``y + MLP(LN(y))``."""

    def __init__(self) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(DIM, eps=NORM_EPS)
        self.attn = Attention()
        self.norm2 = nn.LayerNorm(DIM, eps=NORM_EPS)
        self.mlp = GatedMLP()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both residual sub-layers to tokens ``[B, T, D]``; the shape is unchanged."""
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
class PatchEmbedding(nn.Module):
    """Non-overlapping patch embedding via ``Conv1d(C -> D, kernel_size=stride=PATCH_SIZE)``."""

    def __init__(self) -> None:
        super().__init__()
        # Positional args: in_channels, out_channels, kernel_size.
        self.proj = nn.Conv1d(NUM_CHANNELS, DIM, PATCH_SIZE, stride=PATCH_SIZE, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Turn a waveform ``[B, C, L]`` into patch tokens ``[B, L // PATCH_SIZE, D]``."""
        return self.proj(x).transpose(1, 2)  # conv output is [B, D, T]; swap to [B, T, D]
class LeNEPAEncoder(nn.Module):
    """Encoder trunk of the published checkpoint.

    Conv patch embedding, then DEPTH causal transformer blocks, then a final LayerNorm.
    """

    def __init__(self) -> None:
        super().__init__()
        if CHANNEL_SIZE % PATCH_SIZE:
            raise ValueError("CHANNEL_SIZE must be divisible by PATCH_SIZE")
        self.patch_embed = PatchEmbedding()
        self.blocks = nn.ModuleList(Block() for _ in range(DEPTH))
        self.norm = nn.LayerNorm(DIM, eps=NORM_EPS)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a waveform ``[B, C, L]`` to final-layer patch tokens ``[B, T, D]`` (after the final norm)."""
        tokens = self.patch_embed(x)
        for layer in self.blocks:
            tokens = layer(tokens)
        return self.norm(tokens)
@torch.inference_mode()
def encode_lenepa(*, model: LeNEPAEncoder, x_waveform: torch.Tensor) -> LeNEPAEncoderOutput:
    """Encode a batch of waveforms with the published encoder.

    Runs under ``torch.inference_mode()``, so the returned tensors are inference tensors.
    Per PyTorch's inference-mode rules they cannot take part in autograd-recorded operations;
    ``.clone()`` them outside inference mode first if that is needed.

    Args:
        model: LeNEPA encoder, on the same device as ``x_waveform``.
        x_waveform: float32 tensor of shape ``[B, 1, 5000]``. Inputs are not resampled,
            cropped, padded or normalized here.

    Returns:
        LeNEPAEncoderOutput with ``patch_tokens`` ``[B, 200, 192]`` (the model's output) and
        ``embedding`` ``[B, 192]`` (mean of ``patch_tokens`` over the token axis).

    Raises:
        ValueError: if ``x_waveform`` is not a tensor, is not float32, does not have shape
            ``[B, 1, 5000]``, or is on a different device than the model.
    """
    # Checked first: without it a numpy float32 array fails the dtype check below with the
    # confusing message "must be float32, got float32". ValueError keeps this consistent with
    # the other contract checks in this function.
    if not isinstance(x_waveform, torch.Tensor):
        raise ValueError(f"x_waveform must be a torch.Tensor, got {type(x_waveform).__name__}")
    if x_waveform.dtype != torch.float32:
        raise ValueError(f"x_waveform must be float32, got {x_waveform.dtype}")
    if x_waveform.dim() != 3:
        raise ValueError(f"x_waveform must be [B, C, L], got {tuple(x_waveform.shape)}")
    _, C, L = x_waveform.shape
    if C != NUM_CHANNELS or L != CHANNEL_SIZE:
        raise ValueError(
            "Input must match the published contract: "
            f"expected [B, {NUM_CHANNELS}, {CHANNEL_SIZE}], got {tuple(x_waveform.shape)}"
        )
    model_device = next(model.parameters()).device
    if x_waveform.device != model_device:
        raise ValueError(
            "x_waveform must be on the same device as the model. "
            f"x_waveform.device={x_waveform.device} model.device={model_device}"
        )
    patch_tokens = model(x_waveform)  # [B, T, D]
    embedding = patch_tokens.mean(dim=1)  # [B, D]: mean-pool over the token axis
    return LeNEPAEncoderOutput(patch_tokens=patch_tokens, embedding=embedding)
def load_lenepa_encoder(*, weights_path: Path | str, device: torch.device | str) -> LeNEPAEncoder:
    """Load the published encoder weights from a safetensors file.

    Args:
        weights_path: path to the ``.safetensors`` file (``Path`` or ``str``).
        device: target device for the returned model (anything ``nn.Module.to`` accepts).

    Returns:
        A ``LeNEPAEncoder`` in eval mode with ``requires_grad`` disabled on all parameters,
        moved to ``device``.

    Raises:
        ValueError: if ``weights_path`` does not point to an existing regular file.
    """
    # Path(...) accepts str as well; previously a str crashed with AttributeError on .is_file().
    path = Path(weights_path)
    if not path.is_file():
        raise ValueError(f"weights_path does not exist or is not a file: {str(path)!r}")
    state = safetensors_load(str(path))  # tensors are loaded on CPU, then moved by .to(device)
    model = LeNEPAEncoder()
    model.load_state_dict(state, strict=True)  # strict: every key and shape must match exactly
    model.eval()
    model.requires_grad_(False)
    return model.to(device)
def _smoke_test() -> None:
    """Run the encoder once on random noise and print both output shapes.

    Expects ``lenepa_encoder.safetensors`` in the same directory as this script; uses CUDA when
    available, otherwise the CPU.
    """
    weights = Path(__file__).resolve().parent / "lenepa_encoder.safetensors"
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    encoder = load_lenepa_encoder(weights_path=weights, device=device)
    noise = torch.randn(2, 1, 5000, device=device, dtype=torch.float32)  # [B=2, C=1, L=5000]
    result = encode_lenepa(model=encoder, x_waveform=noise)
    for label, tensor in (("patch_tokens", result.patch_tokens), ("embedding", result.embedding)):
        print(label, tuple(tensor.shape))


if __name__ == "__main__":
    _smoke_test()