File size: 859 Bytes
31e2456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""Δt scalar → conditioning token in R^d via sinusoidal encoding."""
from __future__ import annotations

import math

import torch
from torch import nn


class DeltaTEmbedding(nn.Module):
    """Embed a time-gap scalar Δt (in seconds) as a ``d_model``-dim vector.

    Δt is multiplied by a fixed (non-learned) bank of angular frequencies
    whose periods are geometrically spaced between ``min_period_s`` and
    ``max_period_s``. The result is expanded to ``[sin, cos]`` features and
    projected to ``d_model`` with a single linear layer.

    Args:
        d_model: Output embedding width.
        n_freqs: Number of sinusoidal frequencies. The projection input width
            is ``2 * n_freqs``.
        min_period_s: Shortest sinusoid period in seconds (highest frequency).
        max_period_s: Longest sinusoid period in seconds (lowest frequency).

    The defaults (periods 1 s .. 10 s) reproduce the frequency bank this
    module has always built.

    NOTE(review): a 10 ms .. 10 s span would need ``min_period_s=0.01``;
    confirm which span is intended. Because the ``freqs`` buffer is not
    persisted in the state dict, changing the default would silently change
    the features that already-trained ``proj`` weights receive.

    Raises:
        ValueError: If ``n_freqs < 1`` or either period is not positive.
    """

    def __init__(
        self,
        d_model: int = 256,
        n_freqs: int = 32,
        *,
        min_period_s: float = 1.0,
        max_period_s: float = 10.0,
    ):
        super().__init__()
        if n_freqs < 1:
            raise ValueError(f"n_freqs must be >= 1, got {n_freqs}")
        if min_period_s <= 0 or max_period_s <= 0:
            raise ValueError(
                f"periods must be positive, got min_period_s={min_period_s}, "
                f"max_period_s={max_period_s}"
            )
        # Angular frequency omega = 2*pi / period. Interpolating linearly in
        # log space gives geometrically spaced frequencies, from the highest
        # (shortest period) down to the lowest (longest period).
        freqs = torch.exp(
            torch.linspace(
                math.log(2 * math.pi / min_period_s),
                math.log(2 * math.pi / max_period_s),
                n_freqs,
            )
        )
        # Fixed, not learned. persistent=False means it is rebuilt from the
        # constructor arguments rather than loaded from a checkpoint.
        self.register_buffer("freqs", freqs, persistent=False)
        self.proj = nn.Linear(2 * n_freqs, d_model)

    def forward(self, dt_seconds: torch.Tensor) -> torch.Tensor:
        """Map time gaps to embeddings.

        Args:
            dt_seconds: Time gaps in seconds, shape ``[...]`` (typically
                ``[B]``). Integer or float64 inputs are cast to the buffer
                dtype so they match the projection weights.

        Returns:
            Tensor of shape ``[..., d_model]``.
        """
        # Without this cast, a float64 input would promote the product to
        # float64, and the float32 Linear would reject it.
        dt = dt_seconds.to(dtype=self.freqs.dtype)
        x = dt.unsqueeze(-1) * self.freqs  # [..., n_freqs]
        emb = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)  # [..., 2*n_freqs]
        return self.proj(emb)  # [..., d_model]