"""Streamlit demo: homoskedastic vs. heteroskedastic Gaussian process regression.

Both models use a Matern-3/2 kernel (tinygp) on standardised 1-D data and are
trained side by side with plain SGD (optax). The homoskedastic model learns a
single scalar noise term; the heteroskedastic model learns an input-dependent
noise variance with a small Flax MLP.
"""
import streamlit as st
import jax
import optax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import zeros
from tinygp import kernels, transforms, GaussianProcess
import numpy as np
import matplotlib.pyplot as plt

# --------------------------------------------------------------------------
# Built-in datasets
# --------------------------------------------------------------------------
mcycle_x = np.array([
    35.2, 27.6, 35.6, 28.2, 57.6, 26.4, 46.6, 55.0, 16.6, 8.2,
    32.8, 19.2, 14.6, 42.8, 9.6, 50.6, 2.4, 34.8, 33.4, 6.2,
    34.4, 29.4, 25.6, 16.0, 13.8, 15.6, 20.2, 44.4, 3.6, 21.2,
    8.8, 24.2, 45.0, 33.8, 7.8, 38.0, 2.6, 41.6, 20.4, 23.2,
    26.0, 40.0, 28.4, 55.4, 15.4, 53.2, 11.4, 48.8, 25.4, 13.6,
    39.2, 40.4, 16.8, 4.0, 43.0, 15.8, 24.6, 16.4, 28.6, 52.0,
    30.2, 18.6, 10.2, 32.0, 6.6, 17.6, 19.6, 24.0, 16.2, 42.4,
    22.0, 23.4, 44.0, 6.8, 11.0, 10.6, 26.2, 39.4, 31.2, 19.4,
    21.4, 27.2, 47.8, 35.4, 13.2, 31.0, 36.2, 14.8, 3.2, 25.0,
    21.8, 17.8, 27.0, 10.0,
])
mcycle_y = np.array([
    -16.0, 4.0, 34.8, 12.0, 10.7, -65.6, 10.7, -2.7, -59.0, -2.7,
    46.9, -123.1, -13.3, 0.0, -2.7, 0.0, 0.0, 75.0, 16.0, -2.7,
    1.3, -17.4, -26.8, -42.9, 0.0, -40.2, -123.1, 0.0, 0.0, -134.0,
    -1.3, -95.1, 10.7, 45.6, -2.7, 46.9, -1.3, 30.8, -117.9, -123.1,
    -5.4, -21.5, -21.5, -2.7, -22.8, -14.7, 0.0, -13.3, -72.3, -2.7,
    5.4, -13.3, -71.0, -2.7, 14.7, -21.5, -53.5, -5.4, 46.9, 10.7,
    36.2, -112.5, -5.4, 54.9, -2.7, -37.5, -127.2, -112.5, -21.5, 29.4,
    -123.1, -128.5, -1.3, -1.3, -5.4, -2.7, -107.1, -1.3, 8.1, -85.6,
    -101.9, -45.6, -26.8, 69.6, -2.7, 75.0, -37.5, -2.7, -2.7, -64.4,
    -108.4, -99.1, -16.0, -2.7,
])

oly_x = np.array([
    1896.0, 1900.0, 1904.0, 1908.0, 1912.0, 1920.0, 1924.0, 1928.0, 1932.0,
    1936.0, 1948.0, 1952.0, 1956.0, 1960.0, 1964.0, 1968.0, 1972.0, 1976.0,
    1980.0, 1984.0, 1988.0, 1992.0, 1996.0, 2000.0, 2004.0, 2008.0, 2012.0,
])
oly_y = np.array([
    4.47083333, 4.46472926, 5.22208333, 4.15467867, 3.90331675, 3.56951267,
    3.82454477, 3.62483707, 3.59284275, 3.53880792, 3.67010309, 3.39029111,
    3.43642612, 3.20583007, 3.13275665, 3.32819844, 3.13583758, 3.0789588,
    3.10581822, 3.06552909, 3.09357349, 3.16111704, 3.14255244, 3.08527867,
    3.10265829, 2.99877553, 3.03392977,
])

# --------------------------------------------------------------------------
# Page header
# --------------------------------------------------------------------------
st.title("Heteroscedastic Gaussian Processes")
st.markdown(r"""
Gaussian processes generally assume Homoskedastic noise such as:

$$
y_{i}=f\left(\mathbf{x}_{i}\right)+\epsilon_{i}, \quad \epsilon_{i} \stackrel{\text { i.i.d. }}{\sim} \mathcal{N}\left(0, \sigma_{\epsilon}^{2}\right), \quad 1 \leq i \leq n
$$

We can also assume separate distribution of noise over each data point:

$$
y_{i}=f\left(\mathbf{x}_{i}\right)+\epsilon_{i}, \quad \epsilon_{i} {\sim} \mathcal{N}\left(0, \sigma_{\epsilon_i}^{2}\right), \quad 1 \leq i \leq n
$$

However, this may not be straightforward to extend for inference or conditioning. A simple idea can be to learn a non-linear neural network function to model the relationship between inputs and noise:

$$
\sigma_{\epsilon_i}^{2} = \exp\left(g_{\theta}(\mathbf{x}_i)\right)
$$

This demo is an attempt to experiment with this idea on several synthetic and real datasets with Heteroskedastic noise.
""")

# --------------------------------------------------------------------------
# Data selection and standardisation
# --------------------------------------------------------------------------
data = st.selectbox("Data", ["Motorcycle", "Olympic", "Linear", "GPflow"])
if data == "Motorcycle":
    data_x, data_y = mcycle_x, mcycle_y
elif data == "Olympic":
    data_x, data_y = oly_x, oly_y
elif data == "Linear":
    # Linear trend whose noise standard deviation grows proportionally with x.
    data_x = np.linspace(0, 1, 99)
    data_y = 3 * data_x + 2 + (np.random.randn(data_x.shape[0]) * data_x)
elif data == "GPflow":
    N = 1001
    # Build inputs X
    data_x = np.linspace(0, 4 * np.pi, N)
    # Deterministic functions in place of latent ones
    f1 = np.sin
    f2 = np.cos
    # Use transform = exp to ensure positive-only scale values
    transform = np.exp
    # Compute loc and scale as functions of input X
    loc = f1(data_x)
    scale = transform(f2(data_x))
    # Sample outputs Y from Gaussian Likelihood
    data_y = np.random.normal(loc, scale)

# Standardise inputs and targets to zero mean / unit variance.
x = (data_x - data_x.mean()) / data_x.std()
y = (data_y - data_y.mean()) / data_y.std()

n_tests = st.number_input(
    "Number of test points", min_value=50, max_value=1000, value=100
)
# Evenly spaced (hence already sorted) prediction grid over the input range.
t = np.linspace(x.min(), x.max(), n_tests)
# Fixed noise standard deviation that is always part of the likelihood
# diagonal; the models learn an additional noise variance on top of it.
noise = 0.01

# --------------------------------------------------------------------------
# Models
# --------------------------------------------------------------------------
fet1 = st.slider("Number of neurons in Layer1", min_value=2, max_value=30, value=15)
fet2 = st.slider("Number of neurons in Layer2", min_value=2, max_value=30, value=10)


class Transformer(nn.Module):
    """Small MLP: Dense(fet1) -> ReLU -> Dense(fet2) -> ReLU -> Dense(1).

    `GPLoss` exponentiates its output and adds it to the GP diagonal, so the
    output plays the role of a per-input log noise variance.
    """

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=fet1)(x)
        x = nn.relu(x)
        x = nn.Dense(features=fet2)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x


class BaseGPLoss(nn.Module):
    """Homoskedastic GP: Matern-3/2 kernel plus one learned scalar noise term.

    Returns a tuple ``(nll, (mean_t, var_t), noise_var)``:
    the negative log likelihood of ``y`` at ``x``, the conditioned GP's mean
    and variance at ``t``, and the learned scalar ``exp(log_jitter)``.
    """

    @nn.compact
    def __call__(self, x, y, t):
        # Matern-3/2 kernel with learnable amplitude and length scale.
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())
        kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))

        # The same extra noise variance is shared by every observation.
        gp = GaussianProcess(
            kernel, x[:, None], diag=noise**2 + jnp.exp(log_jitter)
        )
        log_prob, gp_cond = gp.condition(y, t[:, None])
        return (
            -log_prob,
            (gp_cond.loc, gp_cond.variance),
            jnp.exp(log_jitter),
        )


class GPLoss(nn.Module):
    """Heteroskedastic GP: Matern-3/2 kernel plus input-dependent noise.

    A `Transformer` MLP maps each input to a log noise variance, which is
    exponentiated and added to the GP diagonal. Returns
    ``(nll, (mean_t, var_t), noise_var_t)``, where ``noise_var_t`` is the
    MLP's exponentiated output at ``t`` with shape ``(len(t), 1)``.
    """

    @nn.compact
    def __call__(self, x, y, t):
        # Matern-3/2 kernel with learnable amplitude and length scale.
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))

        # Per-observation log noise variance predicted by the network.
        noise_net = Transformer()
        log_jitter = noise_net(x.reshape(-1, 1)).ravel()

        gp = GaussianProcess(
            kernel, x[:, None], diag=noise**2 + jnp.exp(log_jitter)
        )
        log_prob, gp_cond = gp.condition(y, t[:, None])
        return (
            -log_prob,
            (gp_cond.loc, gp_cond.variance),
            jnp.exp(noise_net(t[:, None])),
        )


# --------------------------------------------------------------------------
# Training
# --------------------------------------------------------------------------
def _make_loss(module):
    """Return ``params -> negative log likelihood`` for `module` on (x, y, t)."""

    def loss(params):
        return module.apply(params, x, y, t)[0]

    return loss


base_model = BaseGPLoss()
model = GPLoss()

seed = np.random.randint(0, 100)
# Derive two independent keys from one seed. (np.random.randint(seed) would
# raise ValueError whenever seed == 0.)
base_key, model_key = jax.random.split(jax.random.PRNGKey(seed))
base_params = base_model.init(base_key, x, y, t)
params = model.init(model_key, x, y, t)

n_iters = st.number_input("Number of iterations", min_value=1, max_value=200, value=100)
lr = st.selectbox("Learning rate", [0.1, 0.01, 0.001, 0.0001], 1)
tx = optax.sgd(learning_rate=lr)
base_opt_state = tx.init(base_params)
opt_state = tx.init(params)

# Each model gets its own jitted value-and-grad closure, so the traced
# computation never depends on mutable global state.
base_loss_grad_fn = jax.jit(jax.value_and_grad(_make_loss(base_model)))
loss_grad_fn = jax.jit(jax.value_and_grad(_make_loss(model)))

base_losses = []
losses = []
my_bar = st.progress(0)
for i in range(n_iters):
    base_loss_val, base_grads = base_loss_grad_fn(base_params)
    loss_val, grads = loss_grad_fn(params)

    base_updates, base_opt_state = tx.update(base_grads, base_opt_state)
    updates, opt_state = tx.update(grads, opt_state)
    base_params = optax.apply_updates(base_params, base_updates)
    params = optax.apply_updates(params, updates)

    losses.append(loss_val)
    base_losses.append(base_loss_val)
    my_bar.progress((i + 1) / n_iters)

# --------------------------------------------------------------------------
# Plots. Each figure is closed after rendering: Streamlit reruns this script
# on every widget change, and unclosed pyplot figures would accumulate.
# --------------------------------------------------------------------------
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
m_list = [base_model, model]
p_list = [base_params, params]
t_list = ["Homoskedastic GP", "Heteroskedastic GP"]
j_list = []
for i, (mod, prm, title) in enumerate(zip(m_list, p_list, t_list)):
    _, (mu, var), jitter = mod.apply(prm, x, y, t)
    # Predictive variance of a new observation: latent GP variance plus the
    # full likelihood noise (fixed noise**2 + learned term), matching the
    # diagonal used during training.
    var = var + noise**2 + jitter.ravel()
    j_list.append(jitter)
    ax[i].plot(x, y, ".k", label="data")
    ax[i].plot(t, mu)
    ax[i].set_title(title)
    ax[i].fill_between(
        t, mu + 2 * np.sqrt(var), mu - 2 * np.sqrt(var), alpha=0.5, label="95% conf"
    )
ax[1].legend()
st.pyplot(fig)
plt.close(fig)

# Learned noise variance: constant for the base model, a curve for the MLP.
# t comes from np.linspace, so it is already sorted for a line plot.
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
ax[1].plot(t, j_list[1], label="learned noise")
ax[0].hlines(float(j_list[0]), t.min(), t.max())
ax[0].set_xlabel("x")
ax[1].set_xlabel("x")
ax[0].set_ylabel("learned noise")
ax[1].legend()
st.pyplot(fig)
plt.close(fig)

# Training curves (negative log likelihood per iteration).
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
ax[0].plot(base_losses)
ax[1].plot(losses, label="loss")
ax[0].set_xlabel("iterations")
ax[0].set_ylabel("loss")
ax[1].set_xlabel("iterations")
ax[1].legend()
st.pyplot(fig)
plt.close(fig)