| import math | |
| import torch | |
| import torch.nn.functional as F | |
def transformer_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    max_positions: int = 10000,
) -> torch.Tensor:
    """Return sinusoidal embeddings for a 1-D tensor of timesteps.

    The first ``embedding_dim // 2`` columns hold ``sin(t * f_i)`` and the
    next ``embedding_dim // 2`` columns hold ``cos(t * f_i)``, where the
    frequencies ``f_i`` are geometrically spaced from ``1`` down to
    ``1 / max_positions``. When ``embedding_dim`` is odd, one trailing
    column of zeros is appended so the width matches ``embedding_dim``.

    Args:
        timesteps: 1-D tensor of timestep values; it is cast to float32.
        embedding_dim: Width of the returned embedding.
        max_positions: Sets the lowest frequency (``1 / max_positions``).
            The default of 10000 is the value used by transformers.

    Returns:
        A float32 tensor of shape ``(len(timesteps), embedding_dim)`` on
        the same device as ``timesteps``.

    Raises:
        AssertionError: If ``timesteps`` is not 1-D.
    """
    assert len(timesteps.shape) == 1, "timesteps must be a 1-D tensor"

    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim of 2 or 3) `half_dim - 1` is 0
    # and the division used to raise ZeroDivisionError. Clamping to 1 is safe:
    # the only exponent index is then 0, so the frequency is exp(0) == 1
    # regardless of the scale.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        * -log_step
    )

    # Outer product: one row per timestep, one column per frequency.
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    if embedding_dim % 2 == 1:
        # sin/cos only fill an even number of columns; zero-pad the last one.
        emb = F.pad(emb, (0, 1), mode='constant')

    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb