vikramlingam committed on
Commit
f5bcd25
·
verified ·
1 Parent(s): aab98e0

Upload 2 files

Browse files
Files changed (2) hide show
  1. bayes.py +392 -0
  2. requirements.txt +4 -0
bayes.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ import matplotlib.patches as patches
6
+ from matplotlib.gridspec import GridSpec
7
+ import matplotlib.colors as mcolors
8
+ from matplotlib.ticker import PercentFormatter
9
+
10
# Global matplotlib typography, applied once at import time so every figure
# produced by this module uses the same readable sans-serif fonts and sizes.
plt.rc('font', **{
    'family': 'sans-serif',
    'sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
    'size': 12,
})
plt.rc('axes', titlesize=16, labelsize=14)
plt.rc('xtick', labelsize=12)
plt.rc('ytick', labelsize=12)
plt.rc('legend', fontsize=12)
plt.rc('figure', titlesize=20)
22
+
23
def _style_unit_square(ax, title):
    """Give *ax* a title and reduce it to a bare 100 x 100 drawing area."""
    ax.set_title(title, fontsize=18, pad=15)
    ax.axis('equal')
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 100)

    # No grid, ticks or spines: the panel is a diagram, not a chart.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

    # Faint outline representing the whole population.
    ax.add_patch(patches.Rectangle((0, 0), 100, 100, linewidth=2,
                                   edgecolor='#666666', facecolor='#f0f0f0', alpha=0.3))


def _boxed_label(ax, x, y, text, edgecolor, fontsize=14, alpha=0.9):
    """Write bold, centred *text* at (x, y) on a rounded white background."""
    ax.text(x, y, text, ha='center', va='center', fontsize=fontsize, fontweight='bold',
            bbox=dict(facecolor='white', alpha=alpha, edgecolor=edgecolor,
                      boxstyle='round,pad=0.3'))


def create_bayes_visualization(prior_prob, sensitivity, specificity, population=1000):
    """Create a clear, readable visualization of Bayes' theorem.

    The figure has three rows: a parameter banner, two side-by-side panels
    (population split by disease status, and population split by test
    outcome), and a formula/conclusion panel.

    Args:
        prior_prob: Probability of having the disease, as a fraction in [0, 1].
        sensitivity: P(test positive | disease), as a fraction in [0, 1].
        specificity: P(test negative | no disease), as a fraction in [0, 1].
        population: Size of the illustrative population used for head counts.

    Returns:
        The matplotlib ``Figure``. It is removed from pyplot's figure registry
        before being returned, so repeated calls do not accumulate open
        figures; the returned object can still be saved or rendered.
    """
    # Joint probability of each (disease status, test result) outcome.
    p_true_pos = prior_prob * sensitivity
    p_false_pos = (1 - prior_prob) * (1 - specificity)
    p_false_neg = prior_prob * (1 - sensitivity)
    p_true_neg = (1 - prior_prob) * specificity
    p_test_pos = p_true_pos + p_false_pos
    p_test_neg = p_false_neg + p_true_neg

    # Posterior probability (positive predictive value). Working with
    # probabilities rather than head counts avoids dividing by `population`.
    posterior_prob = p_true_pos / p_test_pos if p_test_pos > 0 else 0

    def head_count(probability):
        # Round to the nearest person. Truncating with int() turns float
        # noise such as 89.999... into 89 instead of 90.
        return int(round(probability * population))

    # Create a large figure with a clean white background
    fig = plt.figure(figsize=(16, 12), facecolor='white')
    gs = GridSpec(3, 2, height_ratios=[1, 2, 1], width_ratios=[1, 1])

    # Title for the entire visualization
    fig.suptitle("Bayes' Theorem: Medical Test Visualization", fontsize=22, fontweight='bold', y=0.98)

    # Parameter banner at the top
    param_ax = fig.add_subplot(gs[0, :])
    param_ax.axis('off')
    param_text = (
        f"Disease Prevalence: {prior_prob:.1%} | "
        f"Test Sensitivity: {sensitivity:.1%} | "
        f"Test Specificity: {specificity:.1%}"
    )
    param_ax.text(0.5, 0.5, param_text, ha='center', va='center', fontsize=16,
                  bbox=dict(facecolor='#e6f2ff', edgecolor='#3399ff', boxstyle='round,pad=0.5'))

    # 1. Population Distribution (Left)
    pop_ax = fig.add_subplot(gs[1, 0])
    _style_unit_square(pop_ax, "Population Distribution")

    # Diseased share of the population, drawn from the left edge
    disease_width = prior_prob * 100
    pop_ax.add_patch(patches.Rectangle((0, 0), disease_width, 100, linewidth=1,
                                       edgecolor='#cc0000', facecolor='#ff9999', alpha=0.7))

    _boxed_label(pop_ax, disease_width / 2, 50,
                 f"Have disease\n{head_count(prior_prob)} people\n({prior_prob:.1%})",
                 edgecolor='#cc0000', alpha=0.8)
    _boxed_label(pop_ax, disease_width + (100 - disease_width) / 2, 50,
                 f"Don't have disease\n{head_count(1 - prior_prob)} people\n({1 - prior_prob:.1%})",
                 edgecolor='#666666', alpha=0.8)

    # Dividing line between diseased and healthy
    pop_ax.axvline(x=disease_width, color='#666666', linestyle='--', linewidth=2, alpha=0.7)

    # 2. Test Results (Right)
    test_ax = fig.add_subplot(gs[1, 1])
    _style_unit_square(test_ax, "Test Results")

    # Strip widths, in the same 0-100 units as the axes. Strips are laid out
    # left to right as TP, FP, FN, TN so that all positives are contiguous.
    tp_width = p_true_pos * 100
    fp_width = p_false_pos * 100
    fn_width = p_false_neg * 100
    tn_width = p_true_neg * 100

    fp_left = tp_width
    fn_left = tp_width + fp_width
    tn_left = tp_width + fp_width + fn_width

    strips = (
        (0, tp_width, '#990000', '#ff5555'),
        (fp_left, fp_width, '#994400', '#ffaa77'),
        (fn_left, fn_width, '#004499', '#77aaff'),
        (tn_left, tn_width, '#000066', '#5588ff'),
    )
    for left, width, edge, face in strips:
        test_ax.add_patch(patches.Rectangle((left, 0), width, 100, linewidth=1,
                                            edgecolor=edge, facecolor=face, alpha=0.8))

    # Dashed line separates positive from negative results; dotted lines
    # separate true from false within each side.
    test_ax.axvline(x=fn_left, color='#666666', linestyle='--', linewidth=2, alpha=0.7)
    test_ax.axvline(x=fp_left, color='#666666', linestyle=':', linewidth=1.5, alpha=0.7)
    test_ax.axvline(x=tn_left, color='#666666', linestyle=':', linewidth=1.5, alpha=0.7)

    # Labels are staggered vertically so that narrow neighbouring strips do
    # not overlap each other.

    # True Positives - top position
    _boxed_label(test_ax, tp_width / 2, 75,
                 f"True Positives\n{head_count(p_true_pos)} people", edgecolor='#990000')

    # False Positives - compact label (lowered when TP is wide) if narrow, otherwise centred
    if fp_width < 10:
        fp_y_pos = 25 if tp_width > 10 else 50
        _boxed_label(test_ax, fp_left + fp_width / 2, fp_y_pos,
                     f"False\nPositives\n{head_count(p_false_pos)}",
                     edgecolor='#994400', fontsize=12)
    else:
        _boxed_label(test_ax, fp_left + fp_width / 2, 50,
                     f"False Positives\n{head_count(p_false_pos)} people", edgecolor='#994400')

    # False Negatives - compact label (raised when TN is wide) if narrow, otherwise centred
    if fn_width < 10:
        fn_y_pos = 75 if tn_width > 10 else 50
        _boxed_label(test_ax, fn_left + fn_width / 2, fn_y_pos,
                     f"False\nNegatives\n{head_count(p_false_neg)}",
                     edgecolor='#004499', fontsize=12)
    else:
        _boxed_label(test_ax, fn_left + fn_width / 2, 50,
                     f"False Negatives\n{head_count(p_false_neg)} people", edgecolor='#004499')

    # True Negatives - bottom position
    _boxed_label(test_ax, tn_left + tn_width / 2, 25,
                 f"True Negatives\n{head_count(p_true_neg)} people", edgecolor='#000066')

    # Summary bars below the main square: all positives vs. all negatives
    test_ax.add_patch(patches.Rectangle((0, -20), tp_width + fp_width, 15,
                                        facecolor='#ffeeee', alpha=0.7, edgecolor='#cc0000'))
    test_ax.add_patch(patches.Rectangle((fn_left, -20), fn_width + tn_width, 15,
                                        facecolor='#eeeeff', alpha=0.7, edgecolor='#0044cc'))

    test_ax.text(tp_width / 2 + fp_width / 2, -12.5,
                 f"Test Positive: {head_count(p_test_pos)} people ({p_test_pos:.1%})",
                 ha='center', va='center', fontsize=14, fontweight='bold', color='#990000')
    test_ax.text(fn_left + fn_width / 2 + tn_width / 2, -12.5,
                 f"Test Negative: {head_count(p_test_neg)} people ({p_test_neg:.1%})",
                 ha='center', va='center', fontsize=14, fontweight='bold', color='#000066')

    # 3. Formula and Conclusion (Bottom)
    formula_ax = fig.add_subplot(gs[2, :])
    formula_ax.axis('off')

    # Background box for the formula
    formula_box = patches.FancyBboxPatch((0.1, 0.4), 0.8, 0.5, boxstyle=patches.BoxStyle("Round", pad=0.6),
                                         facecolor='#f5f5f5', edgecolor='#3399ff', linewidth=2, alpha=0.7,
                                         transform=formula_ax.transAxes)
    formula_ax.add_patch(formula_box)

    formula_title = "Bayes' Theorem Applied to Medical Testing:"
    formula_ax.text(0.5, 0.8, formula_title, ha='center', va='center', fontsize=16, fontweight='bold',
                    transform=formula_ax.transAxes)

    formula = r"$P(Disease|Positive) = \frac{P(Positive|Disease) \times P(Disease)}{P(Positive)}$"
    formula_ax.text(0.5, 0.65, formula, ha='center', va='center', fontsize=16,
                    transform=formula_ax.transAxes)

    formula_explained = "Posterior Probability = Sensitivity × Prior Probability / Probability of Positive Test"
    formula_ax.text(0.5, 0.5, formula_explained, ha='center', va='center', fontsize=14, color='#555555',
                    transform=formula_ax.transAxes)

    # The same formula with the actual values substituted
    calculation = f"= {sensitivity:.1%} × {prior_prob:.1%} / {p_test_pos:.1%} = {posterior_prob:.1%}"
    formula_ax.text(0.5, 0.35, calculation, ha='center', va='center', fontsize=16,
                    transform=formula_ax.transAxes)

    # Highlighted conclusion box
    conclusion_box = patches.FancyBboxPatch((0.15, 0.05), 0.7, 0.2, boxstyle=patches.BoxStyle("Round", pad=0.6),
                                            facecolor='#ffffcc', edgecolor='#ffcc00', linewidth=2,
                                            transform=formula_ax.transAxes)
    formula_ax.add_patch(conclusion_box)

    conclusion = f"If someone tests positive, they have a {posterior_prob:.1%} chance of having the disease"
    formula_ax.text(0.5, 0.15, conclusion, ha='center', va='center', fontsize=18, fontweight='bold',
                    transform=formula_ax.transAxes)

    # Lay out via the figure's own methods rather than pyplot's implicit
    # "current figure", so the calls always target the figure built above.
    fig.tight_layout()
    fig.subplots_adjust(top=0.92, hspace=0.1, wspace=0.1)

    # Drop pyplot's reference to the figure. This function runs on every
    # slider change, and figures that are never closed stay registered.
    plt.close(fig)

    return fig
238
+
239
def explain_bayes(prior_prob, sensitivity, specificity):
    """Generate the Bayes' theorem explanation and visualization.

    Args:
        prior_prob: Probability of having the disease, as a fraction in [0, 1].
        sensitivity: P(test positive | disease), as a fraction in [0, 1].
        specificity: P(test negative | no disease), as a fraction in [0, 1].

    Returns:
        A ``(figure, markdown_text)`` tuple: the figure produced by
        ``create_bayes_visualization`` and a step-by-step Markdown walk-through
        of the same numbers for a population of 1,000 people.
    """
    population = 1000

    # Joint probability of each outcome, and of a positive test overall.
    p_true_pos = prior_prob * sensitivity
    p_false_pos = (1 - prior_prob) * (1 - specificity)
    p_true_neg = (1 - prior_prob) * specificity
    p_false_neg = prior_prob * (1 - sensitivity)
    test_positive_prob = p_true_pos + p_false_pos

    # Posterior probability (positive predictive value)
    posterior_prob = p_true_pos / test_positive_prob if test_positive_prob > 0 else 0

    def head_count(probability, out_of=population):
        # Round to the nearest person. Truncating with int() turns float
        # noise such as 9.999... into 9 instead of 10.
        return int(round(probability * out_of))

    diseased = head_count(prior_prob)
    healthy = head_count(1 - prior_prob)
    true_positive = head_count(p_true_pos)
    false_positive = head_count(p_false_pos)
    true_negative = head_count(p_true_neg)
    false_negative = head_count(p_false_neg)
    test_positive = head_count(test_positive_prob)

    # Create the visualization
    fig = create_bayes_visualization(prior_prob, sensitivity, specificity, population)

    # The Markdown is assembled line by line so that nested-list indentation
    # is explicit and does not depend on how this source file is indented.
    lines = [
        "",
        "### Medical Test Example Explained Step-by-Step",
        "",
        f"Imagine a medical test for a disease that affects {prior_prob:.1%} of the population (prior probability).",
        "",
        "**What the percentages mean:**",
        "",
        f"1. **Disease Prevalence ({prior_prob:.1%})**:",
        f"   - This means that out of every 100 people, about {head_count(prior_prob, 100)} people have the disease",
        f"   - In our population of {population:,} people, {diseased} people have the disease and {healthy} people don't",
        "",
        f"2. **Test Sensitivity ({sensitivity:.1%})**:",
        f"   - This means the test correctly identifies {sensitivity:.1%} of people who actually have the disease",
        f"   - Out of the {diseased} people with the disease:",
        f"     * {true_positive} people test positive (true positives) = {diseased} × {sensitivity:.1%}",
        f"     * {false_negative} people test negative (false negatives) = {diseased} × {(1-sensitivity):.1%}",
        "",
        f"3. **Test Specificity ({specificity:.1%})**:",
        f"   - This means the test correctly identifies {specificity:.1%} of people who don't have the disease",
        f"   - Out of the {healthy} people without the disease:",
        f"     * {true_negative} people test negative (true negatives) = {healthy} × {specificity:.1%}",
        f"     * {false_positive} people test positive (false positives) = {healthy} × {(1-specificity):.1%}",
        "",
        f"4. **Probability of Positive Test ({test_positive_prob:.1%})**:",
        "   - This is the total percentage of people who test positive, regardless of whether they have the disease",
        "   - It's calculated by adding:",
        f"     * True positives: {true_positive} people = {prior_prob:.1%} × {sensitivity:.1%} × {population:,}",
        f"     * False positives: {false_positive} people = {(1-prior_prob):.1%} × {(1-specificity):.1%} × {population:,}",
        f"   - Total positive tests: {test_positive} people out of {population:,} = {test_positive_prob:.1%} of the population",
        "",
        "**How the formula works:**",
        "",
        "Bayes' theorem calculates the probability that someone actually has the disease if they test positive:",
        "",
        "P(Disease|Positive) = P(Positive|Disease) × P(Disease) / P(Positive)",
        "",
        "Breaking this down with our numbers:",
        f"- P(Positive|Disease) = Sensitivity = {sensitivity:.1%}",
        f"- P(Disease) = Prior Probability = {prior_prob:.1%}",
        f"- P(Positive) = Probability of a positive test = {test_positive_prob:.1%}",
        "",
        "Putting these into the formula:",
        f"- Posterior Probability = {sensitivity:.1%} × {prior_prob:.1%} ÷ {test_positive_prob:.1%}",
        f"- = {sensitivity * prior_prob:.1%} ÷ {test_positive_prob:.1%}",
        f"- = {posterior_prob:.1%}",
        "",
        f"**The key insight:** If someone tests positive, they have a {posterior_prob:.1%} chance of having the disease, not {sensitivity:.1%} as many people might think!",
        "",
        "This is often surprising because:",
        f"1. Even a good test ({sensitivity:.1%} accurate) can give misleading results when a disease is rare",
        "2. Most positive results might actually be false alarms when testing for rare conditions",
        "3. The more common a disease is, the more likely a positive test is to be correct",
        "",
    ]
    explanation = "\n".join(lines)

    return fig, explanation
313
+
314
# Build the Gradio interface. `demo` stays a module-level name so the app
# object is available on import as well as when run as a script.
with gr.Blocks(title="Bayes' Theorem Visualizer") as demo:
    gr.Markdown("# Bayes' Theorem Visualizer")
    gr.Markdown("""
    Bayes' theorem helps us update our beliefs based on new evidence. This interactive tool visualizes how prior probability,
    sensitivity, and specificity affect the posterior probability in a medical testing scenario.

    Adjust the sliders below and see how the results change in real-time!
    """)

    with gr.Column():
        # Input controls, grouped under one heading
        with gr.Group():
            gr.Markdown("### Adjust Parameters")

            prior_prob = gr.Slider(
                label="Disease Prevalence (Prior Probability)",
                info="What percentage of the population has the disease?",
                minimum=0.01,
                maximum=0.5,
                value=0.1,
                step=0.01,
            )

            sensitivity = gr.Slider(
                label="Test Sensitivity (True Positive Rate)",
                info="How good is the test at detecting people who have the disease?",
                minimum=0.5,
                maximum=1.0,
                value=0.9,
                step=0.01,
            )

            specificity = gr.Slider(
                label="Test Specificity (True Negative Rate)",
                info="How good is the test at correctly identifying people who don't have the disease?",
                minimum=0.5,
                maximum=1.0,
                value=0.9,
                step=0.01,
            )

        # Output components filled in by explain_bayes
        output_plot = gr.Plot(label="Visualization")
        output_text = gr.Markdown(label="Explanation")

        # Collapsed glossary
        with gr.Accordion("Key Terms", open=False):
            gr.Markdown("""
            - **Prior Probability (Prevalence)**: The initial probability of having a disease before testing
            - **Sensitivity**: The ability to correctly identify those with the disease (true positive rate)
            - **Specificity**: The ability to correctly identify those without the disease (true negative rate)
            - **Posterior Probability**: The updated probability of having the disease after a positive test
            - **True Positive**: Correctly identified as having the disease
            - **False Positive**: Incorrectly identified as having the disease (also called a "Type I error")
            - **True Negative**: Correctly identified as not having the disease
            - **False Negative**: Incorrectly identified as not having the disease (also called a "Type II error")
            """)

    # Every event below feeds the same three sliders into explain_bayes and
    # writes to the same two outputs, so the wiring is named once.
    bayes_inputs = [prior_prob, sensitivity, specificity]
    bayes_outputs = [output_plot, output_text]

    # Recompute whenever any slider moves
    for slider in bayes_inputs:
        slider.change(explain_bayes, inputs=bayes_inputs, outputs=bayes_outputs)

    # Preset scenarios: [prevalence, sensitivity, specificity]
    gr.Examples(
        examples=[
            [0.01, 0.99, 0.99],  # Rare disease, excellent test
            [0.1, 0.9, 0.9],     # Common scenario
            [0.3, 0.8, 0.7],     # More common disease, less accurate test
            [0.5, 0.7, 0.95],    # Very common disease, asymmetric test accuracy
        ],
        inputs=bayes_inputs,
        outputs=bayes_outputs,
        fn=explain_bayes,
        label="Try These Examples",
    )

    # Render the default slider values as soon as the page loads
    demo.load(explain_bayes, inputs=bayes_inputs, outputs=bayes_outputs)

# Start the server only when executed as a script
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ matplotlib>=3.5.0
3
+ numpy>=1.20.0
4
+ seaborn>=0.11.0