File size: 859 Bytes
31e2456 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | """Δt scalar → conditioning token in R^d via sinusoidal encoding."""
from __future__ import annotations
import math
import torch
from torch import nn
class DeltaTEmbedding(nn.Module):
    """Embed a time-interval scalar Δt (in seconds) as a ``d_model``-dim token.

    Δt is multiplied by ``n_freqs`` fixed angular frequencies, log-spaced from
    ``2π / min_period_s`` down to ``2π / max_period_s`` rad/s. The sin and cos
    of the resulting phases are concatenated (``2 * n_freqs`` features) and
    passed through a learned linear projection.

    Args:
        d_model: Output embedding width.
        n_freqs: Number of sinusoidal frequencies.
        min_period_s: Shortest encoded period in seconds (highest frequency).
        max_period_s: Longest encoded period in seconds (lowest frequency).

    Raises:
        ValueError: If ``n_freqs < 1`` or either period is not positive.
    """

    def __init__(
        self,
        d_model: int = 256,
        n_freqs: int = 32,
        *,
        min_period_s: float = 1.0,
        max_period_s: float = 10.0,
    ):
        super().__init__()
        if n_freqs < 1:
            raise ValueError(f"n_freqs must be >= 1, got {n_freqs}")
        if min_period_s <= 0 or max_period_s <= 0:
            raise ValueError(
                f"periods must be positive, got min_period_s={min_period_s}, "
                f"max_period_s={max_period_s}"
            )
        # Fixed (not learned) angular frequencies, log-spaced from
        # 2π/min_period_s to 2π/max_period_s. With the defaults this covers
        # periods of 1 s to 10 s.
        # NOTE(review): an earlier comment claimed "10 ms to 10 s", which the
        # code never did; pass min_period_s=0.01 for that range. The default
        # stays at 1 s because `freqs` is not stored in the state_dict
        # (persistent=False), so changing it could silently mismatch the
        # `proj` weights of existing checkpoints.
        freqs = torch.exp(
            torch.linspace(
                math.log(2 * math.pi / min_period_s),
                math.log(2 * math.pi / max_period_s),
                n_freqs,
            )
        )
        self.register_buffer("freqs", freqs, persistent=False)
        self.proj = nn.Linear(2 * n_freqs, d_model)

    def forward(self, dt_seconds: torch.Tensor) -> torch.Tensor:
        """Return the conditioning embedding for ``dt_seconds``.

        Args:
            dt_seconds: Δt values in seconds, any shape ``[...]`` (typically
                ``[B]``). Python numbers or sequences are also accepted.

        Returns:
            Tensor of shape ``[..., d_model]``.
        """
        # Match the buffer's dtype/device. Otherwise, e.g., a float64 input
        # promotes `emb` to float64 and the float32 `proj` matmul raises.
        dt = torch.as_tensor(
            dt_seconds, dtype=self.freqs.dtype, device=self.freqs.device
        )
        phase = dt.unsqueeze(-1) * self.freqs  # [..., n_freqs]
        emb = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)  # [..., 2*n_freqs]
        return self.proj(emb)  # [..., d_model]
|