import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import config  # noqa: F401

# Activation name -> (elementwise function, LaTeX name used in the axis labels).
_ACTIVATIONS = {
    "sin": (np.sin, r"\sin"),
    "tanh": (np.tanh, r"\tanh"),
    "relu": (lambda z: np.maximum(0, z), r"\mathrm{ReLU}"),
}


@st.cache_data
def generate_scatter_plot(
    activation, sigma, sigma_bias, include_bias, num_samples=5000
):
    """Scatter-plot samples of a random one-unit network evaluated at two inputs.

    Draws ``num_samples`` weights ``theta ~ N(0, sigma)`` (and, if
    ``include_bias`` is true, biases ``b ~ N(0, sigma_bias)``), evaluates
    ``f_i = act(theta * x_i + b)`` at ``x_1 = 1`` and ``x_2 = 2``, and plots
    ``f_2`` against ``f_1``. Sampling uses a fixed seed (42), so identical
    arguments always produce the identical point cloud.

    Args:
        activation: One of ``"sin"``, ``"tanh"`` or ``"relu"``.
        sigma: Standard deviation of the weight ``theta``.
        sigma_bias: Standard deviation of the bias; only used when
            ``include_bias`` is true.
        include_bias: Whether to add a random bias (otherwise ``b = 0``).
        num_samples: Number of ``(theta, b)`` draws to plot.

    Returns:
        A matplotlib ``Figure``. It has already been closed in pyplot's
        figure registry but can still be rendered (e.g. with ``st.pyplot``).

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    try:
        act, tex_name = _ACTIVATIONS[activation]
    except KeyError:
        raise ValueError(
            f"Unknown activation {activation!r}; "
            f"expected one of {sorted(_ACTIVATIONS)}"
        ) from None

    # Draw order (theta first, then bias) fixes which samples the seed yields.
    rng = np.random.default_rng(42)
    theta = rng.normal(0, sigma, num_samples)
    bias = rng.normal(0, sigma_bias, num_samples) if include_bias else 0.0

    x1, x2 = 1.0, 2.0
    f1 = act(theta * x1 + bias)
    f2 = act(theta * x2 + bias)
    label1 = rf"$f_1 = {tex_name}(\theta \cdot x_1 + b)$"
    label2 = rf"$f_2 = {tex_name}(\theta \cdot x_2 + b)$"

    fig, ax = plt.subplots(figsize=(6, 2))
    ax.scatter(f1, f2, alpha=0.3, s=5, c="tab:blue")
    if activation == "relu":
        # ReLU is unbounded above and its samples grow with sigma and the
        # bias, so fixed limits clip the cloud; scale to the data with a small
        # margin. `initial` covers num_samples == 0 (ReLU outputs are >= 0, so
        # it never changes a non-empty result) and the floor avoids a
        # zero-width axis when every sample is 0.
        x_hi = max(float(np.max(f1, initial=0.0)), 1e-3)
        y_hi = max(float(np.max(f2, initial=0.0)), 1e-3)
        ax.set_xlim(-0.0125 * x_hi, 1.0125 * x_hi)
        ax.set_ylim(-0.0125 * y_hi, 1.0125 * y_hi)
    else:
        # sin and tanh outputs lie in [-1, 1].
        ax.set_xlim(-1.02, 1.02)
        ax.set_ylim(-1.05, 1.05)
    ax.set_xlabel(label1)
    ax.set_ylabel(label2)

    # pyplot keeps every figure created by plt.subplots open until plt.close
    # is called, and nothing else on this page closes it, so figures would
    # accumulate across reruns. A closed figure can still be saved/rendered.
    plt.close(fig)
    return fig


# --- Page UI ---
st.title("Function Space Priors: Low-Dimensional Manifolds")

col1, col2 = st.columns([0.5, 2])

with col1:
    activation = st.selectbox(
        "Activation",
        options=["sin", "tanh", "relu"],
        index=0,
    )
    sigma = st.slider(
        "Weight std ($\\sigma$)",
        min_value=0.1,
        max_value=3.0,
        value=1.0,
        step=0.1,
    )
    include_bias = st.toggle("Include bias")
    if include_bias:
        sigma_bias = st.slider(
            "Bias std ($\\sigma_b$)",
            min_value=0.1,
            max_value=3.0,
            value=1.0,
            step=0.1,
        )
    else:
        # Unused by the sampler when the bias is off; 0.0 keeps the cache key stable.
        sigma_bias = 0.0

with col2:
    scatter_fig = generate_scatter_plot(activation, sigma, sigma_bias, include_bias)
    st.pyplot(scatter_fig)