"""Δt scalar → conditioning token in R^d via sinusoidal encoding.""" from __future__ import annotations import math import torch from torch import nn class DeltaTEmbedding(nn.Module): def __init__(self, d_model: int = 256, n_freqs: int = 32): super().__init__() # frequencies span 10 ms to 10 s — sinusoidal, fixed (not learned) freqs = torch.exp( torch.linspace(math.log(2 * math.pi), math.log(2 * math.pi / 10.0), n_freqs) ) self.register_buffer("freqs", freqs, persistent=False) self.proj = nn.Linear(2 * n_freqs, d_model) def forward(self, dt_seconds: torch.Tensor) -> torch.Tensor: # dt_seconds: [B] x = dt_seconds.unsqueeze(-1) * self.freqs # [B, n_freqs] emb = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) return self.proj(emb) # [B, d]