"""Flax Gabor-wavelet layers (complex and real variants) and the WIRE network stacking them."""

from typing import Any

import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.nn.initializers import Initializer

from .utils import custom_uniform


def complex_kernel_uniform_init(
    numerator: float = 6,
    mode: str = "fan_in",
    dtype: jnp.dtype = jnp.float32,
    distribution: str = "uniform",
) -> Initializer:
    """Build a kernel initializer whose real and imaginary parts are drawn independently.

    Each part is sampled with ``custom_uniform(numerator=..., mode=...,
    distribution=...)`` from ``.utils`` (see there for the exact distribution),
    in the real floating dtype that corresponds to the requested dtype.

    Args:
        numerator: Forwarded to ``custom_uniform``.
        mode: Forwarded to ``custom_uniform`` (e.g. ``"fan_in"``).
        dtype: Default dtype used when the returned initializer is called
            without an explicit ``dtype``.
        distribution: Forwarded to ``custom_uniform``.

    Returns:
        An ``init(key, shape, dtype)`` function. For a complex ``dtype`` it
        returns ``real + 1j * imag`` built from two independent draws (separate
        PRNG subkeys). For a real ``dtype`` it returns a single real draw. In
        both cases the result has the requested dtype.
    """
    part_init = custom_uniform(numerator=numerator, mode=mode, distribution=distribution)

    def init(key: jax.Array, shape: tuple, dtype: Any = dtype) -> jax.Array:
        # finfo of a complex dtype describes its component float type
        # (complex64 -> float32); the individual parts are drawn in that type.
        part_dtype = jnp.finfo(dtype).dtype
        if not jnp.issubdtype(dtype, jnp.complexfloating):
            # A real kernel was requested (e.g. a first layer with a real
            # param_dtype); returning a complex array would not honour the
            # dtype argument of the initializer protocol.
            return part_init(key, shape, part_dtype)
        # Separate subkeys: drawing both parts from the same key would make the
        # imaginary part an exact copy of the real part.
        real_key, imag_key = jax.random.split(key)
        real_kernel = part_init(real_key, shape, part_dtype)
        imag_kernel = part_init(imag_key, shape, part_dtype)
        return (real_kernel + 1j * imag_kernel).astype(dtype)

    return init


class WIRE(nn.Module):
    """Stack of Gabor layers followed by a dense read-out; returns the read-out's real part.

    The first Gabor layer uses ``first_omega_0``; the ``num_layers`` layers
    after it use ``hidden_omega_0``. Every Gabor layer receives ``scale`` as
    its ``s_0``.
    """

    output_dim: int  # features of the final dense read-out
    hidden_dim: int  # width of every Gabor layer
    num_layers: int  # Gabor layers after the first one (total = num_layers + 1)
    hidden_omega_0: float  # omega_0 of the non-first Gabor layers
    first_omega_0: float  # omega_0 of the first Gabor layer
    scale: float  # s_0 shared by all Gabor layers
    complexgabor: bool = False  # True -> ComplexGaborLayer, False -> RealGaborLayer
    dtype: jnp.dtype = jnp.float32  # read-out param dtype (and RealGaborLayer dtype)

    def setup(self):
        if self.complexgabor:
            layer_cls = ComplexGaborLayer
            layer_dtype = jnp.complex64
        else:
            layer_cls = RealGaborLayer
            layer_dtype = self.dtype

        # Kept as a single list expression so submodule names stay
        # kernel_net_0 ... kernel_net_<num_layers>.
        self.kernel_net = [
            layer_cls(
                output_dim=self.hidden_dim,
                omega_0=self.first_omega_0,
                s_0=self.scale,
                is_first_layer=True,
                dtype=layer_dtype,
            )
        ] + [
            layer_cls(
                output_dim=self.hidden_dim,
                omega_0=self.hidden_omega_0,
                s_0=self.scale,
                is_first_layer=False,
                dtype=layer_dtype,
            )
            for _ in range(self.num_layers)
        ]
        self.output_linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=custom_uniform(numerator=1, mode="fan_in", distribution="normal"),
            param_dtype=self.dtype,
        )

    def __call__(self, x):
        for layer in self.kernel_net:
            x = layer(x)
        # The Gabor stack may be complex-valued; only the real part is returned.
        return jnp.real(self.output_linear(x))


class ComplexGaborLayer(nn.Module):
    """Dense projection followed by a complex Gabor activation.

    With ``z = linear(x)`` the output is ``exp(1j * omega_0 * z - |s_0 * z|**2)``.
    """

    output_dim: int  # features of the dense projection
    omega_0: float  # frequency multiplier
    s_0: float  # envelope (Gaussian width) multiplier
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32  # param dtype of a first layer; later layers use complex64

    def setup(self):
        # First layer: numerator 1 with "uniform_squared"; later layers:
        # 6 / omega_0**2 with "uniform" (looks like SIREN-style scaling --
        # confirm against custom_uniform in .utils).
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"
        param_dtype = self.dtype if self.is_first_layer else jnp.complex64
        self.linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=complex_kernel_uniform_init(
                numerator=c, mode="fan_in", distribution=distrib
            ),
            param_dtype=param_dtype,
        )

    def __call__(self, x):
        # Frequency and envelope both scale the same projection, so the Dense
        # is evaluated once and reused.
        z = self.linear(x)
        omega = self.omega_0 * z
        scale = self.s_0 * z
        return jnp.exp(1j * omega - (jnp.abs(scale) ** 2))


class RealGaborLayer(nn.Module):
    """Real Gabor activation built from two separate dense projections.

    Output is ``cos(omega_0 * freqs(x)) * exp(-(s_0 * scales(x))**2)``.
    """

    output_dim: int  # features of both dense projections
    omega_0: float  # frequency multiplier
    s_0: float  # envelope (Gaussian width) multiplier
    is_first_layer: bool = False
    dtype: jnp.dtype = jnp.float32  # param dtype of both projections

    def setup(self):
        # Same numerator/distribution choice as ComplexGaborLayer.
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"
        self.freqs = nn.Dense(
            features=self.output_dim,
            kernel_init=custom_uniform(
                numerator=c, mode="fan_in", distribution=distrib, dtype=self.dtype
            ),
            use_bias=True,
            param_dtype=self.dtype,
        )
        self.scales = nn.Dense(
            features=self.output_dim,
            kernel_init=custom_uniform(
                numerator=c, mode="fan_in", distribution=distrib, dtype=self.dtype
            ),
            use_bias=True,
            param_dtype=self.dtype,
        )

    def __call__(self, x):
        omega = self.omega_0 * self.freqs(x)
        scale = self.s_0 * self.scales(x)
        return jnp.cos(omega) * jnp.exp(-(scale**2))