|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
from scipy.stats import beta as beta_distribution |
|
|
import random |
|
|
import io |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# Named prior beliefs about the coin's probability of heads, expressed as
# Beta(a, b) shape parameters. Insertion order matters: it becomes the order
# of the names offered in the UI.
priors = {
    "Uniform Prior (Beta(1,1))": (1, 1),
    "Prior Biased Toward Heads (Beta(5,1))": (5, 1),
    "Prior Biased Toward Fair (Beta(2,2))": (2, 2),
}

# Prior display names in insertion order; the first entry doubles as the
# default/fallback prior.
prior_names = [*priors]

# Evenly spaced evaluation grid on [0, 1] used when plotting Beta densities.
x_pdf = np.linspace(0.0, 1.0, num=300)
|
|
|
|
|
|
|
|
def _draw_beta_density(ax, alpha, beta):
    """Draw the Beta(alpha, beta) pdf, its shaded area and a mean marker on *ax*.

    Expects alpha > 0 and beta > 0; the caller handles the invalid case.
    """
    try:
        pdf = beta_distribution.pdf(x_pdf, alpha, beta)
        # A shape parameter below 1 makes the density infinite at an endpoint;
        # clamp such points to the largest finite value so the curve is drawable.
        finite = np.isfinite(pdf)
        cap = np.nanmax(pdf[finite]) if np.any(finite) else 1.0
        pdf = np.nan_to_num(pdf, nan=0.0, posinf=cap, neginf=0.0)
    except ValueError:
        pdf = np.zeros_like(x_pdf)

    ax.plot(x_pdf, pdf, color='dodgerblue', linewidth=2)
    ax.fill_between(x_pdf, pdf, color='dodgerblue', alpha=0.3)

    # Mark the distribution mean with a dashed line rising to the curve.
    mean_val = alpha / (alpha + beta)
    try:
        peak_y = beta_distribution.pdf(mean_val, alpha, beta)
    except ValueError:
        peak_y = np.nan
    if np.isfinite(peak_y) and peak_y >= 0:
        ax.vlines(mean_val, 0, peak_y, color='red', linestyle='--',
                  label=f'Mean ≈ {mean_val:.2f}')
        ax.legend(fontsize=8)

    ax.set_ylim(bottom=0)
    ax.grid(True, linestyle=':', alpha=0.7)


def generate_plot_image(alpha, beta, title_prefix=""):
    """Generates a Matplotlib plot of Beta(alpha, beta) and returns it as a NumPy array.

    The figure is saved to an in-memory PNG and decoded with Pillow, so the
    result can be passed straight to a ``gr.Image(type="numpy")`` component.

    Args:
        alpha: First Beta shape parameter.
        beta: Second Beta shape parameter.
        title_prefix: Text prepended to the plot title (e.g. ``"Prior: "``).

    Returns:
        numpy.ndarray: Pixel array of the rendered PNG as produced by
        ``np.array(PIL.Image)``. If either parameter is <= 0, the image shows
        an "Invalid parameters" message instead of a density curve.
    """
    plot_title = f"{title_prefix}Beta({alpha:.1f}, {beta:.1f})"

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        if alpha <= 0 or beta <= 0:
            # The Beta density is undefined for non-positive shape parameters.
            ax.text(0.5, 0.5, 'Invalid parameters\nCannot plot Beta(≤0, ≤0)',
                    horizontalalignment='center', verticalalignment='center',
                    transform=ax.transAxes)
        else:
            _draw_beta_density(ax, alpha, beta)

        ax.set_title(plot_title, fontsize=10)
        ax.set_xlabel("Probability of Heads (r)", fontsize=9)
        ax.set_ylabel("Density", fontsize=9)

        # Act on this figure explicitly: pyplot's "current figure" is global
        # state, and Gradio may run several handlers concurrently.
        fig.tight_layout()

        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            # np.array() forces Pillow's lazy decode while buf is still open.
            with Image.open(buf) as img:
                img_array = np.array(img)
    finally:
        # Always release the figure, even if drawing fails, so pyplot does not
        # accumulate open figures in a long-running server.
        plt.close(fig)

    return img_array
|
|
|
|
|
|
|
|
|
|
|
def reset_simulation(selected_prior_name):
    """Start a fresh simulation from the chosen prior.

    An unrecognised prior name falls back to the first entry of ``prior_names``.

    Returns:
        tuple: ``(state, history_text, params_text, plot_image)`` where
        ``state`` holds the prior and current Beta parameters (initially
        equal) plus an empty toss history, ``history_text`` is a placeholder,
        ``params_text`` summarises the current parameters, and ``plot_image``
        is the rendered prior density.
    """
    chosen = selected_prior_name if selected_prior_name in priors else prior_names[0]
    a0, b0 = priors[chosen]

    fresh_state = {
        "prior_a": a0,
        "prior_b": b0,
        "current_a": a0,
        "current_b": b0,
        "history": [],
    }
    summary = f"Current Posterior: Beta({a0:.1f}, {b0:.1f})"
    prior_plot = generate_plot_image(a0, b0, title_prefix="Prior: ")

    return fresh_state, "No tosses yet.", summary, prior_plot
|
|
|
|
|
def perform_next_toss(current_state):
    """Simulate one fair coin toss and apply the Beta update to the state.

    Heads (1) increments ``current_a`` and tails (0) increments ``current_b``.
    The state dict is updated in place and also returned. If the incoming
    state is not a dict containing ``current_a``, it is first replaced by a
    fresh state built from the first prior.

    Returns:
        tuple: ``(state, history_text, params_text, plot_image)``.
    """
    state_ok = isinstance(current_state, dict) and "current_a" in current_state
    if not state_ok:
        print("Warning: Invalid state detected in perform_next_toss. Resetting.")
        try:
            current_state, _history, _params, _plot = reset_simulation(prior_names[0])
        except Exception as e:
            print(f"ERROR during reset within perform_next_toss: {e}")
            error_plot = generate_plot_image(1, 1, title_prefix="ERROR State Invalid - ")
            return current_state, "Error: Invalid state", "Error", error_plot

    # 1 means heads, 0 means tails.
    outcome = random.randint(0, 1)

    # Build a new list rather than appending, so the previous history object
    # is left untouched.
    tosses = current_state.get("history", []) + [outcome]
    current_state["history"] = tosses

    a = current_state.get("current_a", 1.0)
    b = current_state.get("current_b", 1.0)
    if outcome == 1:
        a += 1
    else:
        b += 1
    current_state.update(current_a=a, current_b=b)

    history_text = ", ".join('H' if t == 1 else 'T' for t in tosses) or "No tosses yet."
    params_text = f"Current Posterior: Beta({a:.1f}, {b:.1f})"
    posterior_plot = generate_plot_image(a, b, title_prefix="Posterior: ")

    return current_state, history_text, params_text, posterior_plot
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as app:
    # --- Header ---------------------------------------------------------
    gr.Markdown(
        """
        # Bayesian Coin Toss Simulation 🪙
        Visualize how a Beta distribution (representing belief about a coin's bias)
        updates after each simulated coin toss (Heads or Tails).
        Select a prior belief, then press 'Next Toss' or 'Reset'.
        """
    )

    # Per-session simulation state; populated by the load/reset handlers.
    state = gr.State(value=None)

    # --- Controls -------------------------------------------------------
    with gr.Row():
        prior_selector = gr.Dropdown(
            label="Select Prior Belief",
            choices=prior_names,
            value=prior_names[0],
        )

    with gr.Row():
        next_button = gr.Button("Next Toss", variant="primary")
        reset_button = gr.Button("Reset")

    # --- Outputs --------------------------------------------------------
    with gr.Row():
        history_output = gr.Textbox(label="Toss History", value="Initializing...",
                                    interactive=False, scale=2)
        params_output = gr.Textbox(label="Posterior Parameters", value="Initializing...",
                                   interactive=False, scale=1)

    plot_output = gr.Image(label="Posterior Distribution", type="numpy", value=None)

    # Every handler returns (state, history text, params text, plot image).
    sim_outputs = [state, history_output, params_output, plot_output]

    # --- Event wiring ---------------------------------------------------
    # Switching the prior restarts the simulation from that prior.
    prior_selector.change(
        fn=reset_simulation,
        inputs=[prior_selector],
        outputs=sim_outputs,
        queue=False,
    )

    # Initialise the session from the dropdown's default value on page load.
    app.load(
        fn=reset_simulation,
        inputs=[prior_selector],
        outputs=sim_outputs,
    )

    next_button.click(
        fn=perform_next_toss,
        inputs=[state],
        outputs=sim_outputs,
    )

    reset_button.click(
        fn=reset_simulation,
        inputs=[prior_selector],
        outputs=sim_outputs,
    )
|
|
|
|
|
|
|
|
# Only start the server when run as a script (e.g. `python app.py`), so that
# importing this module has no server-launching side effect.
if __name__ == "__main__":
    app.launch(debug=True)