AndreiB137 commited on
Commit
2b59497
·
verified ·
1 Parent(s): 28f79a0

upload models

Browse files
Files changed (8) hide show
  1. PirateNet.py +85 -0
  2. __init__.py +91 -0
  3. activations.py +95 -0
  4. mlp.py +45 -0
  5. mlp_pinn.py +44 -0
  6. siren.py +66 -0
  7. utils.py +167 -0
  8. wire.py +135 -0
PirateNet.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import flax.linen as nn
4
+ from .utils import Dense, FourierEmbs
5
+ from typing import Union, Dict, Callable
6
+
7
class PIModifiedBottleneck(nn.Module):
    """PirateNet residual bottleneck block with physics-informed gating.

    Runs three Dense+activation stages; after each of the first two stages the
    features are mixed with the shared gating streams ``u`` and ``v`` via
    ``x * u + (1 - x) * v``, and the final output is blended with the block
    input through a learnable scalar ``alpha`` (adaptive residual connection).
    """

    hidden_dim: int             # width of the two inner Dense layers
    output_dim: int             # width of the final Dense (must match input width for the skip)
    act: Callable               # activation applied after every Dense
    nonlinearity: float         # initial value of the residual gate `alpha`
    reparam: Union[None, Dict]  # weight-factorization config forwarded to Dense
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x, u, v):
        # Keep the block input for the adaptive residual connection below.
        identity = x

        x = Dense(features=self.hidden_dim, reparam=self.reparam, dtype=self.dtype)(x)
        x = self.act(x)

        # Gate with the shared streams u/v (convex-style mixing).
        x = x * u + (1 - x) * v

        x = Dense(features=self.hidden_dim, reparam=self.reparam, dtype=self.dtype)(x)
        x = self.act(x)

        x = x * u + (1 - x) * v

        x = Dense(features=self.output_dim, reparam=self.reparam, dtype=self.dtype)(x)
        x = self.act(x)

        # Learnable residual weight; alpha starts at `nonlinearity`, so with the
        # default 0.0 the block begins as the identity map.
        alpha = self.param("alpha", nn.initializers.constant(self.nonlinearity), (1,))
        x = alpha * x + (1 - alpha) * identity

        return x
36
+
37
class PirateNet(nn.Module):
    """PirateNet: stack of adaptive residual bottleneck blocks.

    The input is (optionally) lifted with random Fourier embeddings, two
    gating streams ``u``/``v`` are computed once, and ``num_layers``
    PIModifiedBottleneck blocks mix the hidden state with them. The head is
    either a standard Dense layer or, when ``pi_init`` is given, a kernel
    initialized from that matrix (physics-informed initialization).

    Returns a tuple ``(features, output)``.
    """

    num_layers: int
    hidden_dim: int
    output_dim: int
    act: Callable = nn.silu
    nonlinearity: float = 0.0                 # initial residual-gate value of each block
    pi_init: Union[None, jnp.ndarray] = None  # optional head-kernel init matrix
    reparam : Union[None, Dict] = None        # weight-factorization config for Dense
    fourier_emb : Union[None, Dict] = None    # kwargs for FourierEmbs, or None to skip
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        # Bug fix: `fourier_emb` defaults to None, and `FourierEmbs(**None)`
        # raised a TypeError; the embedding is now optional, matching MLP_PINN.
        if self.fourier_emb is not None:
            x = FourierEmbs(**self.fourier_emb)(x)

        # Shared gating streams reused by every bottleneck block.
        u = Dense(features=self.hidden_dim, reparam=self.reparam, dtype=self.dtype)(x)
        u = self.act(u)

        v = Dense(features=self.hidden_dim, reparam=self.reparam, dtype=self.dtype)(x)
        v = self.act(v)

        for _ in range(self.num_layers):
            x = PIModifiedBottleneck(
                hidden_dim=self.hidden_dim,
                output_dim=x.shape[-1],
                act=self.act,
                nonlinearity=self.nonlinearity,
                reparam=self.reparam,
                dtype=self.dtype
            )(x, u, v)

        if self.pi_init is not None:
            # Physics-informed head: the kernel starts at the provided matrix.
            kernel = self.param("pi_init", nn.initializers.constant(self.pi_init, dtype=self.dtype), self.pi_init.shape)
            y = jnp.dot(x, kernel)

        else:
            y = Dense(features=self.output_dim, reparam=self.reparam, dtype=self.dtype)(x)

        return x, y
77
+
78
+ if __name__ == "__main__":
79
+ # Example usage
80
+ from activations import cauchy
81
+ cauchy_mod = lambda x : cauchy()(x)
82
+ model = PirateNet(num_layers=3, hidden_dim=32, output_dim=16, act=cauchy_mod, reparam=None, fourier_emb={'embed_scale': 1.0, 'embed_dim': 64})
83
+ params = model.init(jax.random.PRNGKey(0), jnp.ones(3))
84
+ output = model.apply(params, jnp.ones(3))
85
+ print(params)
__init__.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .mlp_pinn import MLP_PINN
2
+ from .PirateNet import PirateNet
3
+ from .mlp import MLP
4
+ from .siren import SIREN
5
+ from .wire import WIRE
6
+ from .activations import get_activation, list_activations
7
+
8
# Registry mapping public model names to their Flax module classes.
model_key_dict = {
    "MLP": MLP,
    "SIREN": SIREN,
    "WIRE": WIRE,
    "PirateNet": PirateNet,
    "MLP_PINN": MLP_PINN
}
15
+
16
def get_model(model_name : str):
    """
    Look up a model class by its registry name.

    Args:
        model_name (str): Name of the model.

    Returns:
        nn.Module: The model class registered under ``model_name``.

    Raises:
        ValueError: If ``model_name`` is not a registered model.
    """
    try:
        return model_key_dict[model_name]
    except KeyError:
        raise ValueError(f"Model `{model_name}` is not supported. Supported models are: {list(model_key_dict.keys())}") from None
30
+
31
def create_model_configs():
    """
    Build the per-model default configuration dictionary.

    Returns:
        dict: Mapping from model name to the extra keyword arguments that
        model expects beyond the common ones.
    """
    # Shared defaults; copied per model so the entries stay independent.
    weight_fact = {
        "type": "weight_fact",
        "mean": 1.0,
        "stddev": 0.1,
    }
    fourier = {
        "embed_scale": 2.,
        "embed_dim": 256,
    }
    return {
        "MLP": {},
        "SIREN": {
            "omega_0": 3.
        },
        "WIRE": {
            "first_omega_0": 4.,
            "hidden_omega_0": 4.,
            "scale": 5.,
        },
        "PirateNet": {
            "nonlinearity": 0.0,
            "pi_init": None,
            "reparam": dict(weight_fact),
            "fourier_emb": dict(fourier),
        },
        "MLP_PINN": {
            "reparam": dict(weight_fact),
            "fourier_emb": dict(fourier),
        },
    }
75
+
76
# Module-level default configs, built once at import time.
model_configs = create_model_configs()
77
+
78
def get_extra_model_cfg(model_name: str):
    """
    Fetch the default extra configuration for a given model name.

    Args:
        model_name (str): Name of the model.

    Returns:
        dict: The extra model configuration.

    Raises:
        ValueError: If the model name is unknown.
    """
    try:
        return model_configs[model_name]
    except KeyError:
        raise ValueError(f"Model `{model_name}` is not supported. Available models are: {list(model_configs.keys())}") from None
activations.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import flax.linen as nn
4
+
5
# Registry of activation functions selectable by name.
# `cauchy` and `react` wrap trainable flax modules defined later in this file;
# the lambdas defer the lookup, so the forward reference resolves at call time.
activation_function = {
    "relu": nn.relu,
    "gelu": nn.gelu,
    "silu": nn.silu,
    "swish": nn.silu,  # alias: swish == silu
    "tanh": nn.tanh,
    "sigmoid": nn.sigmoid,
    "softplus": nn.softplus,
    "softmax": nn.softmax,
    "leaky_relu": nn.leaky_relu,
    "elu": nn.elu,
    "selu": nn.selu,
    "telu": lambda x: x * jnp.tanh(jnp.exp(x)),  # TeLU: x * tanh(exp(x))
    "mish": lambda x: x * jnp.tanh(nn.softplus(x)),
    "cauchy": lambda x: cauchy()(x),  # trainable; see `cauchy` below
    "identity": lambda x: x,
    "react": lambda x: react()(x),  # trainable; see `react` below
}
23
+
24
# https://arxiv.org/pdf/2503.02267v1
class react(nn.Module):
    """Trainable activation with four scalar parameters (paper linked above).

    Computes ``(1 - exp(a*x + b)) / (1 + exp(c*x + d))`` where a, b, c, d are
    learnable scalars drawn from N(0, 0.1) at init.
    """

    @nn.compact
    def __call__(self, x):
        a = self.param(
            'a',
            jax.nn.initializers.normal(0.1),
            ()
        )
        b = self.param(
            'b',
            jax.nn.initializers.normal(0.1),
            ()
        )

        c = self.param(
            'c',
            jax.nn.initializers.normal(0.1),
            ()
        )
        d = self.param(
            'd',
            jax.nn.initializers.normal(0.1),
            ()
        )

        return (1 - jnp.exp(a * x + b)) / (1 + jnp.exp(c * x + d))
51
+
52
# https://arxiv.org/abs/2409.19221
class cauchy(nn.Module):
    """Trainable Cauchy-type activation (paper linked above).

    Computes ``l1*x / (x^2 + d^2) + l2 / (x^2 + d^2)`` with learnable scalars
    lambda1, lambda2 and d, all initialized to 1.0.
    """

    @nn.compact
    def __call__(self, x):
        l1 = self.param(
            'lambda1',
            jax.nn.initializers.constant(1.0),
            ()
        )
        l2 = self.param(
            'lambda2',
            jax.nn.initializers.constant(1.0),
            ()
        )
        d = self.param(
            'd',
            jax.nn.initializers.constant(1.0),
            ()
        )

        return l1 * x / (x**2 + d**2) + l2 / (x**2 + d**2)
73
+
74
def get_activation(name):
    """
    Look up an activation function by name.

    Args:
        name (str): Name of the activation function.

    Returns:
        Callable: The activation function.

    Raises:
        ValueError: If the name is not registered.
    """
    try:
        return activation_function[name]
    except KeyError:
        raise ValueError(f"Activation function `{name}` is not supported. Supported activations are : {list(activation_function.keys())}") from None
87
+
88
def list_activations():
    """
    Return the names of all registered activation functions.

    Returns:
        list: A list of activation names.
    """
    return [*activation_function]
mlp.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Flax
2
+ import jax.numpy as jnp
3
+ from flax import linen as nn
4
+ import jax
5
+ from typing import Callable
6
+
7
class MLP(nn.Module):
    """Plain multilayer perceptron with Glorot-normal-initialized Dense layers.

    NOTE(review): the forward pass applies one input Dense plus `num_layers`
    further hidden Dense layers (all width `hidden_dim`) before the linear
    output, i.e. num_layers + 1 activated layers in total — confirm this
    matches the intended meaning of `num_layers`.
    """

    hidden_dim: int   # width of every hidden layer
    output_dim: int   # size of the final (linear) output
    num_layers: int   # number of additional hidden layers after the first
    act: Callable = nn.silu   # activation between layers
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        # Input projection.
        x = nn.Dense(
            features=self.hidden_dim,
            use_bias=True,
            kernel_init=nn.initializers.glorot_normal(dtype=self.dtype),
            param_dtype=self.dtype
        )(x)
        x = self.act(x)
        for _ in range(self.num_layers):
            x = nn.Dense(
                features=self.hidden_dim,
                use_bias=True,
                kernel_init=nn.initializers.glorot_normal(dtype=self.dtype),
                param_dtype=self.dtype
            )(x)
            x = self.act(x)
        # Output head: linear, no activation.
        x = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=nn.initializers.glorot_normal(dtype=self.dtype),
            param_dtype=self.dtype
        )(x)
        return x
38
+
39
+ if __name__ == "__main__":
40
+ # Example usage
41
+ x = jax.random.uniform(jax.random.PRNGKey(0), (1, 3), minval=-3, maxval=3)
42
+ model = MLP(hidden_dim=32, output_dim=16, num_layers=3)
43
+ params = model.init(jax.random.PRNGKey(0), x)
44
+ model_fn = lambda params, x : model.apply(params, x)
45
+ print(model_fn(params, x).shape)
mlp_pinn.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Flax
2
+ import jax.numpy as jnp
3
+ from flax import linen as nn
4
+ from typing import Callable, Union, Dict
5
+ from .utils import Dense, FourierEmbs
6
+
7
+ # Modified MLP version based on the state-of-the-art practicies in PINN training:
8
+ # Fourier embeddings and random weight factorization
9
+ # You can read more about it in the paper: https://arxiv.org/pdf/2210.01274
10
+
11
+
12
class MLP_PINN(nn.Module):
    """MLP for PINN training with optional Fourier embeddings and random
    weight factorization (see https://arxiv.org/pdf/2210.01274).

    NOTE(review): when `fourier_emb` is given, the embedding REPLACES the
    first Dense+activation layer (the `else` branch), so the effective depth
    differs by one between the two configurations — confirm this is intended.
    """

    hidden_dim: int
    output_dim: int
    num_layers: int
    act: Callable = nn.silu
    dtype: jnp.dtype = jnp.float32
    reparam : Union[None, Dict] = None       # e.g. {"type": "weight_fact", "mean": ..., "stddev": ...}
    fourier_emb : Union[None, Dict] = None   # kwargs for FourierEmbs, or None to disable

    @nn.compact
    def __call__(self, x):
        if self.fourier_emb is not None:
            x = FourierEmbs(**self.fourier_emb)(x)
        else:
            x = Dense(
                features=self.hidden_dim,
                reparam=self.reparam,
                dtype=self.dtype
            )(x)
            x = self.act(x)
        for _ in range(self.num_layers):
            x = Dense(
                features=self.hidden_dim,
                reparam=self.reparam,
                dtype=self.dtype
            )(x)
            x = self.act(x)
        # Output head: linear, no activation.
        x = Dense(
            features=self.output_dim,
            reparam=self.reparam,
            dtype=self.dtype
        )(x)
        return x
siren.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax.numpy as jnp
2
+ from flax import linen as nn
3
+
4
+ from .utils import custom_uniform
5
+
6
class SIREN(nn.Module):
    """SIREN: MLP with sinusoidal activations.

    Stacks one first-layer SirenLayer plus `num_layers` hidden SirenLayers,
    followed by a plain linear output layer.
    """

    output_dim: int
    hidden_dim: int
    num_layers: int
    omega_0: float   # frequency scale of the sine activations
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # First layer uses the special first-layer init; the rest use the
        # hidden-layer scheme (see SirenLayer.setup).
        self.kernel_net = [
            SirenLayer(
                output_dim=self.hidden_dim,
                omega_0=self.omega_0,
                is_first_layer=True,
                dtype=self.dtype
            )
        ] + [
            SirenLayer(
                output_dim=self.hidden_dim,
                omega_0=self.omega_0,
                dtype=self.dtype
            )
            for _ in range(self.num_layers)
        ]

        # Linear head: kernel ~ N(0, 1) scaled by sqrt(1 / fan_in)
        # (custom_uniform with distribution="normal").
        self.output_linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=custom_uniform(numerator=1, mode="fan_in", distribution="normal", dtype=self.dtype),
            bias_init=nn.initializers.zeros,
            param_dtype=self.dtype
        )

    def __call__(self, x):
        for layer in self.kernel_net:
            x = layer(x)

        out = self.output_linear(x)

        return out
45
+
46
+
47
class SirenLayer(nn.Module):
    """Single SIREN layer: ``sin(omega_0 * (W x + b))``.

    Kernel init: the first layer draws from U(-1/fan_in, 1/fan_in)
    ("uniform_squared" in custom_uniform); hidden layers draw from
    U(-sqrt(6/fan_in)/omega_0, sqrt(6/fan_in)/omega_0).
    """

    output_dim: int
    omega_0: float
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # `c` is the numerator fed to custom_uniform: the bound becomes
        # sqrt(c / fan_in) for "uniform" and c / fan_in for "uniform_squared".
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"
        self.linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=custom_uniform(numerator=c, mode="fan_in", distribution=distrib, dtype=self.dtype),
            bias_init=nn.initializers.zeros,
            param_dtype=self.dtype
        )

    def __call__(self, x):
        after_linear = self.omega_0 * self.linear(x)
        return jnp.sin(after_linear)
utils.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import math
3
+ from typing import Any, Dict, Sequence, Union
4
+
5
+ import jax.numpy as jnp
6
+ from jax import dtypes, random
7
+ from jax.nn.initializers import Initializer
8
+ from typing import Callable
9
+ from flax import linen as nn
10
+
11
class FourierEmbs(nn.Module):
    """Random Fourier feature embedding.

    Projects the input through a Gaussian kernel (stddev = embed_scale) and
    concatenates cos and sin of the projection; the output width is
    2 * (embed_dim // 2).

    NOTE(review): the kernel is registered via `self.param`, so it is
    trainable like any other parameter — confirm whether it should be frozen.
    """

    embed_scale: float   # stddev of the Gaussian used to draw the kernel
    embed_dim: int       # target embedding width (even values are exact)
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        kernel = self.param(
            "kernel", jax.nn.initializers.normal(self.embed_scale, dtype=self.dtype), (x.shape[-1], self.embed_dim // 2)
        )
        # cos/sin of the same projection, concatenated on the feature axis.
        y = jnp.concatenate(
            [jnp.cos(jnp.dot(x, kernel)), jnp.sin(jnp.dot(x, kernel))], axis=-1
        )
        return y
25
+
26
def _weight_fact(init_fn, mean, stddev, dtype=jnp.float32):
    """Build an initializer implementing random weight factorization.

    The returned ``init(key, shape)`` draws a base kernel with ``init_fn``,
    samples a per-output-column log-scale ``mean + N(0, stddev)`` and returns
    ``(g, v)`` where ``g = exp(log-scale)`` and ``v = w / g``, so ``g * v``
    reconstructs the kernel.
    """
    def init(key, shape):
        kernel_key, scale_key = jax.random.split(key)
        w = init_fn(kernel_key, shape)
        scale_noise = nn.initializers.normal(stddev, dtype=dtype)(scale_key, (shape[-1],))
        g = jnp.exp(mean + scale_noise)
        v = w / g
        return g, v

    return init
36
+
37
class Dense(nn.Module):
    """Dense layer with optional random weight factorization (RWF).

    With ``reparam=None`` this is a standard linear layer; with
    ``reparam={"type": "weight_fact", ...}`` the kernel is stored as a
    per-column scale ``g`` and direction ``v`` (see `_weight_fact`) and
    reassembled as ``g * v`` on every call.

    NOTE(review): `self.kernel_init(dtype=self.dtype)` / `self.bias_init(dtype=
    self.dtype)` call the already-constructed initializer with only a dtype —
    a plain jax Initializer expects (key, shape[, dtype]) — verify this works
    with the flax/jax versions in use.
    NOTE(review): a reparam dict with an unrecognized "type" falls through
    both branches and hits a NameError on `kernel` below — consider an
    explicit error.
    """

    features: int
    kernel_init: Callable = nn.initializers.glorot_normal()
    bias_init: Callable = nn.initializers.zeros
    reparam : Union[None, Dict] = None
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        if self.reparam is None:
            kernel = self.param(
                "kernel", self.kernel_init(dtype=self.dtype), (x.shape[-1], self.features)
            )
        elif self.reparam["type"] == "weight_fact":
            # Factorized storage: one param holding the (g, v) pair.
            g, v = self.param(
                "kernel",
                _weight_fact(
                    self.kernel_init,
                    mean=self.reparam["mean"],
                    stddev=self.reparam["stddev"],
                    dtype=self.dtype
                ),
                (x.shape[-1], self.features),
            )

            kernel = g * v

        bias = self.param("bias", self.bias_init(dtype=self.dtype), (self.features,))

        y = jnp.dot(x, kernel) + bias

        return y
69
+
70
+
71
+ def _compute_fans(
72
+ shape: tuple,
73
+ in_axis: Union[int, Sequence[int]] = -2,
74
+ out_axis: Union[int, Sequence[int]] = -1,
75
+ batch_axis: Union[int, Sequence[int]] = (),
76
+ ):
77
+ """Compute effective input and output sizes for a linear or convolutional layer.
78
+
79
+ Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the "receptive field" of
80
+ a convolution (kernel spatial dimensions).
81
+ """
82
+ if len(shape) <= 1:
83
+ raise ValueError(
84
+ f"Can't compute input and output sizes of a {shape.rank}"
85
+ "-dimensional weights tensor. Must be at least 2D."
86
+ )
87
+
88
+ if isinstance(in_axis, int):
89
+ in_size = shape[in_axis]
90
+ else:
91
+ in_size = math.prod([shape[i] for i in in_axis])
92
+ if isinstance(out_axis, int):
93
+ out_size = shape[out_axis]
94
+ else:
95
+ out_size = math.prod([shape[i] for i in out_axis])
96
+ if isinstance(batch_axis, int):
97
+ batch_size = shape[batch_axis]
98
+ else:
99
+ batch_size = math.prod([shape[i] for i in batch_axis])
100
+ receptive_field_size = math.prod(shape) / in_size / out_size / batch_size
101
+ fan_in = in_size * receptive_field_size
102
+ fan_out = out_size * receptive_field_size
103
+ return fan_in, fan_out
104
+
105
+
106
def custom_uniform(
    numerator: float = 6,
    mode: str = "fan_in",
    dtype: jnp.dtype = jnp.float32,
    in_axis: Union[int, Sequence[int]] = -2,
    out_axis: Union[int, Sequence[int]] = -1,
    batch_axis: Sequence[int] = (),
    distribution: str = "uniform",
) -> Initializer:
    """Builds an initializer that returns real random arrays with fan-based scaling.

    :param numerator: the numerator of the range of the random distribution.
    :type numerator: float
    :param mode: the mode for computing the range of the random distribution.
    :type mode: str
    :param dtype: optional; the initializer's default dtype.
    :type dtype: jnp.dtype
    :param in_axis: the axis or axes that specify the input size.
    :type in_axis: Union[int, Sequence[int]]
    :param out_axis: the axis or axes that specify the output size.
    :type out_axis: Union[int, Sequence[int]]
    :param batch_axis: the axis or axes that specify the batch size.
    :type batch_axis: Sequence[int]
    :param distribution: the distribution of the random distribution.
    :type distribution: str

    :return: An initializer that returns arrays whose values are uniformly distributed in
        the range ``[-range, range)`` (or normally distributed with matching scale).
    :rtype: Initializer
    """

    def init(key: jax.random.key, shape: tuple, dtype: Any = dtype) -> Any:
        dtype = dtypes.canonicalize_dtype(dtype)
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
        # Select the denominator from the requested fan mode.
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(f"invalid mode for variance scaling initializer: {mode}")

        ratio = numerator / denominator
        if distribution == "uniform":
            bound = jnp.sqrt(ratio)
            return random.uniform(key, shape, dtype, minval=-bound, maxval=bound)
        if distribution == "normal":
            return random.normal(key, shape, dtype) * jnp.sqrt(ratio)
        if distribution == "uniform_squared":
            # SIREN-style first-layer init: bound is ratio itself, not its sqrt.
            return random.uniform(key, shape, dtype, minval=-ratio, maxval=ratio)
        raise ValueError(
            f"invalid distribution for variance scaling initializer: {distribution}"
        )

    return init
wire.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax.numpy as jnp
2
+ from flax import linen as nn
3
+ from typing import Any
4
+
5
+ import jax
6
+ from .utils import custom_uniform
7
+ from jax.nn.initializers import Initializer
8
+
9
+
10
def complex_kernel_uniform_init(numerator : float = 6,
                                mode : str = "fan_in",
                                dtype : jnp.dtype = jnp.float32,
                                distribution: str = "uniform") -> Initializer:
    """Build an initializer for complex kernels.

    Real and imaginary parts are each drawn via `custom_uniform` with the
    given numerator/mode/distribution and combined as ``re + 1j * im``.
    """
    def init(key: jax.random.key, shape: tuple, dtype: Any = dtype) -> Any:
        # Bug fix: both parts were previously drawn with the SAME key, making
        # the imaginary kernel an exact copy of the real one; split the key so
        # the two draws are independent.
        real_key, imag_key = jax.random.split(key)
        real_kernel = custom_uniform(numerator=numerator, mode=mode, distribution=distribution)(real_key, shape, dtype)
        imag_kernel = custom_uniform(numerator=numerator, mode=mode, distribution=distribution)(imag_key, shape, dtype)

        return real_kernel + 1j * imag_kernel

    return init
21
+
22
+
23
class WIRE(nn.Module):
    """WIRE: Gabor-wavelet implicit neural representation.

    Stacks RealGaborLayer blocks by default, or ComplexGaborLayer blocks when
    ``complexgabor=True`` (hidden layers then run in complex64); the output
    head is linear and only its real part is returned.
    """

    output_dim: int
    hidden_dim: int
    num_layers: int
    hidden_omega_0: float   # frequency scale of hidden layers
    first_omega_0: float    # frequency scale of the first layer
    scale: float            # Gabor envelope width s_0
    complexgabor: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Choose the layer flavor and the dtype the hidden stack runs in.
        if self.complexgabor:
            WIRElayer = ComplexGaborLayer
            dtype = jnp.complex64
        else:
            WIRElayer = RealGaborLayer
            dtype = self.dtype
        self.kernel_net = [
            WIRElayer(
                output_dim=self.hidden_dim,
                omega_0=self.first_omega_0,
                s_0=self.scale,
                is_first_layer=True,
                dtype=dtype
            )
        ] + [
            WIRElayer(
                output_dim=self.hidden_dim,
                omega_0=self.hidden_omega_0,
                s_0=self.scale,
                is_first_layer=False,
                dtype=dtype
            )
            for _ in range(self.num_layers)
        ]

        # Linear head: kernel ~ N(0, 1) scaled by sqrt(1 / fan_in).
        self.output_linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=custom_uniform(numerator=1, mode="fan_in", distribution="normal"),
            param_dtype=self.dtype,
        )

    def __call__(self, x):
        for layer in self.kernel_net:
            x = layer(x)

        # Real part only: with complex Gabor layers the features are complex.
        out = jnp.real(self.output_linear(x))

        return out
73
+
74
+
75
class ComplexGaborLayer(nn.Module):
    """Complex Gabor wavelet layer for WIRE.

    Computes ``exp(1j * omega_0 * Wx - |s_0 * Wx|^2)`` using a single shared
    linear projection ``Wx``.
    """

    output_dim: int
    omega_0: float
    s_0: float
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # `c` feeds the init bound: the first layer uses c=1 with the
        # "uniform_squared" distribution, later layers 6/omega_0^2 "uniform".
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"

        # Only the first layer keeps the caller dtype; deeper layers operate
        # on complex features, so their kernel is complex64.
        if self.is_first_layer:
            dtype = self.dtype
        else:
            dtype = jnp.complex64

        self.linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=complex_kernel_uniform_init(numerator=c, mode="fan_in", distribution=distrib),
            param_dtype=dtype
        )

    def __call__(self, x):
        # Perf fix: the projection was computed twice (self.linear(x) for both
        # omega and scale); evaluate it once — the result is identical.
        lin = self.linear(x)
        omega = self.omega_0 * lin
        scale = self.s_0 * lin

        return jnp.exp(1j * omega - (jnp.abs(scale)**2))
103
+
104
+
105
class RealGaborLayer(nn.Module):
    """Real Gabor wavelet layer for WIRE.

    Uses two independent linear projections: ``freqs`` drives the oscillation
    ``cos(omega_0 * freqs(x))`` and ``scales`` the Gaussian envelope
    ``exp(-(s_0 * scales(x))^2)``.
    """

    output_dim: int
    omega_0: float
    s_0: float
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):

        # Same init scheme as SirenLayer: first layer "uniform_squared" with
        # c=1, later layers "uniform" with c = 6 / omega_0^2.
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"

        self.freqs = nn.Dense(
            features=self.output_dim,
            kernel_init=custom_uniform(numerator=c, mode="fan_in", distribution=distrib, dtype=self.dtype),
            use_bias=True,
            param_dtype=self.dtype
        )

        self.scales = nn.Dense(
            features = self.output_dim,
            kernel_init=custom_uniform(numerator=c, mode="fan_in", distribution=distrib, dtype=self.dtype),
            use_bias=True,
            param_dtype=self.dtype
        )

    def __call__(self, x):
        omega = self.omega_0 * self.freqs(x)
        scale = self.s_0 * self.scales(x)

        return jnp.cos(omega) * jnp.exp(-(scale**2))