| import numpy as np |
| import matplotlib.pyplot as plt |
| import os |
| import glob |
| from PIL import Image |
| from functools import partial |
| import jax |
| from typing import Any, Callable, Sequence, Optional, NewType |
| from jax import lax, random, vmap, scipy, numpy as jnp |
| |
| from models.ode import odeint |
| import flax |
| from flax.training import train_state |
| from flax import traverse_util |
| from flax.core import freeze, unfreeze |
| from flax import linen as nn |
| from flax import serialization |
| import optax |
| from sklearn.datasets import make_circles, make_moons, make_s_curve |
| from tqdm import tqdm |
|
|
|
|
| |
| |
| |
|
|
|
|
class HyperNetwork(nn.Module):
    """Hyper-network producing the time-dependent parameters of f(z(t), t).

    Adapted from the Pytorch implementation at:
    https://github.com/rtqichen/torchdiffeq/blob/master/examples/cnf.py
    """
    in_out_dim: Any = 2
    hidden_dim: Any = 32
    width: Any = 64

    @nn.compact
    def __call__(self, t):
        block = self.width * self.in_out_dim

        # Small MLP mapping the scalar time t to one flat parameter vector.
        h = lax.expand_dims(t, (0, 1))
        for _ in range(2):
            h = nn.tanh(nn.Dense(self.hidden_dim)(h))
        h = nn.Dense(3 * block + self.width)(h)
        flat = lax.reshape(h, (3 * block + self.width,))

        # Slice the flat vector into the per-width weight tensors.
        W = lax.reshape(flat[:block], (self.width, self.in_out_dim, 1))
        U = lax.reshape(flat[block:2 * block], (self.width, 1, self.in_out_dim))
        G = lax.reshape(flat[2 * block:3 * block], (self.width, 1, self.in_out_dim))
        U = U * nn.sigmoid(G)  # gate U elementwise
        B = lax.expand_dims(flat[3 * block:], (1, 2))
        return W, B, U
|
|
|
|
class CNF(nn.Module):
    """CNF dynamics: returns dz/dt and the instantaneous log-density change.

    Adapted from the Pytorch implementation at:
    https://github.com/rtqichen/torchdiffeq/blob/master/examples/cnf.py
    """
    in_out_dim: Any = 2
    hidden_dim: Any = 32
    width: Any = 64

    @nn.compact
    def __call__(self, t, states):
        z, logp_z = states[:, :2], states[:, 2:]
        W, B, U = HyperNetwork(self.in_out_dim, self.hidden_dim, self.width)(t)

        def dzdt(z):
            # `width`-many maps applied via vmap, averaged over the width axis.
            hidden = nn.tanh(vmap(jnp.matmul, (None, 0))(z, W) + B)
            return jnp.matmul(hidden, U).mean(0)

        dz_dt = dzdt(z)
        # d log p / dt = -trace of the Jacobian df/dz (per sample).
        jac = jax.jacrev(lambda z_: dzdt(z_).sum(0))(z)
        dlogp_z_dt = -1.0 * jnp.trace(jac, 0, 0, 2)

        return lax.concatenate((dz_dt, lax.expand_dims(dlogp_z_dt, (1,))), 1)
|
|
|
|
class Neg_CNF(nn.Module):
    """Time-reversed CNF so that jax's odeint can integrate backwards."""
    in_out_dim: Any = 2
    hidden_dim: Any = 32
    width: Any = 64

    @nn.compact
    def __call__(self, t, states):
        # Evaluate the forward dynamics at -t and negate the derivative.
        cnf = CNF(self.in_out_dim, self.hidden_dim, self.width)
        return -1.0 * cnf(-1.0 * t, states)
|
|
|
|
def get_batch_circles(num_samples):
    """Sample a two-circles batch with a zero-initialized log-density column.

    Adapted from the Pytorch implementation at:
    https://github.com/rtqichen/torchdiffeq/blob/master/examples/cnf.py
    """
    samples, _ = make_circles(n_samples=num_samples, noise=0.06, factor=0.5)
    x = jnp.asarray(samples, dtype=jnp.float32)
    zeros = jnp.zeros((num_samples, 1), dtype=jnp.float32)
    return lax.concatenate((x, zeros), 1)
|
|
|
|
def get_batch_moons(num_samples):
    """Sample a two-moons batch with a zero-initialized log-density column."""
    samples, _ = make_moons(n_samples=num_samples, noise=0.05)
    x = jnp.asarray(samples, dtype=jnp.float32)
    zeros = jnp.zeros((num_samples, 1), dtype=jnp.float32)
    return lax.concatenate((x, zeros), 1)
|
|
|
|
def get_batch_scurve(num_samples):
    """Sample an S-curve batch projected onto two of its three coordinates."""
    samples, _ = make_s_curve(n_samples=num_samples, noise=0.05, random_state=0)
    pts = jnp.asarray(samples, dtype=jnp.float32)
    # Keep only the first and last of the three S-curve coordinates.
    x = lax.concatenate((pts[:, :1], pts[:, 2:]), 1)
    zeros = jnp.zeros((num_samples, 1), dtype=jnp.float32)
    return lax.concatenate((x, zeros), 1)
|
|
|
|
def multivariate_normal(z, mean=None, cov=None):
    """Log-density of a multivariate normal evaluated at a single point.

    Generalized from the original hard-coded N([0, 0], 0.1*I) version: the
    defaults reproduce the old behavior exactly, and the normalizing constant
    now uses -(d/2) log(2*pi) so non-2-D inputs are handled correctly.

    Args:
        z: point of shape (d,) at which to evaluate the log-density.
        mean: optional mean vector of shape (d,); defaults to zeros(2).
        cov: optional covariance matrix of shape (d, d); defaults to 0.1 * I_2.

    Returns:
        Scalar log N(z; mean, cov).
    """
    if mean is None:
        mean = jnp.array([0., 0.])
    if cov is None:
        cov = jnp.array([[0.1, 0.], [0., 0.1]])
    d = mean.shape[0]
    z_m = z - mean
    # log N = -(d/2) log(2*pi) - (1/2) log|cov| - (1/2) (z-mu)^T cov^-1 (z-mu)
    logz = (-0.5 * d * jnp.log(2 * jnp.pi)
            - 0.5 * jnp.log(jnp.linalg.det(cov))
            - 0.5 * jnp.matmul(jnp.matmul(z_m.T, jnp.linalg.inv(cov)), z_m))
    return logz
|
|
|
|
def create_train_state(rng, learning_rate, in_out_dim, hidden_dim, width):
    """Initialize Neg_CNF parameters and wrap them in a Flax TrainState."""
    dummy_inputs = jnp.ones((1, 2))
    model = Neg_CNF(in_out_dim, hidden_dim, width)
    params = model.init(rng, jnp.array(10.), dummy_inputs)['params']
    set_params(params)  # shape/smoke check; its result is discarded
    optimizer = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer
    )
|
|
|
|
def set_params(params):
    """Run one forward pass with constant 0.1 weights of the same tree shape.

    NOTE(review): despite the name, this neither modifies the caller's
    `params` nor returns anything -- it rebuilds a parameter tree of constant
    0.1 leaves and applies a default Neg_CNF once, discarding the output.
    Presumably a debugging/shape smoke test; confirm whether it is still needed.
    """
    params = unfreeze(params)
    # Flatten to {'path/to/leaf': value} so every leaf can be rebuilt uniformly.
    flat_params = {'/'.join(k): v for k, v in traverse_util.flatten_dict(params).items()}
    # Replace each leaf with a constant 0.1 array of the same shape/dtype.
    unflat_params = traverse_util.unflatten_dict({tuple(k.split('/')): 0.1 * jnp.ones_like(v) for k, v in flat_params.items()})
    new_params = freeze(unflat_params)
    test_x = jnp.array([[0., 1.], [2., 3.], [4., 5.]])
    test_log_p = jnp.zeros((3, 1))
    test_inputs = lax.concatenate((test_x, test_log_p), 1)
    # Forward pass with default hyper-parameters; result is intentionally unused.
    Neg_CNF().apply({'params': new_params}, jnp.array(0.), test_inputs)
|
|
|
|
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
def train_step(state, batch, in_out_dim, hidden_dim, width, t0, t1):
    """One optimization step: integrate the CNF backwards, maximize log p(x).

    Removes the unused `vmap_multi` local from the original version.

    Args:
        state: Flax TrainState holding params and optimizer state.
        batch: array of shape (N, 3) -- 2-D samples plus a zero log-density column.
        in_out_dim, hidden_dim, width: static model hyper-parameters.
        t0, t1: static integration endpoints.

    Returns:
        (updated state, scalar negative-log-likelihood loss).
    """
    # Base distribution p(z(t0)) = N([0, 0], 0.1 * I).
    p_z0 = lambda x: scipy.stats.multivariate_normal.logpdf(x,
                                                           mean=jnp.array([0., 0.]),
                                                           cov=jnp.array([[0.1, 0.], [0., 0.1]]))

    def loss_fn(params):
        func = lambda states, t: Neg_CNF(in_out_dim, hidden_dim, width).apply({'params': params}, t, states)
        # Integrate from t1 back to t0 (times negated for the reversed dynamics).
        outputs = odeint(
            func,
            batch,
            -1.0 * jnp.array([t1, t0]),
            atol=1e-5,
            rtol=1e-5
        )
        z_t, logp_diff_t = outputs[:, :, :2], outputs[:, :, 2:]
        z_t0, logp_diff_t0 = z_t[-1], logp_diff_t[-1]
        # Change of variables: log p(x) = log p(z(t0)) - accumulated log-det term.
        logp_x = p_z0(z_t0) - lax.squeeze(logp_diff_t0, dimensions=(1,))
        return -logp_x.mean(0)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss
|
|
|
|
def train(learning_rate, n_iters, batch_size, in_out_dim, hidden_dim, width, t0, t1, visual, dataset):
    """Train the CNF and optionally render the visualization.

    Args:
        learning_rate: Adam step size.
        n_iters: number of training iterations.
        batch_size: samples drawn per iteration.
        in_out_dim, hidden_dim, width: model hyper-parameters.
        t0, t1: integration time interval.
        visual: if truthy, produce density/sample plots after training.
        dataset: one of 'circles', 'moons', 'scurve'.

    Raises:
        ValueError: if `dataset` is not a recognized name (previously an
            unknown name surfaced later as a NameError on `get_batch`).
    """
    rng = jax.random.PRNGKey(0)
    state = create_train_state(rng, learning_rate, in_out_dim, hidden_dim, width)

    samplers = {
        "circles": get_batch_circles,
        "moons": get_batch_moons,
        "scurve": get_batch_scurve,
    }
    try:
        get_batch = samplers[dataset]
    except KeyError:
        raise ValueError("unknown dataset %r; expected one of %s" % (dataset, sorted(samplers)))

    for itr in range(1, n_iters + 1):
        batch = get_batch(batch_size)
        state, loss = train_step(state, batch, in_out_dim, hidden_dim, width, t0, t1)
        print("iter: %d, loss: %.2f" % (itr, loss))

    if visual:
        # Build "positive-time" params for CNF from the trained reversed-time
        # model by stripping the 6-character module prefix from each flat key.
        neg_params = unfreeze(state.params)
        neg_flat_params = {'/'.join(k): v for k, v in traverse_util.flatten_dict(neg_params).items()}
        pos_flat_params = {key[6:]: jnp.array(np.array(neg_flat_params[key])) for key in list(neg_flat_params.keys())}
        pos_unflat_params = traverse_util.unflatten_dict({tuple(k.split('/')): v for k, v in pos_flat_params.items()})
        pos_params = freeze(pos_unflat_params)
        jax.profiler.save_device_memory_profile("memory.prof")
        output = viz(neg_params, pos_params, in_out_dim, hidden_dim, width, t0, t1, dataset)
        z_t_samples, z_t_density, logp_diff_t, viz_timesteps, target_sample, z_t1 = output
        create_plots(z_t_samples, z_t_density, logp_diff_t, t0, t1, viz_timesteps, target_sample, z_t1, dataset)
|
|
|
|
def solve_dynamics(dynamics_fn, initial_state, t):
    """Integrate `dynamics_fn` from `initial_state` over the time grid `t`."""
    # The original wrapped this in a pointless inner function; call directly.
    return odeint(dynamics_fn, initial_state, t, atol=1e-5, rtol=1e-5)
|
|
|
|
def viz(neg_params, pos_params, in_out_dim, hidden_dim, width, t0, t1, dataset):
    """Generate sample trajectories and density grids for visualization.

    Adapted from the PyTorch CNF example. Fixes: unknown `dataset` now raises
    ValueError instead of a later NameError; directory creation uses
    `exist_ok=True` instead of a racy exists-then-create check.

    Returns:
        (z_t_samples, z_t_density, logp_diff_t, viz_timesteps, target_sample, z_t1)
    """
    viz_samples = 30000
    viz_timesteps = 41
    samplers = {
        "circles": get_batch_circles,
        "moons": get_batch_moons,
        "scurve": get_batch_scurve,
    }
    if dataset not in samplers:
        raise ValueError("unknown dataset %r; expected one of %s" % (dataset, sorted(samplers)))
    target_sample = samplers[dataset](viz_samples)[:, :2]

    os.makedirs('results_%s/' % dataset, exist_ok=True)

    # Forward pass: push base-distribution samples through the flow t0 -> t1.
    z_t0 = jnp.array(np.random.multivariate_normal(mean=np.array([0., 0.]),
                                                   cov=np.array([[0.1, 0.], [0., 0.1]]),
                                                   size=viz_samples))
    logp_diff_t0 = jnp.zeros((viz_samples, 1), dtype=jnp.float32)

    func_pos = lambda states, t: CNF(in_out_dim, hidden_dim, width).apply({'params': pos_params}, t, states)
    output = solve_dynamics(func_pos, lax.concatenate((z_t0, logp_diff_t0), 1), jnp.linspace(t0, t1, viz_timesteps))
    z_t_samples, _ = output[..., :2], output[..., 2:]

    # Backward pass: evaluate the density on a grid by integrating t1 -> t0.
    x = jnp.linspace(-1.5, 1.5, 100)
    y = jnp.linspace(-1.5, 1.5, 100)
    points = np.vstack(np.meshgrid(x, y)).reshape([2, -1]).T

    z_t1 = jnp.array(points, dtype=jnp.float32)
    logp_diff_t1 = jnp.zeros((z_t1.shape[0], 1), dtype=jnp.float32)
    func_neg = lambda states, t: Neg_CNF(in_out_dim, hidden_dim, width).apply({'params': neg_params}, t, states)
    output = solve_dynamics(func_neg, lax.concatenate((z_t1, logp_diff_t1), 1), -jnp.linspace(t1, t0, viz_timesteps))
    z_t_density, logp_diff_t = output[..., :2], output[..., 2:]

    return z_t_samples, z_t_density, logp_diff_t, viz_timesteps, target_sample, z_t1
|
|
|
|
def create_plots(z_t_samples, z_t_density, logp_diff_t, t0, t1, viz_timesteps, target_sample, z_t1, dataset):
    """Render per-timestep target/sample/density panels and assemble a GIF.

    One JPEG per timestep is written to 'results_<dataset>/', then all frames
    are collected (sorted by filename, i.e. by time) into 'cnf-viz.gif'.
    Assumes `viz` has already created the output directory.
    """
    # One frame per visualization timestep; arrays are iterated in lockstep.
    for (t, z_sample, z_density, logp_diff) in zip(
            tqdm(np.linspace(t0, t1, viz_timesteps)),
            z_t_samples, z_t_density, logp_diff_t
    ):
        fig = plt.figure(figsize=(12, 4), dpi=200)
        plt.tight_layout()
        plt.axis('off')
        plt.margins(0, 0)
        fig.suptitle(f'{t:.2f}s')

        # Three panels: target data, flow samples, and model log-probability.
        ax1 = fig.add_subplot(1, 3, 1)
        ax1.set_title('Target')
        ax1.get_xaxis().set_ticks([])
        ax1.get_yaxis().set_ticks([])
        ax2 = fig.add_subplot(1, 3, 2)
        ax2.set_title('Samples')
        ax2.get_xaxis().set_ticks([])
        ax2.get_yaxis().set_ticks([])
        ax3 = fig.add_subplot(1, 3, 3)
        ax3.set_title('Log Probability')
        ax3.get_xaxis().set_ticks([])
        ax3.get_yaxis().set_ticks([])

        ax1.hist2d(*jnp.transpose(target_sample), bins=300, density=True,
                   range=[[-1.5, 1.5], [-1.5, 1.5]])

        ax2.hist2d(*jnp.transpose(z_sample), bins=300, density=True,
                   range=[[-1.5, 1.5], [-1.5, 1.5]])
        # Base distribution log-density N([0, 0], 0.1 * I).
        p_z0 = lambda x: scipy.stats.multivariate_normal.logpdf(x,
                                                               mean=jnp.array([0., 0.]),
                                                               cov=jnp.array([[0.1, 0.], [0., 0.1]]))
        # Change of variables: log p(x) = log p(z0) - accumulated log-det delta.
        logp = p_z0(z_density) - lax.squeeze(logp_diff, dimensions=(1,))
        ax3.tricontourf(*jnp.transpose(z_t1),
                        jnp.exp(logp), 200)

        # Zero-padded millisecond stamp keeps filenames in chronological order.
        plt.savefig(os.path.join('results_%s/' % dataset, f"cnf-viz-{int(t * 1000):05d}.jpg"),
                    pad_inches=0.2, bbox_inches='tight')
        plt.close()

    # Assemble all saved frames into an animated GIF (250 ms per frame).
    img, *imgs = [Image.open(f) for f in sorted(glob.glob(os.path.join('results_%s/' % dataset, f"cnf-viz-*.jpg")))]
    img.save(fp=os.path.join('results_%s/' % dataset, "cnf-viz.gif"), format='GIF', append_images=imgs,
             save_all=True, duration=250, loop=0)

    print('Saved visualization animation at {}'.format(os.path.join('results_%s/' % dataset, "cnf-viz.gif")))
|
|
|
|
if __name__ == '__main__':
    # Same configuration as before, spelled out with keyword arguments.
    train(learning_rate=0.001, n_iters=1000, batch_size=512, in_out_dim=2,
          hidden_dim=32, width=64, t0=0., t1=10., visual=True, dataset='scurve')
|
|