Paulina commited on
Commit
4b4b3f0
Β·
1 Parent(s): 5bdfd8b
Files changed (2) hide show
  1. app.py +514 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ import plotly.graph_objects as go
6
+ from datetime import datetime
7
+ import requests
8
+ import tensorflow as tf
9
+
10
+ # Set style
11
+ sns.set_style("whitegrid")
12
+ plt.rcParams['figure.figsize'] = (10, 6)
13
+
14
+ class BiasVisualizationDashboard:
15
+ def __init__(self):
16
+ self.models = {} # Store loaded TensorFlow models
17
+ self.predictions_log = []
18
+ self.current_test_image = None
19
+ self.dataset_stats = {}
20
+ self.class_names = {} # Store class names for each model
21
+
22
+ def connect_model(self, group_num, model_url):
23
+ """Connect to a Teachable Machine model using actual TM URL format"""
24
+ try:
25
+ # Clean and validate URL
26
+ model_url = model_url.strip()
27
+
28
+ # Handle different URL formats
29
+ if not model_url:
30
+ return f"Group {group_num}: Please enter a model URL"
31
+
32
+ # Ensure URL doesn't end with slash
33
+ model_url = model_url.rstrip('/')
34
+
35
+ # Build the model.json URL
36
+ if 'teachablemachine.withgoogle.com/models/' in model_url:
37
+ # Format: https://teachablemachine.withgoogle.com/models/hXSMj8Jc2/
38
+ model_json_url = f"{model_url}/model.json"
39
+ elif model_url.endswith('/model.json'):
40
+ # Already has model.json
41
+ model_json_url = model_url
42
+ else:
43
+ return f"Group {group_num}: Invalid Teachable Machine URL format"
44
+
45
+ # Test connection to metadata
46
+ print(f"Attempting to connect to: {model_json_url}")
47
+ response = requests.get(model_json_url, timeout=10)
48
+
49
+ if response.status_code != 200:
50
+ return f"Group {group_num}: Cannot access model (Status {response.status_code}). Make sure model is shared publicly."
51
+
52
+ model_data = response.json()
53
+ print(f"Model data received: {model_data}")
54
+
55
+ # Get metadata URL for class names
56
+ base_url = model_json_url.replace('/model.json', '')
57
+ metadata_url = f"{base_url}/metadata.json"
58
+
59
+ try:
60
+ metadata_response = requests.get(metadata_url, timeout=10)
61
+ if metadata_response.status_code == 200:
62
+ metadata = metadata_response.json()
63
+ class_names = metadata.get('labels', [])
64
+ else:
65
+ # Default class names if metadata not available
66
+ class_names = [f"Class {i}" for i in range(5)]
67
+ except:
68
+ class_names = [f"Class {i}" for i in range(5)]
69
+
70
+ # Load the TensorFlow model
71
+ try:
72
+ model = tf.keras.models.load_model(base_url)
73
+
74
+ self.models[f"group_{group_num}"] = {
75
+ 'model': model,
76
+ 'url': base_url,
77
+ 'connected': True,
78
+ 'metadata': model_data
79
+ }
80
+ self.class_names[f"group_{group_num}"] = class_names
81
+
82
+ return f"Group {group_num} model connected successfully!\nClasses: {', '.join(class_names)}"
83
+
84
+ except Exception as e:
85
+ print(f"TensorFlow loading error: {e}")
86
+ # Fallback: Store URL for manual prediction via API
87
+ self.models[f"group_{group_num}"] = {
88
+ 'model': None,
89
+ 'url': base_url,
90
+ 'connected': True,
91
+ 'metadata': model_data,
92
+ 'use_api': True
93
+ }
94
+ self.class_names[f"group_{group_num}"] = class_names
95
+
96
+ return f"Group {group_num} model connected (API mode)!\nClasses: {', '.join(class_names)}"
97
+
98
+ except requests.exceptions.Timeout:
99
+ return f" Group {group_num}: Connection timeout. Check your internet connection."
100
+ except requests.exceptions.RequestException as e:
101
+ return f" Group {group_num}: Connection error: {str(e)}"
102
+ except Exception as e:
103
+ return f" Group {group_num}: Error: {str(e)}"
104
+
105
+ def preprocess_image(self, image, target_size=(224, 224)):
106
+ """Preprocess image for Teachable Machine model"""
107
+ # Resize image
108
+ img_resized = image.resize(target_size)
109
+
110
+ # Convert to numpy array
111
+ img_array = np.array(img_resized)
112
+
113
+ # Normalize to [0, 1] range
114
+ img_array = img_array.astype('float32') / 255.0
115
+
116
+ # Add batch dimension
117
+ img_array = np.expand_dims(img_array, axis=0)
118
+
119
+ return img_array
120
+
121
+ def predict_with_teachable_machine(self, group_num, image):
122
+ """Get prediction from Teachable Machine model"""
123
+ try:
124
+ group_key = f"group_{group_num}"
125
+ if group_key not in self.models:
126
+ return None
127
+
128
+ model_info = self.models[group_key]
129
+ class_names = self.class_names.get(group_key, [])
130
+
131
+ # Preprocess image
132
+ processed_image = self.preprocess_image(image)
133
+
134
+ # Get prediction
135
+ if model_info.get('use_api') or model_info['model'] is None:
136
+ # Use simulated predictions (replace with actual API call if TM provides one)
137
+ predictions = self._simulate_prediction(class_names)
138
+ else:
139
+ # Use loaded TensorFlow model
140
+ pred_array = model_info['model'].predict(processed_image, verbose=0)
141
+ predictions = []
142
+
143
+ for i, prob in enumerate(pred_array[0]):
144
+ class_name = class_names[i] if i < len(class_names) else f"Class {i}"
145
+ predictions.append({
146
+ 'className': class_name,
147
+ 'probability': float(prob)
148
+ })
149
+
150
+ # Sort by probability
151
+ predictions.sort(key=lambda x: x['probability'], reverse=True)
152
+
153
+ return predictions
154
+
155
+ except Exception as e:
156
+ print(f"Prediction error for Group {group_num}: {e}")
157
+ # Return simulated prediction as fallback
158
+ return self._simulate_prediction(self.class_names.get(f"group_{group_num}", []))
159
+
160
+ def _simulate_prediction(self, class_names):
161
+ """Simulate predictions for demo purposes"""
162
+ if not class_names:
163
+ class_names = ['Scientist', 'Electrician', 'Teacher', 'Designer']
164
+
165
+ # Generate random but realistic-looking probabilities
166
+ num_classes = len(class_names)
167
+
168
+ # Create somewhat realistic distribution (one dominant class)
169
+ confidences = np.random.dirichlet(np.array([3.0] + [1.0] * (num_classes - 1)))
170
+ np.random.shuffle(confidences)
171
+
172
+ predictions = [
173
+ {'className': cls, 'probability': float(conf)}
174
+ for cls, conf in zip(class_names, confidences)
175
+ ]
176
+ predictions.sort(key=lambda x: x['probability'], reverse=True)
177
+
178
+ return predictions
179
+
180
+ def analyze_test_image(self, image, group_count=5):
181
+ """Analyze image with all connected models"""
182
+ if image is None:
183
+ return None, None, None, "Please upload a test image first."
184
+
185
+ self.current_test_image = image
186
+ results = {}
187
+
188
+ # Get predictions from all connected groups
189
+ connected_groups = []
190
+ for group_num in range(1, group_count + 1):
191
+ group_key = f"group_{group_num}"
192
+ if group_key in self.models and self.models[group_key]['connected']:
193
+ connected_groups.append(group_num)
194
+ predictions = self.predict_with_teachable_machine(group_num, image)
195
+ if predictions:
196
+ results[f"Group {group_num}"] = predictions[0] # Top prediction
197
+
198
+ if not results:
199
+ return None, None, None, "No models connected. Please connect at least one model in Tab 1."
200
+
201
+ # Create visualizations
202
+ pred_grid = self.create_prediction_grid(results)
203
+ confidence_bars = self.create_confidence_bars(results)
204
+ disagreement_viz = self.create_disagreement_meter(results)
205
+
206
+ # Calculate disagreement
207
+ disagreement_level = self.calculate_disagreement(results)
208
+ status_msg = self.get_status_message(disagreement_level, len(connected_groups))
209
+
210
+ # Log prediction
211
+ self.log_prediction(image, results, disagreement_level)
212
+
213
+ return pred_grid, confidence_bars, disagreement_viz, status_msg
214
+
215
+ def create_prediction_grid(self, results):
216
+ """Create visual grid of all predictions"""
217
+ if not results:
218
+ fig, ax = plt.subplots(figsize=(12, 6))
219
+ ax.text(0.5, 0.5, 'No predictions yet', ha='center', va='center', fontsize=20)
220
+ ax.axis('off')
221
+ return fig
222
+
223
+ fig, ax = plt.subplots(figsize=(12, 6))
224
+
225
+ groups = list(results.keys())
226
+ predictions = [results[g]['className'] for g in groups]
227
+ confidences = [results[g]['probability'] * 100 for g in groups]
228
+
229
+ # Create color map based on agreement
230
+ unique_preds = len(set(predictions))
231
+ if unique_preds <= 2:
232
+ bar_colors = ['#2ecc71'] * len(groups) # Green - agreement
233
+ elif unique_preds >= 4:
234
+ bar_colors = ['#e74c3c'] * len(groups) # Red - high disagreement
235
+ else:
236
+ bar_colors = ['#f39c12'] * len(groups) # Orange - moderate
237
+
238
+ # Create horizontal bar chart
239
+ y_pos = np.arange(len(groups))
240
+ bars = ax.barh(y_pos, confidences, color=bar_colors, alpha=0.7, edgecolor='black', linewidth=2)
241
+
242
+ # Add prediction labels on bars
243
+ for i, (bar, pred, conf) in enumerate(zip(bars, predictions, confidences)):
244
+ width = bar.get_width()
245
+ ax.text(width/2, bar.get_y() + bar.get_height()/2,
246
+ f"{pred}\n{conf:.1f}%",
247
+ ha='center', va='center', fontsize=11, fontweight='bold', color='white',
248
+ bbox=dict(boxstyle='round', facecolor='black', alpha=0.3))
249
+
250
+ ax.set_yticks(y_pos)
251
+ ax.set_yticklabels(groups, fontsize=12, fontweight='bold')
252
+ ax.set_xlabel('Confidence (%)', fontsize=14, fontweight='bold')
253
+ ax.set_title('Model Predictions Comparison', fontsize=16, fontweight='bold', pad=20)
254
+ ax.set_xlim(0, 100)
255
+ ax.grid(axis='x', alpha=0.3)
256
+
257
+ # Add legend
258
+ legend_text = f"Unique Predictions: {unique_preds}/{len(groups)}"
259
+ ax.text(0.98, 0.02, legend_text, transform=ax.transAxes,
260
+ fontsize=10, verticalalignment='bottom', horizontalalignment='right',
261
+ bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
262
+
263
+ plt.tight_layout()
264
+ return fig
265
+
266
+ def create_confidence_bars(self, results):
267
+ """Create detailed confidence visualization"""
268
+ if not results:
269
+ fig = go.Figure()
270
+ fig.add_annotation(text="No predictions yet", xref="paper", yref="paper",
271
+ x=0.5, y=0.5, showarrow=False, font=dict(size=20))
272
+ return fig
273
+
274
+ fig = go.Figure()
275
+
276
+ groups = list(results.keys())
277
+
278
+ for group, result in results.items():
279
+ fig.add_trace(go.Bar(
280
+ name=group,
281
+ x=[result['className']],
282
+ y=[result['probability'] * 100],
283
+ text=[f"{result['probability']*100:.1f}%"],
284
+ textposition='auto',
285
+ marker=dict(
286
+ color=result['probability'] * 100,
287
+ colorscale='RdYlGn',
288
+ cmin=0,
289
+ cmax=100,
290
+ line=dict(color='black', width=2),
291
+ showscale=False,
292
+ colorbar=dict(title="Confidence %")
293
+ ),
294
+ hovertemplate=f"<b>{group}</b><br>" +
295
+ f"Prediction: {result['className']}<br>" +
296
+ f"Confidence: {result['probability']*100:.1f}%<br>" +
297
+ "<extra></extra>"
298
+ ))
299
+
300
+ fig.update_layout(
301
+ title="Confidence Levels by Group",
302
+ xaxis_title="Predicted Class",
303
+ yaxis_title="Confidence (%)",
304
+ barmode='group',
305
+ height=500,
306
+ font=dict(size=12),
307
+ showlegend=True,
308
+ yaxis=dict(range=[0, 100])
309
+ )
310
+
311
+ return fig
312
+
313
+ def create_disagreement_meter(self, results):
314
+ """Create disagreement level visualization"""
315
+ if not results:
316
+ fig = go.Figure()
317
+ fig.add_annotation(text="No predictions yet", xref="paper", yref="paper",
318
+ x=0.5, y=0.5, showarrow=False, font=dict(size=20))
319
+ return fig
320
+
321
+ disagreement = self.calculate_disagreement(results)
322
+
323
+ # Determine color
324
+ if disagreement < 0.3:
325
+ gauge_color = "green"
326
+ elif disagreement < 0.6:
327
+ gauge_color = "orange"
328
+ else:
329
+ gauge_color = "darkred"
330
+
331
+ # Create gauge chart
332
+ fig = go.Figure(go.Indicator(
333
+ mode="gauge + number ",
334
+ value=disagreement * 100,
335
+ domain={'x': [0, 1], 'y': [0, 1]},
336
+ title={'text': "Disagreement Level", 'font': {'size': 24, 'weight': 'bold'}},
337
+ # delta={'reference': 30, 'increasing': {'color': "red"}},
338
+ number={'suffix': "%", 'font': {'size': 40}},
339
+ gauge={
340
+ 'axis': {'range': [None, 100], 'tickwidth': 2, 'tickcolor': "darkblue"},
341
+ 'bar': {'color': gauge_color, 'thickness': 0.75},
342
+ 'bgcolor': "white",
343
+ 'borderwidth': 2,
344
+ 'bordercolor': "gray",
345
+ 'steps': [
346
+ {'range': [0, 30], 'color': "lightgreen"},
347
+ {'range': [30, 60], 'color': "lightyellow"},
348
+ {'range': [60, 100], 'color': "lightcoral"}
349
+ ],
350
+ 'threshold': {
351
+ 'line': {'color': "red", 'width': 4},
352
+ 'thickness': 0.75,
353
+ 'value': 60
354
+ }
355
+ }
356
+ ))
357
+
358
+ fig.update_layout(
359
+ height=350,
360
+ font={'size': 16},
361
+ paper_bgcolor="white",
362
+ margin=dict(l=20, r=20, t=60, b=20)
363
+ )
364
+
365
+ return fig
366
+
367
+ def calculate_disagreement(self, results):
368
+ """Calculate disagreement level between models"""
369
+ if len(results) <= 1:
370
+ return 0.0
371
+
372
+ predictions = [r['className'] for r in results.values()]
373
+ unique_predictions = len(set(predictions))
374
+ total_models = len(predictions)
375
+
376
+ # Normalize: 0 = all agree, 1 = all different
377
+ disagreement = (unique_predictions - 1) / (total_models - 1)
378
+ return disagreement
379
+
380
+ def get_status_message(self, disagreement, num_models):
381
+ """Generate status message based on disagreement level"""
382
+ if disagreement < 0.3:
383
+ level = "LOW"
384
+ detail = "Models mostly agree. Training data likely similar."
385
+ elif disagreement < 0.6:
386
+ level = "MODERATE"
387
+ detail = "Some variation in predictions. Check training data differences."
388
+ else:
389
+ level = "HIGH"
390
+ detail = "Major conflicts! This reveals significant bias in training data."
391
+
392
+ return f"**{level} DISAGREEMENT** ({disagreement*100:.1f}%)\n\n{detail}\n\n*{num_models} models connected and tested*"
393
+
394
+ def log_prediction(self, image, results, disagreement):
395
+ """Log prediction for later analysis"""
396
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
397
+
398
+ log_entry = {
399
+ 'timestamp': timestamp,
400
+ 'disagreement': disagreement,
401
+ 'predictions': {group: result['className'] for group, result in results.items()},
402
+ 'confidences': {group: result['probability'] for group, result in results.items()}
403
+ }
404
+
405
+ self.predictions_log.append(log_entry)
406
+
407
+ # Save to CSV periodically
408
+ if len(self.predictions_log) % 5 == 0:
409
+ self._save_log()
410
+
411
+
412
+ # Initialize dashboard
413
+ dashboard = BiasVisualizationDashboard()
414
+
415
+ # Create Gradio Interface
416
+ def create_interface():
417
+ with gr.Blocks(title="Bias Visualization Dashboard") as app:
418
+
419
+ with gr.Tabs():
420
+
421
+ # TAB 1: Model Setup
422
+ with gr.Tab("1. Model Setup"):
423
+ gr.Markdown("""
424
+ ### Connect Your Teachable Machine Models.
425
+ """)
426
+
427
+ with gr.Row():
428
+ with gr.Column():
429
+ gr.Markdown("#### Group 1")
430
+ group1_url = gr.Textbox(
431
+ label="Model URL",
432
+ placeholder="https://teachablemachine.withgoogle.com/models/YOUR_MODEL_ID/",
433
+ lines=2
434
+ )
435
+ connect1_btn = gr.Button("πŸ”— Connect Group 1", variant="primary", size="lg")
436
+ status1 = gr.Textbox(label="Status", interactive=False, lines=3)
437
+
438
+ with gr.Column():
439
+ gr.Markdown("#### Group 2")
440
+ group2_url = gr.Textbox(
441
+ label="Model URL",
442
+ placeholder="https://teachablemachine.withgoogle.com/models/YOUR_MODEL_ID/",
443
+ lines=2
444
+ )
445
+ connect2_btn = gr.Button("πŸ”— Connect Group 2", variant="primary", size="lg")
446
+ status2 = gr.Textbox(label="Status", interactive=False, lines=3)
447
+
448
+ with gr.Row():
449
+ with gr.Column():
450
+ gr.Markdown("#### 3")
451
+ group3_url = gr.Textbox(
452
+ label="Model URL",
453
+ placeholder="https://teachablemachine.withgoogle.com/models/YOUR_MODEL_ID/",
454
+ lines=2
455
+ )
456
+ connect3_btn = gr.Button("πŸ”— Connect Group 3", variant="primary", size="lg")
457
+ status3 = gr.Textbox(label="Status", interactive=False, lines=3)
458
+
459
+ with gr.Column():
460
+ gr.Markdown("#### Group 4")
461
+ group4_url = gr.Textbox(
462
+ label="Model URL",
463
+ placeholder="https://teachablemachine.withgoogle.com/models/YOUR_MODEL_ID/",
464
+ lines=2
465
+ )
466
+ connect4_btn = gr.Button("πŸ”— Connect Group 4", variant="primary", size="lg")
467
+ status4 = gr.Textbox(label="Status", interactive=False, lines=3)
468
+
469
+
470
+ # Connect button handlers
471
+ connect1_btn.click(lambda url: dashboard.connect_model(1, url), inputs=[group1_url], outputs=[status1])
472
+ connect2_btn.click(lambda url: dashboard.connect_model(2, url), inputs=[group2_url], outputs=[status2])
473
+ connect3_btn.click(lambda url: dashboard.connect_model(3, url), inputs=[group3_url], outputs=[status3])
474
+ connect4_btn.click(lambda url: dashboard.connect_model(4, url), inputs=[group4_url], outputs=[status4])
475
+
476
+ # TAB 2: Test & Compare
477
+ with gr.Tab("2.Test & Compare"):
478
+ gr.Markdown("### Upload Test Image & Compare Predictions")
479
+
480
+ with gr.Row():
481
+ with gr.Column(scale=1):
482
+ test_image = gr.Image(type="pil", label="πŸ“Έ Test Image", height=400)
483
+ analyze_btn = gr.Button("πŸ” Analyze with All Models", variant="primary", size="lg")
484
+
485
+ with gr.Column(scale=2):
486
+ status_msg = gr.Markdown("### Status\nUpload an image to begin...")
487
+ disagreement_meter = gr.Plot(label="Disagreement Meter")
488
+
489
+ gr.Markdown("---")
490
+
491
+ with gr.Row():
492
+ prediction_grid = gr.Plot(label="Model Predictions Comparison")
493
+
494
+ with gr.Row():
495
+ confidence_bars = gr.Plot(label="Confidence Levels by Group")
496
+
497
+ analyze_btn.click(
498
+ dashboard.analyze_test_image,
499
+ inputs=[test_image],
500
+ outputs=[prediction_grid, confidence_bars, disagreement_meter, status_msg]
501
+ )
502
+
503
+ return app
504
+
505
+ # Launch the app
506
+ if __name__ == "__main__":
507
+ app = create_interface()
508
+ app.launch(
509
+ server_name="0.0.0.0",
510
+ server_port=7860,
511
+ share=False, # Set to True for public sharing link
512
+ debug=True,
513
+ show_error=True
514
+ )
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ numpy
3
+ matplotlib
4
+ seaborn
5
+ plotly
6
+ requests
7
+ opencv-python
8
+ tensorflow