Spaces:
Runtime error
Runtime error
Update data generation
Browse files
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import streamlit as st
|
|
|
|
| 2 |
import jax.numpy as jnp
|
| 3 |
import jax
|
| 4 |
import matplotlib.pyplot as plt
|
|
@@ -21,15 +22,14 @@ noise_standard_deviation = st.sidebar.slider('Standard deviation of the noise',
|
|
| 21 |
cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
|
| 22 |
|
| 23 |
# Generate random data
|
| 24 |
-
X =
|
| 25 |
-
|
| 26 |
-
w = jnp.array([3.0, -20.0, 32.0]) # coefficients
|
| 27 |
X = jnp.column_stack((X, X[:,1] ** 2)) # add x**2 column
|
| 28 |
-
additional_noise = 8 * jax.random.bernoulli(key, p=0.08, shape=[number_of_observations,])
|
| 29 |
-
y = jnp.dot(X, w) + noise_standard_deviation * jax.random.normal(key, shape=[number_of_observations,]) \
|
| 30 |
-
+ additional_noise
|
| 31 |
-
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# Plot the data
|
| 34 |
fig, ax = plt.subplots(dpi=320)
|
| 35 |
ax.set_xlim((0,1))
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
import numpy as np
|
| 3 |
import jax.numpy as jnp
|
| 4 |
import jax
|
| 5 |
import matplotlib.pyplot as plt
|
|
|
|
| 22 |
cost_function = st.sidebar.radio('What cost function you want to use for the fitting?', options=('RMSE-Loss', 'Huber-Loss'))
|
| 23 |
|
| 24 |
# Generate random data
|
| 25 |
+
X = np.column_stack((np.ones(number_of_observations),
|
| 26 |
+
np.random.random(number_of_observations)))
|
|
|
|
| 27 |
X = jnp.column_stack((X, X[:,1] ** 2)) # add x**2 column
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
additional_noise = 8 * np.random.binomial(1, 0.03, size = number_of_observations)
|
| 30 |
+
y = jnp.array(np.dot(X, w) + noise_standard_deviation * np.random.randn(number_of_observations) \
|
| 31 |
+
+ additional_noise)
|
| 32 |
+
|
| 33 |
# Plot the data
|
| 34 |
fig, ax = plt.subplots(dpi=320)
|
| 35 |
ax.set_xlim((0,1))
|