"""Interactive Gradio demo that visualizes Bayes' theorem for a medical test."""

import textwrap
from typing import NamedTuple, Tuple

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as patches
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
import matplotlib.colors as mcolors
from matplotlib.ticker import PercentFormatter

# Set up the styling for better readability
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'figure.titlesize': 20,
})


class _BayesStats(NamedTuple):
    """Joint outcome probabilities of a binary test plus derived quantities."""

    true_positive: float   # P(disease and positive test)
    false_positive: float  # P(no disease and positive test)
    true_negative: float   # P(no disease and negative test)
    false_negative: float  # P(disease and negative test)
    positive: float        # P(positive test)
    posterior: float       # P(disease | positive test)


def _bayes_stats(prior_prob, sensitivity, specificity):
    """Compute the four outcome probabilities and the posterior in one place."""
    true_positive = prior_prob * sensitivity
    false_positive = (1 - prior_prob) * (1 - specificity)
    true_negative = (1 - prior_prob) * specificity
    false_negative = prior_prob * (1 - sensitivity)
    positive = true_positive + false_positive
    # Guard against a test that can never come back positive.
    posterior = true_positive / positive if positive > 0 else 0
    return _BayesStats(
        true_positive, false_positive, true_negative, false_negative,
        positive, posterior,
    )


def _people(expected_count):
    """Round an expected head-count to the nearest whole person.

    Plain ``int()`` truncates, so binary floating-point noise such as
    ``0.9 * (1 - 0.9) * 1000 == 89.999...`` would be displayed as 89.
    """
    return int(round(expected_count))


def _blank_canvas(ax, title):
    """Turn *ax* into a bare, titled 100x100 drawing area with a grey frame."""
    ax.set_title(title, fontsize=18, pad=15)
    ax.axis('equal')
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 100)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    ax.add_patch(patches.Rectangle(
        (0, 0), 100, 100, linewidth=2,
        edgecolor='#666666', facecolor='#f0f0f0', alpha=0.3))


def _boxed_label(ax, x, y, text, edgecolor, fontsize=14, box_alpha=0.9):
    """Draw centered bold text on a rounded white box outlined in *edgecolor*."""
    ax.text(x, y, text, ha='center', va='center',
            fontsize=fontsize, fontweight='bold',
            bbox=dict(facecolor='white', alpha=box_alpha,
                      edgecolor=edgecolor, boxstyle='round,pad=0.3'))


def _draw_parameter_banner(ax, prior_prob, sensitivity, specificity):
    """Show the three input parameters in a single banner line."""
    ax.axis('off')
    param_text = (
        f"Disease Prevalence: {prior_prob:.1%} | "
        f"Test Sensitivity: {sensitivity:.1%} | "
        f"Test Specificity: {specificity:.1%}"
    )
    ax.text(0.5, 0.5, param_text, ha='center', va='center', fontsize=16,
            bbox=dict(facecolor='#e6f2ff', edgecolor='#3399ff',
                      boxstyle='round,pad=0.5'))


def _draw_population_panel(ax, prior_prob, population):
    """Draw the population split into people with and without the disease."""
    _blank_canvas(ax, "Population Distribution")

    # The diseased share occupies the left part of the 100-unit-wide square.
    disease_width = prior_prob * 100
    ax.add_patch(patches.Rectangle(
        (0, 0), disease_width, 100, linewidth=1,
        edgecolor='#cc0000', facecolor='#ff9999', alpha=0.7))

    _boxed_label(
        ax, disease_width / 2, 50,
        f"Have disease\n{_people(prior_prob * population)} people\n({prior_prob:.1%})",
        edgecolor='#cc0000', box_alpha=0.8)
    _boxed_label(
        ax, disease_width + (100 - disease_width) / 2, 50,
        f"Don't have disease\n{_people((1 - prior_prob) * population)} people\n({1 - prior_prob:.1%})",
        edgecolor='#666666', box_alpha=0.8)

    # Dividing line between the two groups
    ax.axvline(x=disease_width, color='#666666', linestyle='--',
               linewidth=2, alpha=0.7)


def _draw_test_panel(ax, stats, population):
    """Draw the four test outcomes (TP | FP | FN | TN) and the +/- summary."""
    _blank_canvas(ax, "Test Results")

    # Band widths are the joint probabilities scaled to the 100-unit square.
    tp_width = stats.true_positive * 100
    fp_width = stats.false_positive * 100
    fn_width = stats.false_negative * 100
    tn_width = stats.true_negative * 100

    # Left edges of the bands, laid out left to right as TP | FP | FN | TN.
    fp_left = tp_width
    fn_left = tp_width + fp_width
    tn_left = tp_width + fp_width + fn_width

    bands = (
        (0, tp_width, '#990000', '#ff5555'),
        (fp_left, fp_width, '#994400', '#ffaa77'),
        (fn_left, fn_width, '#004499', '#77aaff'),
        (tn_left, tn_width, '#000066', '#5588ff'),
    )
    for left, width, edge, face in bands:
        ax.add_patch(patches.Rectangle(
            (left, 0), width, 100, linewidth=1,
            edgecolor=edge, facecolor=face, alpha=0.8))

    # Dashed line separates positive from negative results; dotted lines
    # separate true from false within each side.
    ax.axvline(x=fn_left, color='#666666', linestyle='--', linewidth=2, alpha=0.7)
    ax.axvline(x=fp_left, color='#666666', linestyle=':', linewidth=1.5, alpha=0.7)
    ax.axvline(x=tn_left, color='#666666', linestyle=':', linewidth=1.5, alpha=0.7)

    tp_count = _people(stats.true_positive * population)
    fp_count = _people(stats.false_positive * population)
    fn_count = _people(stats.false_negative * population)
    tn_count = _people(stats.true_negative * population)

    # Labels are staggered vertically so neighbours in narrow bands do not
    # sit on top of each other.
    _boxed_label(ax, tp_width / 2, 75,
                 f"True Positives\n{tp_count} people", edgecolor='#990000')

    if fp_width < 10:
        # Narrow band: compact label, pushed low when the TP label is nearby.
        fp_y_pos = 25 if tp_width > 10 else 50
        _boxed_label(ax, fp_left + fp_width / 2, fp_y_pos,
                     f"False\nPositives\n{fp_count}",
                     edgecolor='#994400', fontsize=12)
    else:
        _boxed_label(ax, fp_left + fp_width / 2, 50,
                     f"False Positives\n{fp_count} people", edgecolor='#994400')

    if fn_width < 10:
        # Narrow band: compact label, pushed high when the TN label is nearby.
        fn_y_pos = 75 if tn_width > 10 else 50
        _boxed_label(ax, fn_left + fn_width / 2, fn_y_pos,
                     f"False\nNegatives\n{fn_count}",
                     edgecolor='#004499', fontsize=12)
    else:
        _boxed_label(ax, fn_left + fn_width / 2, 50,
                     f"False Negatives\n{fn_count} people", edgecolor='#004499')

    _boxed_label(ax, tn_left + tn_width / 2, 25,
                 f"True Negatives\n{tn_count} people", edgecolor='#000066')

    # Summary strip below the square: total positive vs. total negative tests.
    # NOTE(review): these bands sit at y < 0, outside the 0-100 limits set
    # above, so the rectangles may be clipped by the axes -- confirm rendering.
    negative_share = stats.false_negative + stats.true_negative
    ax.add_patch(patches.Rectangle(
        (0, -20), tp_width + fp_width, 15,
        facecolor='#ffeeee', alpha=0.7, edgecolor='#cc0000'))
    ax.add_patch(patches.Rectangle(
        (fn_left, -20), fn_width + tn_width, 15,
        facecolor='#eeeeff', alpha=0.7, edgecolor='#0044cc'))

    ax.text(tp_width / 2 + fp_width / 2, -12.5,
            f"Test Positive: {_people(stats.positive * population)} people ({stats.positive:.1%})",
            ha='center', va='center', fontsize=14, fontweight='bold',
            color='#990000')
    ax.text(fn_left + fn_width / 2 + tn_width / 2, -12.5,
            f"Test Negative: {_people(negative_share * population)} people ({negative_share:.1%})",
            ha='center', va='center', fontsize=14, fontweight='bold',
            color='#000066')


def _draw_formula_panel(ax, prior_prob, sensitivity, stats):
    """Write out Bayes' formula, the worked calculation and the conclusion."""
    ax.axis('off')

    # NOTE(review): with transform=transAxes the box pad appears to be in
    # axes-fraction units, so pad=0.6 likely inflates both boxes past the
    # axes edges -- kept unchanged to preserve the current look.
    ax.add_patch(patches.FancyBboxPatch(
        (0.1, 0.4), 0.8, 0.5,
        boxstyle=patches.BoxStyle("Round", pad=0.6),
        facecolor='#f5f5f5', edgecolor='#3399ff', linewidth=2, alpha=0.7,
        transform=ax.transAxes))

    ax.text(0.5, 0.8, "Bayes' Theorem Applied to Medical Testing:",
            ha='center', va='center', fontsize=16, fontweight='bold',
            transform=ax.transAxes)

    formula = r"$P(Disease|Positive) = \frac{P(Positive|Disease) \times P(Disease)}{P(Positive)}$"
    ax.text(0.5, 0.65, formula, ha='center', va='center', fontsize=16,
            transform=ax.transAxes)

    formula_explained = "Posterior Probability = Sensitivity × Prior Probability / Probability of Positive Test"
    ax.text(0.5, 0.5, formula_explained, ha='center', va='center',
            fontsize=14, color='#555555', transform=ax.transAxes)

    # The same formula with the actual values plugged in
    calculation = (
        f"= {sensitivity:.1%} × {prior_prob:.1%} / "
        f"{stats.positive:.1%} = {stats.posterior:.1%}"
    )
    ax.text(0.5, 0.35, calculation, ha='center', va='center', fontsize=16,
            transform=ax.transAxes)

    # Highlighted conclusion box
    ax.add_patch(patches.FancyBboxPatch(
        (0.15, 0.05), 0.7, 0.2,
        boxstyle=patches.BoxStyle("Round", pad=0.6),
        facecolor='#ffffcc', edgecolor='#ffcc00', linewidth=2,
        transform=ax.transAxes))

    conclusion = f"If someone tests positive, they have a {stats.posterior:.1%} chance of having the disease"
    ax.text(0.5, 0.15, conclusion, ha='center', va='center', fontsize=18,
            fontweight='bold', transform=ax.transAxes)


def create_bayes_visualization(prior_prob: float, sensitivity: float,
                               specificity: float,
                               population: int = 1000) -> Figure:
    """Create a clear, readable visualization of Bayes' theorem.

    Args:
        prior_prob: Disease prevalence, as a fraction in [0, 1].
        sensitivity: True positive rate of the test, as a fraction in [0, 1].
        specificity: True negative rate of the test, as a fraction in [0, 1].
        population: Size of the illustrative population used for head-counts.

    Returns:
        A matplotlib ``Figure`` with a parameter banner, a population panel,
        a test-result panel and a formula/conclusion panel.
    """
    stats = _bayes_stats(prior_prob, sensitivity, specificity)

    # Build the figure directly instead of through pyplot: pyplot keeps every
    # figure it creates alive in a global registry until it is closed, which
    # would leak one figure per slider change in a long-running server.
    fig = Figure(figsize=(16, 12), facecolor='white')
    gs = GridSpec(3, 2, height_ratios=[1, 2, 1], width_ratios=[1, 1])

    fig.suptitle("Bayes' Theorem: Medical Test Visualization",
                 fontsize=22, fontweight='bold', y=0.98)

    _draw_parameter_banner(fig.add_subplot(gs[0, :]),
                           prior_prob, sensitivity, specificity)
    _draw_population_panel(fig.add_subplot(gs[1, 0]), prior_prob, population)
    _draw_test_panel(fig.add_subplot(gs[1, 1]), stats, population)
    _draw_formula_panel(fig.add_subplot(gs[2, :]),
                        prior_prob, sensitivity, stats)

    fig.tight_layout()
    fig.subplots_adjust(top=0.92, hspace=0.1, wspace=0.1)
    return fig


def explain_bayes(prior_prob: float, sensitivity: float,
                  specificity: float) -> Tuple[Figure, str]:
    """Generate the Bayes' theorem explanation and visualization.

    Args:
        prior_prob: Disease prevalence, as a fraction in [0, 1].
        sensitivity: True positive rate of the test, as a fraction in [0, 1].
        specificity: True negative rate of the test, as a fraction in [0, 1].

    Returns:
        A ``(figure, markdown)`` pair: the figure from
        ``create_bayes_visualization`` and a step-by-step Markdown walkthrough
        using the same numbers.
    """
    population = 1000
    stats = _bayes_stats(prior_prob, sensitivity, specificity)

    # Whole-person head-counts shown in the text
    diseased = _people(prior_prob * population)
    healthy = _people((1 - prior_prob) * population)
    true_positive = _people(stats.true_positive * population)
    false_positive = _people(stats.false_positive * population)
    true_negative = _people(stats.true_negative * population)
    false_negative = _people(stats.false_negative * population)
    test_positive = _people(stats.positive * population)

    fig = create_bayes_visualization(prior_prob, sensitivity, specificity, population)

    # Dedent so the Markdown (and its nested bullets) does not depend on how
    # deeply this literal is indented in the source.
    explanation = textwrap.dedent(f"""
        ### Medical Test Example Explained Step-by-Step

        Imagine a medical test for a disease that affects {prior_prob:.1%} of the population (prior probability).

        **What the percentages mean:**

        1. **Disease Prevalence ({prior_prob:.1%})**:
           - This means that out of every 100 people, about {_people(prior_prob * 100)} people have the disease
           - In our population of {population:,} people, {diseased} people have the disease and {healthy} people don't

        2. **Test Sensitivity ({sensitivity:.1%})**:
           - This means the test correctly identifies {sensitivity:.1%} of people who actually have the disease
           - Out of the {diseased} people with the disease:
             * {true_positive} people test positive (true positives) = {diseased} × {sensitivity:.1%}
             * {false_negative} people test negative (false negatives) = {diseased} × {1 - sensitivity:.1%}

        3. **Test Specificity ({specificity:.1%})**:
           - This means the test correctly identifies {specificity:.1%} of people who don't have the disease
           - Out of the {healthy} people without the disease:
             * {true_negative} people test negative (true negatives) = {healthy} × {specificity:.1%}
             * {false_positive} people test positive (false positives) = {healthy} × {1 - specificity:.1%}

        4. **Probability of Positive Test ({stats.positive:.1%})**:
           - This is the total percentage of people who test positive, regardless of whether they have the disease
           - It's calculated by adding:
             * True positives: {true_positive} people = {prior_prob:.1%} × {sensitivity:.1%} × {population:,}
             * False positives: {false_positive} people = {1 - prior_prob:.1%} × {1 - specificity:.1%} × {population:,}
           - Total positive tests: {test_positive} people out of {population:,} = {stats.positive:.1%} of the population

        **How the formula works:**

        Bayes' theorem calculates the probability that someone actually has the disease if they test positive:

        P(Disease|Positive) = P(Positive|Disease) × P(Disease) / P(Positive)

        Breaking this down with our numbers:
        - P(Positive|Disease) = Sensitivity = {sensitivity:.1%}
        - P(Disease) = Prior Probability = {prior_prob:.1%}
        - P(Positive) = Probability of a positive test = {stats.positive:.1%}

        Putting these into the formula:
        - Posterior Probability = {sensitivity:.1%} × {prior_prob:.1%} ÷ {stats.positive:.1%}
        - = {sensitivity * prior_prob:.1%} ÷ {stats.positive:.1%}
        - = {stats.posterior:.1%}

        **The key insight:**

        If someone tests positive, they have a {stats.posterior:.1%} chance of having the disease, not {sensitivity:.1%} as many people might think!

        This is often surprising because:
        1. Even a good test ({sensitivity:.1%} accurate) can give misleading results when a disease is rare
        2. Most positive results might actually be false alarms when testing for rare conditions
        3. The more common a disease is, the more likely a positive test is to be correct
    """).strip()

    return fig, explanation


# Create the Gradio interface
with gr.Blocks(title="Bayes' Theorem Visualizer") as demo:
    gr.Markdown("# Bayes' Theorem Visualizer")
    gr.Markdown("""
    Bayes' theorem helps us update our beliefs based on new evidence. This interactive tool visualizes how prior probability, sensitivity, and specificity affect the posterior probability in a medical testing scenario.

    Adjust the sliders below and see how the results change in real-time!
    """)

    with gr.Column():
        with gr.Group():
            gr.Markdown("### Adjust Parameters")
            prior_prob = gr.Slider(
                minimum=0.01, maximum=0.5, value=0.1, step=0.01,
                label="Disease Prevalence (Prior Probability)",
                info="What percentage of the population has the disease?"
            )
            sensitivity = gr.Slider(
                minimum=0.5, maximum=1.0, value=0.9, step=0.01,
                label="Test Sensitivity (True Positive Rate)",
                info="How good is the test at detecting people who have the disease?"
            )
            specificity = gr.Slider(
                minimum=0.5, maximum=1.0, value=0.9, step=0.01,
                label="Test Specificity (True Negative Rate)",
                info="How good is the test at correctly identifying people who don't have the disease?"
            )

        output_plot = gr.Plot(label="Visualization")
        output_text = gr.Markdown(label="Explanation")

        with gr.Accordion("Key Terms", open=False):
            gr.Markdown("""
            - **Prior Probability (Prevalence)**: The initial probability of having a disease before testing
            - **Sensitivity**: The ability to correctly identify those with the disease (true positive rate)
            - **Specificity**: The ability to correctly identify those without the disease (true negative rate)
            - **Posterior Probability**: The updated probability of having the disease after a positive test
            - **True Positive**: Correctly identified as having the disease
            - **False Positive**: Incorrectly identified as having the disease (also called a "Type I error")
            - **True Negative**: Correctly identified as not having the disease
            - **False Negative**: Incorrectly identified as not having the disease (also called a "Type II error")
            """)

    # Update when any parameter changes
    for param in [prior_prob, sensitivity, specificity]:
        param.change(
            explain_bayes,
            inputs=[prior_prob, sensitivity, specificity],
            outputs=[output_plot, output_text]
        )

    # Add examples
    gr.Examples(
        examples=[
            [0.01, 0.99, 0.99],  # Rare disease, excellent test
            [0.1, 0.9, 0.9],     # Common scenario
            [0.3, 0.8, 0.7],     # More common disease, less accurate test
            [0.5, 0.7, 0.95]     # Very common disease, asymmetric test accuracy
        ],
        inputs=[prior_prob, sensitivity, specificity],
        outputs=[output_plot, output_text],
        fn=explain_bayes,
        label="Try These Examples"
    )

    # Initialize the visualization
    demo.load(
        explain_bayes,
        inputs=[prior_prob, sensitivity, specificity],
        outputs=[output_plot, output_text]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()