Spaces:
Runtime error
Runtime error
Change numpy code by jax.numpy
Browse files
app.py
CHANGED
|
@@ -1,7 +1,12 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
import numpy as
|
|
|
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
st.title('Fitting simple models with JAX')
|
| 6 |
st.header('A quadratric regression example')
|
| 7 |
|
|
@@ -15,19 +20,18 @@ number_of_observations = st.sidebar.slider('Number of observations', min_value=5
|
|
| 15 |
noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise', min_value = 0.0, max_value=2.0, value=1.0)
|
| 16 |
cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
X = np.column_stack((np.ones(number_of_observations),
|
| 21 |
np.random.random(number_of_observations)))
|
| 22 |
|
| 23 |
w = np.array([3.0, -20.0, 32.0]) # coefficients
|
| 24 |
-
|
| 25 |
X = np.column_stack((X, X[:,1] ** 2)) # add x**2 column
|
| 26 |
-
additional_noise = 8 *
|
| 27 |
-
y =
|
| 28 |
+ additional_noise
|
| 29 |
|
| 30 |
|
|
|
|
| 31 |
fig, ax = plt.subplots(dpi=320)
|
| 32 |
ax.set_xlim((0,1))
|
| 33 |
ax.set_ylim((-5,26))
|
|
@@ -46,6 +50,10 @@ st.latex(r'''\bf{w}\leftarrow \bf{w}-\eta \frac{\partial\ell(\bf{X},\bf{y}, \bf{
|
|
| 46 |
|
| 47 |
# Fitting by the respective cost_function
|
| 48 |
if cost_function == 'RMSE-Loss':
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
st.write('You selected the RMSE loss function.')
|
| 50 |
st.latex(r'''\ell(X, y, w)=\frac{1}{m}||Xw - y||_{2}^2''')
|
| 51 |
st.latex(r'''\ell(X, y, w)=\frac{1}{m}\big(\sqrt{(Xw - y)\cdot(Xw - y)}\big)^2''')
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
import jax.numpy as jnp
|
| 3 |
+
import jax
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
|
| 6 |
+
# Set random key
|
| 7 |
+
seed=321
|
| 8 |
+
key = jax.random.PRNGKey(seed)
|
| 9 |
+
|
| 10 |
st.title('Fitting simple models with JAX')
|
| 11 |
st.header('A quadratric regression example')
|
| 12 |
|
|
|
|
| 20 |
noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise', min_value = 0.0, max_value=2.0, value=1.0)
|
| 21 |
cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
|
| 22 |
|
| 23 |
+
# Generate random data
|
|
|
|
| 24 |
X = np.column_stack((np.ones(number_of_observations),
|
| 25 |
np.random.random(number_of_observations)))
|
| 26 |
|
| 27 |
w = np.array([3.0, -20.0, 32.0]) # coefficients
|
|
|
|
| 28 |
X = np.column_stack((X, X[:,1] ** 2)) # add x**2 column
|
| 29 |
+
additional_noise = 8 * jax.random.bernoulli(key, p=0.08, shape=[number_of_observations,])
|
| 30 |
+
y = jnp.dot(X, w) + noise_standard_deviation * jax.random.normal(key, shape=[number_of_observations,]) \
|
| 31 |
+ additional_noise
|
| 32 |
|
| 33 |
|
| 34 |
+
# Plot the data
|
| 35 |
fig, ax = plt.subplots(dpi=320)
|
| 36 |
ax.set_xlim((0,1))
|
| 37 |
ax.set_ylim((-5,26))
|
|
|
|
| 50 |
|
| 51 |
# Fitting by the respective cost_function
|
| 52 |
if cost_function == 'RMSE-Loss':
|
| 53 |
+
|
| 54 |
+
def loss(w):
|
| 55 |
+
return 1/X.shape[0] * jax.numpy.linalg.norm(jnp.dot(X, w) - y)**2
|
| 56 |
+
|
| 57 |
st.write('You selected the RMSE loss function.')
|
| 58 |
st.latex(r'''\ell(X, y, w)=\frac{1}{m}||Xw - y||_{2}^2''')
|
| 59 |
st.latex(r'''\ell(X, y, w)=\frac{1}{m}\big(\sqrt{(Xw - y)\cdot(Xw - y)}\big)^2''')
|