| |
| import jax.numpy as jnp |
| from flax import linen as nn |
| import jax |
| from typing import Callable |
|
|
class MLP(nn.Module):
    """Fully connected network of ``nn.Dense`` layers with a linear output head.

    The network applies, in order:

    * one ``Dense(hidden_dim)`` layer followed by ``act``,
    * ``num_layers`` further ``Dense(hidden_dim)`` layers, each followed by
      ``act``,
    * a final ``Dense(output_dim)`` layer with no activation.

    ``num_layers`` therefore counts the hidden layers *after* the first one.
    The model always contains ``num_layers + 2`` Dense layers. A
    non-positive ``num_layers`` yields exactly two: the first hidden layer
    and the output layer.

    Every layer uses a bias and a Glorot-normal kernel initializer. Layers
    are created in input-to-output order, so Flax auto-names them
    ``Dense_0`` (first) through ``Dense_{num_layers + 1}`` (output).
    """

    hidden_dim: int  # width of every hidden layer
    output_dim: int  # feature size of the final (linear) layer
    num_layers: int  # number of hidden layers after the first one
    act: Callable = nn.silu  # activation applied after each hidden layer
    # Used only as the parameter dtype (``param_dtype``) and the initializer
    # dtype. The computation dtype is not passed to ``nn.Dense``, so Flax
    # infers it from the inputs and parameters.
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        """Apply the MLP to ``x``.

        Args:
          x: input array. ``nn.Dense`` contracts over the last axis, so this
            is presumably ``(..., in_features)`` — TODO confirm with callers.

        Returns:
          The network output; its last axis has size ``output_dim``.
        """
        kernel_init = nn.initializers.glorot_normal(dtype=self.dtype)

        def dense(features):
            # All layers share one configuration; only the width differs.
            return nn.Dense(
                features=features,
                use_bias=True,
                kernel_init=kernel_init,
                param_dtype=self.dtype,
            )

        x = self.act(dense(self.hidden_dim)(x))
        for _ in range(self.num_layers):
            x = self.act(dense(self.hidden_dim)(x))
        # Output layer is linear: no activation is applied.
        return dense(self.output_dim)(x)
|
|
if __name__ == "__main__":
    # Smoke test: build a small MLP and print the shape of its output.
    # Use independent PRNG keys for the input data and for parameter
    # initialization; reusing a single key correlates the two random draws.
    data_key, init_key = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.uniform(data_key, (1, 3), minval=-3, maxval=3)
    model = MLP(hidden_dim=32, output_dim=16, num_layers=3)
    params = model.init(init_key, x)
    model_fn = model.apply  # called as model_fn(params, x)
    # Batch of 1 with output_dim=16 on the last axis, i.e. (1, 16).
    print(model_fn(params, x).shape)