| | from functools import partial |
| | import jax |
| | from typing import Any, Callable, Sequence, Optional, NewType |
| | from jax import lax, random, vmap, numpy as jnp |
| | from jax.experimental.ode import odeint |
| | import flax |
| | from flax.training import train_state |
| | from flax import traverse_util |
| | from flax.core import freeze, unfreeze |
| | from flax import linen as nn |
| | from flax import serialization |
| | import optax |
| | import tensorflow_datasets as tfds |
| | import numpy as np |
| | from tqdm import tqdm |
| | import os |
| |
|
| |
|
| | |
@jax.jit
def cross_entropy_loss(logits, labels):
    """Mean categorical cross-entropy between ``logits`` and integer ``labels``.

    ``logits`` is first normalised with ``jax.nn.log_softmax``. This makes the
    loss correct for raw, unnormalised scores. It is a no-op for inputs that are
    already log-probabilities (log_softmax is idempotent), so a model that ends
    in a log-softmax layer gets exactly the same value as before.

    Args:
        logits: array of shape ``(..., num_classes)``; the class axis is last.
        labels: integer class indices with shape ``logits.shape[:-1]``.

    Returns:
        Scalar mean negative log-likelihood of the true class.
    """
    # Shapes are static under jit, so the class count can be read from the
    # input instead of being hard-coded.
    num_classes = logits.shape[-1]
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))
| |
|
| |
|
| | |
@jax.jit
def compute_metrics(logits, labels):
    """Summarise a batch of predictions as a metrics dict.

    Args:
        logits: per-class scores with the class axis last.
        labels: integer class indices matching ``logits.shape[:-1]``.

    Returns:
        ``{'loss': ..., 'accuracy': ...}``. ``loss`` is the mean
        cross-entropy; ``accuracy`` is the fraction of rows whose argmax
        equals the label.
    """
    batch_loss = cross_entropy_loss(logits=logits, labels=labels)
    predicted_class = jnp.argmax(logits, -1)
    hit_rate = jnp.mean(predicted_class == labels)
    return {'loss': batch_loss, 'accuracy': hit_rate}
| |
|
| |
|
def get_datasets():
    """Load the complete MNIST train and test splits into memory.

    Both splits are materialised in one go (``batch_size=-1``) via
    ``tfds.as_numpy``.

    Returns:
        ``(train_ds, test_ds)``: the dicts produced by ``tfds.as_numpy``. In
        each, ``'image'`` is replaced by a float32 JAX array divided by 255
        (presumably mapping 8-bit pixels to [0, 1] — confirm dtype upstream).
    """
    builder = tfds.builder('mnist')
    builder.download_and_prepare()
    train_ds, test_ds = (
        tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
        for split_name in ('train', 'test')
    )
    for split in (train_ds, test_ds):
        split['image'] = jnp.float32(split['image']) / 255.
    return train_ds, test_ds
| |
|
| |
|
def create_train_state(model, rng, learning_rate):
    """Build the initial 'TrainState' for ``model``.

    Parameters are initialised by running ``model.init`` on a dummy
    ``(1, 28, 28, 1)`` batch of ones. They are paired with an Adam optimiser,
    and ``model.apply`` is stored as the state's ``apply_fn``.

    Args:
        model: object exposing Flax-style ``init``/``apply``.
        rng: PRNG key used for parameter initialisation.
        learning_rate: Adam step size.
    """
    dummy_batch = jnp.ones([1, 28, 28, 1])
    variables = model.init(rng, dummy_batch)
    optimizer = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optimizer,
    )
| |
|
| |
|
| | |
@jax.jit
def train_step(state, batch):
    """Run one gradient step on a single batch.

    Args:
        state: a Flax ``TrainState``. Its ``apply_fn`` maps
            ``({'params': params}, images)`` to logits.
        batch: mapping with ``'image'`` and ``'label'`` arrays.

    Returns:
        ``(new_state, metrics)``. ``metrics`` is computed from the logits
        produced before the parameter update.
    """
    def loss_fn(params):
        # Use the apply function carried by the state (create_train_state
        # stores model.apply there). The step then no longer depends on a
        # module-level ``apply`` name that may not exist.
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    return state, metrics
| |
|
| |
|
| | |
@jax.jit
def eval_step(params, batch):
    """Score ``batch`` under ``params`` and return its metrics dict.

    NOTE(review): this relies on a module-level ``apply`` that is not defined
    anywhere in the visible code. It is presumably an alias for the model's
    ``apply`` defined elsewhere — confirm, or thread the apply function
    through (e.g. from ``TrainState.apply_fn``).
    """
    return compute_metrics(
        logits=apply({'params': params}, batch['image']),
        labels=batch['label'],
    )
| |
|
| |
|
| | |
def train_epoch(model, state, train_ds, batch_size, epoch, rng):
    """Train for a single epoch over a shuffled pass of ``train_ds``.

    Examples are shuffled with ``rng``. The incomplete final batch is dropped.
    Per-batch metrics are averaged and printed at the end.

    Args:
        model: currently unused here, since ``train_step`` takes only
            ``(state, batch)``. Kept so the signature stays unchanged.
        state: current training state, passed to and returned by ``train_step``.
        train_ds: dict of equally long arrays, including ``'image'``.
        batch_size: number of examples per step; must be positive and no
            larger than the dataset.
        epoch: epoch number, used only for the printed summary.
        rng: PRNG key used to shuffle the examples.

    Returns:
        The updated training state.

    Raises:
        ValueError: if ``batch_size`` is not positive or exceeds the dataset
            size (no full batch could be formed).
    """
    train_ds_size = len(train_ds['image'])
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    steps_per_epoch = train_ds_size // batch_size
    if steps_per_epoch == 0:
        # Otherwise the loop never runs and averaging fails on an empty list.
        raise ValueError(
            f'batch_size={batch_size} exceeds the training set size '
            f'({train_ds_size}); no full batch can be formed'
        )

    # Shuffle, drop the ragged tail, and group indices into one row per step.
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))

    batch_metrics = []
    for perm in tqdm(perms):
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        # train_step's signature is (state, batch). Passing ``model`` first
        # would shift every argument.
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # Transfer all per-batch metrics to host once, then average each key.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }
    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
        epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100
    ))

    return state
| |
|
| |
|
| | |
def eval_model(params, test_ds):
    """Evaluate ``params`` on all of ``test_ds`` as a single batch.

    Args:
        params: model parameters, forwarded to ``eval_step``.
        test_ds: dict of arrays with ``'image'`` and ``'label'``.

    Returns:
        ``(loss, accuracy)`` as plain Python scalars.
    """
    metrics = eval_step(params, test_ds)
    metrics = jax.device_get(metrics)
    # ``jax.tree_map`` is deprecated and removed in newer JAX releases.
    # ``jax.tree_util.tree_map`` works across versions.
    summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']
| |
|
| |
|
def train_and_evaluate(learning_rate, n_epoch, batch_size, model=None):
    """Train ``model`` on MNIST for ``n_epoch`` epochs, evaluating after each.

    Args:
        learning_rate: step size for the Adam optimiser built in
            ``create_train_state``.
        n_epoch: number of training epochs.
        batch_size: examples per training step.
        model: the Flax module to train. Required; it defaults to ``None`` only
            so the existing positional parameters keep their order.

    Returns:
        The final training state.

    Raises:
        TypeError: if ``model`` is not supplied.
    """
    if model is None:
        # create_train_state and train_epoch both require the model as their
        # first argument; fail before the dataset download starts.
        raise TypeError('train_and_evaluate() requires a `model` argument')

    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)

    state = create_train_state(model, init_rng, learning_rate)
    del init_rng  # single-use key; drop it so it cannot be reused

    for epoch in tqdm(range(1, n_epoch + 1)):
        rng, input_rng = jax.random.split(rng)
        state = train_epoch(model, state, train_ds, batch_size, epoch, input_rng)
        test_loss, test_accuracy = eval_model(state.params, test_ds)
        print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
            epoch, test_loss, test_accuracy * 100
        ))

    return state
| |
|