| | |
| | |
| | |
| |
|
| | import torch |
| | from torch import nn |
| | import numpy as np |
| |
|
| |
|
class FourierPositionEncoding(nn.Module):
    """
    Parameter-free module that turns positions into Fourier features.

    The same ``max_resolution`` is used for each of the ``n`` positional
    dimensions; the encoding itself is computed by
    ``_generate_fourier_features``.
    """

    def __init__(self, n, num_bands, max_resolution):
        """
        Store the encoding configuration.

        Args:
            n: number of positional dimensions (``max_resolution`` is
                repeated ``n`` times).
            num_bands: number of frequency bands per positional dimension.
            max_resolution: maximum resolution shared by every dimension.
        """
        super().__init__()

        self.max_resolution = [max_resolution] * n
        self.num_bands = num_bands

    @property
    def channels(self):
        """
        Size of the last dimension produced by ``forward``.

        Each positional dimension contributes ``num_bands`` sine values,
        ``num_bands`` cosine values and the raw coordinate itself.
        """
        dims = len(self.max_resolution)
        return dims * self.num_bands * 2 + dims

    def forward(self, pos):
        """
        Encode ``pos`` with Fourier positional features.

        NOTE(review): the original docstring describes the input as "rays";
        this method only forwards ``pos`` unchanged to
        ``_generate_fourier_features`` — confirm the expected layout with
        the caller.
        """
        return _generate_fourier_features(
            pos,
            num_bands=self.num_bands,
            max_resolution=self.max_resolution,
        )
| |
|
| |
|
def _generate_fourier_features(pos, num_bands, max_resolution):
    """Generate Fourier features from a set of positions.

    For every positional dimension, ``num_bands`` frequencies are spaced
    linearly between 1.0 and half of that dimension's maximum resolution.
    Each coordinate is multiplied by its frequencies and passed through
    ``sin(pi * x)`` and ``cos(pi * x)``; the raw positions are prepended.

    Args:
        pos: tensor whose last dimension holds the coordinates, i.e. shape
            ``(..., d)`` (typically ``(batch, n, d)``). ``d`` has to
            broadcast against ``len(max_resolution)``.
        num_bands: number of frequency bands per positional dimension.
        max_resolution: sequence with one maximum resolution per
            positional dimension.

    Returns:
        Tensor of shape ``(..., d + 2 * d * num_bands)`` laid out as
        ``[pos, sin features, cos features]`` along the last dimension.
    """
    # Frequency table with one row per positional dimension: (d, num_bands).
    min_freq = 1.0
    freq_bands = torch.stack(
        [
            torch.linspace(
                start=min_freq, end=res / 2, steps=num_bands, device=pos.device
            )
            for res in max_resolution
        ],
        dim=0,
    )

    # Scale every coordinate by each of its frequencies in one broadcasted
    # multiply, (..., d, 1) * (d, num_bands), instead of a Python loop over
    # the batch. This also keeps empty batches working.
    per_pos_features = pos[..., None] * freq_bands
    # flatten() rather than reshape(..., -1): the merged size stays well
    # defined even when a leading dimension is empty.
    per_pos_features = per_pos_features.flatten(start_dim=-2)

    # Sine and cosine of the scaled positions, concatenated on the last axis.
    per_pos_features = torch.cat(
        [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)],
        dim=-1,
    )

    # Keep the raw positions in front of the periodic features.
    return torch.cat([pos, per_pos_features], dim=-1)
| |
|