wjnwjn59's picture
add code for main model
e1dc1a5
raw
history blame contribute delete
733 Bytes
import math
import torch
def embed_a_timestep(timestep, embedding_dim=320):
    """Return the sinusoidal embedding of a single timestep as a (1, embedding_dim) tensor.

    Frequencies are ``exp(-ln(10000) * k / half_dim)`` for ``k = 0 .. half_dim - 1``
    with ``half_dim = embedding_dim // 2``. The output is ``[cos(t * f), sin(t * f)]``
    concatenated along the last dimension (cosines first).

    Args:
        timestep: A scalar timestep (Python number or a value ``torch.tensor`` accepts
            inside a list).
        embedding_dim: Requested output width. When odd, the last column is a zero pad
            so the result still has exactly ``embedding_dim`` columns.

    Returns:
        A float32 tensor of shape ``(1, embedding_dim)``.
    """
    half_dim = embedding_dim // 2
    # Keep the original evaluation order so even-width results are bit-identical.
    freqs = torch.exp(
        -math.log(10000)
        * torch.arange(start=0, end=half_dim, dtype=torch.float32)
        / half_dim
    )
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    emb = torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
    if embedding_dim % 2 == 1:
        # half_dim rounds down, so an odd width would otherwise lose one column.
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)
    return emb
def embed_timesteps(timesteps, embedding_dim=320):
    """Return sinusoidal embeddings for a batch of timesteps.

    Frequencies are ``exp(-ln(10000) * k / half_dim)`` for ``k = 0 .. half_dim - 1``
    with ``half_dim = embedding_dim // 2``. Each row is ``[cos(t * f), sin(t * f)]``
    (cosines first).

    Args:
        timesteps: A 1-D tensor of timesteps. A 0-d tensor or a Python scalar is
            treated as a batch of one.
        embedding_dim: Requested output width. When odd, the last column is a zero pad
            so the result still has exactly ``embedding_dim`` columns.

    Returns:
        A float32 tensor of shape ``(len(timesteps), embedding_dim)`` on the same
        device as ``timesteps``.
    """
    timesteps = torch.as_tensor(timesteps)
    if timesteps.dim() == 0:
        # `timesteps[:, None]` below needs at least one dimension.
        timesteps = timesteps.reshape(1)
    half_dim = embedding_dim // 2
    # Frequencies are computed on CPU and then moved, as before, so values stay the
    # same regardless of the target device.
    freqs = torch.exp(
        -math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / half_dim
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None, :]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if embedding_dim % 2 == 1:
        # half_dim rounds down, so an odd width would otherwise lose one column.
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)
    return emb