Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import math | |
# Embedding type names known to this module. get_embedder only builds
# "identity" and "fourier"; the remaining names are rejected there.
VALID_EMBED_TYPES = [
    "identity",
    "fourier",
    "hashgrid",
    "sphere_harmonic",
    "triplane_fourier",
]
class FourierEmbedder(nn.Module):
    """Sin/cosine positional embedding applied to every feature channel.

    Given an input tensor ``x`` of shape ``[..., input_dim]``, every channel
    ``x[..., i]`` is multiplied by the ``num_freqs`` frequencies
    ``f_0, ..., f_{N-1}`` and the results are concatenated on the last axis as::

        [x,                                          # only if include_input
         sin(x[..., 0] * f_0), ..., sin(x[..., 0] * f_{N-1}),
         sin(x[..., 1] * f_0), ..., sin(x[..., D-1] * f_{N-1}),
         cos(x[..., 0] * f_0), ..., cos(x[..., D-1] * f_{N-1})]

    That is: the raw input first (when included), then all sines, then all
    cosines; inside each sine/cosine group the frequency index varies fastest.

    Frequencies:
        * ``logspace=True``:  ``f_k = 2 ** k`` for ``k = 0 .. num_freqs - 1``.
        * ``logspace=False``: ``num_freqs`` values linearly spaced over
          ``[1.0, 2 ** (num_freqs - 1)]``.
        * If ``include_pi`` is True every frequency is also multiplied by pi.

    When ``num_freqs == 0`` the module is the identity: ``forward`` returns
    ``x`` unchanged whatever ``include_input`` is.

    Args:
        num_freqs (int): number of frequencies, default 6.
        logspace (bool): powers-of-two frequencies if True, otherwise linearly
            spaced frequencies (see above).
        input_dim (int): size of the input's last axis, default 3.
        include_input (bool): prepend the raw input to the embedding, default True.
        include_pi (bool): multiply all frequencies by pi, default True.

    Attributes:
        frequencies (torch.Tensor): float32 buffer of shape ``[num_freqs]``;
            registered with ``persistent=False`` so it is not stored in the
            state dict.
        include_input (bool): see Args.
        num_freqs (int): see Args.
        out_dim (int): ``input_dim * (2 * num_freqs + 1)`` if ``include_input``
            is True or ``num_freqs == 0``, otherwise ``input_dim * 2 * num_freqs``.
    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:
        """Build the frequency buffer and compute ``out_dim``."""
        super().__init__()
        if logspace:
            # f_k = 2 ** k, k = 0 .. num_freqs - 1
            frequencies = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )
        else:
            # num_freqs evenly spaced values in [1, 2 ** (num_freqs - 1)]
            frequencies = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )
        if include_pi:
            frequencies *= torch.pi
        # Derived from constructor args, so it does not need to be checkpointed.
        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs
        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        """Return the embedding width for an input whose last axis has ``input_dim`` entries."""
        # With num_freqs == 0, forward returns x itself, so one copy of the
        # input is counted even when include_input is False.
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)
        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x``.

        Args:
            x: tensor of shape ``[..., dim]``.

        Returns:
            Tensor of shape ``[..., dim * (num_freqs * 2 + temp)]``, where
            ``temp`` is 1 if ``include_input`` is True and 0 otherwise; when
            ``num_freqs == 0`` the input is returned unchanged.
        """
        if self.num_freqs > 0:
            # [..., dim, 1] * [num_freqs] -> [..., dim, num_freqs] -> [..., dim * num_freqs]
            embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x
class LearnedFourierEmbedder(nn.Module):
    """Sinusoidal embedding with learned frequencies.

    Follows @crowsonkb's lead with learned sinusoidal positional embeddings:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    Every input channel is multiplied by the same ``per_channel_dim`` learned
    frequencies (scaled by 2*pi). The output is ``[x, sin(...), cos(...)]``,
    concatenated along the last axis.

    Args:
        in_channels: size of the input's last axis.
        dim: requested width of the sin + cos part; must be even. The number of
            learned frequencies is ``(dim // 2) // in_channels``, so when
            ``dim // 2`` is not a multiple of ``in_channels`` the sin/cos part
            is narrower than ``dim``.

    Attributes:
        weights: learned frequencies of shape ``[per_channel_dim]``,
            initialised from ``torch.randn``.
        out_dim: width of the output's last axis,
            ``in_channels + 2 * in_channels * per_channel_dim`` (equal to
            ``in_channels + dim`` when ``dim // 2`` is divisible by ``in_channels``).
    """

    def __init__(self, in_channels, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        per_channel_dim = half_dim // in_channels
        self.weights = nn.Parameter(torch.randn(per_channel_dim))
        # Real output width; floor division above can make it smaller than in_channels + dim.
        self.out_dim = in_channels + 2 * in_channels * per_channel_dim

    def forward(self, x):
        """
        Args:
            x (torch.FloatTensor): [..., c]
        Returns:
            x (torch.FloatTensor): [..., out_dim]
        """
        # [..., c, 1] * [1, d] = [..., c, d] -> [..., c * d]
        freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
        fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
        return fouriered
class TriplaneLearnedFourierEmbedder(nn.Module):
    """Sum of three independent LearnedFourierEmbedder outputs ("yz", "xz", "xy").

    NOTE(review): all three sub-embedders receive the full input ``x``; nothing
    here projects ``x`` onto the named planes, so the names are only labels.
    Because each sub-embedder prepends the raw input, the first ``in_channels``
    entries of the output equal ``3 * x``. Confirm this is intended.

    Args:
        in_channels: size of the input's last axis.
        dim: requested sin + cos width handed to each LearnedFourierEmbedder
            (must be even).

    Attributes:
        out_dim: width of the output's last axis.
    """

    def __init__(self, in_channels, dim):
        super().__init__()
        self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
        self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
        self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
        # LearnedFourierEmbedder learns (dim // 2) // in_channels frequencies per
        # channel, so its output is narrower than in_channels + dim whenever
        # dim // 2 is not a multiple of in_channels; mirror that arithmetic.
        per_channel_dim = (dim // 2) // in_channels
        self.out_dim = in_channels + 2 * in_channels * per_channel_dim

    def forward(self, x):
        """Return the element-wise sum of the three sub-embeddings of ``x``."""
        yz_embed = self.yz_plane_embedder(x)
        xz_embed = self.xz_plane_embedder(x)
        xy_embed = self.xy_plane_embedder(x)
        embed = yz_embed + xz_embed + xy_embed
        return embed
def sequential_pos_embed(num_len, embed_dim, max_period=10000):
    """Fixed 1-D sin/cos embedding for the positions 0 .. num_len - 1.

    Args:
        num_len: number of positions M.
        embed_dim: embedding width D; must be even.
        max_period: base of the geometric frequency progression (default 10000,
            the value previously hard-coded). Frequency k is
            ``max_period ** (-k / (D / 2))`` for k = 0 .. D/2 - 1.

    Returns:
        float32 tensor of shape (M, D): the first D/2 columns hold
        sin(pos * omega), the last D/2 hold cos(pos * omega).
    """
    assert embed_dim % 2 == 0
    positions = torch.arange(num_len, dtype=torch.float32)  # (M,)
    exponents = torch.arange(embed_dim // 2, dtype=torch.float32) / (embed_dim / 2.)
    omega = 1. / max_period ** exponents  # (D/2,)
    out = torch.einsum("m,d->md", positions, omega)  # (M, D/2), outer product
    return torch.cat([torch.sin(out), torch.cos(out)], dim=1)  # (M, D)
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional. A 0-d tensor is treated as N = 1.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings: the first dim // 2
             columns are cosines, the next dim // 2 are sines, followed by a
             single zero column when dim is odd.
    """
    if timesteps.ndim == 0:
        # Accept a scalar timestep; previously this raised IndexError.
        timesteps = timesteps[None]
    half = dim // 2
    # Build the frequencies directly on the timesteps' device instead of
    # creating them on the CPU and copying.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
        / half
    )
    # Type promotion makes integer timesteps * float32 freqs a float32 result.
    args = timesteps[:, None] * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
                 num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
                 log2_hashmap_size=19, desired_resolution=None):
    """Build an input embedder and report its output width.

    Args:
        embed_type: embedder name. Only "identity" and "fourier" are built here;
            "hashgrid", "sphere_harmonic" and "triplane_fourier" are listed in
            VALID_EMBED_TYPES but raise NotImplementedError.
        num_freqs: number of Fourier frequencies. -1 (the default) with
            embed_type="fourier" means "no embedding" and yields an identity module.
        input_dim: size of the input's last axis.
        degree, num_levels, level_dim, per_level_scale, base_resolution,
        log2_hashmap_size, desired_resolution: not used by this function;
            presumably reserved for the unimplemented hashgrid / sphere_harmonic
            encoders (TODO confirm).

    Returns:
        A ``(module, out_dim)`` tuple.

    Raises:
        NotImplementedError: for known but unimplemented embed types.
        ValueError: for an unknown embed type.
    """
    if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
        return nn.Identity(), input_dim
    if embed_type == "fourier":
        embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
                                       logspace=True, include_input=True)
        return embedder_obj, embedder_obj.out_dim
    if embed_type in ("hashgrid", "sphere_harmonic", "triplane_fourier"):
        # Advertised in VALID_EMBED_TYPES but not built here; previously
        # "triplane_fourier" fell through to a contradictory "not valid" error.
        raise NotImplementedError(f"{embed_type} embedder is not implemented yet")
    raise ValueError(f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}")