Spaces:
Runtime error
Runtime error
first push
Browse files- .vscode/settings.json +4 -0
- app.py +105 -0
- requirements.txt +6 -0
.vscode/settings.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"python.linting.mypyEnabled": true,
|
| 3 |
+
"python.linting.enabled": true
|
| 4 |
+
}
|
app.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
import jax
|
| 4 |
+
import optax
|
| 5 |
+
import jax.numpy as jnp
|
| 6 |
+
import flax.linen as nn
|
| 7 |
+
from flax.linen.initializers import zeros
|
| 8 |
+
from tinygp import kernels, transforms, GaussianProcess
|
| 9 |
+
import numpy as np
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import regdata as rd
|
| 12 |
+
|
| 13 |
+
# --- Data selection -------------------------------------------------------
# Map of display name -> regdata dataset loader. Add entries here to offer
# more datasets in the selectbox.
data_dict = {'motorcycle': rd.MotorcycleHelmet()}

# A dict iterates its keys directly; `list(d)` replaces the redundant
# `[key for key in d.keys()]` comprehension.
data = st.selectbox('Data', list(data_dict))

# x: training inputs, y: training targets, t: test/query inputs.
# Flatten everything to 1-D; the model re-adds the feature axis where needed.
x, y, t = data_dict[data].get_data()
x = x.ravel()
t = t.ravel()
y = y.ravel()

# Seeded RNG so any synthetic-data experiments below are reproducible.
random = np.random.default_rng(567)

# Fixed base observation-noise scale added to the GP diagonal (as noise**2).
noise = 0.01

# Synthetic step-function data, kept for quick experimentation:
# x = np.sort(random.uniform(-1, 1, 100))
# y = 2 * (x > 0) - 1 + random.normal(0.0, noise, len(x))
# t = np.linspace(-1.5, 1.5, 500)
| 32 |
+
# Small MLP mapping a 1-D input to a scalar; GPLoss below uses its output as
# an input-dependent (heteroscedastic) log-noise term.
class Transformer(nn.Module):
    """Two-hidden-layer ReLU network: 1 -> 15 -> 10 -> 1."""

    @nn.compact
    def __call__(self, x):
        # Dense layers are created in the same order as the original
        # (15, 10, 1), so the flax parameter tree is unchanged.
        hidden = nn.relu(nn.Dense(features=15)(x))
        hidden = nn.relu(nn.Dense(features=10)(hidden))
        return nn.Dense(features=1)(hidden)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class GPLoss(nn.Module):
    """Gaussian process with a Matern-3/2 kernel and a learned,
    input-dependent noise term produced by the `Transformer` network.

    Returns the negative log likelihood, the conditional mean/variance at
    the test inputs, and the learned noise at the test inputs.
    """

    @nn.compact
    def __call__(self, x, y, t):
        # Log-parameterized amplitude and length scale so the optimizer
        # works in an unconstrained space.
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        # log_jitter = self.param("log_jitter", zeros, ())
        kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))

        # Pass the (column-vector) inputs through the neural network to get a
        # per-point log-noise value.
        transform = Transformer()
        log_jitter = transform(x.reshape(-1, 1)).ravel()

        # GP with per-point diagonal noise: fixed base noise**2 plus the
        # learned heteroscedastic term.
        gp = GaussianProcess(kernel, x[:, None], diag=noise**2 + jnp.exp(log_jitter))

        # Condition on the observations and evaluate at the test inputs; the
        # same `transform` instance is reused so no extra parameters appear.
        log_prob, gp_cond = gp.condition(y, t[:, None])
        return -log_prob, (gp_cond.loc, gp_cond.variance), jnp.exp(transform(t[:, None]))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Define and train the model
|
| 70 |
+
def loss(params):
    """Negative log marginal likelihood of the GP for the given parameters."""
    neg_log_prob, _, _ = model.apply(params, x, y, t)
    return neg_log_prob
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# --- Training -------------------------------------------------------------
model = GPLoss()

# Fixed PRNG seed so the parameter initialization is reproducible.
params = model.init(jax.random.PRNGKey(1234), x, y, t)

# Plain SGD; the value-and-grad function is JIT-compiled once and reused
# across all iterations.
tx = optax.sgd(learning_rate=1e-3)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))

for step in range(100):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
|
| 83 |
+
|
| 84 |
+
# --- Plot the fitted model ------------------------------------------------
plt.figure()

# Predictive mean/variance at the test inputs, plus the learned
# heteroscedastic noise variance at those inputs.
_, (mu, var), jitter = model.apply(params, x, y, t)
jitter = jitter.ravel()

# The plotted band should reflect the full per-point noise the model puts on
# its diagonal: the fixed base noise**2 plus the learned term (the original
# added only the learned term).
var = var + noise**2 + jitter

# Sort everything by t so line plots and the shaded band are drawn
# left-to-right even when the test grid is not already ordered (the original
# sorted only the noise curve, which garbles the others for unsorted t).
idx = np.argsort(t)
t_s, mu_s, var_s = t[idx], mu[idx], var[idx]

plt.plot(x, y, ".k", label="data")
plt.plot(t_s, mu_s)
plt.plot(t_s, jitter[idx], label='noise')
plt.fill_between(
    t_s, mu_s + 2 * np.sqrt(var_s), mu_s - 2 * np.sqrt(var_s), alpha=0.5, label="95% conf"
)
plt.xlabel("x")
plt.ylabel("y")
_ = plt.legend()

# Render the figure in the Streamlit app.
col = st.columns(1)[0]
with col:
    st.pyplot(plt)
|
| 105 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
jaxlib
jax
optax
flax
tinygp
numpy
matplotlib
regdata
|