Zeel commited on
Commit
814ce70
·
1 Parent(s): 8be2f17

first push

Browse files
Files changed (3) hide show
  1. .vscode/settings.json +4 -0
  2. app.py +105 -0
  3. requirements.txt +6 -0
.vscode/settings.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "python.linting.mypyEnabled": true,
3
+ "python.linting.enabled": true
4
+ }
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ import jax
4
+ import optax
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+ from flax.linen.initializers import zeros
8
+ from tinygp import kernels, transforms, GaussianProcess
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ import regdata as rd
12
+
13
+ data_dict = {'motorcycle': rd.MotorcycleHelmet()}
14
+
15
+ data = st.selectbox('Data', [key for key in data_dict.keys()])
16
+
17
+ # st.markdown(f"{data}")
18
+
19
+ x, y, t = data_dict[data].get_data()
20
+ x = x.ravel()
21
+ t = t.ravel()
22
+ y = y.ravel()
23
+
24
+ random = np.random.default_rng(567)
25
+
26
+ noise = 0.01
27
+
28
+ # x = np.sort(random.uniform(-1, 1, 100))
29
+ # y = 2 * (x > 0) - 1 + random.normal(0.0, noise, len(x))
30
+ # t = np.linspace(-1.5, 1.5, 500)
31
+
32
+ # Define a small neural network used to non-linearly transform the input data in our model
33
+ class Transformer(nn.Module):
34
+ @nn.compact
35
+ def __call__(self, x):
36
+ x = nn.Dense(features=15)(x)
37
+ x = nn.relu(x)
38
+ x = nn.Dense(features=10)(x)
39
+ x = nn.relu(x)
40
+ x = nn.Dense(features=1)(x)
41
+ return x
42
+
43
+
44
+ class GPLoss(nn.Module):
45
+ @nn.compact
46
+ def __call__(self, x, y, t):
47
+ # Set up a typical Matern-3/2 kernel
48
+ log_sigma = self.param("log_sigma", zeros, ())
49
+ log_rho = self.param("log_rho", zeros, ())
50
+ # log_jitter = self.param("log_jitter", zeros, ())
51
+ base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(
52
+ jnp.exp(log_rho)
53
+ )
54
+
55
+ # Define a custom transform to pass the input coordinates through our `Transformer`
56
+ # network from above
57
+ transform = Transformer()
58
+ log_jitter = transform(x.reshape(-1,1)).ravel()
59
+ kernel = base_kernel
60
+
61
+ # Evaluate and return the GP negative log likelihood as usual
62
+ gp = GaussianProcess(
63
+ kernel, x[:, None], diag=noise**2 + jnp.exp(log_jitter)
64
+ )
65
+ log_prob, gp_cond = gp.condition(y, t[:, None])
66
+ return -log_prob, (gp_cond.loc, gp_cond.variance), jnp.exp(transform(t[:, None]))
67
+
68
+
69
+ # Define and train the model
70
+ def loss(params):
71
+ return model.apply(params, x, y, t)[0]
72
+
73
+
74
+ model = GPLoss()
75
+ params = model.init(jax.random.PRNGKey(1234), x, y, t)
76
+ tx = optax.sgd(learning_rate=1e-3)
77
+ opt_state = tx.init(params)
78
+ loss_grad_fn = jax.jit(jax.value_and_grad(loss))
79
+ for i in range(100):
80
+ loss_val, grads = loss_grad_fn(params)
81
+ updates, opt_state = tx.update(grads, opt_state)
82
+ params = optax.apply_updates(params, updates)
83
+
84
+ # Plot the results and compare to the true model
85
+ plt.figure()
86
+ _, (mu, var), jitter = model.apply(params, x, y, t)
87
+ var += jitter.ravel()
88
+ # plt.plot(t, 2 * (t > 0) - 1, "k", lw=1, label="truth")
89
+ plt.plot(x, y, ".k", label="data")
90
+ plt.plot(t, mu)
91
+ idx = np.argsort(t)
92
+ plt.plot(t[idx], jitter[idx], label='noise')
93
+ plt.fill_between(
94
+ t, mu + 2*np.sqrt(var), mu - 2*np.sqrt(var), alpha=0.5, label="95% conf"
95
+ )
96
+ # plt.xlim(-1.5, 1.5)
97
+ # plt.ylim(-1.3, 1.3)
98
+ plt.xlabel("x")
99
+ plt.ylabel("y")
100
+ _ = plt.legend()
101
+
102
+ col = st.columns(1)[0]
103
+ with col:
104
+ st.pyplot(plt)
105
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ jaxlib
2
+ jax
3
+ optax
4
+ flax
5
+ tinygp
6
+ matplotlib