Zeel commited on
Commit
dd9f338
·
1 Parent(s): 5dd083c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py CHANGED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import streamlit as st
3
+ import matplotlib.pyplot as plt
4
+
5
+ st.subheader("Gaussian Process Regression")
6
+ st_col = st.columns(1)[0]
7
+
8
+ lengthscale = st.slider('Lengthscale', min_value=0.01, max_value=1.0, value=0.25, step=0.01)
9
+ variance = st.slider('Variance', min_value=0.001, max_value=0.1, value=0.025, step=0.001)
10
+ noise_variance = st.slider('Noise Variance', min_value=0.001, max_value=0.01, value=0.0, step=0.001)
11
+
12
+ def rbf_kernel(x1, x2, lengthscale, variance):
13
+ x1_ = x1.reshape(-1,1)/lengthscale
14
+ x2_ = x2.reshape(1,-1)/lengthscale
15
+ dist_sqr = (x1_ - x2_) ** 2
16
+ return variance * np.exp(-dist_sqr)
17
+
18
+ fig, ax = plt.subplots()
19
+
20
+ x_train = np.array([0.2, 0.5, 0.8]).reshape(-1,1)
21
+ y_train = np.array([0.8, 0.3, 0.6]).reshape(-1,1)
22
+
23
+ ax.scatter(x_train, y_train, label='train points')
24
+ ax.set_xlim(-0.2,1.2)
25
+ ax.set_ylim(-0.2,1.2)
26
+
27
+ N = 100
28
+ x_test = np.linspace(-0.2,1.2,N).reshape(-1,1)
29
+
30
+ k_train_train = rbf_kernel(x_train, x_train, lengthscale, variance)
31
+ k_test_train = rbf_kernel(x_test, x_train, lengthscale, variance)
32
+ k_test_test = rbf_kernel(x_test, x_test, lengthscale, variance)
33
+
34
+ c = np.linalg.inv(np.linalg.cholesky(k_train_train + noise_variance * np.eye(3)))
35
+ k_inv = np.dot(c.T,c)
36
+
37
+ pred_mean = k_test_train@k_inv@y_train
38
+ pred_var = k_test_test - k_test_train@k_inv@k_test_train.T
39
+ pred_std2 = 2 * (pred_var.diagonal() ** 0.5)
40
+
41
+ ax.plot(x_test, pred_mean, label='predictive mean')
42
+ ax.fill_between(x_test.ravel(), pred_mean.ravel()-pred_std2.ravel(),
43
+ pred_mean.ravel()+pred_std2.ravel(), alpha=0.5, label='95% confidence')
44
+ ax.set_xlabel('x')
45
+ ax.set_ylabel('y')
46
+ ax.legend()
47
+
48
+ with st_col:
49
+ st.pyplot(fig)
50
+
51
+ hide_streamlit_style = """
52
+ <style>
53
+ #MainMenu {visibility: hidden;}
54
+ footer {visibility: hidden;}
55
+ </style>
56
+ """
57
+ st.markdown(hide_streamlit_style, unsafe_allow_html=True)