# bayes-theorem / app.py
# Hugging Face Space by vikramlingam — "Rename bayes.py to app.py" (commit 1272e90, verified)
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as patches
from matplotlib.gridspec import GridSpec
import matplotlib.colors as mcolors
from matplotlib.ticker import PercentFormatter
# Readability-focused matplotlib defaults, applied once at import time so
# every figure produced by this module inherits the same typography.
_PLOT_STYLE = {
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'figure.titlesize': 20,
}
plt.rcParams.update(_PLOT_STYLE)
def create_bayes_visualization(prior_prob, sensitivity, specificity, population=1000):
    """Create a clear, readable visualization of Bayes' theorem.

    Lays out a 3-row figure for a medical-testing scenario:
      row 0 - banner echoing the three input parameters;
      row 1 - left: the population split by disease prevalence,
              right: the same population split into TP/FP/FN/TN test outcomes;
      row 2 - Bayes' formula with the actual numbers plugged in, plus a
              highlighted one-line conclusion.

    Args:
        prior_prob: Disease prevalence P(Disease); expected in [0, 1].
        sensitivity: P(Positive | Disease); expected in [0, 1].
        specificity: P(Negative | No Disease); expected in [0, 1].
        population: Illustrative head count used to turn rates into people.

    Returns:
        The assembled matplotlib Figure (not shown or saved here).
    """
    # Expected (fractional) head counts for each confusion-matrix cell.
    true_positive = prior_prob * sensitivity * population
    false_positive = (1 - prior_prob) * (1 - specificity) * population
    true_negative = (1 - prior_prob) * specificity * population
    false_negative = prior_prob * (1 - sensitivity) * population
    # Posterior P(Disease | Positive), i.e. the positive predictive value;
    # guarded so a zero denominator yields 0 instead of ZeroDivisionError.
    posterior_prob = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
    # Create a large figure with a clean white background.
    fig = plt.figure(figsize=(16, 12), facecolor='white')
    # Three rows (banner / panels / formula); middle row gets double height.
    gs = GridSpec(3, 2, height_ratios=[1, 2, 1], width_ratios=[1, 1])
    # Title for the entire visualization.
    fig.suptitle("Bayes' Theorem: Medical Test Visualization", fontsize=22, fontweight='bold', y=0.98)
    # Parameter banner spanning both columns; axes hidden, text only.
    param_ax = fig.add_subplot(gs[0, :])
    param_ax.axis('off')
    param_text = (
        f"Disease Prevalence: {prior_prob:.1%} | "
        f"Test Sensitivity: {sensitivity:.1%} | "
        f"Test Specificity: {specificity:.1%}"
    )
    param_ax.text(0.5, 0.5, param_text, ha='center', va='center', fontsize=16,
                  bbox=dict(facecolor='#e6f2ff', edgecolor='#3399ff', boxstyle='round,pad=0.5'))
    # 1. Population Distribution (left panel): a 100x100 square in data
    # units, so widths double as percentages of the population.
    pop_ax = fig.add_subplot(gs[1, 0])
    pop_ax.set_title("Population Distribution", fontsize=18, pad=15)
    pop_ax.axis('equal')
    pop_ax.set_xlim(0, 100)
    pop_ax.set_ylim(0, 100)
    # Hide grid, ticks and all four spines for a clean, poster-like panel.
    pop_ax.grid(False)
    pop_ax.set_xticks([])
    pop_ax.set_yticks([])
    pop_ax.spines['top'].set_visible(False)
    pop_ax.spines['right'].set_visible(False)
    pop_ax.spines['bottom'].set_visible(False)
    pop_ax.spines['left'].set_visible(False)
    # Draw the population rectangle with a light border.
    pop_rect = patches.Rectangle((0, 0), 100, 100, linewidth=2, edgecolor='#666666', facecolor='#f0f0f0', alpha=0.3)
    pop_ax.add_patch(pop_rect)
    # Left slice of the square = the diseased fraction (width in % units).
    disease_width = prior_prob * 100
    disease_rect = patches.Rectangle((0, 0), disease_width, 100, linewidth=1,
                                     edgecolor='#cc0000', facecolor='#ff9999', alpha=0.7)
    pop_ax.add_patch(disease_rect)
    # Add clear labels with contrasting backgrounds.
    # NOTE(review): int() truncates the expected counts; e.g. 0.29 * 1000 is
    # 289.999...94 in float and would display as 289 — confirm whether
    # round() is the intended behavior here.
    pop_ax.text(disease_width/2, 50, f"Have disease\n{int(prior_prob*population)} people\n({prior_prob:.1%})",
                ha='center', va='center', fontsize=14, fontweight='bold',
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='#cc0000', boxstyle='round,pad=0.3'))
    pop_ax.text(disease_width + (100-disease_width)/2, 50,
                f"Don't have disease\n{int((1-prior_prob)*population)} people\n({1-prior_prob:.1%})",
                ha='center', va='center', fontsize=14, fontweight='bold',
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='#666666', boxstyle='round,pad=0.3'))
    # Dashed line marking the diseased / healthy boundary.
    pop_ax.axvline(x=disease_width, color='#666666', linestyle='--', linewidth=2, alpha=0.7)
    # 2. Test Results (right panel): same 100x100 square, now partitioned
    # into four vertical bands TP | FP | FN | TN (widths sum to 100).
    test_ax = fig.add_subplot(gs[1, 1])
    test_ax.set_title("Test Results", fontsize=18, pad=15)
    test_ax.axis('equal')
    test_ax.set_xlim(0, 100)
    test_ax.set_ylim(0, 100)
    # Clean styling (same treatment as the population panel).
    test_ax.grid(False)
    test_ax.set_xticks([])
    test_ax.set_yticks([])
    test_ax.spines['top'].set_visible(False)
    test_ax.spines['right'].set_visible(False)
    test_ax.spines['bottom'].set_visible(False)
    test_ax.spines['left'].set_visible(False)
    # Draw the test results rectangle.
    test_rect = patches.Rectangle((0, 0), 100, 100, linewidth=2, edgecolor='#666666', facecolor='#f0f0f0', alpha=0.3)
    test_ax.add_patch(test_rect)
    # Band widths as percentages of the population (mirror the counts above).
    tp_width = (prior_prob * sensitivity) * 100
    fp_width = ((1 - prior_prob) * (1 - specificity)) * 100
    fn_width = (prior_prob * (1 - sensitivity)) * 100
    tn_width = ((1 - prior_prob) * specificity) * 100
    # Use a clear, distinct color palette: reds for positives, blues for negatives.
    tp_rect = patches.Rectangle((0, 0), tp_width, 100, linewidth=1,
                                edgecolor='#990000', facecolor='#ff5555', alpha=0.8)
    test_ax.add_patch(tp_rect)
    fp_rect = patches.Rectangle((tp_width, 0), fp_width, 100, linewidth=1,
                                edgecolor='#994400', facecolor='#ffaa77', alpha=0.8)
    test_ax.add_patch(fp_rect)
    fn_rect = patches.Rectangle((tp_width + fp_width, 0), fn_width, 100, linewidth=1,
                                edgecolor='#004499', facecolor='#77aaff', alpha=0.8)
    test_ax.add_patch(fn_rect)
    tn_rect = patches.Rectangle((tp_width + fp_width + fn_width, 0), tn_width, 100, linewidth=1,
                                edgecolor='#000066', facecolor='#5588ff', alpha=0.8)
    test_ax.add_patch(tn_rect)
    # Dividing lines: dashed = positive/negative boundary, dotted = within-group.
    test_ax.axvline(x=tp_width + fp_width, color='#666666', linestyle='--', linewidth=2, alpha=0.7)
    test_ax.axvline(x=tp_width, color='#666666', linestyle=':', linewidth=1.5, alpha=0.7)
    test_ax.axvline(x=tp_width + fp_width + fn_width, color='#666666', linestyle=':', linewidth=1.5, alpha=0.7)
    # Improved label placement to avoid overlap:
    # vertical positioning (y = 25/50/75) separates labels of narrow bands.
    # True Positives - top position.
    test_ax.text(tp_width/2, 75, f"True Positives\n{int(true_positive)} people",
                 ha='center', va='center', fontsize=14, fontweight='bold',
                 bbox=dict(facecolor='white', alpha=0.9, edgecolor='#990000', boxstyle='round,pad=0.3'))
    # False Positives - bottom position if narrow, otherwise center.
    if fp_width < 10:  # If the section is narrow
        fp_y_pos = 25 if tp_width > 10 else 50  # Adjust based on TP width
        test_ax.text(tp_width + fp_width/2, fp_y_pos, f"False\nPositives\n{int(false_positive)}",
                     ha='center', va='center', fontsize=12, fontweight='bold',
                     bbox=dict(facecolor='white', alpha=0.9, edgecolor='#994400', boxstyle='round,pad=0.3'))
    else:
        test_ax.text(tp_width + fp_width/2, 50, f"False Positives\n{int(false_positive)} people",
                     ha='center', va='center', fontsize=14, fontweight='bold',
                     bbox=dict(facecolor='white', alpha=0.9, edgecolor='#994400', boxstyle='round,pad=0.3'))
    # False Negatives - top position if narrow, otherwise center.
    if fn_width < 10:  # If the section is narrow
        fn_y_pos = 75 if tn_width > 10 else 50  # Adjust based on TN width
        test_ax.text(tp_width + fp_width + fn_width/2, fn_y_pos, f"False\nNegatives\n{int(false_negative)}",
                     ha='center', va='center', fontsize=12, fontweight='bold',
                     bbox=dict(facecolor='white', alpha=0.9, edgecolor='#004499', boxstyle='round,pad=0.3'))
    else:
        test_ax.text(tp_width + fp_width + fn_width/2, 50, f"False Negatives\n{int(false_negative)} people",
                     ha='center', va='center', fontsize=14, fontweight='bold',
                     bbox=dict(facecolor='white', alpha=0.9, edgecolor='#004499', boxstyle='round,pad=0.3'))
    # True Negatives - bottom position.
    test_ax.text(tp_width + fp_width + fn_width + tn_width/2, 25, f"True Negatives\n{int(true_negative)} people",
                 ha='center', va='center', fontsize=14, fontweight='bold',
                 bbox=dict(facecolor='white', alpha=0.9, edgecolor='#000066', boxstyle='round,pad=0.3'))
    # Aggregate totals for the positive / negative summary strip below.
    test_positive = true_positive + false_positive
    test_negative = false_negative + true_negative
    # Summary strip drawn at y in [-20, -5], below the main square.
    # NOTE(review): ylim was fixed to (0, 100) above, so these rectangles lie
    # outside the axes limits and may be clipped (text artists are not clipped
    # by default) — confirm the strip actually renders as intended.
    test_ax.add_patch(patches.Rectangle((0, -20), tp_width + fp_width, 15,
                                        facecolor='#ffeeee', alpha=0.7, edgecolor='#cc0000'))
    test_ax.add_patch(patches.Rectangle((tp_width + fp_width, -20), fn_width + tn_width, 15,
                                        facecolor='#eeeeff', alpha=0.7, edgecolor='#0044cc'))
    # Labels centered on each summary band (tp/2 + fp/2 == (tp+fp)/2, etc.).
    test_ax.text(tp_width/2 + fp_width/2, -12.5,
                 f"Test Positive: {int(test_positive)} people ({test_positive/population:.1%})",
                 ha='center', va='center', fontsize=14, fontweight='bold', color='#990000')
    test_ax.text(tp_width + fp_width + fn_width/2 + tn_width/2, -12.5,
                 f"Test Negative: {int(test_negative)} people ({test_negative/population:.1%})",
                 ha='center', va='center', fontsize=14, fontweight='bold', color='#000066')
    # 3. Formula and Conclusion (bottom row, spans both columns).
    formula_ax = fig.add_subplot(gs[2, :])
    formula_ax.axis('off')
    # Create a box for the formula (axes-fraction coordinates via transAxes).
    formula_box = patches.FancyBboxPatch((0.1, 0.4), 0.8, 0.5, boxstyle=patches.BoxStyle("Round", pad=0.6),
                                         facecolor='#f5f5f5', edgecolor='#3399ff', linewidth=2, alpha=0.7,
                                         transform=formula_ax.transAxes)
    formula_ax.add_patch(formula_box)
    # Add Bayes formula with clear formatting (mathtext rendering).
    formula_title = "Bayes' Theorem Applied to Medical Testing:"
    formula_ax.text(0.5, 0.8, formula_title, ha='center', va='center', fontsize=16, fontweight='bold',
                    transform=formula_ax.transAxes)
    formula = r"$P(Disease|Positive) = \frac{P(Positive|Disease) \times P(Disease)}{P(Positive)}$"
    formula_ax.text(0.5, 0.65, formula, ha='center', va='center', fontsize=16,
                    transform=formula_ax.transAxes)
    formula_explained = "Posterior Probability = Sensitivity × Prior Probability / Probability of Positive Test"
    formula_ax.text(0.5, 0.5, formula_explained, ha='center', va='center', fontsize=14, color='#555555',
                    transform=formula_ax.transAxes)
    # Add the calculation with the actual values substituted in.
    test_positive_prob = test_positive/population
    calculation = f"= {sensitivity:.1%} × {prior_prob:.1%} / {test_positive_prob:.1%} = {posterior_prob:.1%}"
    formula_ax.text(0.5, 0.35, calculation, ha='center', va='center', fontsize=16,
                    transform=formula_ax.transAxes)
    # Create a highlighted conclusion box (yellow call-out).
    conclusion_box = patches.FancyBboxPatch((0.15, 0.05), 0.7, 0.2, boxstyle=patches.BoxStyle("Round", pad=0.6),
                                            facecolor='#ffffcc', edgecolor='#ffcc00', linewidth=2,
                                            transform=formula_ax.transAxes)
    formula_ax.add_patch(conclusion_box)
    # Add the conclusion with emphasis.
    conclusion = f"If someone tests positive, they have a {posterior_prob:.1%} chance of having the disease"
    formula_ax.text(0.5, 0.15, conclusion, ha='center', va='center', fontsize=18, fontweight='bold',
                    transform=formula_ax.transAxes)
    # tight_layout first, then fine-tune spacing; subplots_adjust overrides
    # the tight_layout margins it touches.
    plt.tight_layout()
    plt.subplots_adjust(top=0.92, hspace=0.1, wspace=0.1)
    return fig
def explain_bayes(prior_prob, sensitivity, specificity):
    """Generate the Bayes' theorem explanation and visualization.

    Args:
        prior_prob: Disease prevalence P(Disease); expected in [0, 1].
        sensitivity: P(Positive | Disease); expected in [0, 1].
        specificity: P(Negative | No Disease); expected in [0, 1].

    Returns:
        A (figure, explanation) tuple: the matplotlib Figure produced by
        create_bayes_visualization and a markdown string walking through
        the calculation step by step.
    """
    population = 1000
    # Expected (fractional) head counts for each confusion-matrix cell.
    true_positive = prior_prob * sensitivity * population
    false_positive = (1 - prior_prob) * (1 - specificity) * population
    true_negative = (1 - prior_prob) * specificity * population
    false_negative = prior_prob * (1 - sensitivity) * population
    # Posterior probability (positive predictive value); guard against a
    # zero denominator when no one tests positive.
    test_positive = true_positive + false_positive
    test_positive_prob = test_positive / population
    posterior_prob = true_positive / test_positive if test_positive > 0 else 0
    # Create the visualization.
    fig = create_bayes_visualization(prior_prob, sensitivity, specificity, population)
    # Generate explanation text with clearer explanations of the percentages.
    # BUGFIX: counts are displayed with round() rather than int() — int()
    # truncates, so float error (e.g. 0.29 * 1000 == 289.999...94) would
    # show 289 people instead of the intended 290.
    explanation = f"""
### Medical Test Example Explained Step-by-Step
Imagine a medical test for a disease that affects {prior_prob:.1%} of the population (prior probability).
**What the percentages mean:**
1. **Disease Prevalence ({prior_prob:.1%})**:
- This means that out of every 100 people, about {round(prior_prob*100)} people have the disease
- In our population of 1,000 people, {round(prior_prob*population)} people have the disease and {round((1-prior_prob)*population)} people don't
2. **Test Sensitivity ({sensitivity:.1%})**:
- This means the test correctly identifies {sensitivity:.1%} of people who actually have the disease
- Out of the {round(prior_prob*population)} people with the disease:
* {round(true_positive)} people test positive (true positives) = {round(prior_prob*population)} × {sensitivity:.1%}
* {round(false_negative)} people test negative (false negatives) = {round(prior_prob*population)} × {(1-sensitivity):.1%}
3. **Test Specificity ({specificity:.1%})**:
- This means the test correctly identifies {specificity:.1%} of people who don't have the disease
- Out of the {round((1-prior_prob)*population)} people without the disease:
* {round(true_negative)} people test negative (true negatives) = {round((1-prior_prob)*population)} × {specificity:.1%}
* {round(false_positive)} people test positive (false positives) = {round((1-prior_prob)*population)} × {(1-specificity):.1%}
4. **Probability of Positive Test ({test_positive_prob:.1%})**:
- This is the total percentage of people who test positive, regardless of whether they have the disease
- It's calculated by adding:
* True positives: {round(true_positive)} people = {prior_prob:.1%} × {sensitivity:.1%} × 1,000
* False positives: {round(false_positive)} people = {(1-prior_prob):.1%} × {(1-specificity):.1%} × 1,000
- Total positive tests: {round(test_positive)} people out of 1,000 = {test_positive_prob:.1%} of the population
**How the formula works:**
Bayes' theorem calculates the probability that someone actually has the disease if they test positive:
P(Disease|Positive) = P(Positive|Disease) × P(Disease) / P(Positive)
Breaking this down with our numbers:
- P(Positive|Disease) = Sensitivity = {sensitivity:.1%}
- P(Disease) = Prior Probability = {prior_prob:.1%}
- P(Positive) = Probability of a positive test = {test_positive_prob:.1%}
Putting these into the formula:
- Posterior Probability = {sensitivity:.1%} × {prior_prob:.1%} ÷ {test_positive_prob:.1%}
- = {sensitivity * prior_prob:.1%} ÷ {test_positive_prob:.1%}
- = {posterior_prob:.1%}
**The key insight:** If someone tests positive, they have a {posterior_prob:.1%} chance of having the disease, not {sensitivity:.1%} as many people might think!
This is often surprising because:
1. Even a good test ({sensitivity:.1%} accurate) can give misleading results when a disease is rare
2. Most positive results might actually be false alarms when testing for rare conditions
3. The more common a disease is, the more likely a positive test is to be correct
"""
    return fig, explanation
# Create the Gradio interface: three probability sliders as inputs, a plot
# and a markdown pane as outputs, plus canned examples and a glossary.
with gr.Blocks(title="Bayes' Theorem Visualizer") as demo:
    gr.Markdown("# Bayes' Theorem Visualizer")
    gr.Markdown("""
Bayes' theorem helps us update our beliefs based on new evidence. This interactive tool visualizes how prior probability,
sensitivity, and specificity affect the posterior probability in a medical testing scenario.
Adjust the sliders below and see how the results change in real-time!
""")
    with gr.Column():
        with gr.Group():
            gr.Markdown("### Adjust Parameters")
            # Prevalence ranges from rare (1%) to common (50%).
            prior_prob = gr.Slider(
                minimum=0.01, maximum=0.5, value=0.1, step=0.01,
                label="Disease Prevalence (Prior Probability)",
                info="What percentage of the population has the disease?"
            )
            # Sensitivity / specificity start at 50% (a coin-flip test) up to 100%.
            sensitivity = gr.Slider(
                minimum=0.5, maximum=1.0, value=0.9, step=0.01,
                label="Test Sensitivity (True Positive Rate)",
                info="How good is the test at detecting people who have the disease?"
            )
            specificity = gr.Slider(
                minimum=0.5, maximum=1.0, value=0.9, step=0.01,
                label="Test Specificity (True Negative Rate)",
                info="How good is the test at correctly identifying people who don't have the disease?"
            )
    # Output widgets populated by explain_bayes (figure + markdown text).
    output_plot = gr.Plot(label="Visualization")
    output_text = gr.Markdown(label="Explanation")
    with gr.Accordion("Key Terms", open=False):
        gr.Markdown("""
- **Prior Probability (Prevalence)**: The initial probability of having a disease before testing
- **Sensitivity**: The ability to correctly identify those with the disease (true positive rate)
- **Specificity**: The ability to correctly identify those without the disease (true negative rate)
- **Posterior Probability**: The updated probability of having the disease after a positive test
- **True Positive**: Correctly identified as having the disease
- **False Positive**: Incorrectly identified as having the disease (also called a "Type I error")
- **True Negative**: Correctly identified as not having the disease
- **False Negative**: Incorrectly identified as not having the disease (also called a "Type II error")
""")
    # Update when any parameter changes: every slider triggers a full
    # re-render with the current value of all three sliders.
    for param in [prior_prob, sensitivity, specificity]:
        param.change(
            explain_bayes,
            inputs=[prior_prob, sensitivity, specificity],
            outputs=[output_plot, output_text]
        )
    # Add clickable example parameter sets.
    gr.Examples(
        examples=[
            [0.01, 0.99, 0.99],  # Rare disease, excellent test
            [0.1, 0.9, 0.9],  # Common scenario
            [0.3, 0.8, 0.7],  # More common disease, less accurate test
            [0.5, 0.7, 0.95]  # Very common disease, asymmetric test accuracy
        ],
        inputs=[prior_prob, sensitivity, specificity],
        outputs=[output_plot, output_text],
        fn=explain_bayes,
        label="Try These Examples"
    )
    # Initialize the visualization once when the page loads.
    demo.load(
        explain_bayes,
        inputs=[prior_prob, sensitivity, specificity],
        outputs=[output_plot, output_text]
    )
# Launch the app only when run as a script (e.g. `python app.py`).
if __name__ == "__main__":
    demo.launch()