Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import math | |
| class SinusoidalEmbedding(nn.Module): | |
| def __init__(self, embed_dim : int, theta : int = 10000): | |
| """ | |
| Creates sinusoidal embeddings for timesteps. | |
| Args: | |
| embed_dim: The dimensionality of the embedding. | |
| theta: The base for the log-spaced frequencies. | |
| """ | |
| super().__init__() | |
| self.embed_dim = embed_dim | |
| self.theta = theta | |
| def forward(self, x): | |
| """ | |
| Computes sinusoidal embeddings for the input timesteps. | |
| Args: | |
| x: A 1D torch.Tensor of timesteps (shape: [batch_size]). | |
| Returns: | |
| A torch.Tensor of sinusoidal embeddings (shape: [batch_size, embed_dim]). | |
| """ | |
| assert isinstance(x, torch.Tensor) # Input must be a torch.Tensor | |
| assert x.ndim == 1 # Input must be a 1D tensor | |
| assert isinstance(self.embed_dim, int) and self.embed_dim > 0 # embed_dim must be a positive integer | |
| half_dim = self.embed_dim // 2 | |
| # Create a sequence of log-spaced frequencies | |
| embeddings = math.log(self.theta) / (half_dim - 1) | |
| embeddings = torch.exp(torch.arange(half_dim, device=x.device) * -embeddings) | |
| # Outer product: timesteps x frequencies | |
| embeddings = x[:, None] * embeddings[None, :] | |
| embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) | |
| # Handle odd embedding dimensions | |
| if self.embed_dim % 2 == 1: | |
| embeddings = torch.cat([embeddings, torch.zeros_like(embeddings[:, :1])], dim=-1) | |
| return embeddings |