"""Gradio demo: Bayesian updating of a Beta prior from simulated coin tosses.

The belief about a coin's probability of heads is modelled as a Beta(a, b)
distribution.  Each simulated toss increments ``a`` (heads) or ``b`` (tails)
and the resulting posterior density is re-plotted.
"""

import io
import random

import gradio as gr
import matplotlib

# Select the non-interactive Agg backend before pyplot is imported: figures
# are created inside Gradio event handlers (not on the main thread) and are
# only ever rendered to an in-memory PNG, never shown in a GUI window.
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image  # Requires Pillow: pip install Pillow

# Imported under an alias so it cannot clash with parameters named ``beta``.
from scipy.stats import beta as beta_distribution

# --- Configuration ---
# Display name -> (alpha, beta) parameters of the prior.
priors = {
    "Uniform Prior (Beta(1,1))": (1, 1),
    "Prior Biased Toward Heads (Beta(5,1))": (5, 1),
    "Prior Biased Toward Fair (Beta(2,2))": (2, 2),
}
prior_names = list(priors.keys())
# Grid over [0, 1] on which every density curve is evaluated.
x_pdf = np.linspace(0, 1, 300)


# --- Helper Function to Generate Plot Image (NumPy Array) ---
def generate_plot_image(alpha, beta, title_prefix=""):
    """Render the Beta(alpha, beta) density as an image.

    Args:
        alpha: First shape parameter of the Beta distribution.
        beta: Second shape parameter of the Beta distribution.
        title_prefix: Text placed before ``Beta(alpha, beta)`` in the title.

    Returns:
        The rendered PNG decoded into a NumPy array.  When either parameter
        is non-positive, the image shows an "Invalid parameters" message
        instead of a curve.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        if alpha <= 0 or beta <= 0:
            # The Beta density is undefined for non-positive parameters.
            ax.text(
                0.5,
                0.5,
                'Invalid parameters\nCannot plot Beta(≤0, ≤0)',
                horizontalalignment='center',
                verticalalignment='center',
                transform=ax.transAxes,
            )
        else:
            try:
                pdf = beta_distribution.pdf(x_pdf, alpha, beta)
                # The density can be infinite at 0 or 1 when a parameter is
                # below 1; cap +inf at the largest finite value (or 1.0 if
                # there is none) so the curve stays drawable.
                finite_values = pdf[np.isfinite(pdf)]
                inf_cap = finite_values.max() if finite_values.size else 1.0
                pdf = np.nan_to_num(pdf, nan=0.0, posinf=inf_cap, neginf=0.0)
            except ValueError:
                pdf = np.zeros_like(x_pdf)  # Fallback: flat line

            ax.plot(x_pdf, pdf, color='dodgerblue', linewidth=2)
            ax.fill_between(x_pdf, pdf, color='dodgerblue', alpha=0.3)

            # Mark the distribution mean with a dashed line up to the curve.
            mean_val = alpha / (alpha + beta)
            try:
                peak_y = beta_distribution.pdf(mean_val, alpha, beta)
            except ValueError:
                peak_y = None  # Best effort: skip the marker
            if peak_y is not None and np.isfinite(peak_y) and peak_y >= 0:
                ax.vlines(
                    mean_val,
                    0,
                    peak_y,
                    color='red',
                    linestyle='--',
                    label=f'Mean ≈ {mean_val:.2f}',
                )
                ax.legend(fontsize=8)

        ax.set_title(f"{title_prefix}Beta({alpha:.1f}, {beta:.1f})", fontsize=10)
        ax.set_xlabel("Probability of Heads (r)", fontsize=9)
        ax.set_ylabel("Density", fontsize=9)
        ax.set_ylim(bottom=0)
        ax.grid(True, linestyle=':', alpha=0.7)
        # Use the figure's own method rather than plt.tight_layout(): pyplot's
        # "current figure" is global state shared by concurrent handlers.
        fig.tight_layout()

        # --- Convert plot to NumPy array ---
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            # np.array() forces Pillow to decode while the buffer is open.
            return np.array(Image.open(buf))
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)


# --- Gradio Action Functions ---
def reset_simulation(selected_prior_name):
    """Build a fresh simulation state from the selected prior.

    Args:
        selected_prior_name: Key into ``priors``; an unknown name falls back
            to the first configured prior.

    Returns:
        Tuple ``(state, history_text, params_text, plot_image)`` matching the
        outputs wired to the Gradio event listeners.
    """
    if selected_prior_name not in priors:
        selected_prior_name = prior_names[0]  # Default fallback

    prior_a, prior_b = priors[selected_prior_name]
    initial_state = {
        "prior_a": prior_a,
        "prior_b": prior_b,
        "current_a": prior_a,
        "current_b": prior_b,
        "history": [],
    }
    history_str = "No tosses yet."
    params_str = f"Current Posterior: Beta({prior_a:.1f}, {prior_b:.1f})"
    plot_image = generate_plot_image(prior_a, prior_b, title_prefix="Prior: ")
    return initial_state, history_str, params_str, plot_image


def perform_next_toss(current_state):
    """Simulate one coin toss and update the posterior.

    Args:
        current_state: State dict as produced by ``reset_simulation``.  If it
            is not a dict or lacks ``"current_a"``, the simulation is first
            reset to the first configured prior.

    Returns:
        Tuple ``(state, history_text, params_text, plot_image)``.  The
        returned state is a new dict; the input is not mutated.
    """
    if not isinstance(current_state, dict) or "current_a" not in current_state:
        print("Warning: Invalid state detected in perform_next_toss. Resetting.")
        try:
            current_state, _, _, _ = reset_simulation(prior_names[0])
        except Exception as e:
            # Handler boundary: report the failure in the UI rather than crash.
            print(f"ERROR during reset within perform_next_toss: {e}")
            error_plot = generate_plot_image(
                1, 1, title_prefix="ERROR State Invalid - "
            )
            return current_state, "Error: Invalid state", "Error", error_plot

    # Simulate a fair coin toss: 0 = Tails, 1 = Heads.
    toss_result = random.randint(0, 1)

    new_history = current_state.get("history", []) + [toss_result]
    current_a = current_state.get("current_a", 1.0)
    current_b = current_state.get("current_b", 1.0)
    # Conjugate update: heads increments alpha, tails increments beta.
    if toss_result == 1:
        current_a += 1
    else:
        current_b += 1

    new_state = {
        **current_state,
        "history": new_history,
        "current_a": current_a,
        "current_b": current_b,
    }

    history_str = ", ".join('H' if t == 1 else 'T' for t in new_history)
    params_str = f"Current Posterior: Beta({current_a:.1f}, {current_b:.1f})"
    plot_image = generate_plot_image(
        current_a, current_b, title_prefix="Posterior: "
    )
    return new_state, history_str, params_str, plot_image


# --- Build Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
        # Bayesian Coin Toss Simulation 🪙
        Visualize how a Beta distribution (representing belief about a coin's bias) updates after each simulated coin toss (Heads or Tails).
        Select a prior belief, then press 'Next Toss' or 'Reset'.
        """
    )

    # Hidden state holding the prior, the current alpha/beta and the toss
    # history.  Starts as None and is populated by the load/change events.
    state = gr.State(value=None)

    with gr.Row():
        prior_selector = gr.Dropdown(
            label="Select Prior Belief",
            choices=prior_names,
            value=prior_names[0],  # Default selection
        )

    with gr.Row():
        next_button = gr.Button("Next Toss", variant="primary")
        reset_button = gr.Button("Reset")

    with gr.Row():
        history_output = gr.Textbox(
            label="Toss History",
            value="Initializing...",
            interactive=False,
            scale=2,
        )
        params_output = gr.Textbox(
            label="Posterior Parameters",
            value="Initializing...",
            interactive=False,
            scale=1,
        )

    # The plot is delivered as a NumPy array by generate_plot_image.
    plot_output = gr.Image(
        label="Posterior Distribution", type="numpy", value=None
    )

    # --- Event Listeners ---
    # 1. Reset whenever the prior selection changes.
    prior_selector.change(
        fn=reset_simulation,
        inputs=[prior_selector],
        outputs=[state, history_output, params_output, plot_output],
        queue=False,  # Prevent queuing if rapid changes happen
    )

    # 2. Initialize on page load using the default dropdown value.
    app.load(
        fn=reset_simulation,
        inputs=[prior_selector],
        outputs=[state, history_output, params_output, plot_output],
    )

    # 3. Advance the simulation when "Next Toss" is clicked.
    next_button.click(
        fn=perform_next_toss,
        inputs=[state],
        outputs=[state, history_output, params_output, plot_output],
    )

    # 4. Reset to the *currently selected* prior when "Reset" is clicked.
    reset_button.click(
        fn=reset_simulation,
        inputs=[prior_selector],
        outputs=[state, history_output, params_output, plot_output],
    )

# --- Launch the App ---
app.launch(debug=True)  # debug=True helps see errors in console