| import jax |
| from typing import Any, Callable, Sequence, Optional |
| from jax import lax, random, numpy as jnp |
| import flax |
| from flax.training import train_state |
| from flax.core import freeze, unfreeze |
| from flax import linen as nn |
| from flax import serialization |
| import optax |
|
|
|
|
| class ExplicitMLP(nn.Module): |
| features: Sequence[int] |
|
|
| def setup(self): |
| self.layers = [nn.Dense(feat) for feat in self.features] |
|
|
| def __call__(self, inputs): |
| x = inputs |
| for i, lyr in enumerate(self.layers): |
| x = lyr(x) |
| if i != len(self.layers) - 1: |
| x = nn.relu(x) |
| return x |
|
|
|
|
| class SimpleMLP(nn.Module): |
| features: Sequence[int] |
|
|
| @nn.compact |
| def __call__(self, inputs): |
| x = inputs |
| for i, feat in enumerate(self.features): |
| x = nn.Dense(feat)(x) |
| if i != len(self.features - 1): |
| x = nn.relu(x) |
| return x |
|
|
|
|
| if __name__ == '__main__': |
| key1, key2 = random.split(random.PRNGKey(0), 2) |
|
|
| |
| nsamples = 20 |
| xdim = 10 |
| ydim = 5 |
|
|
| |
| W = random.normal(key1, (xdim, ydim)) |
| b = random.normal(key2, (ydim,)) |
| true_params = freeze({'params': {'bias': b, 'kernel': W}}) |
|
|
| |
| ksample, knoise = random.split(key1) |
| x_samples = random.normal(ksample, (nsamples, xdim)) |
| y_samples = jnp.dot(x_samples, W) + b |
| y_samples += 0.1 * random.normal(knoise, (nsamples, ydim)) |
| print('x shape:', x_samples.shape, '; y shape:', y_samples.shape) |
|
|
| key_init, subkey = random.split(ksample, 2) |
| model = ExplicitMLP(features=[5]) |
| params = model.init(subkey, x_samples) |
|
|
| def make_mse_func(x_batched, y_batched): |
| def mse(params): |
| |
| def squared_error(x, y): |
| pred = model.apply(params, x) |
| return jnp.inner(y - pred, y - pred) / 2.0 |
|
|
| |
| return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0) |
|
|
| return jax.jit(mse) |
|
|
| |
| loss = make_mse_func(x_samples, y_samples) |
|
|
| lr = 0.3 |
| tx = optax.sgd(learning_rate=lr) |
| opt_state = tx.init(params) |
| loss_grad_fn = jax.value_and_grad(loss) |
|
|
| for i in range(101): |
| loss_val, grads = loss_grad_fn(params) |
| updates, opt_state = tx.update(grads, opt_state) |
| params = optax.apply_updates(params, updates) |
|
|
| if i % 10 == 0: |
| print('Loss step {}: '.format(i), loss_val) |
|
|
| |
| bytes_output = serialization.to_bytes(params) |
| dict_output = serialization.to_state_dict(params) |
| print('Dict output') |
| print(dict_output) |
| print('Bytes output') |
| print(bytes_output) |
|
|
| |
| saved_params = serialization.from_bytes(params, bytes_output) |
| print(loss(saved_params)) |
| print(loss(params)) |
|
|