joel-woodfield committed on
Commit
fdd3bfb
·
1 Parent(s): 8b69b54

Add option to use a custom function

Browse files
Files changed (1) hide show
  1. gp_visualizer.py +86 -36
gp_visualizer.py CHANGED
@@ -9,6 +9,7 @@ from jinja2 import Template
9
  import matplotlib.pyplot as plt
10
  import matplotlib.lines as mlines
11
  import numpy as np
 
12
  import pandas as pd
13
  from PIL import Image
14
  import plotly.graph_objects as go
@@ -35,6 +36,14 @@ logging.basicConfig(
35
  )
36
  logger = logging.getLogger("ELVIS")
37
 
 
 
 
 
 
 
 
 
38
  def eval_kernel(kernel_str):
39
  # List of allowed kernel constructors
40
  allowed_names = {
@@ -64,6 +73,25 @@ def eval_kernel(kernel_str):
64
 
65
  return result
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def make_sine(xlim=(0,1), nsample=20, sigma=0.1, uniform=False, sort=True):
68
  np.random.seed(42)
69
  if uniform:
@@ -79,15 +107,21 @@ def make_sine(xlim=(0,1), nsample=20, sigma=0.1, uniform=False, sort=True):
79
 
80
  return X, y
81
 
 
82
  class GPVisualizer:
83
  DEFAULT_KERNEL = "RBF() + WhiteKernel()"
 
84
 
85
  def __init__(self, width, height):
86
  self.canvas_width = width
87
  self.canvas_height = height
88
 
 
89
  self.kernel = eval_kernel(self.DEFAULT_KERNEL)
90
 
 
 
 
91
  self.plot_options = {
92
  "show_training_data": True,
93
  "show_confidence_interval": True,
@@ -102,6 +136,15 @@ class GPVisualizer:
102
  display: none;
103
  }"""
104
 
 
 
 
 
 
 
 
 
 
105
  def plot(self):
106
  '''
107
  '''
@@ -113,17 +156,8 @@ class GPVisualizer:
113
  ax = fig.add_axes([0., 0., 1., 1.]) #
114
  ax.margins(x=0, y=0) # no padding in both directions
115
 
116
- # make a synthetic dataset
117
- logger.info("Generating data")
118
- X_tr, y_tr = make_sine(xlim=(0, 1), nsample=20)
119
- print(X_tr.shape, y_tr.shape)
120
- X_ts, y_ts = make_sine(xlim=(-1, 2), uniform=True, sort=True, nsample=100)
121
-
122
- # fit GP
123
- gpr = GaussianProcessRegressor(kernel=self.kernel, random_state=0)
124
- logger.info('fitting ' + str(gpr))
125
- gpr.fit(X_tr, y_tr)
126
- y_pred, y_std = gpr.predict(X_ts, return_std=True)
127
 
128
  # plot
129
  fig, ax = plt.subplots(figsize=(8, 8))
@@ -131,28 +165,25 @@ class GPVisualizer:
131
  ax.set_xlabel("x")
132
  ax.set_ylabel("y")
133
 
134
- #plt.scatter(X_tr.flatten(), y_tr)
135
- #plt.errorbar(X_ts.flatten(), y_pred, yerr=1.95*y_std, fmt='+', color='orange')
136
-
137
- R2 = gpr.score(X_tr, y_tr)
138
 
139
  if self.plot_options["show_training_data"]:
140
- if len(X_tr) > 1:
141
- plt.scatter(X_tr.flatten(), y_tr, label='training data (R2=%.2f)' % (R2), color=self.plot_cmap(0))
142
  else:
143
- plt.scatter(X_tr.flatten(), y_tr, label='training data', color=self.plot_cmap(0))
144
 
145
  if self.plot_options["show_true_function"]:
146
- plt.plot(X_ts.flatten(), np.sin(2*np.pi*X_ts.flatten()), label='true function', color=self.plot_cmap(1))
147
 
148
  if self.plot_options["show_predictions"]:
149
- plt.scatter(X_ts.flatten(), y_pred, marker='+', label='predictions', color=self.plot_cmap(2))
150
 
151
  if self.plot_options["show_confidence_interval"]:
152
  plt.fill_between(
153
- X_ts.flatten(),
154
- y_pred - 1.96*y_std,
155
- y_pred + 1.96*y_std,
156
  alpha=0.2,
157
  label='95% confidence interval',
158
  color=self.plot_cmap(3)
@@ -201,10 +232,21 @@ class GPVisualizer:
201
  )
202
  return fig
203
 
204
- #def update_kernel_type(self, kernel_type):
205
- #self.kernel_type = kernel_type
206
- #self.kernel = self.kernels[kernel_type]
207
- #return self.plot()
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  def update_kernel_spec(self, kernel_spec):
210
  self.kernel = eval_kernel(kernel_spec)
@@ -231,8 +273,18 @@ class GPVisualizer:
231
 
232
  with gr.Column(scale=1):
233
  with gr.Tab("Settings"):
234
- dataset_radio = gr.Radio(["sine", "Upload"],
235
- value="sine", label="Dataset", elem_id="rowheight")
 
 
 
 
 
 
 
 
 
 
236
 
237
  # upload data
238
  file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
@@ -272,14 +324,12 @@ class GPVisualizer:
272
  with gr.Tab("Usage"):
273
  gr.Markdown(''.join(open('usage.md', 'r').readlines()))
274
 
 
 
 
 
 
275
 
276
- # event handlers for GUI elements
277
- #kernel_type.change(
278
- #fn=self.update_kernel_type,
279
- #inputs=kernel_type,
280
- #outputs=(self.canvas)
281
- #)
282
-
283
  kernel_spec.submit(
284
  fn=self.update_kernel_spec,
285
  inputs=kernel_spec,
 
9
  import matplotlib.pyplot as plt
10
  import matplotlib.lines as mlines
11
  import numpy as np
12
+ import numexpr
13
  import pandas as pd
14
  from PIL import Image
15
  import plotly.graph_objects as go
 
36
  )
37
  logger = logging.getLogger("ELVIS")
38
 
39
+
40
+ NUMEXPR_CONSTANTS = {
41
+ 'pi': np.pi,
42
+ 'PI': np.pi,
43
+ 'e': np.e,
44
+ }
45
+
46
+
47
  def eval_kernel(kernel_str):
48
  # List of allowed kernel constructors
49
  allowed_names = {
 
73
 
74
  return result
75
 
76
+
77
def get_function(function, xlim=(-1, 1), nsample=100):
    """Evaluate a user-supplied expression of ``x`` on an evenly spaced grid.

    Args:
        function: Expression string in the variable ``x``, evaluated with
            ``numexpr``. The names in ``NUMEXPR_CONSTANTS`` (``pi``, ``PI``,
            ``e``) are available inside the expression.
        xlim: ``(lower, upper)`` bounds of the grid.
        nsample: Number of grid points.

    Returns:
        A tuple ``(x, y)`` where ``x`` is the grid reshaped to a column
        (``reshape(-1, 1)``) and ``y`` is the result of evaluating the
        expression on the flat grid.
    """
    # Constants are handed to numexpr explicitly through local_dict, so no
    # function-local aliases for pi are needed here.
    x = np.linspace(xlim[0], xlim[1], nsample)
    y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS})
    # Column shape is what the caller passes to the model's predict().
    x = x.reshape(-1, 1)
    return x, y
85
+
86
+
87
def get_data_points(function, xlim=(-1, 1), nsample=10, sigma=0):
    """Sample noisy observations of an expression at random sorted locations.

    Args:
        function: Expression string in the variable ``x``, evaluated with
            ``numexpr`` (names from ``NUMEXPR_CONSTANTS`` are available).
        xlim: ``(lower, upper)`` bounds the sample locations are drawn from.
        nsample: Number of points to draw.
        sigma: Scale applied to the standard-normal noise added to ``y``.

    Returns:
        A tuple ``(x, y)``: ``x`` is the sorted sample locations reshaped to
        a column, ``y`` the expression values plus noise.
    """
    lower = xlim[0]
    upper = xlim[1]

    # Uniform draws on [lower, upper), sorted ascending.
    locations = np.sort(lower + (upper - lower) * np.random.rand(nsample))

    variables = {'x': locations, **NUMEXPR_CONSTANTS}
    clean = numexpr.evaluate(function, local_dict=variables)
    # Noise is drawn after the expression is evaluated, keeping the order of
    # random-number consumption unchanged.
    noisy = clean + sigma * np.random.randn(nsample)

    return locations.reshape(-1, 1), noisy
93
+
94
+
95
  def make_sine(xlim=(0,1), nsample=20, sigma=0.1, uniform=False, sort=True):
96
  np.random.seed(42)
97
  if uniform:
 
107
 
108
  return X, y
109
 
110
+
111
  class GPVisualizer:
112
  DEFAULT_KERNEL = "RBF() + WhiteKernel()"
113
+ DEFAULT_FUNCTION = "sin(2 * pi * x)"
114
 
115
  def __init__(self, width, height):
116
  self.canvas_width = width
117
  self.canvas_height = height
118
 
119
+ self.function = self.DEFAULT_FUNCTION
120
  self.kernel = eval_kernel(self.DEFAULT_KERNEL)
121
 
122
+ self.x_train, self.y_train = self.generate_data(self.function)
123
+ self.model = self.train_model(self.kernel, self.x_train, self.y_train)
124
+
125
  self.plot_options = {
126
  "show_training_data": True,
127
  "show_confidence_interval": True,
 
136
  display: none;
137
  }"""
138
 
139
def generate_data(self, function):
    """Draw the training set for ``function`` using fixed sampling settings.

    Delegates to ``get_data_points`` with 30 points on (-1, 1) and a noise
    scale of 0.1, and returns its ``(x, y)`` result unchanged.
    """
    sampling = {"xlim": (-1, 1), "nsample": 30, "sigma": 0.1}
    return get_data_points(function, **sampling)
141
+
142
def train_model(self, kernel, x_train, y_train):
    """Fit a Gaussian process regressor on the given training data.

    Args:
        kernel: Kernel object passed to ``GaussianProcessRegressor``.
        x_train: Training inputs passed to ``fit``.
        y_train: Training targets passed to ``fit``.

    Returns:
        The fitted ``GaussianProcessRegressor``.
    """
    model = GaussianProcessRegressor(kernel=kernel, random_state=0)
    # Log the regressor's configuration before the fit starts.
    logger.info('fitting ' + str(model))
    model.fit(x_train, y_train)
    return model
147
+
148
  def plot(self):
149
  '''
150
  '''
 
156
  ax = fig.add_axes([0., 0., 1., 1.]) #
157
  ax.margins(x=0, y=0) # no padding in both directions
158
 
159
+ x_test, y_test = get_function(self.function, xlim=(-2, 2), nsample=100)
160
+ y_pred, y_std = self.model.predict(x_test, return_std=True)
 
 
 
 
 
 
 
 
 
161
 
162
  # plot
163
  fig, ax = plt.subplots(figsize=(8, 8))
 
165
  ax.set_xlabel("x")
166
  ax.set_ylabel("y")
167
 
168
+ R2 = self.model.score(self.x_train, self.y_train)
 
 
 
169
 
170
  if self.plot_options["show_training_data"]:
171
+ if len(self.x_train) > 1:
172
+ plt.scatter(self.x_train.flatten(), self.y_train, label='training data (R2=%.2f)' % (R2), color=self.plot_cmap(0))
173
  else:
174
+ plt.scatter(self.x_train.flatten(), self.y_train, label='training data', color=self.plot_cmap(0))
175
 
176
  if self.plot_options["show_true_function"]:
177
+ plt.plot(x_test.flatten(), y_test, label='true function', color=self.plot_cmap(1))
178
 
179
  if self.plot_options["show_predictions"]:
180
+ plt.plot(x_test.flatten(), y_pred, linestyle="--", label='mean prediction', color=self.plot_cmap(2))
181
 
182
  if self.plot_options["show_confidence_interval"]:
183
  plt.fill_between(
184
+ x_test.flatten(),
185
+ y_pred - 1.96 * y_std,
186
+ y_pred + 1.96 * y_std,
187
  alpha=0.2,
188
  label='95% confidence interval',
189
  color=self.plot_cmap(3)
 
232
  )
233
  return fig
234
 
235
def update_function(self, function):
    """Switch to a new target function, regenerate data, refit and redraw.

    Args:
        function: Expression string in the variable ``x`` to be evaluated
            with ``numexpr``.

    Returns:
        The result of ``self.plot()`` after the data and model are rebuilt.

    Raises:
        ValueError: If the expression cannot be evaluated on a small probe
            grid. The underlying exception is attached as ``__cause__``.
    """
    # NOTE(review): `function` appears to come straight from a UI textbox;
    # confirm the installed numexpr version sanitizes expression strings
    # before exposing this to untrusted users.

    # Dry-run the expression on a small grid so an invalid expression is
    # rejected before any state on `self` is modified.
    try:
        probe = np.linspace(-1, 1, 10)
        numexpr.evaluate(function, local_dict={'x': probe, **NUMEXPR_CONSTANTS})
    except Exception as e:
        # Chain explicitly so the original numexpr error stays in the traceback.
        raise ValueError(f"Invalid function: {e}") from e

    self.function = function

    # reset data and model
    self.x_train, self.y_train = self.generate_data(self.function)
    self.model = self.train_model(self.kernel, self.x_train, self.y_train)

    return self.plot()
250
 
251
  def update_kernel_spec(self, kernel_spec):
252
  self.kernel = eval_kernel(kernel_spec)
 
273
 
274
  with gr.Column(scale=1):
275
  with gr.Tab("Settings"):
276
+ dataset_radio = gr.Radio(
277
+ ["Generate", "Upload"],
278
+ value="Generate",
279
+ label="Dataset",
280
+ )
281
+
282
+ function_box = gr.Textbox(
283
+ label="Function",
284
+ placeholder="function of x",
285
+ value=self.DEFAULT_FUNCTION,
286
+ interactive=True,
287
+ )
288
 
289
  # upload data
290
  file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
 
324
  with gr.Tab("Usage"):
325
  gr.Markdown(''.join(open('usage.md', 'r').readlines()))
326
 
327
+ function_box.submit(
328
+ fn=self.update_function,
329
+ inputs=function_box,
330
+ outputs=[self.canvas],
331
+ )
332
 
 
 
 
 
 
 
 
333
  kernel_spec.submit(
334
  fn=self.update_kernel_spec,
335
  inputs=kernel_spec,