# Page-scrape residue kept as a comment (author caption, commit message, commit hash):
# Nipun's picture | updates | 311f782
import streamlit as st
import torch
import numpy as np
import plotly.graph_objects as go
st.set_page_config(layout="wide")
st.title("2D Probability Distributions")

# Sidebar: choose a distribution family and its parameters, then build the
# matching torch.distributions object.
with st.sidebar:
    st.header("Controls")
    distribution_type = st.selectbox("Select Distribution", ["Bivariate Normal", "2D Uniform"])

    if distribution_type == "Bivariate Normal":
        mu_x = st.slider("Mean of X (μx)", -2.0, 2.0, 0.0)
        mu_y = st.slider("Mean of Y (μy)", -2.0, 2.0, 0.0)
        sigma_x = st.slider("Std Dev of X (σx)", 0.1, 2.0, 1.0)
        sigma_y = st.slider("Std Dev of Y (σy)", 0.1, 2.0, 1.0)
        # |rho| is capped at 0.9 (< 1) and sigmas are >= 0.1, so the
        # covariance matrix below stays positive definite.
        rho = st.slider("Correlation (ρ)", -0.9, 0.9, 0.0)
        # Covariance matrix [[σx², ρσxσy], [ρσxσy, σy²]]
        cov_matrix = torch.tensor([[sigma_x**2, rho * sigma_x * sigma_y],
                                   [rho * sigma_x * sigma_y, sigma_y**2]])
        mean_vector = torch.tensor([mu_x, mu_y])
        # Create distribution
        distribution = torch.distributions.MultivariateNormal(mean_vector, cov_matrix)
    elif distribution_type == "2D Uniform":
        low_x = st.slider("Lower Bound X", -4.0, 0.0, -2.0)
        high_x = st.slider("Upper Bound X", 0.0, 4.0, 2.0)
        low_y = st.slider("Lower Bound Y", -4.0, 0.0, -2.0)
        high_y = st.slider("Upper Bound Y", 0.0, 4.0, 2.0)
        # The lower- and upper-bound slider ranges meet at 0.0, so a bound pair
        # can collapse to an empty interval. Then the density 1/area is
        # undefined (division by zero) and torch's Uniform rejects
        # low >= high. Halt this run with a readable message instead of a
        # traceback.
        if low_x >= high_x or low_y >= high_y:
            st.error("Each lower bound must be strictly less than its upper bound.")
            st.stop()
        # Create distribution: vector low/high give one independent uniform
        # per axis (a batch of two 1-D uniforms).
        distribution = torch.distributions.Uniform(torch.tensor([low_x, low_y]), torch.tensor([high_x, high_y]))
# --- Evaluate the joint density and both marginals on a fixed grid ---
# 100 evenly spaced points per axis over [-4, 4].
x = torch.linspace(-4, 4, 100)
y = torch.linspace(-4, 4, 100)
# With 'xy' indexing, X[i, j] == x[j] and Y[i, j] == y[i]: rows follow y,
# columns follow x.
X, Y = torch.meshgrid(x, y, indexing='xy')
# Last dimension holds the (x, y) coordinate of each grid cell.
pos = torch.stack([X, Y], dim=-1)

if distribution_type == "Bivariate Normal":
    Z = distribution.log_prob(pos).exp()
    # Each coordinate's marginal is a 1-D normal with that coordinate's mean
    # and the square root of its diagonal covariance entry as std dev.
    std_x, std_y = cov_matrix.diagonal().sqrt()
    marginal_x = torch.distributions.Normal(mean_vector[0], std_x)
    marginal_y = torch.distributions.Normal(mean_vector[1], std_y)
    pdf_x = marginal_x.log_prob(x).exp()
    pdf_y = marginal_y.log_prob(y).exp()
elif distribution_type == "2D Uniform":
    # Indicator masks of the rectangle [low_x, high_x] x [low_y, high_y].
    mask_x = (X >= low_x) & (X <= high_x)
    mask_y = (Y >= low_y) & (Y <= high_y)
    # Constant density 1/area inside the rectangle, zero outside.
    area = (high_x - low_x) * (high_y - low_y)
    Z = torch.zeros_like(X)
    Z[mask_x & mask_y] = 1.0 / area
    # Row 0 of mask_x covers every x value and column 0 of mask_y covers
    # every y value, so they double as the 1-D support masks of the marginals.
    pdf_x = torch.where(mask_x[0], 1 / (high_x - low_x), 0)
    pdf_y = torch.where(mask_y[:, 0], 1 / (high_y - low_y), 0)
# Convert to numpy for plotting (X, Y, Z are rebound to numpy arrays).
X, Y, Z = X.numpy(), Y.numpy(), Z.numpy()
x_np, y_np = x.numpy(), y.numpy()
# Marginal curves are rescaled to the surface's peak height so both are
# visible at the same scale; this is for display only.
z_peak = np.max(Z)


def _rescale_to_peak(pdf, peak):
    """Scale a 1-D marginal so its maximum equals ``peak`` (display only).

    Returns zeros instead of NaNs when the marginal is identically zero on
    the grid, which would otherwise divide by zero.
    """
    pdf_max = np.max(pdf)
    if pdf_max <= 0:
        return np.zeros_like(pdf)
    return pdf / pdf_max * peak


# Create 3D surface plot
fig = go.Figure()
fig.add_trace(go.Surface(z=Z, x=X, y=Y, colorscale='Viridis', opacity=0.9, name='Density'))
# Marginal distributions on the walls: X's marginal on the y = min(y) wall,
# Y's marginal on the x = max(x) wall. The wall positions come from the grid
# itself so they keep matching it if the grid range changes.
fig.add_trace(go.Scatter3d(
    x=x_np,
    y=np.full_like(x_np, y_np.min()),
    z=_rescale_to_peak(pdf_x.numpy(), z_peak),
    mode='lines',
    line=dict(color='red', width=4),
    name='Marginal X',
))
fig.add_trace(go.Scatter3d(
    x=np.full_like(y_np, x_np.max()),
    y=y_np,
    z=_rescale_to_peak(pdf_y.numpy(), z_peak),
    mode='lines',
    line=dict(color='blue', width=4),
    name='Marginal Y',
))
# --- Layout and rendering ---
scene_cfg = dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Density')
margin_cfg = dict(l=0, r=0, t=20, b=20)
legend_cfg = dict(x=0.8, y=0.9, font=dict(size=14))
fig.update_layout(
    scene=scene_cfg,
    margin=margin_cfg,
    legend=legend_cfg,
    width=1100,
    height=800,
)

# Main display. use_container_width=True makes Streamlit use the parent
# container's width instead of the figure's own width setting.
st.plotly_chart(fig, use_container_width=True)