# Model-hub header (preserved from the published repository page):
#   uploader: alexanderchemeris
#   LeNEPA trained on a custom CauKer-2M dataset with 5000 points per series
#   checkpoint revision: 563f37c (verified)
"""Minimal inference for the published LeNEPA encoder checkpoint.
Published IO contract:
- x_waveform: torch.float32 [B, 1, 5000], channel order: ["c0"]
- outputs:
patch_tokens: [B, 625, 256]
embedding: [B, 256]
This code intentionally does NOT:
- resample / crop / pad inputs
- support other checkpoints or architectures
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import torch
from safetensors.torch import load_file as safetensors_load
from torch import nn
from torch.nn import functional as F
# -----------------------------
# Published constants (no knobs)
# -----------------------------
SAMPLING_FREQUENCY = 1  # samples per unit time; units not specified here — TODO confirm against dataset docs
CHANNELS = ("c0",)  # published channel order (single channel)
NUM_CHANNELS = 1  # C in the [B, C, L] input contract
CHANNEL_SIZE = 5000  # L: required number of samples per series (validated in _tokenize)
PATCH_SIZE = 8  # samples folded into each patch token
NUM_PATCHES = 625  # 5000 / 8
DIM = 256  # transformer token width: 192 conv features + 32 mean-enc + 32 std-enc
DEPTH = 8  # number of transformer blocks
NUM_HEADS = 4  # attention heads (head_dim = 256 / 4 = 64)
MLP_RATIO = 4.0  # expansion ratio before the SwiGLU 2/3 shrink in GatedMLP
PATCH_EMBED_CNN_DIM = 192  # Conv1d patch-embedding output channels
SCALAR_HIDDEN_DIM = 32  # hidden width of each scalar encoder
SCALAR_SCALES = (0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0, 10000.0)  # magnitude scales for the MSSE blend
SCALAR_EPSILON = 1.1  # additive log offset for MSSE blend weights; > 1 keeps log positive at ratio 0
QKV_BIAS = True  # bias on the fused QKV projection
BIAS = True  # bias on the other linear / conv layers
NORM_EPS = 1e-6  # LayerNorm eps; also used as the std floor in patch normalization
ROPE_BASE = 10_000  # rotary embedding frequency base
QK_NORM_EPS = 1e-6  # eps for the non-affine LayerNorm applied to Q and K
@dataclass(frozen=True)
class LeNEPAEncoderOutput:
    """Outputs of the published LeNEPA encoder."""
    # Per-patch token sequence after the encoder's final LayerNorm.
    patch_tokens: torch.Tensor  # [B, T=625, D=256]
    # Mean of patch_tokens over the token (time) axis.
    embedding: torch.Tensor  # [B, D=256]
class ScalarEncoder(nn.Module):
    """Affine + LayerNorm scalar encoder used inside the MSSE blocks.

    Maps a scalar input of shape [..., 1] to a hidden vector [..., H] via the
    learned affine map ``x * w + k * b`` followed by LayerNorm.

    Args:
        k: Fixed per-scale multiplier applied to the bias term.
        hidden_dim: Width H of the encoded representation (>= 1).
        eps: LayerNorm epsilon (> 0).

    Raises:
        ValueError: If ``hidden_dim < 1`` or ``eps <= 0``.
    """
    def __init__(self, *, k: float, hidden_dim: int, eps: float) -> None:
        super().__init__()
        if hidden_dim < 1:
            raise ValueError(f"hidden_dim must be >= 1, got {hidden_dim}")
        if eps <= 0:
            raise ValueError(f"eps must be > 0, got {eps}")
        self.k = float(k)
        # nn.Parameter already enables gradients; passing requires_grad=True to
        # torch.rand was redundant. The random init is a placeholder — the real
        # values come from the published checkpoint.
        self.w = nn.Parameter(torch.rand((1, hidden_dim), dtype=torch.float32))
        self.b = nn.Parameter(torch.rand((1, hidden_dim), dtype=torch.float32))
        self.norm = nn.LayerNorm(hidden_dim, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode scalars ``x`` of shape [..., 1] to [..., H].

        Raises:
            ValueError: If the last dimension of ``x`` is not 1.
        """
        if x.size(-1) != 1:
            raise ValueError(f"Expected x[..., 1], got x.shape={tuple(x.shape)}")
        z = x * self.w + self.k * self.b  # broadcast [..., 1] -> [..., H]
        return self.norm(z)  # [..., H]
class MultiScaledScalarEncoder(nn.Module):
    """Blend per-scale scalar encoders based on the scalar magnitude.

    One ScalarEncoder per scale; the outputs are mixed with weights derived
    from |x| / scale through an inverse-log soft assignment.
    """
    def __init__(
        self,
        *,
        scales: tuple[float, ...],
        hidden_dim: int,
        epsilon: float,
        eps: float,
    ) -> None:
        super().__init__()
        if not scales:
            raise ValueError("scales must be non-empty")
        if any(scale <= 0 for scale in scales):
            raise ValueError(f"All scales must be > 0, got scales={scales}")
        if hidden_dim < 1:
            raise ValueError(f"hidden_dim must be >= 1, got {hidden_dim}")
        if epsilon <= 0:
            raise ValueError(f"epsilon must be > 0, got {epsilon}")
        if eps <= 0:
            raise ValueError(f"eps must be > 0, got {eps}")
        # Non-persistent buffer so the scales follow module .to() moves but
        # stay out of the checkpoint's state dict.
        self.register_buffer(
            "scales", torch.tensor(scales, dtype=torch.float32), persistent=False
        )  # [S]
        self.epsilon = float(epsilon)
        self.encoders = nn.ModuleList(
            ScalarEncoder(k=float(s), hidden_dim=hidden_dim, eps=eps) for s in scales
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode scalars ``x`` of shape [..., 1] to a blended [..., H]."""
        if x.size(-1) != 1:
            raise ValueError(f"Expected x[..., 1], got x.shape={tuple(x.shape)}")
        magnitude = x.abs().to(dtype=torch.float32)  # [..., 1]
        inv_scales = (1.0 / self.scales.to(device=x.device)).reshape(1, -1)  # [1, S]
        ratio = torch.matmul(magnitude, inv_scales)  # [..., S]
        # Inverse-log weights, normalized to sum to 1 across scales.
        weights = (1.0 / torch.log(ratio + self.epsilon)).abs()  # [..., S]
        weights = (weights / weights.sum(dim=-1, keepdim=True)).unsqueeze(-1)  # [..., S, 1]
        stacked = torch.stack([enc(x) for enc in self.encoders], dim=-2)  # [..., S, H]
        blended = (stacked.to(dtype=torch.float32) * weights).sum(dim=-2)  # [..., H]
        return blended.to(dtype=x.dtype)
class RotaryEmbedding(nn.Module):
    """Rotary positional embeddings (RoPE) applied to Q/K.

    Cos/sin tables are cached and rebuilt whenever the sequence length,
    device, or dtype of the incoming tensors changes. Rotation uses the
    interleaved-pair convention: adjacent feature pairs (x[2i], x[2i+1])
    are rotated together.
    """
    def __init__(self, *, dim: int, base: int) -> None:
        """Build inverse frequencies for a head dimension ``dim`` (must be even)."""
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE requires even head_dim, got {dim}")
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))  # [Dh/2]
        # Non-persistent: follows .to() moves, excluded from the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Cache key: (seq_len, device, dtype) of the last built tables.
        self._seq_len_cached: int | None = None
        self._cos_cached: torch.Tensor | None = None
        self._sin_cached: torch.Tensor | None = None
        self._device_cached: torch.device | None = None
        self._dtype_cached: torch.dtype | None = None
    def _build_cache(self, *, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        """(Re)build the cos/sin tables for ``seq_len`` positions."""
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)  # [T]
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)  # [T, Dh/2]
        self._cos_cached = freqs.cos().to(dtype)  # [T, Dh/2]
        self._sin_cached = freqs.sin().to(dtype)  # [T, Dh/2]
        self._seq_len_cached = seq_len
        self._device_cached = device
        self._dtype_cached = dtype
    def _get_cos_sin(
        self, *, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cached cos/sin tables, rebuilding on any key mismatch."""
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._device_cached != device
            or self._dtype_cached != dtype
        ):
            self._build_cache(seq_len=seq_len, device=device, dtype=dtype)
        if self._cos_cached is None or self._sin_cached is None:
            raise RuntimeError("RoPE cache was not built; this is a bug")
        return self._cos_cached, self._sin_cached
    def _apply_rotary(self, x: torch.Tensor, *, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate interleaved feature pairs of ``x`` by the position angles."""
        # x: [B, H, T, Dh]
        B, H, T, Dh = x.shape
        # Split Dh into Dh/2 adjacent (even, odd) pairs.
        x_2 = x.view(B, H, T, Dh // 2, 2)  # [B, H, T, Dh/2, 2]
        x1 = x_2[..., 0]  # [B, H, T, Dh/2]
        x2 = x_2[..., 1]  # [B, H, T, Dh/2]
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, T, Dh/2]
        sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, T, Dh/2]
        # Standard 2D rotation of each pair by the per-position angle.
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
        # Re-interleave pairs back into the last dimension.
        return torch.stack((out1, out2), dim=-1).flatten(-2)  # [B, H, T, Dh]
    def apply(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to query and key tensors of shape [B, H, T, Dh]."""
        cos, sin = self._get_cos_sin(seq_len=q.size(-2), device=q.device, dtype=q.dtype)  # [T, Dh/2]
        return self._apply_rotary(q, cos=cos, sin=sin), self._apply_rotary(k, cos=cos, sin=sin)
class Attention(nn.Module):
    """Causal multi-head self-attention with RoPE and QK-Norm (no dropout).

    Note: the non-affine LayerNorm on Q/K is applied AFTER the rotary
    embedding, matching the published checkpoint.
    """
    def __init__(self) -> None:
        super().__init__()
        if DIM % NUM_HEADS != 0:
            raise ValueError(f"DIM must be divisible by NUM_HEADS, got DIM={DIM} NUM_HEADS={NUM_HEADS}")
        head_dim = DIM // NUM_HEADS
        self.num_heads = NUM_HEADS
        self.rope = RotaryEmbedding(dim=head_dim, base=ROPE_BASE)
        self.qk_norm = nn.LayerNorm(head_dim, eps=QK_NORM_EPS, elementwise_affine=False)
        self.qkv = nn.Linear(DIM, DIM * 3, bias=QKV_BIAS)
        self.proj = nn.Linear(DIM, DIM, bias=BIAS)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over tokens ``x`` of shape [B, T, D] with a causal mask."""
        B, T, D = x.shape
        head_dim = D // self.num_heads
        # Fused projection, then split into per-head Q/K/V: each [B, H, T, Dh].
        fused = self.qkv(x).reshape(B, T, 3, self.num_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        q, k = self.rope.apply(q, k)
        q, k = self.qk_norm(q), self.qk_norm(k)
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
        # Merge heads back to [B, T, D] and project out.
        return self.proj(attended.transpose(1, 2).reshape(B, T, D))
class GatedMLP(nn.Module):
    """SwiGLU MLP used in this checkpoint."""
    def __init__(self) -> None:
        super().__init__()
        # Conventional SwiGLU 2/3 shrink of the MLP_RATIO expansion.
        hidden_dim = int((2 / 3) * MLP_RATIO * DIM)
        if hidden_dim <= 0:
            raise ValueError(f"hidden_dim must be > 0, got {hidden_dim}")
        self.fc1 = nn.Linear(DIM, hidden_dim * 2, bias=BIAS)
        self.fc2 = nn.Linear(hidden_dim, DIM, bias=BIAS)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gated projection of ``x`` [B, T, D]: fc2(silu(gate) * value)."""
        gate, value = self.fc1(x).chunk(2, dim=-1)  # each [B, T, H]
        return self.fc2(self.act(gate) * value)  # [B, T, D]
class Block(nn.Module):
    """Pre-norm transformer block: LN -> Attn -> add, LN -> MLP -> add."""
    def __init__(self) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(DIM, eps=NORM_EPS)
        self.attn = Attention()
        self.norm2 = nn.LayerNorm(DIM, eps=NORM_EPS)
        self.mlp = GatedMLP()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both sublayers with residual connections; x is [B, T, D]."""
        attended = x + self.attn(self.norm1(x))  # [B, T, D]
        return attended + self.mlp(self.norm2(attended))  # [B, T, D]
class PatchEmbedding(nn.Module):
    """Non-overlapping Conv1d patchifier (kernel == stride == PATCH_SIZE)."""
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Conv1d(
            NUM_CHANNELS,
            PATCH_EMBED_CNN_DIM,
            kernel_size=PATCH_SIZE,
            stride=PATCH_SIZE,
            bias=BIAS,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map waveforms [B, C, L] to patch features [B, T, D_cnn]."""
        return self.proj(x).transpose(1, 2)
class LeNEPAEncoder(nn.Module):
    """LeNEPA encoder trunk for this exact checkpoint."""
    def __init__(self) -> None:
        super().__init__()
        if CHANNEL_SIZE % PATCH_SIZE != 0:
            raise ValueError("CHANNEL_SIZE must be divisible by PATCH_SIZE")
        # Scalar encoders for the per-patch mean and std statistics; their
        # attribute names must match the published checkpoint's state dict.
        self.nepa_patch_embed_mean_encoder = MultiScaledScalarEncoder(
            scales=SCALAR_SCALES,
            hidden_dim=SCALAR_HIDDEN_DIM,
            epsilon=SCALAR_EPSILON,
            eps=NORM_EPS,
        )
        self.nepa_patch_embed_std_encoder = MultiScaledScalarEncoder(
            scales=SCALAR_SCALES,
            hidden_dim=SCALAR_HIDDEN_DIM,
            epsilon=SCALAR_EPSILON,
            eps=NORM_EPS,
        )
        self.patch_embed = PatchEmbedding()
        self.blocks = nn.ModuleList([Block() for _ in range(DEPTH)])
        self.norm = nn.LayerNorm(DIM, eps=NORM_EPS)
    def _tokenize(self, x: torch.Tensor) -> torch.Tensor:
        """Turn waveforms [B, C, L] into transformer tokens [B, T, DIM].

        Each patch is normalized by its own mean/std; the removed statistics
        are re-injected as learned scalar features, so tokens keep both shape
        and magnitude information.

        Raises:
            ValueError: If the input does not match the published [B, 1, 5000]
                contract (or L is not divisible by PATCH_SIZE).
        """
        # x: [B, C, L]
        B, C, L = x.shape
        if C != NUM_CHANNELS or L != CHANNEL_SIZE:
            raise ValueError(
                "Input must match the published contract: "
                f"expected [B, {NUM_CHANNELS}, {CHANNEL_SIZE}], got {tuple(x.shape)}"
            )
        if L % PATCH_SIZE != 0:
            raise ValueError(f"Expected L divisible by PATCH_SIZE, got L={L}, PATCH_SIZE={PATCH_SIZE}")
        T = L // PATCH_SIZE
        # Statistics are computed in float32 regardless of the input dtype.
        x_f32 = x.to(dtype=torch.float32)  # [B, C, L]
        x_patches = x_f32.reshape(B, C, T, PATCH_SIZE)  # [B, C, T, P]
        # Per-patch mean/std over channels and within-patch samples.
        mean_patch = x_patches.mean(dim=(1, 3))  # [B, T]
        std_patch = x_patches.std(dim=(1, 3), unbiased=False)  # [B, T]
        mean_broadcast = mean_patch[:, None, :, None]  # [B, 1, T, 1]
        std_broadcast = std_patch[:, None, :, None]  # [B, 1, T, 1]
        # NORM_EPS added to std guards against division by zero on flat patches.
        x_norm_patches = (x_patches - mean_broadcast) / (std_broadcast + NORM_EPS)  # [B, C, T, P]
        x_norm = x_norm_patches.reshape(B, C, L).to(dtype=x.dtype)  # [B, C, L]
        z_cnn = self.patch_embed(x_norm)  # [B, T, D_cnn]
        e_mean = self.nepa_patch_embed_mean_encoder(mean_patch.unsqueeze(-1))  # [B, T, H]
        e_std = self.nepa_patch_embed_std_encoder(std_patch.unsqueeze(-1))  # [B, T, H]
        # Token layout (order is checkpoint-critical): [conv | mean | std] = 192 + 32 + 32 = 256.
        tokens = torch.cat([z_cnn, e_mean.to(dtype=z_cnn.dtype), e_std.to(dtype=z_cnn.dtype)], dim=-1)  # [B, T, D]
        if tokens.size(-1) != DIM:
            raise RuntimeError(
                "Tokenizer produced unexpected dim. "
                f"Got tokens.shape={tuple(tokens.shape)}, expected last dim={DIM}"
            )
        return tokens
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode waveforms [B, C, L] into normalized patch tokens [B, T, D]."""
        # x: [B, C, L]
        z = self._tokenize(x)  # [B, T, D]
        for block in self.blocks:
            z = block(z)  # [B, T, D]
        return self.norm(z)  # [B, T, D]
@torch.inference_mode()
def encode_lenepa(*, model: LeNEPAEncoder, x_waveform: torch.Tensor) -> LeNEPAEncoderOutput:
    """Encode a batch of waveforms.

    Args:
        model: A loaded LeNEPA encoder.
        x_waveform: float32 tensor of shape [B, 1, 5000] on the model's device.

    Raises:
        ValueError: On wrong dtype, wrong rank, or a device mismatch.
    """
    if x_waveform.dtype != torch.float32:
        raise ValueError(f"x_waveform must be float32, got {x_waveform.dtype}")
    if x_waveform.dim() != 3:
        raise ValueError(f"x_waveform must be [B, C, L], got {tuple(x_waveform.shape)}")
    target_device = next(model.parameters()).device
    if x_waveform.device != target_device:
        raise ValueError(
            "x_waveform must be on the same device as the model. "
            f"x_waveform.device={x_waveform.device} model.device={target_device}"
        )
    tokens = model(x_waveform)  # [B, T, D]
    # The pooled embedding is the mean over the token axis.
    return LeNEPAEncoderOutput(patch_tokens=tokens, embedding=tokens.mean(dim=1))
def load_lenepa_encoder(*, weights_path: Path, device: torch.device) -> LeNEPAEncoder:
    """Load the published encoder weights from a safetensors file.

    Args:
        weights_path: Path to the checkpoint's .safetensors file.
        device: Device the frozen model is moved to.

    Raises:
        ValueError: If ``weights_path`` is not an existing file.
    """
    if not weights_path.is_file():
        raise ValueError(f"weights_path does not exist: {str(weights_path)!r}")
    model = LeNEPAEncoder()
    # strict=True: the checkpoint must match this architecture exactly.
    model.load_state_dict(safetensors_load(str(weights_path)), strict=True)
    model.eval()
    model.requires_grad_(False)  # inference-only: freeze all parameters
    return model.to(device)
def _smoke_test() -> None:
    """End-to-end sanity check on random input; prints output shapes."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weights = Path(__file__).resolve().parent / "lenepa_encoder.safetensors"
    model = load_lenepa_encoder(weights_path=weights, device=device)
    batch = torch.randn(2, 1, 5000, device=device, dtype=torch.float32)  # [B=2, C=1, L=5000]
    result = encode_lenepa(model=model, x_waveform=batch)
    print("patch_tokens", tuple(result.patch_tokens.shape))
    print("embedding", tuple(result.embedding.shape))
# Script entry point: run the smoke test when executed directly.
if __name__ == "__main__":
    _smoke_test()