PhysioJEPA / src /physiojepa /dt_embed.py
guychuk's picture
Upload folder using huggingface_hub
31e2456 verified
"""Δt scalar → conditioning token in R^d via sinusoidal encoding."""
from __future__ import annotations
import math
import torch
from torch import nn
class DeltaTEmbedding(nn.Module):
    """Embed a time offset Δt (seconds) as a ``d_model``-dim conditioning token.

    Δt is multiplied by a fixed (non-learned) bank of ``n_freqs`` angular
    frequencies. The resulting phases go through ``sin`` and ``cos``, and the
    ``2 * n_freqs`` features are projected to ``d_model`` by a learned linear
    layer.

    Args:
        d_model: Output embedding size.
        n_freqs: Number of sinusoidal frequencies (must be >= 1).
        min_period: Shortest sinusoid period, in seconds (keyword-only).
        max_period: Longest sinusoid period, in seconds (keyword-only).

    Raises:
        ValueError: If ``n_freqs < 1`` or either period is not positive.
    """

    def __init__(
        self,
        d_model: int = 256,
        n_freqs: int = 32,
        *,
        min_period: float = 1.0,
        max_period: float = 10.0,
    ):
        super().__init__()
        if n_freqs < 1:
            raise ValueError(f"n_freqs must be >= 1, got {n_freqs}")
        if min_period <= 0 or max_period <= 0:
            raise ValueError(
                f"periods must be positive, got min_period={min_period}, "
                f"max_period={max_period}"
            )
        # Angular frequencies 2π / period, log-spaced from 2π/min_period down to
        # 2π/max_period. The defaults give periods of 1 s to 10 s, which is the
        # schedule this module has always used. The buffer is non-persistent,
        # so it is NOT stored in checkpoints. Changing these defaults would
        # therefore silently change the features a trained `proj` receives.
        freqs = torch.exp(
            torch.linspace(
                math.log(2 * math.pi / min_period),
                math.log(2 * math.pi / max_period),
                n_freqs,
            )
        )
        self.register_buffer("freqs", freqs, persistent=False)
        self.proj = nn.Linear(2 * n_freqs, d_model)

    def forward(self, dt_seconds: torch.Tensor) -> torch.Tensor:
        """Map Δt values of shape ``[...]`` (e.g. ``[B]``) to ``[..., d_model]``."""
        # Cast to the buffer's dtype. A float64 (or integer) Δt would otherwise
        # change the feature dtype and break the matmul inside `proj`.
        phase = dt_seconds.to(self.freqs.dtype).unsqueeze(-1) * self.freqs
        emb = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        return self.proj(emb)