# Source scraped from a Hugging Face Space page (status at capture time:
# "Runtime error"; file size 9,810 bytes). The page's per-line commit-hash
# gutter and line-number gutter were removed here.
import streamlit as st
import jax
import optax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import zeros
from tinygp import kernels, transforms, GaussianProcess
import numpy as np
import matplotlib.pyplot as plt
# Motorcycle dataset: 94 (x, y) pairs, mcycle_x[i] paired with mcycle_y[i].
# NOTE(review): presumably a subset of the classic `mcycle` motorcycle-crash
# data (x = time after impact in ms, y = head acceleration in g); confirm the
# source and units. Both arrays are standardised later before fitting.
mcycle_x = np.array([ 35.2, 27.6, 35.6, 28.2, 57.6, 26.4, 46.6, 55. ,
16.6, 8.2, 32.8, 19.2, 14.6, 42.8, 9.6, 50.6,
2.4, 34.8, 33.4, 6.2, 34.4, 29.4, 25.6, 16. ,
13.8, 15.6, 20.2, 44.4, 3.6, 21.2, 8.8, 24.2,
45. , 33.8, 7.8, 38. , 2.6, 41.6, 20.4, 23.2,
26. , 40. , 28.4, 55.4, 15.4, 53.2, 11.4, 48.8,
25.4, 13.6, 39.2, 40.4, 16.8, 4. , 43. , 15.8,
24.6, 16.4, 28.6, 52. , 30.2, 18.6, 10.2, 32. ,
6.6, 17.6, 19.6, 24. , 16.2, 42.4, 22. , 23.4,
44. , 6.8, 11. , 10.6, 26.2, 39.4, 31.2, 19.4,
21.4, 27.2, 47.8, 35.4, 13.2, 31. , 36.2, 14.8,
3.2, 25. , 21.8, 17.8, 27. , 10. ])
# Targets aligned element-wise with mcycle_x (not sorted by x).
mcycle_y = np.array([ -16. , 4. , 34.8, 12. , 10.7, -65.6, 10.7, -2.7,
-59. , -2.7, 46.9, -123.1, -13.3, 0. , -2.7, 0. ,
0. , 75. , 16. , -2.7, 1.3, -17.4, -26.8, -42.9,
0. , -40.2, -123.1, 0. , 0. , -134. , -1.3, -95.1,
10.7, 45.6, -2.7, 46.9, -1.3, 30.8, -117.9, -123.1,
-5.4, -21.5, -21.5, -2.7, -22.8, -14.7, 0. , -13.3,
-72.3, -2.7, 5.4, -13.3, -71. , -2.7, 14.7, -21.5,
-53.5, -5.4, 46.9, 10.7, 36.2, -112.5, -5.4, 54.9,
-2.7, -37.5, -127.2, -112.5, -21.5, 29.4, -123.1, -128.5,
-1.3, -1.3, -5.4, -2.7, -107.1, -1.3, 8.1, -85.6,
-101.9, -45.6, -26.8, 69.6, -2.7, 75. , -37.5, -2.7,
-2.7, -64.4, -108.4, -99.1, -16. , -2.7])
# Olympic dataset: 27 Games years from 1896 to 2012 (1916, 1940 and 1944 are
# absent) and one target value per year, oly_x[i] paired with oly_y[i].
# NOTE(review): the targets look like the men's marathon winning pace
# (minutes per km) used in GP tutorials; confirm the source and units.
oly_x = np.array([1896. , 1900. , 1904. , 1908. ,
1912. , 1920. , 1924. , 1928. ,
1932. , 1936. , 1948. , 1952. ,
1956. , 1960. , 1964. , 1968. ,
1972. , 1976. , 1980. , 1984. ,
1988. , 1992. , 1996. , 2000. ,
2004. , 2008. , 2012. ])
oly_y = np.array([ 4.47083333, 4.46472926, 5.22208333, 4.15467867,
3.90331675, 3.56951267, 3.82454477, 3.62483707,
3.59284275, 3.53880792, 3.67010309, 3.39029111,
3.43642612, 3.20583007, 3.13275665, 3.32819844,
3.13583758, 3.0789588 , 3.10581822, 3.06552909,
3.09357349, 3.16111704, 3.14255244, 3.08527867,
3.10265829, 2.99877553, 3.03392977])
st.title("Heteroscedastic Gaussian Processes")
# Introductory explanation shown at the top of the app. The last equation
# matches what GPLoss implements: the training-covariance diagonal is
# noise**2 (a fixed floor) plus exp(network output) for each input.
st.markdown(r"""
Gaussian processes usually assume homoskedastic noise, i.e. every observation shares the same noise variance:
$$
y_{i}=f\left(\mathbf{x}_{i}\right)+\epsilon_{i}, \quad \epsilon_{i} \stackrel{\text { i.i.d. }}{\sim} \mathcal{N}\left(0, \sigma_{\epsilon}^{2}\right), \quad 1 \leq i \leq n
$$
We can instead give each data point its own noise variance:
$$
y_{i}=f\left(\mathbf{x}_{i}\right)+\epsilon_{i}, \quad \epsilon_{i} \sim \mathcal{N}\left(0, \sigma_{\epsilon_i}^{2}\right), \quad 1 \leq i \leq n
$$
However, a separate variance per training point says nothing about the noise level at new inputs, so on its own it does not support prediction or conditioning. A simple idea is to learn a non-linear function $g$ (a small neural network) that maps each input to its noise variance:
$$
\sigma_{\epsilon_i}^{2} = \sigma_{0}^{2} + \exp\left(g\left(\mathbf{x}_{i}\right)\right)
$$
where $\sigma_{0}$ is a small fixed noise floor and the exponential keeps the learned part positive.
This demo experiments with this idea on several synthetic and real datasets with heteroskedastic noise.
""")
# Choose a dataset. The two synthetic options draw fresh random noise each
# time this script executes (no seed is set).
data = st.selectbox("Data", ["Motorcycle", "Olympic", "Linear", 'GPflow'])
if data == "Motorcycle":
    data_x, data_y = mcycle_x, mcycle_y
elif data == "Olympic":
    data_x, data_y = oly_x, oly_y
elif data == "Linear":
    # Straight line 3x + 2 whose noise standard deviation equals x itself.
    data_x = np.linspace(0, 1, 99)
    data_y = 2 + 3 * data_x + data_x * np.random.randn(data_x.shape[0])
elif data == 'GPflow':
    # 1001 inputs on [0, 4*pi]; Gaussian targets with mean sin(x) and a
    # strictly positive scale exp(cos(x)).
    N = 1001
    data_x = np.linspace(0, 4 * np.pi, N)
    f1, f2, transform = np.sin, np.cos, np.exp
    loc = f1(data_x)
    scale = transform(f2(data_x))
    data_y = np.random.normal(loc, scale)
# Standardise inputs and targets to zero mean and unit standard deviation.
x, y = ((arr - arr.mean()) / arr.std() for arr in (data_x, data_y))
n_tests = st.number_input(
    "Number of test points", min_value=50, max_value=1000, value=100
)
# Evenly spaced (hence already sorted) prediction inputs spanning the
# standardised training range.
t = np.linspace(x.min(), x.max(), n_tests)
# Fixed observation-noise standard deviation. Both GP models put noise**2 on
# the diagonal of the training covariance, in addition to their learned term.
noise = 0.01
# Hidden-layer widths of the `Transformer` network. GPLoss feeds its output
# through exp() to get a per-input noise variance; it does not transform the
# GP inputs.
fet1 = st.slider("Number of neurons in Layer1", min_value=2, max_value=30, value=15)
fet2 = st.slider("Number of neurons in Layer2", min_value=2, max_value=30, value=10)
class Transformer(nn.Module):
    """Two-hidden-layer ReLU MLP producing one output value per input row.

    Hidden widths are read from the module-level slider values ``fet1`` and
    ``fet2``. GPLoss treats the output as a log noise variance.
    """

    @nn.compact
    def __call__(self, x):
        # Dense_0 (fet1) and Dense_1 (fet2) are created in the same order as
        # before, each followed by a ReLU; Dense_2 is the linear 1-unit head.
        for width in (fet1, fet2):
            x = nn.relu(nn.Dense(features=width)(x))
        return nn.Dense(features=1)(x)
class BaseGPLoss(nn.Module):
    """Homoscedastic GP baseline: Matern-3/2 kernel with one learned noise term.

    Returns a 3-tuple of:
    - the negative log probability of ``y`` under the GP on ``x``;
    - ``(mean, variance)`` of the GP conditioned on ``(x, y)``, evaluated at ``t``;
    - the learned jitter ``exp(log_jitter)``, a scalar added on top of ``noise**2``.
    """

    @nn.compact
    def __call__(self, x, y, t):
        # All hyper-parameters are stored in log space and start at zero.
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())

        amplitude = jnp.exp(2 * log_sigma)
        length_scale = jnp.exp(log_rho)
        jitter = jnp.exp(log_jitter)

        # The same noise variance is shared by every training point.
        gp = GaussianProcess(
            amplitude * kernels.Matern32(length_scale),
            x[:, None],
            diag=noise**2 + jitter,
        )
        log_prob, conditioned = gp.condition(y, t[:, None])
        return -log_prob, (conditioned.loc, conditioned.variance), jitter
class GPLoss(nn.Module):
    """Heteroscedastic GP: Matern-3/2 kernel plus input-dependent noise.

    A ``Transformer`` network maps each input to a log noise variance. The
    training covariance diagonal is therefore ``noise**2 + exp(net(x_i))``.

    Returns a 3-tuple of:
    - the negative log probability of ``y`` under the GP on ``x``;
    - ``(mean, variance)`` of the GP conditioned on ``(x, y)``, evaluated at
      ``t`` (NOTE(review): callers assume this excludes observation noise;
      confirm against tinygp's ``condition`` defaults);
    - ``exp(net(t))``, the learned noise-variance term at the test inputs.
    """

    @nn.compact
    def __call__(self, x, y, t):
        # Matern-3/2 kernel with log-parameterised amplitude and length scale.
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))

        # A single network instance (auto-named Transformer_0) is evaluated at
        # both the training and the test inputs, so both calls share weights.
        noise_net = Transformer()
        log_noise_var = noise_net(x.reshape(-1, 1)).ravel()

        gp = GaussianProcess(kernel, x[:, None], diag=noise**2 + jnp.exp(log_noise_var))
        log_prob, gp_cond = gp.condition(y, t[:, None])
        return (
            -log_prob,
            (gp_cond.loc, gp_cond.variance),
            jnp.exp(noise_net(t[:, None])),
        )
# Define and train the model
def loss(params):
    """Return the negative log likelihood computed by the global model ``m``.

    ``m``, ``x``, ``y`` and ``t`` are module-level names looked up when the
    function runs. ``m.apply`` yields a 3-tuple whose first item is the loss.
    """
    neg_log_likelihood, *_ = m.apply(params, x, y, t)
    return neg_log_likelihood
base_model = BaseGPLoss()
model = GPLoss()
seed = np.random.randint(0, 100)
# Derive both init keys from one seeded key. The former
# np.random.randint(seed) raised ValueError whenever seed happened to be 0.
base_key, model_key = jax.random.split(jax.random.PRNGKey(seed))
base_params = base_model.init(base_key, x, y, t)
params = model.init(model_key, x, y, t)
n_iters = st.number_input("Number of iterations", min_value=1, max_value=200, value=100)
lr = st.selectbox("Learning rate", [0.1, 0.01, 0.001, 0.0001], 1)
tx = optax.sgd(learning_rate=lr)
base_opt_state = tx.init(base_params)
opt_state = tx.init(params)


def _make_loss_grad_fn(module):
    """Return a jitted ``params -> (loss, grads)`` function bound to *module*.

    Binding the module explicitly replaces the old pattern of one jitted
    function that read a mutable global ``m``. That pattern only produced the
    right model because jit retraced on the two differing parameter pytrees.
    """
    def _neg_log_likelihood(p):
        return module.apply(p, x, y, t)[0]

    return jax.jit(jax.value_and_grad(_neg_log_likelihood))


base_loss_grad_fn = _make_loss_grad_fn(base_model)
loss_grad_fn = _make_loss_grad_fn(model)
base_losses = []
losses = []
my_bar = st.progress(0)
for i in range(n_iters):
    # Take one SGD step for each model and record its loss before the step.
    base_loss_val, base_grads = base_loss_grad_fn(base_params)
    loss_val, grads = loss_grad_fn(params)
    base_updates, base_opt_state = tx.update(base_grads, base_opt_state)
    updates, opt_state = tx.update(grads, opt_state)
    base_params = optax.apply_updates(base_params, base_updates)
    params = optax.apply_updates(params, updates)
    losses.append(loss_val)
    base_losses.append(base_loss_val)
    my_bar.progress((i + 1) / n_iters)
# Plot the results and compare the homoscedastic and heteroscedastic models.


def _show(figure):
    """Render *figure* in the app, then release it from pyplot.

    Streamlit re-executes this script on every interaction. Without
    ``plt.close``, each run would leave its figures registered in pyplot's
    global state for the lifetime of the server process.
    """
    st.pyplot(figure)
    plt.close(figure)


# Figure 1: predictive mean with a +/- 2 standard-deviation band per model.
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
m_list = [base_model, model]
p_list = [base_params, params]
t_list = ['Homoskedastic GP', 'Heteroskedastic GP']
j_list = []
for i, (module, module_params, title) in enumerate(zip(m_list, p_list, t_list)):
    _, (mu, var), jitter = module.apply(module_params, x, y, t)
    # Add the full observation-noise variance the likelihood uses
    # (fixed noise**2 plus the learned term) to the conditional variance.
    var = var + noise**2 + jitter.ravel()
    j_list.append(jitter)
    ax[i].plot(x, y, ".k", label="data")
    ax[i].plot(t, mu)
    ax[i].set_title(title)
    ax[i].fill_between(
        t, mu + 2 * np.sqrt(var), mu - 2 * np.sqrt(var), alpha=0.5, label="95% conf"
    )
ax[1].legend()
_show(fig)

# Figure 2: learned noise-variance term over the test grid (t is already
# sorted because it comes from np.linspace).
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
ax[1].plot(t, np.ravel(j_list[1]), label="learned noise")
ax[0].hlines(j_list[0], t.min(), t.max())
ax[0].set_xlabel('x')
ax[1].set_xlabel('x')
ax[0].set_ylabel('learned noise')
ax[1].legend()
_show(fig)

# Figure 3: training loss curves.
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
ax[0].plot(base_losses)
ax[1].plot(losses, label="loss")
ax[0].set_xlabel('iterations')
ax[0].set_ylabel('loss')
ax[1].set_xlabel('iterations')
ax[1].legend()
_show(fig)