Spaces:
Sleeping
Sleeping
File size: 10,670 Bytes
# MIT License
# Copyright (c) [2026] [Tim Büchner, Sai Karthikeya Vemuri, Joachim Denzler]
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Public API of this module: the 2D model factory and its configuration enums.
__all__ = [
    "get_model_2D",
    "MLPType",
    "EmbeddingType",
    "DecompositionType",
]
from abc import ABC
from enum import Enum
from typing import Optional
import flax.linen as nn
import jax
import jax.numpy as jnp
class MLPType(Enum):
    """Backend choice for the coordinate networks built by ``BACKEND.create_subnetwork``.

    Member values are the strings below; note their mixed casing
    (``"ReLU"``, ``"NeuRBF"``) when looking members up by value.
    """

    TANH = "TANH"  # Dense layers with tanh activations
    RELU = "ReLU"  # Dense layers with ReLU activations
    WIRE = "WIRE"  # stack of RealGaborLayer modules
    SIREN = "SIREN"  # Dense layers with sine activations
    SIREN2 = "SIREN2"  # NOTE(review): not handled by create_subnetwork (raises ValueError)
    FINER = "FINER"  # Dense layers with sin((|x| + 1) * x) activations
    NEURBF = "NeuRBF"  # a single NeuRBF1D module
class EmbeddingType(Enum):
    """Coordinate embedding applied by ``BACKEND.encode`` before the network."""

    PE000 = "PE000"  # no embedding: coordinates are passed through
    PE010 = "PE010"  # sin/cos positional encoding, band count derived from n = 10
    PE020 = "PE020"  # sin/cos positional encoding, band count derived from n = 20
    PE100 = "PE100"  # sin/cos positional encoding, band count derived from n = 100
    HE = "HE"  # multi-level hash encoding (SimpleHashEncoder1D)
class DecompositionType(Enum):
    """Model layout selected in ``get_model_2D``."""

    BASELINE = "Baseline"  # one network over concatenated (x, y) encodings
    CP = "CP"  # rank-r CP factorisation from two 1D networks
    # NOTE(review): the following are declared but get_model_2D raises
    # ValueError for them; presumably tensor-train / Tucker / tensor-ring.
    TT = "TT"
    TU = "TU"
    TR = "TR"
class NeuRBF1D(nn.Module):
    """Radial-basis-function coordinate network.

    Each of ``num_rbfs`` learnable centres scores the input with the inverse
    quadratic kernel ``1 / (1 + sum(((x - c) / sigma) ** 2))``. Each score
    drives a sinusoid ``sin(score * freq + bias)``, which is weighted by a
    per-centre feature vector and summed over the centres. The pooled
    features are then refined by Dense -> ``sin(h * freq) + h`` -> Dense.
    The result's last dimension is ``feature_dim``.

    NOTE(review): ``centers`` are initialised with ``nn.initializers.uniform()``
    at the library's default scale; confirm that range suits the inputs.
    """

    num_rbfs: int  # number of centres R
    feature_dim: int  # feature width F, also the width of both Dense layers

    @nn.compact
    def __call__(self, x):
        # assumes x is (N, D) -- TODO confirm against callers.
        # Parameters are declared in the same order as before; parameter RNG
        # draws may depend on that order, so keep it stable.
        centers = self.param("centers", nn.initializers.uniform(), (self.num_rbfs, 1))
        log_sigma = self.param("log_sigma", nn.initializers.zeros, (self.num_rbfs, 1))
        freq = self.param("freq", nn.initializers.normal(stddev=5.0), (1, self.feature_dim))
        bias = self.param("bias", nn.initializers.zeros, (1, self.feature_dim))
        features = self.param("features", nn.initializers.normal(stddev=0.1), (self.num_rbfs, self.feature_dim))

        # Strictly positive widths; the epsilon keeps them away from zero.
        width = jnp.exp(log_sigma) + 1e-6

        # Per (sample, centre) squared distance scaled by the centre width.
        delta = x[:, None, :] - centers[None, :, :]
        scaled_sq = (delta**2) / (width[None, :, :] ** 2)
        kernel = 1.0 / (1.0 + jnp.sum(scaled_sq, axis=-1))

        # One sinusoid per (sample, centre, feature), weighted, then pooled over centres.
        waves = jnp.sin(kernel[:, :, None] * freq + bias)
        pooled = jnp.sum(waves * features[None, :, :], axis=1)

        hidden = nn.Dense(self.feature_dim)(pooled)
        hidden = jnp.sin(hidden * freq[0]) + hidden
        return nn.Dense(self.feature_dim)(hidden)
class RealGaborLayer(nn.Module):
    """Real-valued Gabor-style layer used by the WIRE backend.

    Computes ``cos(omega0 * F(x)) * exp(-(sigma0 * S(x)) ** 2)``, where ``F``
    and ``S`` are two independent Dense projections of the same input.
    """

    in_features: int  # not read by this layer
    out_features: int  # width of both Dense projections
    bias: bool = True  # whether the projections carry a bias term
    is_first: bool = False  # not read by this layer
    omega0: float = 10.0  # multiplier on the oscillation projection
    sigma0: float = 10.0  # multiplier on the envelope projection

    def setup(self):
        # Projection driving the cosine (phase).
        self.freqs = nn.Dense(self.out_features, use_bias=self.bias)
        # Projection driving the Gaussian envelope.
        self.scale = nn.Dense(self.out_features, use_bias=self.bias)
        # Aliases of the omega0 / sigma0 fields.
        self.omega_0 = self.omega0
        self.scale_0 = self.sigma0

    def __call__(self, input):
        phase = self.omega_0 * self.freqs(input)
        envelope_arg = self.scale(input) * self.scale_0
        envelope = jnp.exp(-(envelope_arg**2))
        return jnp.cos(phase) * envelope
class SineLayer(nn.Module):
    """Dense projection followed by a sine: ``sin(omega_0 * Dense(x))``.

    When ``init_weights`` is True (the default) the Dense kernel uses the
    uniform initializer returned by :meth:`init_weights_fn`. When it is False,
    no ``kernel_init`` is passed, so ``nn.Dense`` keeps its own default
    kernel initializer.
    """

    in_features: int  # not read by this layer
    out_features: int  # width of the Dense projection
    bias: bool = True  # whether the Dense projection has a bias
    is_first: bool = False  # selects the first-layer initializer bound
    omega_0: float = 30.0  # frequency multiplier inside the sine
    init_weights: bool = True  # apply init_weights_fn() to the Dense kernel

    def setup(self):
        dense_options = {"use_bias": self.bias}
        if self.init_weights:
            # Override the kernel initializer only when requested, so that
            # init_weights=False actually falls back to nn.Dense's default
            # (previously the flag was ignored).
            dense_options["kernel_init"] = self.init_weights_fn()
        self.linear = nn.Dense(self.out_features, **dense_options)

    def init_weights_fn(self):
        """Return a uniform kernel initializer ``U(-limit, limit)``.

        ``limit`` is ``1 / shape[0]`` for the first layer and
        ``sqrt(6 / shape[0]) / omega_0`` otherwise. ``shape[0]`` is the
        kernel's leading dimension (its input width for a Dense kernel).
        """
        if self.is_first:
            def init(key, shape, dtype=jnp.float32):
                limit = 1.0 / shape[0]
                return jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)
        else:
            def init(key, shape, dtype=jnp.float32):
                limit = jnp.sqrt(6.0 / shape[0]) / self.omega_0
                return jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)
        return init

    def __call__(self, input):
        return jnp.sin(self.omega_0 * self.linear(input))
class SimpleHashEncoder1D(nn.Module):
    """Multi-level hashed lookup encoding for coordinates.

    Inputs in ``[-bound, bound]`` are mapped to ``[0, 1]``. They are then
    scaled by ``N_min * b**l - 1`` for each level ``l`` in ``range(L)``,
    shifted by 0.5, and floored to integer cells. The cells are wrapped
    modulo ``T`` and used to index one learnable ``(T, F)`` table. The
    looked-up features are flattened per row, so the output has shape
    ``(x.shape[0], -1)``.

    NOTE(review): every level indexes the same table and neighbouring cells
    are not interpolated, so the encoding is piecewise constant in ``x``.
    Confirm this simplification is intended.
    """

    L: int  # number of resolution levels
    F: int  # feature width stored per table entry
    N_min: int  # base resolution (level 0)
    N_max: int  # resolution reached at level L - 1
    T: int = 2**14  # number of table entries

    @property
    def b(self) -> jax.Array:
        """Per-level growth factor ``exp((log N_max - log N_min) / (L - 1))``."""
        log_span = jnp.log(self.N_max) - jnp.log(self.N_min)
        return jnp.exp(log_span / (self.L - 1))

    @nn.compact
    def __call__(self, x: jax.Array, bound: float) -> jax.Array:
        unit = (x + bound) / (2 * bound)
        resolutions = self.N_min * (self.b ** jnp.arange(self.L)) - 1
        grid_pos = unit[:, None] * resolutions[None, :] + 0.5
        slots = jnp.floor(grid_pos).astype(jnp.int32) % self.T

        def tiny_uniform(key, shape):
            # Small symmetric initial values for the table entries.
            return jax.random.uniform(key, shape, minval=-0.001, maxval=0.001)

        table = self.param("hash_table", tiny_uniform, (self.T, self.F))
        return table[slots].reshape(unit.shape[0], -1)
class BACKEND(ABC, nn.Module):
    """Shared base for the 2D INR models: coordinate encoding plus network factory.

    Subclasses call ``super().setup()`` from their own ``setup``, which
    registers ``hash_encoder`` when ``embedding`` is ``EmbeddingType.HE``.
    They pass each coordinate input through :meth:`encode` and build their
    networks with :meth:`create_subnetwork`.
    """

    features: list  # layer widths; create_subnetwork only reads features[:-1]
    r: int  # rank; the network output width is r * out_dim
    in_dim: int  # stored but not read by this class
    out_dim: int  # number of output channels
    embedding: EmbeddingType
    mlp: MLPType
    # Hash-encoder settings, forwarded to SimpleHashEncoder1D for the HE embedding.
    L: int = 16
    F: int = 2
    N_min: int = 16
    N_max: int = 524288
    T: int = 2**14

    def setup(self):
        """Register ``hash_encoder`` for the HE embedding; otherwise do nothing."""
        if self.embedding != EmbeddingType.HE:
            return
        self.hash_encoder = SimpleHashEncoder1D(
            L=self.L,
            F=self.F,
            N_min=self.N_min,
            N_max=self.N_max,
            T=self.T,
        )

    def encode(self, input):
        """Turn a batch of coordinates into network input features.

        The checks run in this order:

        * ``mlp`` is NEURBF: ``input`` is returned unchanged (embedding ignored).
        * HE: ``self.hash_encoder(input, 1.0)``.
        * PE000: ``input`` unchanged.
        * PE010 / PE020 / PE100: with ``n`` = 10 / 20 / 100, build
          ``freq = [[2**k for k in range(-((n - 1) // 2), (n + 1) // 2)]]``
          and return ``concatenate((sin(input @ freq), cos(input @ freq)), axis=1)``.

        Raises:
            ValueError: if ``embedding`` matches none of the handled members.
        """
        if self.mlp == MLPType.NEURBF:
            return input
        if self.embedding == EmbeddingType.HE:
            return self.hash_encoder(input, 1.0)
        if self.embedding == EmbeddingType.PE000:
            return input

        positional_bands = (
            (EmbeddingType.PE010, 10),
            (EmbeddingType.PE020, 20),
            (EmbeddingType.PE100, 100),
        )
        for member, n_bands in positional_bands:
            if self.embedding == member:
                break
        else:
            raise ValueError(f"Unsupported embedding type: {self.embedding}")

        # NOTE(review): the exponent range is symmetric around 0, so an even
        # n yields n - 1 frequencies (e.g. 9 for PE010). Confirm intended.
        lowest = -((n_bands - 1) // 2)
        highest = (n_bands + 1) // 2  # exclusive upper bound
        freq = jnp.array([[2**k for k in range(lowest, highest)]])
        phase = input @ freq
        return jnp.concatenate((jnp.sin(phase), jnp.cos(phase)), axis=1)

    def create_subnetwork(self, decomposition: Optional[DecompositionType] = None):
        """Build one coordinate network for the configured ``mlp`` backend.

        Hidden layers are created for ``self.features[:-1]`` (the last entry
        is not read). Every Sequential backend ends in a Dense head with
        ``self.r * self.out_dim`` units:

        * RELU: Dense (Glorot-uniform kernel) + ``nn.relu`` per hidden width.
        * TANH: Dense (Xavier-normal kernel) + ``nn.tanh`` per hidden width.
        * WIRE: one ``RealGaborLayer`` (omega0 = sigma0 = 5) per hidden width.
        * SIREN: Dense + ``sin(100 x)`` for the first hidden layer, then
          Dense + ``sin(30 x)``, with kernels from :meth:`custom_init`.
        * FINER: Dense (``finer_init(0.5)``) + :meth:`finer_activation`.
        * NEURBF: a single ``NeuRBF1D`` with ``num_rbfs = r``.

        Layer order within each list is kept stable; it may determine the
        parameter names inside ``nn.Sequential``.

        Args:
            decomposition: accepted for interface compatibility; not read.

        Raises:
            ValueError: for any other ``mlp`` value (e.g. ``MLPType.SIREN2``).
        """
        head_width = self.r * self.out_dim
        kind = self.mlp

        if kind == MLPType.RELU:
            relu_init = nn.initializers.glorot_uniform()
            layers = [
                layer
                for width in self.features[:-1]
                for layer in (nn.Dense(width, kernel_init=relu_init), nn.relu)
            ]
            layers.append(nn.Dense(head_width, kernel_init=relu_init))
            return nn.Sequential(layers)

        if kind == MLPType.TANH:
            tanh_init = nn.initializers.xavier_normal()
            layers = [
                layer
                for width in self.features[:-1]
                for layer in (nn.Dense(width, kernel_init=tanh_init), nn.tanh)
            ]
            layers.append(nn.Dense(head_width, kernel_init=tanh_init))
            return nn.Sequential(layers)

        if kind == MLPType.WIRE:
            omega = sigma = 5
            layers = [
                RealGaborLayer(width, width, is_first=(position == 0), omega0=omega, sigma0=sigma)
                for position, width in enumerate(self.features[:-1])
            ]
            layers.append(nn.Dense(head_width, kernel_init=self.custom_init(False)))
            return nn.Sequential(layers)

        if kind == MLPType.SIREN:
            layers = []
            for position, width in enumerate(self.features[:-1]):
                first = position == 0
                layers.append(nn.Dense(width, kernel_init=self.custom_init(first)))
                layers.append(self.scaled_sine_activation if first else self.sine_activation)
            layers.append(nn.Dense(head_width, kernel_init=self.custom_init(False)))
            return nn.Sequential(layers)

        if kind == MLPType.FINER:
            layers = [
                layer
                for width in self.features[:-1]
                for layer in (nn.Dense(width, kernel_init=self.finer_init(0.5)), self.finer_activation)
            ]
            layers.append(nn.Dense(head_width, kernel_init=self.finer_init(0.5)))
            return nn.Sequential(layers)

        if kind == MLPType.NEURBF:
            return NeuRBF1D(num_rbfs=self.r, feature_dim=head_width)

        raise ValueError(f"Unsupported MLP type: {self.mlp}")

    def custom_init(self, is_first):
        """Return a uniform kernel initializer ``U(-bound, bound)``.

        ``bound`` is ``1 / shape[0]`` when ``is_first`` is set, otherwise
        ``sqrt(6 / shape[0]) / 100``.

        NOTE(review): the SIREN branch pairs the non-first bound (divided by
        100) with ``sin(30 x)`` activations. SineLayer, by contrast, divides by
        its own activation frequency. Confirm the constant.
        """
        def init(key, shape, dtype=jnp.float32):
            fan_in = shape[0]
            if is_first:
                bound = 1.0 / fan_in
            else:
                bound = jnp.sqrt(6.0 / fan_in) / 100
            return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)

        return init

    def finer_init(self, scale=1.0):
        """Return a uniform kernel initializer ``U(-scale / sqrt(shape[0]), +scale / sqrt(shape[0]))``."""
        def init(key, shape, dtype=jnp.float32):
            bound = scale / jnp.sqrt(shape[0])
            return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)

        return init

    @staticmethod
    def sine_activation(x):
        """Hidden-layer sine activation: ``sin(30 * x)``."""
        return jnp.sin(30 * x)

    @staticmethod
    def scaled_sine_activation(x):
        """First-layer sine activation: ``sin(100 * x)``."""
        return jnp.sin(100.0 * x)

    @staticmethod
    def finer_activation(x):
        """FINER-style activation: ``sin((|x| + 1) * x)``."""
        return jnp.sin((jnp.abs(x) + 1.0) * x)
class INR_Baseline2D(BACKEND):
    """One coordinate network applied to the concatenated (x, y) encodings.

    NOTE(review): the network's last dimension is ``r * out_dim`` (the width
    create_subnetwork builds for every backend). Confirm callers expect this
    rather than ``out_dim``.
    """

    def setup(self):
        super().setup()
        self.network = self.create_subnetwork()

    def __call__(self, x, y):
        # Encode each coordinate input separately, then join along axis 1.
        joint = jnp.concatenate([self.encode(x), self.encode(y)], axis=1)
        return self.network(joint)
class FINR_CP_2D(BACKEND):
    """Rank-``r`` CP factorisation of a 2D signal built from two 1D networks.

    ``network_x`` maps the x coordinates and ``network_y`` the y coordinates
    to ``r * out_dim`` factors each. Channel ``c`` is the sum over its ``r``
    factors of the outer products of the x and y factor vectors.
    """

    def setup(self):
        super().setup()
        self.network_x = self.create_subnetwork()
        self.network_y = self.create_subnetwork()

    def __call__(self, x, y):
        """Return a list of ``out_dim`` arrays of shape ``(x.shape[0], y.shape[0])``."""
        enc_x = self.encode(x)
        enc_y = self.encode(y)
        # Factor-major layout: row f holds factor f for every coordinate.
        factors_x = jnp.transpose(self.network_x(enc_x), (1, 0))
        factors_y = jnp.transpose(self.network_y(enc_y), (1, 0))
        rank = self.r
        return [
            jnp.einsum(
                "fx, fy->xy",
                factors_x[rank * channel : rank * (channel + 1)],
                factors_y[rank * channel : rank * (channel + 1)],
            )
            for channel in range(self.out_dim)
        ]
def _coerce_enum(enum_cls, value):
    """Return ``enum_cls(value)`` when ``value`` is a string matching a member value.

    Members, non-matching strings and any other objects are returned
    unchanged, so callers that previously worked see identical behaviour.
    """
    if isinstance(value, str):
        try:
            return enum_cls(value)
        except ValueError:
            return value
    return value


def get_model_2D(backend=MLPType.RELU, embedding=EmbeddingType.PE100, decomp=DecompositionType.CP, rank=128, **kwargs):
    """Construct a 2D implicit neural representation model.

    Args:
        backend: ``MLPType`` member, or its string value (e.g. ``"ReLU"``).
        embedding: ``EmbeddingType`` member, or its string value (e.g. ``"PE100"``).
        decomp: ``DecompositionType.BASELINE`` or ``DecompositionType.CP``, or
            their string values ``"Baseline"`` / ``"CP"``.
        rank: forwarded as the model's ``r``.
        **kwargs: forwarded to the model constructor (e.g. ``features`` or the
            hash-encoder settings). ``in_dim`` and ``out_dim`` default to 2
            and 3 and may now be overridden here (e.g. ``out_dim=1``).

    Returns:
        An ``INR_Baseline2D`` for BASELINE or an ``FINR_CP_2D`` for CP.

    Raises:
        ValueError: if ``decomp`` is neither BASELINE nor CP.
    """
    backend = _coerce_enum(MLPType, backend)
    embedding = _coerce_enum(EmbeddingType, embedding)
    decomp = _coerce_enum(DecompositionType, decomp)
    # Defaults rather than hard-coded keywords: passing them via kwargs used
    # to raise "got multiple values for keyword argument".
    kwargs.setdefault("in_dim", 2)
    kwargs.setdefault("out_dim", 3)
    if decomp == DecompositionType.BASELINE:
        return INR_Baseline2D(r=rank, embedding=embedding, mlp=backend, **kwargs)
    if decomp == DecompositionType.CP:
        return FINR_CP_2D(r=rank, embedding=embedding, mlp=backend, **kwargs)
    raise ValueError(f"Unsupported decomposition type: {decomp}")
|