File size: 8,541 Bytes
aabf0e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8ea705
aabf0e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8ea705
aabf0e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import io
import random

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from PIL import Image # Requires Pillow: pip install Pillow
# Imported under an alias so it does not clash with the `beta` parameter names used below.
from scipy.stats import beta as beta_distribution

# --- Configuration ---
# Selectable Beta(a, b) priors, keyed by the label shown in the dropdown.
# Dict insertion order determines the order of the dropdown choices.
priors = {
    "Uniform Prior (Beta(1,1))": (1, 1),
    "Prior Biased Toward Heads (Beta(5,1))": (5, 1),
    "Prior Biased Toward Fair (Beta(2,2))": (2, 2),
}
prior_names = [*priors]
# Common grid of 300 evenly spaced points on [0, 1] where every PDF is evaluated.
x_pdf = np.linspace(0.0, 1.0, num=300)

# --- Helper Function to Generate Plot Image (NumPy Array) ---
def generate_plot_image(alpha, beta, title_prefix=""):
    """Render the Beta(alpha, beta) density on ``x_pdf`` and return it as pixels.

    The figure is built from ``matplotlib.figure.Figure`` directly instead of
    through ``pyplot``. Gradio may run event handlers concurrently in worker
    threads, and pyplot's global "current figure" state (used by e.g.
    ``plt.tight_layout()``) is shared between them. A bare ``Figure`` is never
    registered with pyplot, so it does not need ``plt.close`` and is freed
    once it goes out of scope, even if rendering raises.

    Args:
        alpha: First Beta shape parameter.
        beta: Second Beta shape parameter.
        title_prefix: Text placed before ``"Beta(a, b)"`` in the plot title.

    Returns:
        numpy.ndarray: The rendered PNG decoded via ``np.array(PIL.Image)``.
        Non-positive parameters produce a plot with an "Invalid parameters"
        message instead of a density curve.
    """
    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    plot_title = f"{title_prefix}Beta({alpha:.1f}, {beta:.1f})"

    if alpha <= 0 or beta <= 0:
        # The Beta distribution is only defined for strictly positive parameters.
        ax.text(0.5, 0.5, 'Invalid parameters\nCannot plot Beta(≤0, ≤0)',
                horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes)
    else:
        try:
            pdf = beta_distribution.pdf(x_pdf, alpha, beta)
            # For parameters below 1 the density diverges at an endpoint; cap
            # infinities at the largest finite value (or 1.0 if none) and zero NaNs.
            finite_vals = pdf[np.isfinite(pdf)]
            inf_cap = finite_vals.max() if finite_vals.size else 1.0
            pdf = np.nan_to_num(pdf, nan=0.0, posinf=inf_cap, neginf=0.0)
        except ValueError:
            pdf = np.zeros_like(x_pdf)  # Fallback: draw a flat line

        ax.plot(x_pdf, pdf, color='dodgerblue', linewidth=2)
        ax.fill_between(x_pdf, pdf, color='dodgerblue', alpha=0.3)

        # Mark the distribution mean a / (a + b) with a dashed line up to the curve.
        mean_val = alpha / (alpha + beta)
        try:
            peak_y = beta_distribution.pdf(mean_val, alpha, beta)
        except ValueError:
            peak_y = np.nan  # Density not computable: skip the marker
        if np.isfinite(peak_y) and peak_y >= 0:
            ax.vlines(mean_val, 0, peak_y, color='red', linestyle='--',
                      label=f'Mean ≈ {mean_val:.2f}')
            ax.legend(fontsize=8)

        ax.set_ylim(bottom=0)
        ax.grid(True, linestyle=':', alpha=0.7)

    # Title and axis labels are identical for the valid and invalid cases.
    ax.set_title(plot_title, fontsize=10)
    ax.set_xlabel("Probability of Heads (r)", fontsize=9)
    ax.set_ylabel("Density", fontsize=9)
    fig.tight_layout()

    # --- Convert plot to NumPy array via an in-memory PNG ---
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png')
        buf.seek(0)
        with Image.open(buf) as img:
            img_array = np.array(img)  # Forces the pixel data to load before buf closes

    return img_array

# --- Gradio Action Functions ---

def reset_simulation(selected_prior_name):
    """Create a fresh session state for a prior and render that prior.

    An unrecognised ``selected_prior_name`` falls back to ``prior_names[0]``.

    Returns:
        tuple: ``(state, history_text, params_text, plot_image)``. ``state`` is a
        dict with the prior parameters (``prior_a``/``prior_b``), the current
        posterior parameters (``current_a``/``current_b``, initially equal to
        the prior), and an empty ``history`` list.
    """
    chosen = selected_prior_name if selected_prior_name in priors else prior_names[0]
    alpha0, beta0 = priors[chosen]

    fresh_state = dict(
        prior_a=alpha0,
        prior_b=beta0,
        current_a=alpha0,
        current_b=beta0,
        history=[],
    )
    params_text = f"Current Posterior: Beta({alpha0:.1f}, {beta0:.1f})"
    prior_plot = generate_plot_image(alpha0, beta0, title_prefix="Prior: ")
    return fresh_state, "No tosses yet.", params_text, prior_plot

def perform_next_toss(current_state):
    """Simulate one fair coin toss and apply the Beta posterior update.

    Heads (1) adds one to ``current_a``; tails (0) adds one to ``current_b``.
    The toss is appended to ``history``. ``current_state`` is modified in
    place and also returned.

    Args:
        current_state: Session state dict as built by ``reset_simulation``. If
            it is not a dict containing ``"current_a"``, a fresh state for the
            first prior is created before tossing.

    Returns:
        tuple: ``(state, history_text, params_text, plot_image)``. If
        re-initialising an invalid state raises, the error placeholders
        ``"Error: Invalid state"`` and ``"Error"`` are returned with a
        default Beta(1, 1) plot instead.
    """
    if not (isinstance(current_state, dict) and "current_a" in current_state):
        print("Warning: Invalid state detected in perform_next_toss. Resetting.") # Debug print
        try:
            current_state, _history, _params, _plot = reset_simulation(prior_names[0])
        except Exception as err:
            print(f"ERROR during reset within perform_next_toss: {err}")
            error_plot = generate_plot_image(1, 1, title_prefix="ERROR State Invalid - ")
            return current_state, "Error: Invalid state", "Error", error_plot

    outcome = random.randint(0, 1)  # 1 = Heads, 0 = Tails
    tosses = current_state.get("history", []) + [outcome]

    heads_param = current_state.get("current_a", 1.0)  # Default if missing
    tails_param = current_state.get("current_b", 1.0)  # Default if missing
    if outcome == 1:
        heads_param += 1
    else:
        tails_param += 1

    # Keyword order preserves the original write order: history, current_a, current_b.
    current_state.update(history=tosses, current_a=heads_param, current_b=tails_param)

    history_text = ", ".join("H" if toss == 1 else "T" for toss in tosses) or "No tosses yet."
    params_text = f"Current Posterior: Beta({heads_param:.1f}, {tails_param:.1f})"
    posterior_plot = generate_plot_image(heads_param, tails_param, title_prefix="Posterior: ")
    return current_state, history_text, params_text, posterior_plot

# --- Build Gradio Interface ---
# Components are laid out in the order they are created inside each `with`
# context, so the statement order below defines the page layout.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
        # Bayesian Coin Toss Simulation 🪙
        Visualize how a Beta distribution (representing belief about a coin's bias)
        updates after each simulated coin toss (Heads or Tails).
        Select a prior belief, then press 'Next Toss' or 'Reset'.
        """
    )

    # Per-session hidden state: the dict built by reset_simulation / perform_next_toss
    # (keys prior_a, prior_b, current_a, current_b, history).
    state = gr.State(value=None) # Starts as None; populated by app.load or a dropdown change

    with gr.Row():
        prior_selector = gr.Dropdown(
            label="Select Prior Belief",
            choices=prior_names,
            value=prior_names[0] # Default selection: the first configured prior
        )

    with gr.Row():
        next_button = gr.Button("Next Toss", variant="primary")
        reset_button = gr.Button("Reset")

    # Text outputs; "Initializing..." is shown until the load handler replaces it.
    with gr.Row():
         history_output = gr.Textbox(label="Toss History", value="Initializing...", interactive=False, scale=2)
         params_output = gr.Textbox(label="Posterior Parameters", value="Initializing...", interactive=False, scale=1)

    # The plot helper returns a NumPy pixel array, hence type="numpy".
    plot_output = gr.Image(label="Posterior Distribution", type="numpy", value=None) # Expect NumPy array

    # --- Event Listeners ---
    # Every handler returns (state, history text, params text, plot image),
    # matching the `outputs` list in the same order.
    # 1. Re-initialize when the prior selection changes
    prior_selector.change(
        fn=reset_simulation,
        inputs=[prior_selector],
        outputs=[state, history_output, params_output, plot_output],
        queue=False # Run directly rather than through the queue (original intent: rapid changes)
    )

    # 2. Initialize when the page loads, using the dropdown's current (default) value
    app.load(
        fn=reset_simulation,
        inputs=[prior_selector],
        outputs=[state, history_output, params_output, plot_output],
    )

    # 3. Toss once and update the posterior when "Next Toss" is clicked
    next_button.click(
        fn=perform_next_toss,
        inputs=[state], # Pass the current session state
        outputs=[state, history_output, params_output, plot_output] # Store the updated state back
    )

    # 4. Reset when "Reset" is clicked
    reset_button.click(
        fn=reset_simulation,
        inputs=[prior_selector], # Reset based on the *currently selected* prior
        outputs=[state, history_output, params_output, plot_output]
    )

# --- Launch the App ---
# Start the server only when this file is run as a script (e.g. `python app.py`),
# so importing the module does not launch a server as a side effect.
if __name__ == "__main__":
    app.launch(debug=True) # debug=True helps see errors in console