# geolocation/src/g3/locationencoder.py
import torch
from torch import nn
from .nn.mlp import MLP
from .nn.rff_mlp import RFFMLP
from .nn.siren import SirenNet
from .pe.projection import Projection
from .pe.projection_rff import ProjectionRFF
from .pe.spherical_harmonics import SphericalHarmonics
def get_positional_encoding(positional_encoding_type, hparams, device="cuda"):
    """
    Return a positional encoding module of the requested type.

    Args:
        positional_encoding_type (str): The type of positional encoding to
            use. Options are 'projectionrff', 'projection', 'sh'.
        hparams (dict): Hyperparameters read by the chosen encoder, e.g.
            'projection' and 'sigma' for the projection variants,
            'legendre_polys' and 'harmonics_calculation' for 'sh'.
        device (str): Device the module is built for. Defaults to "cuda".

    Returns:
        nn.Module: The positional encoding module.

    Raises:
        ValueError: If ``positional_encoding_type`` is not one of the
            supported options.
    """
    # Guard-style dispatch: each branch returns immediately.
    if positional_encoding_type == "projectionrff":
        return ProjectionRFF(
            projection=hparams["projection"],
            sigma=hparams["sigma"],
            hparams=hparams,
            device=device,
        )
    if positional_encoding_type == "projection":
        return Projection(
            projection=hparams["projection"], hparams=hparams, device=device
        )
    if positional_encoding_type == "sh":
        return SphericalHarmonics(
            legendre_polys=hparams["legendre_polys"],
            harmonics_calculation=hparams["harmonics_calculation"],
            hparams=hparams,
            device=device,
        )
    raise ValueError(f"Unsupported encoding type: {positional_encoding_type}")
def get_neural_network(
    neural_network_type: str,
    input_dim: int,
    hparams: dict,
    device="cuda",
):
    """
    Return a neural network module of the requested type.

    Args:
        neural_network_type (str): The type of neural network to use.
            Options are 'siren', 'mlp', 'rffmlp'.
        input_dim (int): The input dimension for the neural network.
        hparams (dict): Hyperparameters read by the chosen network, e.g.
            'hidden_dim', 'num_layers', 'output_dim', 'sigma'.
        device (str): Device the module is built for. Defaults to "cuda".

    Returns:
        nn.Module: The neural network module.

    Raises:
        ValueError: If ``neural_network_type`` is not one of the
            supported options.
    """
    # Guard-style dispatch: each branch returns immediately.
    if neural_network_type == "siren":
        return SirenNet(
            input_dim=input_dim,
            output_dim=hparams["output_dim"],
            hidden_dim=hparams["hidden_dim"],
            num_layers=hparams["num_layers"],
            hparams=hparams,
            device=device,
        )
    if neural_network_type == "mlp":
        return MLP(
            input_dim=input_dim,
            hidden_dim=hparams["hidden_dim"],
            hparams=hparams,
            device=device,
        )
    if neural_network_type == "rffmlp":
        return RFFMLP(
            input_dim=input_dim,
            hidden_dim=hparams["hidden_dim"],
            sigma=hparams["sigma"],
            hparams=hparams,
            device=device,
        )
    raise ValueError(f"Unsupported network type: {neural_network_type}")
class LocationEncoder(nn.Module):
def __init__(
self,
positional_encoding_type: str = "sh",
neural_network_type: str = "siren",
hparams: dict | None = None,
device: str = "cuda",
):
super().__init__()
self.device = device
self.position_encoder = get_positional_encoding(
positional_encoding_type=positional_encoding_type,
hparams=hparams,
device=device,
)
if hparams is None:
hparams = {}
self.neural_network = nn.ModuleList(
[
get_neural_network(
neural_network_type, input_dim=dim, hparams=hparams, device=device
)
for dim in self.position_encoder.embedding_dim
]
)
def forward(self, x):
embedding = self.position_encoder(x)
if embedding.ndim == 2:
# If the embedding is (batch, n), we need to add a dimension
embedding = embedding.unsqueeze(0)
location_features = torch.zeros(embedding.shape[1], 512).to(self.device)
for nn, e in zip(self.neural_network, embedding):
location_features += nn(e)
return location_features