# Author: aryamanpathak
# Commit f8ea705: "Update app.py" (verified)
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
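# Import scipy's Beta distribution under a distinct name so it does not clash
# with the `beta` parameter of generate_plot_image below.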
from scipy.stats import beta as beta_distribution # <--- RENAMED IMPORT
import random
import io
from PIL import Image # Requires Pillow: pip install Pillow
# --- Configuration ---
# Named prior beliefs offered in the dropdown, each mapped to the (a, b)
# shape parameters of its Beta distribution. Insertion order determines the
# dropdown order, and the first entry is used as the default/fallback prior.
priors = dict(
    [
        ("Uniform Prior (Beta(1,1))", (1, 1)),
        ("Prior Biased Toward Heads (Beta(5,1))", (5, 1)),
        ("Prior Biased Toward Fair (Beta(2,2))", (2, 2)),
    ]
)
prior_names = [*priors]
# Grid of 300 evenly spaced points on [0, 1] at which densities are evaluated.
x_pdf = np.linspace(start=0, stop=1, num=300)
# --- Helper Function to Generate Plot Image (NumPy Array) ---
def generate_plot_image(alpha, beta, title_prefix=""):
    """Render a Beta(alpha, beta) density plot and return it as a NumPy array.

    Args:
        alpha: First Beta shape parameter. Both ``alpha`` and ``beta`` must be
            > 0 for a density curve to be drawn; otherwise an "Invalid
            parameters" message is rendered instead.
        beta: Second Beta shape parameter (see ``alpha``).
        title_prefix: Text prepended to the plot title, e.g. ``"Prior: "``.

    Returns:
        The figure saved as PNG and decoded by Pillow into a NumPy array of
        shape (height, width, channels).
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        _draw_beta_axes(ax, alpha, beta, f"{title_prefix}Beta({alpha:.1f}, {beta:.1f})")
        # Call tight_layout on *this* figure: plt.tight_layout() acts on
        # pyplot's global "current figure", which can belong to a different
        # request when handlers run concurrently.
        fig.tight_layout()
        return _figure_to_array(fig)
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)


def _draw_beta_axes(ax, alpha, beta, title):
    """Draw the Beta(alpha, beta) density (or an error message) onto ``ax``."""
    if alpha <= 0 or beta <= 0:
        ax.text(0.5, 0.5, 'Invalid parameters\nCannot plot Beta(≤0, ≤0)',
                horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes)
    else:
        try:
            pdf = beta_distribution.pdf(x_pdf, alpha, beta)
            finite = np.isfinite(pdf)
            # A shape parameter below 1 makes the density diverge at 0 or 1.
            # Cap infinities at the largest finite value (or 1.0 if none) so
            # the curve stays plottable.
            cap = np.max(pdf[finite]) if finite.any() else 1.0
            pdf = np.nan_to_num(pdf, nan=0.0, posinf=cap, neginf=0.0)
        except ValueError:
            pdf = np.zeros_like(x_pdf)  # Fallback: flat zero curve
        ax.plot(x_pdf, pdf, color='dodgerblue', linewidth=2)
        ax.fill_between(x_pdf, pdf, color='dodgerblue', alpha=0.3)
        _add_mean_marker(ax, alpha, beta)
        ax.set_ylim(bottom=0)
        ax.grid(True, linestyle=':', alpha=0.7)
    ax.set_title(title, fontsize=10)
    ax.set_xlabel("Probability of Heads (r)", fontsize=9)
    ax.set_ylabel("Density", fontsize=9)


def _add_mean_marker(ax, alpha, beta):
    """Draw a dashed vertical line from 0 up to the density at the Beta mean."""
    mean_val = alpha / (alpha + beta)
    try:
        peak_y = beta_distribution.pdf(mean_val, alpha, beta)
    except ValueError:
        return  # Best effort: skip the marker if the density can't be evaluated
    if np.isfinite(peak_y) and peak_y >= 0:
        ax.vlines(mean_val, 0, peak_y, color='red', linestyle='--',
                  label=f'Mean ≈ {mean_val:.2f}')
        ax.legend(fontsize=8)


def _figure_to_array(fig):
    """Save ``fig`` as PNG in memory and decode it into a NumPy array."""
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png')
        buf.seek(0)
        with Image.open(buf) as img:
            # np.array forces Pillow to load the pixels before buf is closed.
            return np.array(img)
# --- Gradio Action Functions ---
def reset_simulation(selected_prior_name):
    """Build a fresh simulation state for the chosen prior.

    Args:
        selected_prior_name: A key of ``priors``. Any other value (including
            None) falls back to the first entry of ``prior_names``.

    Returns:
        A 4-tuple matching the Gradio outputs wired to this handler:
        (state dict with prior/current Beta parameters and an empty toss
        history, history text, posterior-parameter text, plot image array).
    """
    name = selected_prior_name if selected_prior_name in priors else prior_names[0]
    a0, b0 = priors[name]
    # Before any toss, the posterior equals the prior.
    fresh_state = dict(
        prior_a=a0,
        prior_b=b0,
        current_a=a0,
        current_b=b0,
        history=[],
    )
    return (
        fresh_state,
        "No tosses yet.",
        f"Current Posterior: Beta({a0:.1f}, {b0:.1f})",
        generate_plot_image(a0, b0, title_prefix="Prior: "),
    )
def perform_next_toss(current_state):
    """Simulate one fair coin toss and apply the Beta posterior update.

    Heads increments ``current_a``; tails increments ``current_b``. The toss
    (1 = heads, 0 = tails) is appended to ``history``. The state dict is
    updated in place and also returned.

    If ``current_state`` is not a dict containing ``"current_a"`` (e.g. the
    app has not finished initializing), it is first rebuilt from the first
    prior via ``reset_simulation``. If that rebuild fails, an error plot and
    error messages are returned without tossing.

    Returns:
        (state, history text, posterior-parameter text, plot image array).
    """
    state_ok = isinstance(current_state, dict) and "current_a" in current_state
    if not state_ok:
        print("Warning: Invalid state detected in perform_next_toss. Resetting.")
        try:
            current_state, _hist, _params, _plot = reset_simulation(prior_names[0])
        except Exception as exc:
            print(f"ERROR during reset within perform_next_toss: {exc}")
            error_plot = generate_plot_image(1, 1, title_prefix="ERROR State Invalid - ")
            return current_state, "Error: Invalid state", "Error", error_plot

    toss = random.randint(0, 1)
    history = current_state.get("history", []) + [toss]
    current_state["history"] = history

    a = current_state.get("current_a", 1.0)
    b = current_state.get("current_b", 1.0)
    if toss == 1:
        a += 1
    else:
        b += 1
    current_state.update(current_a=a, current_b=b)

    history_text = ", ".join("H" if t == 1 else "T" for t in history) or "No tosses yet."
    params_text = f"Current Posterior: Beta({a:.1f}, {b:.1f})"
    plot = generate_plot_image(a, b, title_prefix="Posterior: ")
    return current_state, history_text, params_text, plot
# --- Build Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
# Bayesian Coin Toss Simulation 🪙
Visualize how a Beta distribution (representing belief about a coin's bias)
updates after each simulated coin toss (Heads or Tails).
Select a prior belief, then press 'Next Toss' or 'Reset'.
"""
    )

    # Per-session simulation state (the dict built by reset_simulation).
    # Starts as None until app.load or a prior change fills it in.
    state = gr.State(value=None)

    with gr.Row():
        prior_selector = gr.Dropdown(
            label="Select Prior Belief",
            choices=prior_names,
            value=prior_names[0],
        )

    with gr.Row():
        next_button = gr.Button("Next Toss", variant="primary")
        reset_button = gr.Button("Reset")

    with gr.Row():
        history_output = gr.Textbox(
            label="Toss History", value="Initializing...", interactive=False, scale=2
        )
        params_output = gr.Textbox(
            label="Posterior Parameters", value="Initializing...", interactive=False, scale=1
        )

    # The plot is delivered as a NumPy image array from generate_plot_image.
    plot_output = gr.Image(label="Posterior Distribution", type="numpy", value=None)

    # Every handler returns (state, history text, params text, plot image).
    sim_outputs = [state, history_output, params_output, plot_output]

    # Choosing a different prior restarts the simulation from that prior.
    # queue=False so rapid dropdown changes are handled immediately.
    prior_selector.change(
        fn=reset_simulation, inputs=[prior_selector], outputs=sim_outputs, queue=False
    )
    # Populate the initial state and outputs from the default dropdown value.
    app.load(fn=reset_simulation, inputs=[prior_selector], outputs=sim_outputs)
    # One toss per click; the updated state is fed back into gr.State.
    next_button.click(fn=perform_next_toss, inputs=[state], outputs=sim_outputs)
    # Restart from whichever prior is currently selected.
    reset_button.click(fn=reset_simulation, inputs=[prior_selector], outputs=sim_outputs)
# --- Launch the App ---
# Guarded so that importing this module (e.g. from tests or other tooling)
# builds the `app` Blocks object without starting a blocking server.
if __name__ == "__main__":
    app.launch(debug=True)  # debug=True helps see errors in console