# MIT License
#
# Copyright (c) [2026] [Tim Büchner, Sai Karthikeya Vemuri, Joachim Denzler]
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Flax coordinate networks for 2-D inputs and a factory to build them.

The module defines several per-coordinate building blocks (sine, Gabor, RBF
and hash-table layers), a shared ``BACKEND`` base class that encodes inputs and
assembles sub-networks, and two concrete models: a baseline that consumes the
concatenated coordinates and a factorised model that combines one sub-network
per axis through a rank contraction.
"""

__all__ = ["get_model_2D", "MLPType", "EmbeddingType", "DecompositionType"]

from abc import ABC
from enum import Enum
from typing import Optional

import flax.linen as nn
import jax
import jax.numpy as jnp


class MLPType(Enum):
    """Sub-network families selectable in ``BACKEND.create_subnetwork``."""

    TANH = "TANH"
    RELU = "ReLU"
    WIRE = "WIRE"
    SIREN = "SIREN"
    SIREN2 = "SIREN2"
    FINER = "FINER"
    NEURBF = "NeuRBF"


class EmbeddingType(Enum):
    """Input encodings selectable in ``BACKEND.encode``."""

    PE000 = "PE000"
    PE010 = "PE010"
    PE020 = "PE020"
    PE100 = "PE100"
    HE = "HE"


class DecompositionType(Enum):
    """Model layouts; ``get_model_2D`` only builds ``BASELINE`` and ``CP``."""

    BASELINE = "Baseline"
    CP = "CP"
    TT = "TT"
    TU = "TU"
    TR = "TR"


class NeuRBF1D(nn.Module):
    """Radial-basis feature extractor followed by two dense layers."""

    # Number of radial basis functions (rows of ``centers`` and ``features``).
    num_rbfs: int
    # Width of the composed features and of both dense layers.
    feature_dim: int

    @nn.compact
    def __call__(self, x):
        """Map 2-D input ``x`` of shape ``(N, D)`` to ``(N, feature_dim)``.

        Each basis function responds with ``1 / (1 + sum(((x - c) / s)^2))``.
        The response is passed through a learned sine, weighted by a per-basis
        feature vector and summed over the basis functions.
        """
        per_rbf = (self.num_rbfs, 1)
        per_feature = (1, self.feature_dim)

        # Parameters are declared in a fixed order so initialisation is stable.
        centers = self.param("centers", nn.initializers.uniform(), per_rbf)
        log_sigma = self.param("log_sigma", nn.initializers.zeros, per_rbf)
        freq = self.param("freq", nn.initializers.normal(stddev=5.0), per_feature)
        bias = self.param("bias", nn.initializers.zeros, per_feature)
        features = self.param(
            "features",
            nn.initializers.normal(stddev=0.1),
            (self.num_rbfs, self.feature_dim),
        )

        # The width is stored in log space; the epsilon keeps it strictly positive.
        sigma = jnp.exp(log_sigma) + 1e-6

        # (N, 1, D) against (1, num_rbfs, 1) -> (N, num_rbfs, D).
        sq_dist = ((x[:, None, :] - centers[None, :, :]) ** 2) / (sigma[None, :, :] ** 2)
        rbf_vals = 1.0 / (1.0 + sq_dist.sum(-1))

        # (N, num_rbfs, feature_dim), then reduced over the basis axis.
        composed = jnp.sin(rbf_vals[:, :, None] * freq + bias)
        aggregated = jnp.sum(composed * features[None, :, :], axis=1)

        hidden = nn.Dense(self.feature_dim)(aggregated)
        # Sine branch plus identity skip, reusing the learned frequencies.
        hidden = jnp.sin(hidden * freq[0]) + hidden
        return nn.Dense(self.feature_dim)(hidden)


class RealGaborLayer(nn.Module):
    """Layer computing ``cos(omega0 * W1 x) * exp(-(sigma0 * W2 x)^2)``."""

    # NOTE(review): ``in_features`` and ``is_first`` are not read by this layer.
    in_features: int
    out_features: int
    bias: bool = True
    is_first: bool = False
    omega0: float = 10.0
    sigma0: float = 10.0

    def setup(self):
        """Create the two dense projections and copy the scale factors."""
        self.omega_0 = self.omega0
        self.scale_0 = self.sigma0
        self.freqs = nn.Dense(self.out_features, use_bias=self.bias)
        self.scale = nn.Dense(self.out_features, use_bias=self.bias)

    def __call__(self, input):
        """Return the cosine carrier multiplied by the Gaussian envelope."""
        omega = self.omega_0 * self.freqs(input)
        scale = self.scale(input) * self.scale_0
        return jnp.cos(omega) * jnp.exp(-(scale**2))


class SineLayer(nn.Module):
    """Dense layer followed by ``sin(omega_0 * .)``."""

    # NOTE(review): ``in_features`` is not read; the fan-in is taken from the
    # kernel shape at initialisation time.
    in_features: int
    out_features: int
    bias: bool = True
    is_first: bool = False
    omega_0: float = 30.0
    # When False the dense layer keeps its default kernel initialiser.
    init_weights: bool = True

    def setup(self):
        """Create the dense layer, applying the custom init only if requested."""
        dense_kwargs = {"use_bias": self.bias}
        if self.init_weights:
            dense_kwargs["kernel_init"] = self.init_weights_fn()
        self.linear = nn.Dense(self.out_features, **dense_kwargs)

    def init_weights_fn(self):
        """Return a uniform kernel initialiser.

        The bound is ``1 / fan_in`` for a first layer and
        ``sqrt(6 / fan_in) / omega_0`` otherwise, where ``fan_in`` is
        ``shape[0]`` of the kernel.
        """
        is_first = self.is_first
        omega_0 = self.omega_0

        def init(key, shape, dtype=jnp.float32):
            if is_first:
                limit = 1.0 / shape[0]
            else:
                limit = jnp.sqrt(6.0 / shape[0]) / omega_0
            return jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)

        return init

    def __call__(self, input):
        """Apply the dense layer and the scaled sine activation."""
        return jnp.sin(self.omega_0 * self.linear(input))


class SimpleHashEncoder1D(nn.Module):
    """Multi-resolution table lookup sharing one ``(T, F)`` table across levels."""

    # Number of resolution levels.
    L: int
    # Features stored per table entry.
    F: int
    # Coarsest and finest resolution.
    N_min: int
    N_max: int
    # Number of table entries; indices are wrapped modulo ``T``.
    T: int = 2**14

    @property
    def b(self) -> jax.Array:
        """Per-level growth factor ``exp((ln N_max - ln N_min) / (L - 1))``.

        NOTE(review): the denominator is zero when ``L == 1``.
        """
        return jnp.exp((jnp.log(self.N_max) - jnp.log(self.N_min)) / (self.L - 1))

    @nn.compact
    def __call__(self, x: jax.Array, bound: float) -> jax.Array:
        """Look up table features for ``x`` at every level.

        Args:
            x: Coordinates; ``[-bound, bound]`` is mapped affinely to ``[0, 1]``.
            bound: Half-width of the input interval.

        Returns:
            The gathered features reshaped to ``(x.shape[0], -1)``.
        """
        unit = (x + bound) / (2 * bound)
        scales = self.N_min * (self.b ** jnp.arange(self.L)) - 1
        # The +0.5 before the floor rounds to the nearest grid position.
        grid_pos = unit[:, None] * scales[None, :] + 0.5
        indices = jnp.floor(grid_pos).astype(jnp.int32) % self.T

        def table_init(key, shape):
            return jax.random.uniform(key, shape, minval=-0.001, maxval=0.001)

        table = self.param("hash_table", table_init, (self.T, self.F))
        return table[indices].reshape(unit.shape[0], -1)


class BACKEND(ABC, nn.Module):
    """Shared input encoding and sub-network construction for the models."""

    # Hidden widths; the last entry is not used when building a sub-network.
    features: list
    # Rank; every sub-network outputs ``r * out_dim`` values.
    r: int
    # NOTE(review): ``in_dim`` is not read anywhere in this module.
    in_dim: int
    out_dim: int
    embedding: EmbeddingType
    mlp: MLPType
    # Configuration forwarded to ``SimpleHashEncoder1D``.
    L: int = 16
    F: int = 2
    N_min: int = 16
    N_max: int = 524288
    T: int = 2**14

    def setup(self):
        """Create the hash encoder when the ``HE`` embedding is selected."""
        if self.embedding == EmbeddingType.HE:
            self.hash_encoder = SimpleHashEncoder1D(
                L=self.L, F=self.F, N_min=self.N_min, N_max=self.N_max, T=self.T
            )

    def encode(self, input):
        """Encode one coordinate column according to ``self.embedding``.

        ``NEURBF`` networks always receive the raw input. ``HE`` uses the hash
        encoder with bound 1.0, ``PE000`` is the identity, and the other ``PE``
        variants return ``[sin(input @ freq), cos(input @ freq)]`` joined on
        axis 1, with ``freq`` a row of powers of two.

        Raises:
            ValueError: If the embedding type is not handled.
        """
        if self.mlp == MLPType.NEURBF:
            return input

        if self.embedding == EmbeddingType.HE:
            return self.hash_encoder(input, 1.0)
        elif self.embedding == EmbeddingType.PE000:
            return input
        elif self.embedding == EmbeddingType.PE010:
            pos_enc = 10
        elif self.embedding == EmbeddingType.PE020:
            pos_enc = 20
        elif self.embedding == EmbeddingType.PE100:
            pos_enc = 100
        else:
            raise ValueError(f"Unsupported embedding type: {self.embedding}")

        # Exponents are roughly centred on zero; the even sizes above give
        # ``pos_enc - 1`` of them (for example -4..4 for 10).
        lowest = -((pos_enc - 1) // 2)
        highest = (pos_enc + 1) // 2
        freq = jnp.array([[2**k for k in range(lowest, highest)]])
        phase = input @ freq
        return jnp.concatenate((jnp.sin(phase), jnp.cos(phase)), axis=1)

    def create_subnetwork(self, decomposition: Optional[DecompositionType] = None):
        """Build the network selected by ``self.mlp``.

        Every dense stack uses ``self.features[:-1]`` as hidden widths and ends
        in a dense layer of width ``self.r * self.out_dim``.

        Args:
            decomposition: Accepted for interface compatibility; not read.

        Raises:
            ValueError: If ``self.mlp`` has no branch here.
                NOTE(review): ``MLPType.SIREN2`` currently ends up here.
        """
        hidden_widths = self.features[:-1]
        out_width = self.r * self.out_dim
        layers = []

        if self.mlp == MLPType.RELU or self.mlp == MLPType.TANH:
            if self.mlp == MLPType.RELU:
                init, activation = nn.initializers.glorot_uniform(), nn.relu
            else:
                init, activation = nn.initializers.xavier_normal(), nn.tanh
            for width in hidden_widths:
                layers.append(nn.Dense(width, kernel_init=init))
                layers.append(activation)
            layers.append(nn.Dense(out_width, kernel_init=init))
            return nn.Sequential(layers)

        if self.mlp == MLPType.WIRE:
            omega, sigma = 5, 5
            for idx, width in enumerate(hidden_widths):
                layers.append(
                    RealGaborLayer(width, width, is_first=(idx == 0), omega0=omega, sigma0=sigma)
                )
            layers.append(nn.Dense(out_width, kernel_init=self.custom_init(False)))
            return nn.Sequential(layers)

        if self.mlp == MLPType.SIREN:
            # The first layer pairs the 1/fan_in init with the sin(100 x)
            # activation; later layers use sin(30 x).
            # NOTE(review): the hidden init divides by 100 while the hidden
            # activation uses 30 - confirm this is intended.
            for idx, width in enumerate(hidden_widths):
                first = idx == 0
                layers.append(nn.Dense(width, kernel_init=self.custom_init(first)))
                layers.append(self.scaled_sine_activation if first else self.sine_activation)
            layers.append(nn.Dense(out_width, kernel_init=self.custom_init(False)))
            return nn.Sequential(layers)

        if self.mlp == MLPType.FINER:
            for width in hidden_widths:
                layers.append(nn.Dense(width, kernel_init=self.finer_init(0.5)))
                layers.append(self.finer_activation)
            layers.append(nn.Dense(out_width, kernel_init=self.finer_init(0.5)))
            return nn.Sequential(layers)

        if self.mlp == MLPType.NEURBF:
            return NeuRBF1D(num_rbfs=self.r, feature_dim=out_width)

        raise ValueError(f"Unsupported MLP type: {self.mlp}")

    def custom_init(self, is_first):
        """Return a uniform kernel initialiser.

        The bound is ``1 / fan_in`` when ``is_first`` and
        ``sqrt(6 / fan_in) / 100`` otherwise, with ``fan_in = shape[0]``.
        """

        def init(key, shape, dtype=jnp.float32):
            limit = 1.0 / shape[0] if is_first else jnp.sqrt(6.0 / shape[0]) / 100
            return jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)

        return init

    def finer_init(self, scale=1.0):
        """Return a uniform kernel initialiser bounded by ``scale / sqrt(fan_in)``."""

        def init(key, shape, dtype=jnp.float32):
            limit = scale / jnp.sqrt(shape[0])
            return jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)

        return init

    @staticmethod
    def sine_activation(x):
        """Return ``sin(30 * x)``."""
        return jnp.sin(30 * x)

    @staticmethod
    def scaled_sine_activation(x):
        """Return ``sin(100 * x)``."""
        return jnp.sin(100.0 * x)

    @staticmethod
    def finer_activation(x):
        """Return ``sin((|x| + 1) * x)``."""
        return jnp.sin((jnp.abs(x) + 1.0) * x)


class INR_Baseline2D(BACKEND):
    """Single network applied to the concatenated encodings of ``x`` and ``y``."""

    def setup(self):
        """Build the shared encoder state and the single network."""
        super().setup()
        self.network = self.create_subnetwork()

    def __call__(self, x, y):
        """Encode both inputs, join them on axis 1 and run the network."""
        encoded_x, encoded_y = self.encode(x), self.encode(y)
        return self.network(jnp.concatenate([encoded_x, encoded_y], axis=1))


class FINR_CP_2D(BACKEND):
    """One sub-network per axis, combined by a contraction over the rank."""

    def setup(self):
        """Build the shared encoder state and one sub-network per axis."""
        super().setup()
        self.network_x = self.create_subnetwork()
        self.network_y = self.create_subnetwork()

    def __call__(self, x, y):
        """Return one ``(len(x), len(y))`` array per output channel.

        Each sub-network output is transposed to ``(r * out_dim, N)``. Channel
        ``i`` uses rows ``r * i`` to ``r * (i + 1)`` of both factor matrices and
        sums their outer products over the rank axis.
        """
        encoded_x, encoded_y = self.encode(x), self.encode(y)
        factors_x = jnp.transpose(self.network_x(encoded_x), (1, 0))
        factors_y = jnp.transpose(self.network_y(encoded_y), (1, 0))

        rank = self.r
        return [
            jnp.einsum(
                "fx, fy->xy",
                factors_x[rank * channel : rank * (channel + 1)],
                factors_y[rank * channel : rank * (channel + 1)],
            )
            for channel in range(self.out_dim)
        ]


def get_model_2D(
    backend: MLPType = MLPType.RELU,
    embedding: EmbeddingType = EmbeddingType.PE100,
    decomp: DecompositionType = DecompositionType.CP,
    rank: int = 128,
    **kwargs,
) -> nn.Module:
    """Build a 2-D model for the requested decomposition.

    Args:
        backend: Sub-network family, stored as the model's ``mlp`` field.
        embedding: Input encoding.
        decomp: ``BASELINE`` or ``CP``.
        rank: Stored as the model's ``r`` field.
        **kwargs: Remaining model fields, such as ``features``. ``in_dim`` and
            ``out_dim`` default to 2 and 3 and may be overridden here.

    Raises:
        ValueError: If ``decomp`` is neither ``BASELINE`` nor ``CP``.
    """
    if decomp == DecompositionType.BASELINE:
        model_cls = INR_Baseline2D
    elif decomp == DecompositionType.CP:
        model_cls = FINR_CP_2D
    else:
        raise ValueError(f"Unsupported decomposition type: {decomp}")

    # Defaults go first so caller-supplied values win instead of colliding
    # with hard-coded keyword arguments.
    model_kwargs = {"in_dim": 2, "out_dim": 3, **kwargs}
    return model_cls(r=rank, embedding=embedding, mlp=backend, **model_kwargs)