import torch
from torch import nn
from .nn.mlp import MLP
from .nn.rff_mlp import RFFMLP
from .nn.siren import SirenNet
from .pe.projection import Projection
from .pe.projection_rff import ProjectionRFF
from .pe.spherical_harmonics import SphericalHarmonics
def get_positional_encoding(positional_encoding_type, hparams, device="cuda"):
    """
    Return a positional encoding module based on the specified encoding type.

    Args:
        positional_encoding_type (str): The type of positional encoding to use.
            Options are 'projectionrff', 'projection', and 'sh'.
        hparams (dict): Hyperparameters for the chosen encoding type. Keys read
            here: 'projection' and 'sigma' for 'projectionrff'; 'projection'
            for 'projection'; 'legendre_polys' and 'harmonics_calculation'
            for 'sh'. The full dict is also forwarded to the module.
        device (str): Device to place the module on. Defaults to "cuda".

    Returns:
        nn.Module: The positional encoding module.

    Raises:
        ValueError: If ``positional_encoding_type`` is not a supported option.
    """
    if positional_encoding_type == "projectionrff":
        return ProjectionRFF(
            projection=hparams["projection"],
            sigma=hparams["sigma"],
            hparams=hparams,
            device=device,
        )
    elif positional_encoding_type == "projection":
        return Projection(
            projection=hparams["projection"], hparams=hparams, device=device
        )
    elif positional_encoding_type == "sh":
        return SphericalHarmonics(
            legendre_polys=hparams["legendre_polys"],
            harmonics_calculation=hparams["harmonics_calculation"],
            hparams=hparams,
            device=device,
        )
    else:
        raise ValueError(f"Unsupported encoding type: {positional_encoding_type}")
def get_neural_network(
    neural_network_type: str,
    input_dim: int,
    hparams: dict,
    device="cuda",
):
    """
    Return a neural network module based on the specified network type.

    Args:
        neural_network_type (str): The type of neural network to use.
            Options are 'siren', 'mlp', and 'rffmlp'.
        input_dim (int): The input dimension for the neural network.
        hparams (dict): Hyperparameters for the chosen network type. Keys read
            here: 'output_dim', 'hidden_dim', 'num_layers' for 'siren';
            'hidden_dim' for 'mlp'; 'hidden_dim' and 'sigma' for 'rffmlp'.
            The full dict is also forwarded to the module.
        device (str): Device to place the module on. Defaults to "cuda".

    Returns:
        nn.Module: The neural network module.

    Raises:
        ValueError: If ``neural_network_type`` is not a supported option.
    """
    if neural_network_type == "siren":
        return SirenNet(
            input_dim=input_dim,
            output_dim=hparams["output_dim"],
            hidden_dim=hparams["hidden_dim"],
            num_layers=hparams["num_layers"],
            hparams=hparams,
            device=device,
        )
    elif neural_network_type == "mlp":
        return MLP(
            input_dim=input_dim,
            hidden_dim=hparams["hidden_dim"],
            hparams=hparams,
            device=device,
        )
    elif neural_network_type == "rffmlp":
        return RFFMLP(
            input_dim=input_dim,
            hidden_dim=hparams["hidden_dim"],
            sigma=hparams["sigma"],
            hparams=hparams,
            device=device,
        )
    else:
        raise ValueError(f"Unsupported network type: {neural_network_type}")
class LocationEncoder(nn.Module):
def __init__(
self,
positional_encoding_type: str = "sh",
neural_network_type: str = "siren",
hparams: dict | None = None,
device: str = "cuda",
):
super().__init__()
self.device = device
self.position_encoder = get_positional_encoding(
positional_encoding_type=positional_encoding_type,
hparams=hparams,
device=device,
)
if hparams is None:
hparams = {}
self.neural_network = nn.ModuleList(
[
get_neural_network(
neural_network_type, input_dim=dim, hparams=hparams, device=device
)
for dim in self.position_encoder.embedding_dim
]
)
def forward(self, x):
embedding = self.position_encoder(x)
if embedding.ndim == 2:
# If the embedding is (batch, n), we need to add a dimension
embedding = embedding.unsqueeze(0)
location_features = torch.zeros(embedding.shape[1], 512).to(self.device)
for nn, e in zip(self.neural_network, embedding):
location_features += nn(e)
return location_features