Spaces:
Sleeping
Sleeping
Commit ·
fdd3bfb
1
Parent(s): 8b69b54
Add option to use a custom function
Browse files- gp_visualizer.py +86 -36
gp_visualizer.py
CHANGED
|
@@ -9,6 +9,7 @@ from jinja2 import Template
|
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
import matplotlib.lines as mlines
|
| 11 |
import numpy as np
|
|
|
|
| 12 |
import pandas as pd
|
| 13 |
from PIL import Image
|
| 14 |
import plotly.graph_objects as go
|
|
@@ -35,6 +36,14 @@ logging.basicConfig(
|
|
| 35 |
)
|
| 36 |
logger = logging.getLogger("ELVIS")
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
def eval_kernel(kernel_str):
|
| 39 |
# List of allowed kernel constructors
|
| 40 |
allowed_names = {
|
|
@@ -64,6 +73,25 @@ def eval_kernel(kernel_str):
|
|
| 64 |
|
| 65 |
return result
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def make_sine(xlim=(0,1), nsample=20, sigma=0.1, uniform=False, sort=True):
|
| 68 |
np.random.seed(42)
|
| 69 |
if uniform:
|
|
@@ -79,15 +107,21 @@ def make_sine(xlim=(0,1), nsample=20, sigma=0.1, uniform=False, sort=True):
|
|
| 79 |
|
| 80 |
return X, y
|
| 81 |
|
|
|
|
| 82 |
class GPVisualizer:
|
| 83 |
DEFAULT_KERNEL = "RBF() + WhiteKernel()"
|
|
|
|
| 84 |
|
| 85 |
def __init__(self, width, height):
|
| 86 |
self.canvas_width = width
|
| 87 |
self.canvas_height = height
|
| 88 |
|
|
|
|
| 89 |
self.kernel = eval_kernel(self.DEFAULT_KERNEL)
|
| 90 |
|
|
|
|
|
|
|
|
|
|
| 91 |
self.plot_options = {
|
| 92 |
"show_training_data": True,
|
| 93 |
"show_confidence_interval": True,
|
|
@@ -102,6 +136,15 @@ class GPVisualizer:
|
|
| 102 |
display: none;
|
| 103 |
}"""
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
def plot(self):
|
| 106 |
'''
|
| 107 |
'''
|
|
@@ -113,17 +156,8 @@ class GPVisualizer:
|
|
| 113 |
ax = fig.add_axes([0., 0., 1., 1.]) #
|
| 114 |
ax.margins(x=0, y=0) # no padding in both directions
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
X_tr, y_tr = make_sine(xlim=(0, 1), nsample=20)
|
| 119 |
-
print(X_tr.shape, y_tr.shape)
|
| 120 |
-
X_ts, y_ts = make_sine(xlim=(-1, 2), uniform=True, sort=True, nsample=100)
|
| 121 |
-
|
| 122 |
-
# fit GP
|
| 123 |
-
gpr = GaussianProcessRegressor(kernel=self.kernel, random_state=0)
|
| 124 |
-
logger.info('fitting ' + str(gpr))
|
| 125 |
-
gpr.fit(X_tr, y_tr)
|
| 126 |
-
y_pred, y_std = gpr.predict(X_ts, return_std=True)
|
| 127 |
|
| 128 |
# plot
|
| 129 |
fig, ax = plt.subplots(figsize=(8, 8))
|
|
@@ -131,28 +165,25 @@ class GPVisualizer:
|
|
| 131 |
ax.set_xlabel("x")
|
| 132 |
ax.set_ylabel("y")
|
| 133 |
|
| 134 |
-
|
| 135 |
-
#plt.errorbar(X_ts.flatten(), y_pred, yerr=1.95*y_std, fmt='+', color='orange')
|
| 136 |
-
|
| 137 |
-
R2 = gpr.score(X_tr, y_tr)
|
| 138 |
|
| 139 |
if self.plot_options["show_training_data"]:
|
| 140 |
-
if len(
|
| 141 |
-
plt.scatter(
|
| 142 |
else:
|
| 143 |
-
plt.scatter(
|
| 144 |
|
| 145 |
if self.plot_options["show_true_function"]:
|
| 146 |
-
plt.plot(
|
| 147 |
|
| 148 |
if self.plot_options["show_predictions"]:
|
| 149 |
-
plt.
|
| 150 |
|
| 151 |
if self.plot_options["show_confidence_interval"]:
|
| 152 |
plt.fill_between(
|
| 153 |
-
|
| 154 |
-
y_pred - 1.96*y_std,
|
| 155 |
-
y_pred + 1.96*y_std,
|
| 156 |
alpha=0.2,
|
| 157 |
label='95% confidence interval',
|
| 158 |
color=self.plot_cmap(3)
|
|
@@ -201,10 +232,21 @@ class GPVisualizer:
|
|
| 201 |
)
|
| 202 |
return fig
|
| 203 |
|
| 204 |
-
|
| 205 |
-
#
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
def update_kernel_spec(self, kernel_spec):
|
| 210 |
self.kernel = eval_kernel(kernel_spec)
|
|
@@ -231,8 +273,18 @@ class GPVisualizer:
|
|
| 231 |
|
| 232 |
with gr.Column(scale=1):
|
| 233 |
with gr.Tab("Settings"):
|
| 234 |
-
dataset_radio = gr.Radio(
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
# upload data
|
| 238 |
file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
|
|
@@ -272,14 +324,12 @@ class GPVisualizer:
|
|
| 272 |
with gr.Tab("Usage"):
|
| 273 |
gr.Markdown(''.join(open('usage.md', 'r').readlines()))
|
| 274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
-
# event handlers for GUI elements
|
| 277 |
-
#kernel_type.change(
|
| 278 |
-
#fn=self.update_kernel_type,
|
| 279 |
-
#inputs=kernel_type,
|
| 280 |
-
#outputs=(self.canvas)
|
| 281 |
-
#)
|
| 282 |
-
|
| 283 |
kernel_spec.submit(
|
| 284 |
fn=self.update_kernel_spec,
|
| 285 |
inputs=kernel_spec,
|
|
|
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
import matplotlib.lines as mlines
|
| 11 |
import numpy as np
|
| 12 |
+
import numexpr
|
| 13 |
import pandas as pd
|
| 14 |
from PIL import Image
|
| 15 |
import plotly.graph_objects as go
|
|
|
|
| 36 |
)
|
| 37 |
logger = logging.getLogger("ELVIS")
|
| 38 |
|
| 39 |
+
|
| 40 |
+
NUMEXPR_CONSTANTS = {
|
| 41 |
+
'pi': np.pi,
|
| 42 |
+
'PI': np.pi,
|
| 43 |
+
'e': np.e,
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
def eval_kernel(kernel_str):
|
| 48 |
# List of allowed kernel constructors
|
| 49 |
allowed_names = {
|
|
|
|
| 73 |
|
| 74 |
return result
|
| 75 |
|
| 76 |
+
|
| 77 |
+
def get_function(function, xlim=(-1, 1), nsample=100):
|
| 78 |
+
pi = np.pi
|
| 79 |
+
PI = np.pi
|
| 80 |
+
|
| 81 |
+
x = np.linspace(xlim[0], xlim[1], nsample)
|
| 82 |
+
y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS})
|
| 83 |
+
x = x.reshape(-1, 1)
|
| 84 |
+
return x, y
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_data_points(function, xlim=(-1, 1), nsample=10, sigma=0):
|
| 88 |
+
x = xlim[0] + (xlim[1] - xlim[0]) * np.random.rand(nsample)
|
| 89 |
+
x = np.sort(x)
|
| 90 |
+
y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS}) + sigma * np.random.randn(nsample)
|
| 91 |
+
x = x.reshape(-1, 1)
|
| 92 |
+
return x, y
|
| 93 |
+
|
| 94 |
+
|
| 95 |
def make_sine(xlim=(0,1), nsample=20, sigma=0.1, uniform=False, sort=True):
|
| 96 |
np.random.seed(42)
|
| 97 |
if uniform:
|
|
|
|
| 107 |
|
| 108 |
return X, y
|
| 109 |
|
| 110 |
+
|
| 111 |
class GPVisualizer:
|
| 112 |
DEFAULT_KERNEL = "RBF() + WhiteKernel()"
|
| 113 |
+
DEFAULT_FUNCTION = "sin(2 * pi * x)"
|
| 114 |
|
| 115 |
def __init__(self, width, height):
|
| 116 |
self.canvas_width = width
|
| 117 |
self.canvas_height = height
|
| 118 |
|
| 119 |
+
self.function = self.DEFAULT_FUNCTION
|
| 120 |
self.kernel = eval_kernel(self.DEFAULT_KERNEL)
|
| 121 |
|
| 122 |
+
self.x_train, self.y_train = self.generate_data(self.function)
|
| 123 |
+
self.model = self.train_model(self.kernel, self.x_train, self.y_train)
|
| 124 |
+
|
| 125 |
self.plot_options = {
|
| 126 |
"show_training_data": True,
|
| 127 |
"show_confidence_interval": True,
|
|
|
|
| 136 |
display: none;
|
| 137 |
}"""
|
| 138 |
|
| 139 |
+
def generate_data(self, function):
|
| 140 |
+
return get_data_points(function, xlim=(-1, 1), nsample=30, sigma=0.1)
|
| 141 |
+
|
| 142 |
+
def train_model(self, kernel, x_train, y_train):
|
| 143 |
+
gpr = GaussianProcessRegressor(kernel=kernel, random_state=0)
|
| 144 |
+
logger.info('fitting ' + str(gpr))
|
| 145 |
+
gpr.fit(x_train, y_train)
|
| 146 |
+
return gpr
|
| 147 |
+
|
| 148 |
def plot(self):
|
| 149 |
'''
|
| 150 |
'''
|
|
|
|
| 156 |
ax = fig.add_axes([0., 0., 1., 1.]) #
|
| 157 |
ax.margins(x=0, y=0) # no padding in both directions
|
| 158 |
|
| 159 |
+
x_test, y_test = get_function(self.function, xlim=(-2, 2), nsample=100)
|
| 160 |
+
y_pred, y_std = self.model.predict(x_test, return_std=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
# plot
|
| 163 |
fig, ax = plt.subplots(figsize=(8, 8))
|
|
|
|
| 165 |
ax.set_xlabel("x")
|
| 166 |
ax.set_ylabel("y")
|
| 167 |
|
| 168 |
+
R2 = self.model.score(self.x_train, self.y_train)
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
if self.plot_options["show_training_data"]:
|
| 171 |
+
if len(self.x_train) > 1:
|
| 172 |
+
plt.scatter(self.x_train.flatten(), self.y_train, label='training data (R2=%.2f)' % (R2), color=self.plot_cmap(0))
|
| 173 |
else:
|
| 174 |
+
plt.scatter(self.x_train.flatten(), self.y_train, label='training data', color=self.plot_cmap(0))
|
| 175 |
|
| 176 |
if self.plot_options["show_true_function"]:
|
| 177 |
+
plt.plot(x_test.flatten(), y_test, label='true function', color=self.plot_cmap(1))
|
| 178 |
|
| 179 |
if self.plot_options["show_predictions"]:
|
| 180 |
+
plt.plot(x_test.flatten(), y_pred, linestyle="--", label='mean prediction', color=self.plot_cmap(2))
|
| 181 |
|
| 182 |
if self.plot_options["show_confidence_interval"]:
|
| 183 |
plt.fill_between(
|
| 184 |
+
x_test.flatten(),
|
| 185 |
+
y_pred - 1.96 * y_std,
|
| 186 |
+
y_pred + 1.96 * y_std,
|
| 187 |
alpha=0.2,
|
| 188 |
label='95% confidence interval',
|
| 189 |
color=self.plot_cmap(3)
|
|
|
|
| 232 |
)
|
| 233 |
return fig
|
| 234 |
|
| 235 |
+
def update_function(self, function):
|
| 236 |
+
# test if function is valid
|
| 237 |
+
try:
|
| 238 |
+
x = np.linspace(-1, 1, 10)
|
| 239 |
+
y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS})
|
| 240 |
+
except Exception as e:
|
| 241 |
+
raise ValueError(f"Invalid function: {e}")
|
| 242 |
+
|
| 243 |
+
self.function = function
|
| 244 |
+
|
| 245 |
+
# reset data and model
|
| 246 |
+
self.x_train, self.y_train = self.generate_data(self.function)
|
| 247 |
+
self.model = self.train_model(self.kernel, self.x_train, self.y_train)
|
| 248 |
+
|
| 249 |
+
return self.plot()
|
| 250 |
|
| 251 |
def update_kernel_spec(self, kernel_spec):
|
| 252 |
self.kernel = eval_kernel(kernel_spec)
|
|
|
|
| 273 |
|
| 274 |
with gr.Column(scale=1):
|
| 275 |
with gr.Tab("Settings"):
|
| 276 |
+
dataset_radio = gr.Radio(
|
| 277 |
+
["Generate", "Upload"],
|
| 278 |
+
value="Generate",
|
| 279 |
+
label="Dataset",
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
function_box = gr.Textbox(
|
| 283 |
+
label="Function",
|
| 284 |
+
placeholder="function of x",
|
| 285 |
+
value=self.DEFAULT_FUNCTION,
|
| 286 |
+
interactive=True,
|
| 287 |
+
)
|
| 288 |
|
| 289 |
# upload data
|
| 290 |
file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
|
|
|
|
| 324 |
with gr.Tab("Usage"):
|
| 325 |
gr.Markdown(''.join(open('usage.md', 'r').readlines()))
|
| 326 |
|
| 327 |
+
function_box.submit(
|
| 328 |
+
fn=self.update_function,
|
| 329 |
+
inputs=function_box,
|
| 330 |
+
outputs=[self.canvas],
|
| 331 |
+
)
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
kernel_spec.submit(
|
| 334 |
fn=self.update_kernel_spec,
|
| 335 |
inputs=kernel_spec,
|