# File size: 8,541 Bytes
# aabf0e7 f8ea705 aabf0e7 f8ea705 aabf0e7
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
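# Import scipy's Beta distribution under an alias so it does not clash with the
# `beta` parameter name used by generate_plot_image.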
from scipy.stats import beta as beta_distribution # <--- RENAMED IMPORT
import random
import io
from PIL import Image # Requires Pillow: pip install Pillow
# --- Configuration ---
# Selectable priors: display label -> (alpha, beta) shape parameters of the Beta prior.
priors = dict([
    ("Uniform Prior (Beta(1,1))", (1, 1)),
    ("Prior Biased Toward Heads (Beta(5,1))", (5, 1)),
    ("Prior Biased Toward Fair (Beta(2,2))", (2, 2)),
])
# Dropdown choices in insertion order; the first entry doubles as the default.
prior_names = [name for name in priors]
# Grid of r values on [0, 1] at which every density curve is evaluated.
x_pdf = np.linspace(0.0, 1.0, num=300)
# --- Helper Function to Generate Plot Image (NumPy Array) ---
def _draw_beta_density(ax, alpha, beta):
    """Plot the Beta(alpha, beta) pdf on ``ax`` plus a dashed line at the mean.

    Assumes ``alpha > 0`` and ``beta > 0`` (the caller checks this).
    """
    try:
        pdf = beta_distribution.pdf(x_pdf, alpha, beta)
        # For shape parameters < 1 the density diverges at the edges; cap
        # +inf at the largest finite value (or 1.0 if nothing is finite) so
        # the curve and fill stay drawable.
        finite = pdf[np.isfinite(pdf)]
        cap = finite.max() if finite.size else 1.0
        pdf = np.nan_to_num(pdf, nan=0.0, posinf=cap, neginf=0.0)
    except ValueError:
        pdf = np.zeros_like(x_pdf)  # fall back to a flat line

    ax.plot(x_pdf, pdf, color='dodgerblue', linewidth=2)
    ax.fill_between(x_pdf, pdf, color='dodgerblue', alpha=0.3)

    mean_val = alpha / (alpha + beta)
    try:
        peak_y = beta_distribution.pdf(mean_val, alpha, beta)
    except ValueError:
        return  # best-effort: the mean marker is optional decoration
    if np.isfinite(peak_y) and peak_y >= 0:
        ax.vlines(mean_val, 0, peak_y, color='red', linestyle='--',
                  label=f'Mean ≈ {mean_val:.2f}')
        ax.legend(fontsize=8)


def generate_plot_image(alpha, beta, title_prefix=""):
    """Render the Beta(alpha, beta) density and return it as a NumPy image array.

    The plot is drawn on a standalone ``matplotlib.figure.Figure`` instead of
    through ``pyplot``. Concurrent requests therefore never share pyplot's
    global "current figure", and no figure stays registered with pyplot if
    drawing raises part-way through.

    Args:
        alpha: First Beta shape parameter (incremented on heads in this app).
        beta: Second Beta shape parameter (incremented on tails in this app).
        title_prefix: Text prepended to the plot title, e.g. "Prior: ".

    Returns:
        numpy.ndarray: The plot rendered to PNG and decoded by Pillow, shaped
        (height, width, channels).
    """
    from matplotlib.figure import Figure  # pyplot-free figure, no global state

    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    plot_title = f"{title_prefix}Beta({alpha:.1f}, {beta:.1f})"

    if alpha <= 0 or beta <= 0:
        # The Beta distribution is undefined here; show a message instead of a curve.
        ax.text(0.5, 0.5, 'Invalid parameters\nCannot plot Beta(≤0, ≤0)',
                horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes)
    else:
        _draw_beta_density(ax, alpha, beta)

    ax.set_title(plot_title, fontsize=10)
    ax.set_xlabel("Probability of Heads (r)", fontsize=9)
    ax.set_ylabel("Density", fontsize=9)
    ax.set_ylim(bottom=0)
    ax.grid(True, linestyle=':', alpha=0.7)
    fig.tight_layout()  # acts on this figure only, not pyplot's "current" one

    # Round-trip through an in-memory PNG to get a plain pixel array for gr.Image.
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png')
        buf.seek(0)
        with Image.open(buf) as img:
            img_array = np.array(img)  # forces the decode before buf closes
    return img_array
# --- Gradio Action Functions ---
def reset_simulation(selected_prior_name):
    """Start a fresh simulation from the chosen prior.

    An unrecognised prior name falls back to the first entry of ``prior_names``.

    Returns:
        tuple: (state dict, history text, posterior-parameter text, plot image),
        matching the order of the Gradio outputs list.
    """
    name = selected_prior_name if selected_prior_name in priors else prior_names[0]
    a0, b0 = priors[name]

    # Both the remembered prior and the running posterior start at the prior.
    fresh_state = dict(
        prior_a=a0,
        prior_b=b0,
        current_a=a0,
        current_b=b0,
        history=[],
    )
    params_text = f"Current Posterior: Beta({a0:.1f}, {b0:.1f})"
    image = generate_plot_image(a0, b0, title_prefix="Prior: ")
    return fresh_state, "No tosses yet.", params_text, image
def perform_next_toss(current_state):
    """Simulate one fair coin toss and apply the conjugate Beta update.

    Heads adds 1 to ``current_a``; tails adds 1 to ``current_b``. A new state
    dict is built and returned rather than mutating ``current_state`` in place.
    If plot rendering raises, the caller's dict is therefore left untouched
    instead of recording a toss the UI never displayed.

    Args:
        current_state: State dict as built by ``reset_simulation`` (keys
            ``prior_a``, ``prior_b``, ``current_a``, ``current_b``,
            ``history``). Anything else, e.g. ``None`` before the app's load
            event has run, triggers a reset to the first prior.

    Returns:
        tuple: (new state, history text, posterior-parameter text, plot image).
        If the reset itself fails, the plot image may be ``None``.
    """
    if not isinstance(current_state, dict) or "current_a" not in current_state:
        print("Warning: Invalid state detected in perform_next_toss. Resetting.")
        try:
            current_state, _, _, _ = reset_simulation(prior_names[0])
        except Exception as e:
            print(f"ERROR during reset within perform_next_toss: {e}")
            # Plot rendering is the likeliest thing to have failed inside the
            # reset, so retrying it here may fail as well; degrade to no image
            # rather than raising from the error handler.
            try:
                error_plot = generate_plot_image(1, 1, title_prefix="ERROR State Invalid - ")
            except Exception:
                error_plot = None
            return current_state, "Error: Invalid state", "Error", error_plot

    toss_result = random.randint(0, 1)  # 0 = Tails, 1 = Heads

    history = current_state.get("history", []) + [toss_result]
    current_a = current_state.get("current_a", 1.0)
    current_b = current_state.get("current_b", 1.0)
    if toss_result == 1:
        current_a += 1
    else:
        current_b += 1

    history_str = ", ".join('H' if t == 1 else 'T' for t in history) or "No tosses yet."
    params_str = f"Current Posterior: Beta({current_a:.1f}, {current_b:.1f})"
    plot_image = generate_plot_image(current_a, current_b, title_prefix="Posterior: ")

    # Only build the updated state once every output rendered successfully.
    new_state = {
        **current_state,
        "history": history,
        "current_a": current_a,
        "current_b": current_b,
    }
    return new_state, history_str, params_str, plot_image
# --- Build Gradio Interface ---
# Layout: prior selector, control buttons, text outputs and the plot, followed by
# the event listeners that wire them to reset_simulation / perform_next_toss.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
# Bayesian Coin Toss Simulation 🪙
Visualize how a Beta distribution (representing belief about a coin's bias)
updates after each simulated coin toss (Heads or Tails).
Select a prior belief, then press 'Next Toss' or 'Reset'.
"""
    )

    # Per-session simulation dict (prior/current parameters and toss history).
    # It starts as None and is filled in by the load / dropdown-change handlers.
    state = gr.State(value=None)

    with gr.Row():
        prior_selector = gr.Dropdown(
            choices=prior_names,
            value=prior_names[0],
            label="Select Prior Belief",
        )

    with gr.Row():
        next_button = gr.Button("Next Toss", variant="primary")
        reset_button = gr.Button("Reset")

    with gr.Row():
        history_output = gr.Textbox(
            value="Initializing...",
            label="Toss History",
            interactive=False,
            scale=2,
        )
        params_output = gr.Textbox(
            value="Initializing...",
            label="Posterior Parameters",
            interactive=False,
            scale=1,
        )

    # generate_plot_image returns a NumPy pixel array, so the image works in numpy mode.
    plot_output = gr.Image(value=None, type="numpy", label="Posterior Distribution")

    # Every handler returns the same 4-tuple, so all events share one output list.
    session_outputs = [state, history_output, params_output, plot_output]

    # Picking a different prior restarts the simulation; skip the queue so
    # rapid dropdown changes are applied immediately.
    prior_selector.change(
        fn=reset_simulation,
        inputs=[prior_selector],
        outputs=session_outputs,
        queue=False,
    )
    # On page load, seed the session from the dropdown's default prior.
    app.load(
        fn=reset_simulation,
        inputs=[prior_selector],
        outputs=session_outputs,
    )
    # One toss per click: feed the stored state in and store the result back.
    next_button.click(
        fn=perform_next_toss,
        inputs=[state],
        outputs=session_outputs,
    )
    # Reset restarts from whichever prior is currently selected.
    reset_button.click(
        fn=reset_simulation,
        inputs=[prior_selector],
        outputs=session_outputs,
    )

# NOTE(review): consider guarding this with `if __name__ == "__main__":` so that
# importing the module does not start a server.
app.launch(debug=True)  # debug=True surfaces handler errors in the console