File size: 4,263 Bytes
eff2be4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
from torch import nn
from .nn.mlp import MLP
from .nn.rff_mlp import RFFMLP
from .nn.siren import SirenNet
from .pe.projection import Projection
from .pe.projection_rff import ProjectionRFF
from .pe.spherical_harmonics import SphericalHarmonics


def get_positional_encoding(positional_encoding_type, hparams, device="cuda"):
    """
    Return a positional encoding module for the requested encoding type.

    Args:
        positional_encoding_type (str): One of 'projectionrff', 'projection',
            or 'sh' (spherical harmonics).
        hparams (dict): Hyperparameters consumed by the chosen encoder:
            'projection' and 'sigma' for 'projectionrff'; 'projection' for
            'projection'; 'legendre_polys' and 'harmonics_calculation' for 'sh'.
            The full dict is also forwarded to the encoder.
        device (str): Device the module is created on. Defaults to "cuda".

    Returns:
        nn.Module: The positional encoding module.

    Raises:
        ValueError: If ``positional_encoding_type`` is not one of the
            supported options.
        KeyError: If a required hyperparameter is missing from ``hparams``.
    """
    if positional_encoding_type == "projectionrff":
        return ProjectionRFF(
            projection=hparams["projection"],
            sigma=hparams["sigma"],
            hparams=hparams,
            device=device,
        )
    if positional_encoding_type == "projection":
        return Projection(
            projection=hparams["projection"], hparams=hparams, device=device
        )
    if positional_encoding_type == "sh":
        return SphericalHarmonics(
            legendre_polys=hparams["legendre_polys"],
            harmonics_calculation=hparams["harmonics_calculation"],
            hparams=hparams,
            device=device,
        )
    raise ValueError(f"Unsupported encoding type: {positional_encoding_type}")


def get_neural_network(
    neural_network_type: str,
    input_dim: int,
    hparams: dict,
    device="cuda",
):
    """
    Return a neural network module for the requested network type.

    Args:
        neural_network_type (str): One of 'siren', 'mlp', or 'rffmlp'.
        input_dim (int): Input dimension of the network.
        hparams (dict): Hyperparameters consumed by the chosen network:
            'output_dim', 'hidden_dim' and 'num_layers' for 'siren';
            'hidden_dim' for 'mlp'; 'hidden_dim' and 'sigma' for 'rffmlp'.
            The full dict is also forwarded to the network.
        device (str): Device the module is created on. Defaults to "cuda".

    Returns:
        nn.Module: The neural network module.

    Raises:
        ValueError: If ``neural_network_type`` is not one of the supported
            options.
        KeyError: If a required hyperparameter is missing from ``hparams``.
    """
    if neural_network_type == "siren":
        return SirenNet(
            input_dim=input_dim,
            output_dim=hparams["output_dim"],
            hidden_dim=hparams["hidden_dim"],
            num_layers=hparams["num_layers"],
            hparams=hparams,
            device=device,
        )
    if neural_network_type == "mlp":
        return MLP(
            input_dim=input_dim,
            hidden_dim=hparams["hidden_dim"],
            hparams=hparams,
            device=device,
        )
    if neural_network_type == "rffmlp":
        return RFFMLP(
            input_dim=input_dim,
            hidden_dim=hparams["hidden_dim"],
            sigma=hparams["sigma"],
            hparams=hparams,
            device=device,
        )
    raise ValueError(f"Unsupported network type: {neural_network_type}")


class LocationEncoder(nn.Module):
    """
    Encode locations by applying a positional encoding and summing the
    outputs of one neural network per positional-encoding component.
    """

    def __init__(
        self,
        positional_encoding_type: str = "sh",
        neural_network_type: str = "siren",
        hparams: dict | None = None,
        device: str = "cuda",
    ):
        super().__init__()
        self.device = device

        # Normalise hparams BEFORE building the position encoder: the
        # factory subscripts hparams, so forwarding None (the documented
        # default) would raise TypeError.
        if hparams is None:
            hparams = {}

        self.position_encoder = get_positional_encoding(
            positional_encoding_type=positional_encoding_type,
            hparams=hparams,
            device=device,
        )

        # One network per embedding component; embedding_dim is assumed to
        # be an iterable of per-component input widths — TODO confirm
        # against the encoder implementations.
        self.neural_network = nn.ModuleList(
            [
                get_neural_network(
                    neural_network_type, input_dim=dim, hparams=hparams, device=device
                )
                for dim in self.position_encoder.embedding_dim
            ]
        )

    def forward(self, x):
        """Return the summed per-component network outputs for input x."""
        embedding = self.position_encoder(x)

        if embedding.ndim == 2:
            # A (batch, n) embedding is treated as a single component.
            embedding = embedding.unsqueeze(0)

        # Sum network outputs directly rather than accumulating into a
        # hard-coded (batch, 512) zero buffer: this is equivalent when the
        # networks emit 512-wide features and correct for any other width.
        # NOTE: `net` (not `nn`) avoids shadowing the torch.nn import.
        location_features = None
        for net, component in zip(self.neural_network, embedding):
            out = net(component)
            location_features = (
                out if location_features is None else location_features + out
            )

        if location_features is None:
            # No components were produced — surface this explicitly instead
            # of returning an arbitrarily-sized zero tensor.
            raise RuntimeError("position encoder produced no embedding components")

        return location_features