upload models
Browse files- PirateNet.py +85 -0
- __init__.py +91 -0
- activations.py +95 -0
- mlp.py +45 -0
- mlp_pinn.py +44 -0
- siren.py +66 -0
- utils.py +167 -0
- wire.py +135 -0
PirateNet.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax
|
| 2 |
+
import jax.numpy as jnp
|
| 3 |
+
import flax.linen as nn
|
| 4 |
+
from .utils import Dense, FourierEmbs
|
| 5 |
+
from typing import Union, Dict, Callable
|
| 6 |
+
|
| 7 |
+
class PIModifiedBottleneck(nn.Module):
    """Physics-informed modified bottleneck block used by PirateNet.

    After each activated dense layer the hidden state is mixed with the two
    shared gating branches ``u`` and ``v`` ("modified MLP" gating), and the
    block output is blended with its input through a learnable scalar
    ``alpha`` initialized at ``nonlinearity`` (alpha == 0 makes the block an
    identity map at initialization).
    """

    hidden_dim: int
    output_dim: int
    act: Callable
    nonlinearity: float
    reparam: Union[None, Dict]
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x, u, v):
        # Keep the block input for the adaptive residual blend at the end.
        residual = x

        def dense_act(h, width):
            # Dense layer (with optional weight factorization) + activation.
            h = Dense(features=width, reparam=self.reparam, dtype=self.dtype)(h)
            return self.act(h)

        # Two hidden layers, each followed by the u/v gating.
        for _ in range(2):
            x = dense_act(x, self.hidden_dim)
            x = x * u + (1 - x) * v

        # Projection back to the block's output width.
        x = dense_act(x, self.output_dim)

        alpha = self.param("alpha", nn.initializers.constant(self.nonlinearity), (1,))
        return alpha * x + (1 - alpha) * residual
|
| 36 |
+
|
| 37 |
+
class PirateNet(nn.Module):
    """PirateNet: residual adaptive network for physics-informed training.

    Attributes:
        num_layers: number of PIModifiedBottleneck blocks.
        hidden_dim: width of the hidden dense layers.
        output_dim: number of output features.
        act: activation applied after every dense layer.
        nonlinearity: initial value of each block's residual weight ``alpha``.
        pi_init: optional array used as the (trainable) output kernel
            (physics-informed initialization); when given, the output stage
            is a plain matmul with that kernel.
        reparam: optional dict configuring weight factorization (see Dense).
        fourier_emb: optional kwargs for FourierEmbs; skipped when None.
        dtype: parameter dtype.

    Returns (from ``__call__``): a ``(hidden, output)`` pair — the last
    hidden state and the network output.
    """

    num_layers: int
    hidden_dim: int
    output_dim: int
    act: Callable = nn.silu
    nonlinearity: float = 0.0
    pi_init: Union[None, jnp.ndarray] = None
    reparam: Union[None, Dict] = None
    fourier_emb: Union[None, Dict] = None
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        # Bug fix: `fourier_emb` defaults to None but was unconditionally
        # unpacked with `**`, raising a TypeError. Skip the embedding when it
        # is not configured (consistent with MLP_PINN).
        if self.fourier_emb is not None:
            x = FourierEmbs(**self.fourier_emb)(x)

        # Shared gating branches reused by every bottleneck block.
        u = self.act(Dense(features=self.hidden_dim, reparam=self.reparam, dtype=self.dtype)(x))
        v = self.act(Dense(features=self.hidden_dim, reparam=self.reparam, dtype=self.dtype)(x))

        for _ in range(self.num_layers):
            x = PIModifiedBottleneck(
                hidden_dim=self.hidden_dim,
                output_dim=x.shape[-1],
                act=self.act,
                nonlinearity=self.nonlinearity,
                reparam=self.reparam,
                dtype=self.dtype,
            )(x, u, v)

        if self.pi_init is not None:
            # Physics-informed init: output kernel starts at the given matrix.
            kernel = self.param(
                "pi_init",
                nn.initializers.constant(self.pi_init, dtype=self.dtype),
                self.pi_init.shape,
            )
            y = jnp.dot(x, kernel)
        else:
            y = Dense(features=self.output_dim, reparam=self.reparam, dtype=self.dtype)(x)

        return x, y
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
    # Example usage.
    from activations import cauchy

    def cauchy_mod(x):
        return cauchy()(x)

    model = PirateNet(
        num_layers=3,
        hidden_dim=32,
        output_dim=16,
        act=cauchy_mod,
        reparam=None,
        fourier_emb={'embed_scale': 1.0, 'embed_dim': 64},
    )
    params = model.init(jax.random.PRNGKey(0), jnp.ones(3))
    output = model.apply(params, jnp.ones(3))
    print(params)
|
__init__.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .mlp_pinn import MLP_PINN
|
| 2 |
+
from .PirateNet import PirateNet
|
| 3 |
+
from .mlp import MLP
|
| 4 |
+
from .siren import SIREN
|
| 5 |
+
from .wire import WIRE
|
| 6 |
+
from .activations import get_activation, list_activations
|
| 7 |
+
|
| 8 |
+
# Registry mapping public model names to their Flax module classes.
# `get_model` below resolves user-facing name strings through this dict.
model_key_dict = {
    "MLP": MLP,
    "SIREN": SIREN,
    "WIRE": WIRE,
    "PirateNet": PirateNet,
    "MLP_PINN": MLP_PINN
}
|
| 15 |
+
|
| 16 |
+
def get_model(model_name : str):
    """
    Get the model class by name.

    Args:
        model_name (str): Name of the model.

    Returns:
        nn.Module: The model class.

    Raises:
        ValueError: If `model_name` is not registered in `model_key_dict`.
    """
    try:
        return model_key_dict[model_name]
    except KeyError:
        raise ValueError(
            f"Model `{model_name}` is not supported. Supported models are: {list(model_key_dict.keys())}"
        ) from None
|
| 30 |
+
|
| 31 |
+
def create_model_configs():
    """
    Create a dictionary of model configurations.

    Returns:
        dict: A dictionary mapping model names to extra constructor kwargs.
    """
    # Shared defaults used by the PINN-style models.
    weight_fact = {"type": "weight_fact", "mean": 1.0, "stddev": 0.1}
    fourier = {"embed_scale": 2., "embed_dim": 256}

    return {
        "MLP": {},
        "SIREN": {"omega_0": 3.},
        "WIRE": {
            "first_omega_0": 4.,
            "hidden_omega_0": 4.,
            "scale": 5.,
        },
        "PirateNet": {
            "nonlinearity": 0.0,
            "pi_init": None,
            "reparam": dict(weight_fact),
            "fourier_emb": dict(fourier),
        },
        "MLP_PINN": {
            "reparam": dict(weight_fact),
            "fourier_emb": dict(fourier),
        },
    }
|
| 75 |
+
|
| 76 |
+
model_configs = create_model_configs()
|
| 77 |
+
|
| 78 |
+
def get_extra_model_cfg(model_name: str):
    """
    Get the extra model configuration for a given model name.

    Args:
        model_name (str): Name of the model.

    Returns:
        dict: The extra model configuration.

    Raises:
        ValueError: If `model_name` has no registered configuration.
    """
    if model_name in model_configs:
        return model_configs[model_name]
    raise ValueError(
        f"Model `{model_name}` is not supported. Available models are: {list(model_configs.keys())}"
    )
|
activations.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax
|
| 2 |
+
import jax.numpy as jnp
|
| 3 |
+
import flax.linen as nn
|
| 4 |
+
|
| 5 |
+
# Registry of available activation functions: name -> callable taking one
# array. `cauchy` and `react` are wrapped in lambdas so their classes (defined
# later in this module) are looked up lazily at call time, not at dict
# creation time.
activation_function = {
    "relu": nn.relu,
    "gelu": nn.gelu,
    "silu": nn.silu,
    "swish": nn.silu,  # alias: swish == silu
    "tanh": nn.tanh,
    "sigmoid": nn.sigmoid,
    "softplus": nn.softplus,
    "softmax": nn.softmax,
    "leaky_relu": nn.leaky_relu,
    "elu": nn.elu,
    "selu": nn.selu,
    "telu": lambda x: x * jnp.tanh(jnp.exp(x)),
    "mish": lambda x: x * jnp.tanh(nn.softplus(x)),
    # NOTE: `cauchy` and `react` are trainable flax modules, not plain
    # functions; each call here constructs a fresh module instance.
    "cauchy": lambda x: cauchy()(x),
    "identity": lambda x: x,
    "react": lambda x: react()(x),
}
|
| 23 |
+
|
| 24 |
+
# https://arxiv.org/pdf/2503.02267v1
|
| 25 |
+
# https://arxiv.org/pdf/2503.02267v1
class react(nn.Module):
    """Trainable activation: (1 - exp(a*x + b)) / (1 + exp(c*x + d))."""

    @nn.compact
    def __call__(self, x):
        # Four scalar parameters, each initialized from N(0, 0.1).
        init = jax.nn.initializers.normal(0.1)
        a = self.param('a', init, ())
        b = self.param('b', init, ())
        c = self.param('c', init, ())
        d = self.param('d', init, ())
        return (1 - jnp.exp(a * x + b)) / (1 + jnp.exp(c * x + d))
|
| 51 |
+
|
| 52 |
+
# https://arxiv.org/abs/2409.19221
|
| 53 |
+
# https://arxiv.org/abs/2409.19221
class cauchy(nn.Module):
    """Trainable Cauchy-kernel activation: (l1*x + l2) / (x^2 + d^2)."""

    @nn.compact
    def __call__(self, x):
        # Three scalar parameters, all initialized to 1.0.
        ones = jax.nn.initializers.constant(1.0)
        l1 = self.param('lambda1', ones, ())
        l2 = self.param('lambda2', ones, ())
        d = self.param('d', ones, ())
        # Shared denominator computed once.
        denom = x**2 + d**2
        return l1 * x / denom + l2 / denom
|
| 73 |
+
|
| 74 |
+
def get_activation(name):
    """
    Get the activation function by name.

    Args:
        name (str): Name of the activation function.

    Returns:
        Callable: The activation function.

    Raises:
        ValueError: If `name` is not a known activation.
    """
    try:
        return activation_function[name]
    except KeyError:
        raise ValueError(
            f"Activation function `{name}` is not supported. Supported activations are : {list(activation_function.keys())}"
        ) from None
|
| 87 |
+
|
| 88 |
+
def list_activations():
    """
    List all available activation functions.

    Returns:
        list: A list of activation names.
    """
    # Iterating the registry yields its keys.
    return [*activation_function]
|
mlp.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Flax
|
| 2 |
+
import jax.numpy as jnp
|
| 3 |
+
from flax import linen as nn
|
| 4 |
+
import jax
|
| 5 |
+
from typing import Callable
|
| 6 |
+
|
| 7 |
+
class MLP(nn.Module):
    """Plain multilayer perceptron.

    Attributes:
        hidden_dim: width of every hidden Dense layer.
        output_dim: number of output features.
        num_layers: number of additional hidden layers after the input layer.
        act: activation applied after each hidden layer.
        dtype: parameter dtype.
    """

    hidden_dim: int
    output_dim: int
    num_layers: int
    act: Callable = nn.silu
    dtype: jnp.dtype = jnp.float32

    def _dense(self, features):
        # All layers share the same Glorot-normal init / dtype configuration.
        return nn.Dense(
            features=features,
            use_bias=True,
            kernel_init=nn.initializers.glorot_normal(dtype=self.dtype),
            param_dtype=self.dtype,
        )

    @nn.compact
    def __call__(self, x):
        # Input layer plus `num_layers` hidden layers, each activated.
        for _ in range(self.num_layers + 1):
            x = self.act(self._dense(self.hidden_dim)(x))
        # Linear output head (no activation).
        return self._dense(self.output_dim)(x)
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
    # Example usage.
    x = jax.random.uniform(jax.random.PRNGKey(0), (1, 3), minval=-3, maxval=3)
    model = MLP(hidden_dim=32, output_dim=16, num_layers=3)
    params = model.init(jax.random.PRNGKey(0), x)

    def model_fn(params, x):
        return model.apply(params, x)

    print(model_fn(params, x).shape)
|
mlp_pinn.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Flax
|
| 2 |
+
import jax.numpy as jnp
|
| 3 |
+
from flax import linen as nn
|
| 4 |
+
from typing import Callable, Union, Dict
|
| 5 |
+
from .utils import Dense, FourierEmbs
|
| 6 |
+
|
| 7 |
+
# Modified MLP version based on the state-of-the-art practicies in PINN training:
|
| 8 |
+
# Fourier embeddings and random weight factorization
|
| 9 |
+
# You can read more about it in the paper: https://arxiv.org/pdf/2210.01274
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MLP_PINN(nn.Module):
    """MLP following state-of-the-art PINN training practice.

    Optional random Fourier feature embeddings and random weight
    factorization (via `Dense`); see https://arxiv.org/pdf/2210.01274.
    """

    # hidden_dim: width of each hidden Dense layer.
    hidden_dim: int
    # output_dim: number of output features.
    output_dim: int
    # num_layers: number of activated hidden Dense layers after the input stage.
    num_layers: int
    # act: activation applied after each hidden Dense layer.
    act: Callable = nn.silu
    dtype: jnp.dtype = jnp.float32
    # reparam: optional weight-factorization config forwarded to Dense.
    reparam : Union[None, Dict] = None
    # fourier_emb: optional kwargs for FourierEmbs; a plain activated Dense
    # input layer is used when None.
    fourier_emb : Union[None, Dict] = None

    @nn.compact
    def __call__(self, x):
        if self.fourier_emb is not None:
            # Random Fourier feature embedding of the raw inputs.
            x = FourierEmbs(**self.fourier_emb)(x)
        else:
            # No embedding configured: activated Dense input layer instead.
            # NOTE(review): in the mangled original the placement of the
            # activation after this branch is ambiguous; it is reconstructed
            # here inside the else-branch — confirm against upstream.
            x = Dense(
                features=self.hidden_dim,
                reparam=self.reparam,
                dtype=self.dtype
            )(x)
            x = self.act(x)
        for _ in range(self.num_layers):
            x = Dense(
                features=self.hidden_dim,
                reparam=self.reparam,
                dtype=self.dtype
            )(x)
            x = self.act(x)
        # Linear output head (no activation).
        x = Dense(
            features=self.output_dim,
            reparam=self.reparam,
            dtype=self.dtype
        )(x)
        return x
|
siren.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax.numpy as jnp
|
| 2 |
+
from flax import linen as nn
|
| 3 |
+
|
| 4 |
+
from .utils import custom_uniform
|
| 5 |
+
|
| 6 |
+
class SIREN(nn.Module):
    """Sinusoidal representation network built from SirenLayer blocks.

    Attributes:
        output_dim: number of output features.
        hidden_dim: width of each sinusoidal hidden layer.
        num_layers: number of hidden SirenLayers after the first layer.
        omega_0: frequency multiplier shared by all layers.
        dtype: parameter dtype.
    """

    output_dim: int
    hidden_dim: int
    num_layers: int
    omega_0: float
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        first = SirenLayer(
            output_dim=self.hidden_dim,
            omega_0=self.omega_0,
            is_first_layer=True,
            dtype=self.dtype,
        )
        hidden = [
            SirenLayer(
                output_dim=self.hidden_dim,
                omega_0=self.omega_0,
                dtype=self.dtype,
            )
            for _ in range(self.num_layers)
        ]
        self.kernel_net = [first, *hidden]

        # Final linear head: normal weights scaled by 1/fan_in, zero bias.
        self.output_linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=custom_uniform(numerator=1, mode="fan_in", distribution="normal", dtype=self.dtype),
            bias_init=nn.initializers.zeros,
            param_dtype=self.dtype,
        )

    def __call__(self, x):
        for layer in self.kernel_net:
            x = layer(x)
        return self.output_linear(x)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class SirenLayer(nn.Module):
    """Single SIREN layer: sin(omega_0 * (W x + b)).

    Uses the SIREN initialization scheme: the first layer draws weights from
    a plain uniform range, deeper layers from a sqrt(6/fan_in)/omega_0 range.
    """

    output_dim: int
    omega_0: float
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.is_first_layer:
            numerator, distrib = 1, "uniform_squared"
        else:
            numerator, distrib = 6 / self.omega_0**2, "uniform"

        self.linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=custom_uniform(numerator=numerator, mode="fan_in", distribution=distrib, dtype=self.dtype),
            bias_init=nn.initializers.zeros,
            param_dtype=self.dtype,
        )

    def __call__(self, x):
        return jnp.sin(self.omega_0 * self.linear(x))
|
utils.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax
|
| 2 |
+
import math
|
| 3 |
+
from typing import Any, Dict, Sequence, Union
|
| 4 |
+
|
| 5 |
+
import jax.numpy as jnp
|
| 6 |
+
from jax import dtypes, random
|
| 7 |
+
from jax.nn.initializers import Initializer
|
| 8 |
+
from typing import Callable
|
| 9 |
+
from flax import linen as nn
|
| 10 |
+
|
| 11 |
+
class FourierEmbs(nn.Module):
    """Random Fourier feature embedding.

    Projects the input through a random Gaussian kernel (stddev
    ``embed_scale``) and returns [cos(xW), sin(xW)] concatenated on the last
    axis, giving an output of size ``embed_dim`` (must be even).
    """

    embed_scale: float  # stddev of the random projection kernel
    embed_dim: int      # total embedding size (cos half + sin half)
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        kernel = self.param(
            "kernel",
            jax.nn.initializers.normal(self.embed_scale, dtype=self.dtype),
            (x.shape[-1], self.embed_dim // 2),
        )
        # Compute the projection once; the original evaluated jnp.dot twice.
        proj = jnp.dot(x, kernel)
        return jnp.concatenate([jnp.cos(proj), jnp.sin(proj)], axis=-1)
|
| 25 |
+
|
| 26 |
+
def _weight_fact(init_fn, mean, stddev, dtype=jnp.float32):
    """Random weight factorization initializer.

    Produces a ``(g, v)`` pair with ``kernel = g * v``: ``g`` is a
    per-output-column scale drawn as exp(N(mean, stddev)) and ``v`` is the
    base kernel divided by ``g``. See https://arxiv.org/abs/2210.01274.
    """

    def init(key, shape):
        key_w, key_g = jax.random.split(key)
        base = init_fn(key_w, shape)
        scale = jnp.exp(mean + nn.initializers.normal(stddev, dtype=dtype)(key_g, (shape[-1],)))
        return scale, base / scale

    return init
|
| 36 |
+
|
| 37 |
+
class Dense(nn.Module):
    """Dense layer with optional random weight factorization (RWF).

    When ``reparam`` is None the kernel is a single parameter; when
    ``reparam["type"] == "weight_fact"`` the kernel is stored as a ``(g, v)``
    pair with ``kernel = g * v`` (https://arxiv.org/abs/2210.01274).
    """

    features: int
    kernel_init: Callable = nn.initializers.glorot_normal()
    bias_init: Callable = nn.initializers.zeros
    reparam : Union[None, Dict] = None
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        if self.reparam is None:
            # Bug fix: the original called `self.kernel_init(dtype=...)`, but
            # jax/flax initializers have signature init(key, shape, dtype) and
            # raise TypeError when given only a dtype kwarg. `self.param`
            # forwards extra positional args to the initializer instead.
            kernel = self.param(
                "kernel", self.kernel_init, (x.shape[-1], self.features), self.dtype
            )
        elif self.reparam["type"] == "weight_fact":
            # Factorized kernel: stored as a (g, v) pytree under one name.
            g, v = self.param(
                "kernel",
                _weight_fact(
                    self.kernel_init,
                    mean=self.reparam["mean"],
                    stddev=self.reparam["stddev"],
                    dtype=self.dtype
                ),
                (x.shape[-1], self.features),
            )
            kernel = g * v

        # Same fix for the bias: pass shape and dtype through self.param.
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)

        return jnp.dot(x, kernel) + bias
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _compute_fans(
|
| 72 |
+
shape: tuple,
|
| 73 |
+
in_axis: Union[int, Sequence[int]] = -2,
|
| 74 |
+
out_axis: Union[int, Sequence[int]] = -1,
|
| 75 |
+
batch_axis: Union[int, Sequence[int]] = (),
|
| 76 |
+
):
|
| 77 |
+
"""Compute effective input and output sizes for a linear or convolutional layer.
|
| 78 |
+
|
| 79 |
+
Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the "receptive field" of
|
| 80 |
+
a convolution (kernel spatial dimensions).
|
| 81 |
+
"""
|
| 82 |
+
if len(shape) <= 1:
|
| 83 |
+
raise ValueError(
|
| 84 |
+
f"Can't compute input and output sizes of a {shape.rank}"
|
| 85 |
+
"-dimensional weights tensor. Must be at least 2D."
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
if isinstance(in_axis, int):
|
| 89 |
+
in_size = shape[in_axis]
|
| 90 |
+
else:
|
| 91 |
+
in_size = math.prod([shape[i] for i in in_axis])
|
| 92 |
+
if isinstance(out_axis, int):
|
| 93 |
+
out_size = shape[out_axis]
|
| 94 |
+
else:
|
| 95 |
+
out_size = math.prod([shape[i] for i in out_axis])
|
| 96 |
+
if isinstance(batch_axis, int):
|
| 97 |
+
batch_size = shape[batch_axis]
|
| 98 |
+
else:
|
| 99 |
+
batch_size = math.prod([shape[i] for i in batch_axis])
|
| 100 |
+
receptive_field_size = math.prod(shape) / in_size / out_size / batch_size
|
| 101 |
+
fan_in = in_size * receptive_field_size
|
| 102 |
+
fan_out = out_size * receptive_field_size
|
| 103 |
+
return fan_in, fan_out
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def custom_uniform(
    numerator: float = 6,
    mode: str = "fan_in",
    dtype: jnp.dtype = jnp.float32,
    in_axis: Union[int, Sequence[int]] = -2,
    out_axis: Union[int, Sequence[int]] = -1,
    batch_axis: Sequence[int] = (),
    distribution: str = "uniform",
) -> Initializer:
    """Builds an initializer that returns real uniformly-distributed random arrays.

    :param numerator: the numerator of the range of the random distribution.
    :type numerator: float
    :param mode: the mode for computing the range of the random distribution.
    :type mode: str
    :param dtype: optional; the initializer's default dtype.
    :type dtype: jnp.dtype
    :param in_axis: the axis or axes that specify the input size.
    :type in_axis: Union[int, Sequence[int]]
    :param out_axis: the axis or axes that specify the output size.
    :type out_axis: Union[int, Sequence[int]]
    :param batch_axis: the axis or axes that specify the batch size.
    :type batch_axis: Sequence[int]
    :param distribution: the distribution of the random distribution.
    :type distribution: str

    :return: An initializer that returns arrays whose values are distributed in
        the range ``[-range, range)`` (or normal, scaled by the same bound).
    :rtype: Initializer
    """

    def init(key: jax.random.key, shape: tuple, dtype: Any = dtype) -> Any:
        dtype = dtypes.canonicalize_dtype(dtype)
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)

        # Denominator selection mirrors jax's variance-scaling modes.
        fans = {"fan_in": fan_in, "fan_out": fan_out, "fan_avg": (fan_in + fan_out) / 2}
        if mode not in fans:
            raise ValueError(f"invalid mode for variance scaling initializer: {mode}")
        denominator = fans[mode]

        if distribution == "uniform":
            bound = jnp.sqrt(numerator / denominator)
            return random.uniform(key, shape, dtype, minval=-bound, maxval=bound)
        if distribution == "normal":
            return random.normal(key, shape, dtype) * jnp.sqrt(numerator / denominator)
        if distribution == "uniform_squared":
            # Range without the square root: [-numerator/denominator, ...).
            bound = numerator / denominator
            return random.uniform(key, shape, dtype, minval=-bound, maxval=bound)
        raise ValueError(
            f"invalid distribution for variance scaling initializer: {distribution}"
        )

    return init
|
wire.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax.numpy as jnp
|
| 2 |
+
from flax import linen as nn
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import jax
|
| 6 |
+
from .utils import custom_uniform
|
| 7 |
+
from jax.nn.initializers import Initializer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def complex_kernel_uniform_init(numerator : float = 6,
                                mode : str = "fan_in",
                                dtype : jnp.dtype = jnp.float32,
                                distribution: str = "uniform") -> Initializer:
    """Initializer producing complex kernels via `custom_uniform`.

    Real and imaginary parts are drawn with the same configuration but from
    independent PRNG streams.
    """
    def init(key: jax.random.key, shape: tuple, dtype: Any = dtype) -> Any:
        # Bug fix: the original reused one key for both parts, making the
        # imaginary kernel an exact copy of the real one. Split the key so
        # the two parts are independent.
        key_re, key_im = jax.random.split(key)
        sample = custom_uniform(numerator=numerator, mode=mode, distribution=distribution)
        real_kernel = sample(key_re, shape, dtype)
        imag_kernel = sample(key_im, shape, dtype)

        return real_kernel + 1j * imag_kernel

    return init
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class WIRE(nn.Module):
    """WIRE network built from Gabor-wavelet layers.

    Attributes:
        output_dim: number of output features.
        hidden_dim: width of each Gabor layer.
        num_layers: number of hidden Gabor layers after the first.
        hidden_omega_0: frequency scale for hidden layers.
        first_omega_0: frequency scale for the first layer.
        scale: Gaussian window scale (s_0) shared by all layers.
        complexgabor: use complex Gabor layers (complex64 params) when True.
        dtype: parameter dtype for the real path and the output head.
    """

    output_dim: int
    hidden_dim: int
    num_layers: int
    hidden_omega_0: float
    first_omega_0: float
    scale: float
    complexgabor: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Complex Gabor layers run in complex64; real ones keep `dtype`.
        if self.complexgabor:
            layer_cls, layer_dtype = ComplexGaborLayer, jnp.complex64
        else:
            layer_cls, layer_dtype = RealGaborLayer, self.dtype

        def make_layer(omega, first):
            return layer_cls(
                output_dim=self.hidden_dim,
                omega_0=omega,
                s_0=self.scale,
                is_first_layer=first,
                dtype=layer_dtype,
            )

        self.kernel_net = [make_layer(self.first_omega_0, True)] + [
            make_layer(self.hidden_omega_0, False) for _ in range(self.num_layers)
        ]

        self.output_linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=custom_uniform(numerator=1, mode="fan_in", distribution="normal"),
            param_dtype=self.dtype,
        )

    def __call__(self, x):
        for layer in self.kernel_net:
            x = layer(x)
        # Take the real part: complex activations collapse to a real output.
        return jnp.real(self.output_linear(x))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ComplexGaborLayer(nn.Module):
    """Complex Gabor wavelet layer: exp(i*omega_0*z - |s_0*z|^2), z = W x + b."""

    output_dim: int
    omega_0: float   # frequency scale
    s_0: float       # Gaussian window scale
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # SIREN-style init bounds: tighter plain-uniform range for the first
        # layer, sqrt(6)/omega_0 range for deeper ones.
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"

        # The first layer may receive real inputs; deeper layers are complex.
        if self.is_first_layer:
            dtype = self.dtype
        else:
            dtype = jnp.complex64

        self.linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=complex_kernel_uniform_init(numerator=c, mode="fan_in", distribution=distrib),
            param_dtype=dtype
        )

    def __call__(self, x):
        # Evaluate the dense layer once; the original called self.linear(x)
        # twice (same parameters → same result, but double the work).
        z = self.linear(x)
        omega = self.omega_0 * z
        scale = self.s_0 * z

        return jnp.exp(1j * omega - (jnp.abs(scale)**2))
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class RealGaborLayer(nn.Module):
    """Real Gabor wavelet layer: cos(omega_0 * f(x)) * exp(-(s_0 * g(x))^2).

    Uses two separate dense projections: one for the oscillation frequencies
    and one for the Gaussian window scales.
    """

    output_dim: int
    omega_0: float
    s_0: float
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Same SIREN-style init bounds as the complex variant.
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"

        def make_dense():
            return nn.Dense(
                features=self.output_dim,
                kernel_init=custom_uniform(numerator=c, mode="fan_in", distribution=distrib, dtype=self.dtype),
                use_bias=True,
                param_dtype=self.dtype,
            )

        self.freqs = make_dense()
        self.scales = make_dense()

    def __call__(self, x):
        omega = self.omega_0 * self.freqs(x)
        scale = self.s_0 * self.scales(x)
        return jnp.cos(omega) * jnp.exp(-(scale**2))
|