| """Δt scalar → conditioning token in R^d via sinusoidal encoding.""" | |
| from __future__ import annotations | |
| import math | |
| import torch | |
| from torch import nn | |
class DeltaTEmbedding(nn.Module):
    """Map a time gap Δt (in seconds) to a ``d_model``-dimensional token.

    Δt is multiplied by a fixed (not learned) bank of ``n_freqs`` angular
    frequencies. The result is encoded as ``[sin(ω·Δt), cos(ω·Δt)]`` and
    passed through a learned linear projection. The frequencies are
    log-spaced so that the sinusoid periods run from ``min_period`` (first,
    highest frequency) to ``max_period`` (last, lowest frequency).

    Args:
        d_model: Output embedding width.
        n_freqs: Number of sinusoid frequencies. The projection input is
            ``2 * n_freqs`` wide.
        min_period: Shortest sinusoid period, in seconds (keyword-only).
        max_period: Longest sinusoid period, in seconds (keyword-only).

    Raises:
        ValueError: If ``n_freqs`` < 1 or either period is not positive.
    """

    # Fixed angular frequencies in rad/s, shape [n_freqs]. This is registered
    # as a non-persistent buffer: it is rebuilt from the constructor arguments
    # and is not stored in / loaded from the state_dict.
    freqs: torch.Tensor

    def __init__(
        self,
        d_model: int = 256,
        n_freqs: int = 32,
        *,
        min_period: float = 1.0,
        max_period: float = 10.0,
    ) -> None:
        super().__init__()
        if n_freqs < 1:
            raise ValueError(f"n_freqs must be >= 1, got {n_freqs}")
        if min_period <= 0 or max_period <= 0:
            raise ValueError(
                f"periods must be positive, got min_period={min_period}, "
                f"max_period={max_period}"
            )
        # NOTE(review): an earlier comment claimed these frequencies span
        # 10 ms to 10 s, but the code has always produced periods of 1 s to
        # 10 s. The defaults keep exactly that behaviour, because the buffer
        # is not checkpointed and changing it would silently alter trained
        # models. Pass min_period=0.01 to get the 10 ms lower bound.
        freqs = torch.exp(
            torch.linspace(
                math.log(2 * math.pi / min_period),
                math.log(2 * math.pi / max_period),
                n_freqs,
            )
        )
        self.register_buffer("freqs", freqs, persistent=False)
        self.proj = nn.Linear(2 * n_freqs, d_model)

    def forward(self, dt_seconds: torch.Tensor) -> torch.Tensor:
        """Embed time gaps.

        Args:
            dt_seconds: Δt values in seconds, typically shape ``[B]``. Any
                shape ``[...]`` works, because the frequency axis is
                appended last.

        Returns:
            Tensor of shape ``[..., d_model]`` (``[B, d_model]`` for ``[B]``
            input).
        """
        # Cast to the buffer dtype. Integer Δt would otherwise rely on type
        # promotion, and float64 Δt would make the product float64, which
        # then fails in the (float32) projection with a dtype mismatch.
        dt = dt_seconds.to(dtype=self.freqs.dtype)
        phase = dt.unsqueeze(-1) * self.freqs  # [..., n_freqs]
        emb = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)  # [..., 2*n_freqs]
        return self.proj(emb)  # [..., d_model]