Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import config # noqa: F401 | |
def generate_scatter_plot(
    activation, sigma, sigma_bias, include_bias, num_samples=5000
):
    """Scatter-plot samples of a one-neuron random function at two inputs.

    Draws ``num_samples`` weights ``theta ~ N(0, sigma)``. When
    ``include_bias`` is true it also draws biases ``b ~ N(0, sigma_bias)``.
    It then evaluates ``f1 = act(theta * x1 + b)`` and
    ``f2 = act(theta * x2 + b)`` at the fixed inputs ``x1 = 1`` and
    ``x2 = 2`` and scatters the ``(f1, f2)`` pairs. Without a bias every
    sample is parameterised by the single scalar ``theta``, so the points lie
    on a one-dimensional curve. The bias adds a second degree of freedom.

    Args:
        activation: One of ``"sin"``, ``"tanh"`` or ``"relu"``.
        sigma: Standard deviation of the weight prior.
        sigma_bias: Standard deviation of the bias prior; only used when
            ``include_bias`` is true.
        include_bias: Whether to sample a bias term.
        num_samples: Number of parameter draws to plot.

    Returns:
        The matplotlib ``Figure`` holding the scatter plot. It is created via
        ``plt.subplots``, so the caller should ``plt.close`` it once rendered.

    Raises:
        ValueError: If ``activation`` is not a supported name.
    """
    # Activation name -> (elementwise function, LaTeX name for axis labels).
    activations = {
        "sin": (np.sin, r"\sin"),
        "tanh": (np.tanh, r"\tanh"),
        "relu": (lambda z: np.maximum(0, z), r"\mathrm{ReLU}"),
    }
    try:
        act_fn, tex_name = activations[activation]
    except KeyError:
        # Previously any unrecognised name silently plotted ReLU.
        raise ValueError(
            f"Unknown activation {activation!r}; "
            f"expected one of {sorted(activations)}"
        ) from None

    # Fixed seed: the same parameters give the same cloud on every rerun.
    rng = np.random.default_rng(42)
    theta = rng.normal(0, sigma, num_samples)
    if include_bias:
        bias = rng.normal(0, sigma_bias, num_samples)
    else:
        bias = 0.0

    x1, x2 = 1.0, 2.0
    f1 = act_fn(theta * x1 + bias)
    f2 = act_fn(theta * x2 + bias)

    fig, ax = plt.subplots(figsize=(6, 2))
    ax.scatter(f1, f2, alpha=0.3, s=5, c="tab:blue")
    if activation == "relu":
        # ReLU is unbounded above, so size each axis to ~4 standard
        # deviations of its pre-activation z = theta * x + b. At sigma=1
        # without a bias this gives the former fixed (4.05, 8.05) limits,
        # and it keeps the cloud on-screen for other slider values.
        bias_var = sigma_bias**2 if include_bias else 0.0
        z1_std = np.sqrt((sigma * x1) ** 2 + bias_var)
        z2_std = np.sqrt((sigma * x2) ** 2 + bias_var)
        ax.set_xlim(-0.05, 4 * z1_std + 0.05)
        ax.set_ylim(-0.05, 4 * z2_std + 0.05)
    else:
        # sin and tanh both map into [-1, 1].
        ax.set_xlim(-1.02, 1.02)
        ax.set_ylim(-1.05, 1.05)
    ax.set_xlabel(rf"$f_1 = {tex_name}(\theta \cdot x_1 + b)$")
    ax.set_ylabel(rf"$f_2 = {tex_name}(\theta \cdot x_2 + b)$")
    return fig
# --- Page UI ---
st.title("Function Space Priors: Low-Dimensional Manifolds")

# Narrow control column on the left, wide plot column on the right.
col1, col2 = st.columns([0.5, 2])

with col1:
    activation = st.selectbox(
        "Activation",
        options=["sin", "tanh", "relu"],
        index=0,
    )
    sigma = st.slider(
        "Weight std ($\\sigma$)",
        min_value=0.1,
        max_value=3.0,
        value=1.0,
        step=0.1,
    )
    include_bias = st.toggle("Include bias")
    # Only show the bias-std slider when the bias is enabled; otherwise pass
    # a zero std.
    if include_bias:
        sigma_bias = st.slider(
            "Bias std ($\\sigma_b$)",
            min_value=0.1,
            max_value=3.0,
            value=1.0,
            step=0.1,
        )
    else:
        sigma_bias = 0.0

with col2:
    scatter_fig = generate_scatter_plot(activation, sigma, sigma_bias, include_bias)
    st.pyplot(scatter_fig)
    # Streamlit re-executes this script on every widget change. A figure made
    # with plt.subplots stays registered with pyplot until it is closed, so
    # close it after rendering to stop figures piling up across reruns.
    plt.close(scatter_fig)