# File size: 432 Bytes (revision a3682cf)
import torch
import math
class TimeEncoding(torch.nn.Module):
    """Sinusoidal feature encoding for scalar time values.

    Each time value is multiplied by ``dim`` frequencies
    ``1 / base ** (k / dim)`` for ``k = 0 .. dim - 1``; the sines of the
    resulting angles are concatenated with their cosines.

    Args:
        dim: Number of frequencies. The encoding has ``2 * dim`` features.
        base: Base of the geometric frequency progression. The default of
            10 reproduces the value that used to be hard-coded.
    """

    def __init__(self, dim: int, base: float = 10):
        super().__init__()
        self.dim = dim
        self.base = base

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Encode time values as sine/cosine features.

        Args:
            t: Tensor of time values. A 1-D tensor of shape ``(N,)``
                produces an output of shape ``(N, 2 * dim)``.
                NOTE(review): inputs of other ranks are passed through
                ``unsqueeze(1)`` unchanged from the original code;
                confirm with callers whether they are expected.

        Returns:
            Tensor holding the ``dim`` sine features followed by the
            ``dim`` cosine features, concatenated along dimension 1.
        """
        # Frequencies are rebuilt on every call, on the input's device, so
        # the module carries no parameters or buffers of its own.
        exponents = torch.arange(self.dim, device=t.device).float() / self.dim
        freqs = 1 / (self.base ** exponents)

        # (N,) -> (N, 1) so that the product broadcasts to (N, dim).
        angles = t.unsqueeze(1) * freqs
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)