Zeel committed on
Commit
9d9276d
·
1 Parent(s): 06326ea

add gpflow

Browse files
Files changed (1) hide show
  1. app.py +36 -10
app.py CHANGED
@@ -55,21 +55,38 @@ st.title("Heteroscedastic Gaussian Processes")
55
 
56
  st.markdown(r"We are learning the noise v/s inputs relationship with a neural net.")
57
 
58
- data = st.selectbox("Data", ["Motorcycle", "Olympic"])
59
  if data == "Motorcycle":
60
  data_x, data_y = mcycle_x, mcycle_y
61
  elif data == "Olympic":
62
  data_x, data_y = oly_x, oly_y
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  x = (data_x - data_x.mean()) / data_x.std()
65
  y = (data_y - data_y.mean()) / data_y.std()
66
 
67
  n_tests = st.number_input(
68
- "Number of test points", min_value=10, max_value=200, value=100
69
  )
70
 
71
- n_iters = st.number_input("Number of iterations", min_value=1, max_value=100, value=10)
72
-
73
  t = np.linspace(x.min(), x.max(), n_tests)
74
  noise = 0.01
75
 
@@ -78,12 +95,16 @@ noise = 0.01
78
  # t = np.linspace(-1.5, 1.5, 500)
79
 
80
  # Define a small neural network used to non-linearly transform the input data in our model
 
 
 
 
81
  class Transformer(nn.Module):
82
  @nn.compact
83
  def __call__(self, x):
84
- x = nn.Dense(features=15)(x)
85
  x = nn.relu(x)
86
- x = nn.Dense(features=10)(x)
87
  x = nn.relu(x)
88
  x = nn.Dense(features=1)(x)
89
  return x
@@ -138,15 +159,19 @@ def loss(params):
138
 
139
  base_model = BaseGPLoss()
140
  model = GPLoss()
141
- base_params = base_model.init(jax.random.PRNGKey(1234), x, y, t)
142
- params = model.init(jax.random.PRNGKey(1234), x, y, t)
143
- tx = optax.sgd(learning_rate=1e-3)
 
 
 
144
  base_opt_state = tx.init(base_params)
145
  opt_state = tx.init(params)
146
  loss_grad_fn = jax.jit(jax.value_and_grad(loss))
147
  base_losses = []
148
  losses = []
149
- for i in range(200):
 
150
  m = base_model
151
  base_loss_val, base_grads = loss_grad_fn(base_params)
152
  m = model
@@ -157,6 +182,7 @@ for i in range(200):
157
  params = optax.apply_updates(params, updates)
158
  losses.append(loss_val)
159
  base_losses.append(base_loss_val)
 
160
 
161
  # Plot the results and compare to the true model
162
  fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
 
55
 
56
  st.markdown(r"We are learning the noise v/s inputs relationship with a neural net.")
57
 
58
+ data = st.selectbox("Data", ["Motorcycle", "Olympic", 'GPflow'])
59
  if data == "Motorcycle":
60
  data_x, data_y = mcycle_x, mcycle_y
61
  elif data == "Olympic":
62
  data_x, data_y = oly_x, oly_y
63
+ elif data == 'GPflow':
64
+ N = 1001
65
+
66
+ # Build inputs X
67
+ data_x = np.linspace(0, 4 * np.pi, N)
68
+
69
+ # Deterministic functions in place of latent ones
70
+ f1 = np.sin
71
+ f2 = np.cos
72
+
73
+ # Use transform = exp to ensure positive-only scale values
74
+ transform = np.exp
75
+
76
+ # Compute loc and scale as functions of input X
77
+ loc = f1(data_x)
78
+ scale = transform(f2(data_x))
79
+
80
+ # Sample outputs Y from Gaussian Likelihood
81
+ data_y = np.random.normal(loc, scale)
82
 
83
  x = (data_x - data_x.mean()) / data_x.std()
84
  y = (data_y - data_y.mean()) / data_y.std()
85
 
86
  n_tests = st.number_input(
87
+ "Number of test points", min_value=50, max_value=1000, value=100
88
  )
89
 
 
 
90
  t = np.linspace(x.min(), x.max(), n_tests)
91
  noise = 0.01
92
 
 
95
  # t = np.linspace(-1.5, 1.5, 500)
96
 
97
  # Define a small neural network used to non-linearly transform the input data in our model
98
+
99
+ fet1 = st.slider("Number of neurons in Layer1", min_value=2, max_value=30, value=15)
100
+ fet2 = st.slider("Number of neurons in Layer2", min_value=2, max_value=30, value=10)
101
+
102
  class Transformer(nn.Module):
103
  @nn.compact
104
  def __call__(self, x):
105
+ x = nn.Dense(features=fet1)(x)
106
  x = nn.relu(x)
107
+ x = nn.Dense(features=fet2)(x)
108
  x = nn.relu(x)
109
  x = nn.Dense(features=1)(x)
110
  return x
 
159
 
160
  base_model = BaseGPLoss()
161
  model = GPLoss()
162
+ seed = np.random.randint(0,100)
163
+ base_params = base_model.init(jax.random.PRNGKey(seed), x, y, t)
164
+ params = model.init(jax.random.PRNGKey(np.random.randint(seed)), x, y, t)
165
+ n_iters = st.number_input("Number of iterations", min_value=1, max_value=200, value=100)
166
+ lr = st.selectbox("Learning rate", [0.1, 0.01, 0.001, 0.0001], 1)
167
+ tx = optax.sgd(learning_rate=lr)
168
  base_opt_state = tx.init(base_params)
169
  opt_state = tx.init(params)
170
  loss_grad_fn = jax.jit(jax.value_and_grad(loss))
171
  base_losses = []
172
  losses = []
173
+ my_bar = st.progress(0)
174
+ for i in range(n_iters):
175
  m = base_model
176
  base_loss_val, base_grads = loss_grad_fn(base_params)
177
  m = model
 
182
  params = optax.apply_updates(params, updates)
183
  losses.append(loss_val)
184
  base_losses.append(base_loss_val)
185
+ my_bar.progress((i+1) / n_iters)
186
 
187
  # Plot the results and compare to the true model
188
  fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))