Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import plotly.graph_objects as go | |
| from datetime import datetime | |
| import requests | |
| import tensorflow as tf | |
# Global plotting defaults: seaborn "whitegrid" theme, then a 10x6 inch
# default size for matplotlib figures that do not request their own.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (10, 6)})
class BiasVisualizationDashboard:
    """Compare predictions of several Teachable Machine models on one image.

    Each group's model is registered under the key ``"group_<n>"``. The
    dashboard collects every connected group's top prediction for a test
    image, renders three comparison plots, and keeps an in-memory log of the
    results that is periodically written to a CSV file.
    """

    def __init__(self):
        self.models = {}  # "group_<n>" -> connection record (see connect_model)
        self.predictions_log = []  # one dict per analysed image
        self.current_test_image = None
        self.dataset_stats = {}
        self.class_names = {}  # "group_<n>" -> list of class labels

    @staticmethod
    def _build_model_json_url(model_url):
        """Return the ``model.json`` URL for ``model_url``, or ``None`` if unsupported.

        ``model_url`` is expected to be stripped of whitespace and of any
        trailing slash.
        """
        # The explicit suffix must be tested first: a full Teachable Machine
        # ".../models/<id>/model.json" link also contains the host substring
        # and would otherwise get a second "/model.json" appended.
        if model_url.endswith('/model.json'):
            return model_url
        if 'teachablemachine.withgoogle.com/models/' in model_url:
            # Format: https://teachablemachine.withgoogle.com/models/hXSMj8Jc2/
            return f"{model_url}/model.json"
        return None

    def _fetch_class_names(self, base_url):
        """Return the labels from ``<base_url>/metadata.json``.

        Best effort: when the metadata cannot be fetched or parsed, five
        generic names ("Class 0" .. "Class 4") are returned instead.
        """
        fallback = [f"Class {i}" for i in range(5)]
        try:
            metadata_response = requests.get(f"{base_url}/metadata.json", timeout=10)
            if metadata_response.status_code != 200:
                return fallback
            labels = metadata_response.json().get('labels') or []
            return [str(label) for label in labels]
        except Exception:
            # Deliberately broad (metadata is optional), but no longer a bare
            # ``except`` that would also swallow KeyboardInterrupt/SystemExit.
            return fallback

    def connect_model(self, group_num, model_url):
        """Connect a group's Teachable Machine model from its shareable URL.

        Args:
            group_num: Group number; the model is stored under "group_<group_num>".
            model_url: Either a ``teachablemachine.withgoogle.com/models/<id>``
                link or a direct URL ending in ``/model.json``.

        Returns:
            A human-readable status string. Failures are reported through the
            returned message; no exception is propagated to the caller.
        """
        group_key = f"group_{group_num}"
        try:
            # Clean and validate URL (None is treated like an empty box)
            model_url = (model_url or "").strip()
            if not model_url:
                return f"Group {group_num}: Please enter a model URL"

            model_json_url = self._build_model_json_url(model_url.rstrip('/'))
            if model_json_url is None:
                return f"Group {group_num}: Invalid Teachable Machine URL format"

            # Test connection to the model definition
            print(f"Attempting to connect to: {model_json_url}")
            response = requests.get(model_json_url, timeout=10)
            if response.status_code != 200:
                return f"Group {group_num}: Cannot access model (Status {response.status_code}). Make sure model is shared publicly."
            model_data = response.json()
            print(f"Model data received: {model_data}")

            # Strip only the trailing "/model.json" to get the model folder URL
            base_url = model_json_url[:-len('/model.json')]
            class_names = self._fetch_class_names(base_url)

            record = {
                'model': None,
                'url': base_url,
                'connected': True,
                'metadata': model_data,
            }
            try:
                # NOTE(review): load_model is handed a URL here; it normally
                # expects a local model path, so this may always end up in the
                # fallback below -- confirm.
                record['model'] = tf.keras.models.load_model(base_url)
            except Exception as e:
                print(f"TensorFlow loading error: {e}")
                # Fallback: keep the URL only; predictions are then simulated
                record['use_api'] = True

            self.models[group_key] = record
            self.class_names[group_key] = class_names

            if record['model'] is None:
                return f"Group {group_num} model connected (API mode)!\nClasses: {', '.join(class_names)}"
            return f"Group {group_num} model connected successfully!\nClasses: {', '.join(class_names)}"
        except requests.exceptions.Timeout:
            return f" Group {group_num}: Connection timeout. Check your internet connection."
        except requests.exceptions.RequestException as e:
            return f" Group {group_num}: Connection error: {str(e)}"
        except Exception as e:
            return f" Group {group_num}: Error: {str(e)}"

    def preprocess_image(self, image, target_size=(224, 224)):
        """Turn an image into a float32 batch of shape (1, height, width, 3).

        Args:
            image: Image object offering PIL-style ``convert`` and ``resize``.
            target_size: ``(width, height)`` passed to ``resize``.

        Returns:
            A numpy array with values scaled to [0, 1] and a leading batch axis.
        """
        # Force three channels: RGBA or grayscale uploads would otherwise
        # produce an array with the wrong channel count for the model.
        rgb_image = image.convert('RGB')
        img_resized = rgb_image.resize(target_size)
        # NOTE(review): scaling is to [0, 1]; verify this matches the
        # normalisation the exported model was trained with.
        img_array = np.array(img_resized).astype('float32') / 255.0
        # Add batch dimension
        return np.expand_dims(img_array, axis=0)

    def predict_with_teachable_machine(self, group_num, image):
        """Return a group's predictions for ``image``, most probable first.

        Returns:
            ``None`` if the group has no registered model; otherwise a list of
            ``{'className': str, 'probability': float}`` dicts sorted by
            descending probability. When no TensorFlow model is loaded, or
            when prediction fails, the values are simulated (random).
        """
        group_key = f"group_{group_num}"
        try:
            if group_key not in self.models:
                return None
            model_info = self.models[group_key]
            class_names = self.class_names.get(group_key, [])

            if model_info.get('use_api') or model_info['model'] is None:
                # No usable model: simulated predictions, no preprocessing needed
                return self._simulate_prediction(class_names)

            processed_image = self.preprocess_image(image)
            pred_array = model_info['model'].predict(processed_image, verbose=0)
            predictions = [
                {
                    'className': class_names[i] if i < len(class_names) else f"Class {i}",
                    'probability': float(prob),
                }
                for i, prob in enumerate(pred_array[0])
            ]
            predictions.sort(key=lambda x: x['probability'], reverse=True)
            return predictions
        except Exception as e:
            print(f"Prediction error for Group {group_num}: {e}")
            # Return simulated prediction as fallback
            return self._simulate_prediction(self.class_names.get(group_key, []))

    def _simulate_prediction(self, class_names):
        """Return random demo predictions with one likely-dominant class.

        Falls back to four built-in class names when ``class_names`` is empty.
        """
        if not class_names:
            class_names = ['Graphic Design', 'Chair', 'Font', 'Cake']
        num_classes = len(class_names)
        # Dirichlet with one larger concentration -> one class tends to
        # dominate; shuffling makes the dominant class random.
        confidences = np.random.dirichlet(np.array([3.0] + [1.0] * (num_classes - 1)))
        np.random.shuffle(confidences)
        predictions = [
            {'className': cls, 'probability': float(conf)}
            for cls, conf in zip(class_names, confidences)
        ]
        predictions.sort(key=lambda x: x['probability'], reverse=True)
        return predictions

    def analyze_test_image(self, image, group_count=5):
        """Run ``image`` through every connected model and build the plots.

        Args:
            image: Test image, or ``None`` when nothing was uploaded.
            group_count: Groups 1..group_count are checked for a connection.

        Returns:
            ``(prediction_grid, confidence_bars, disagreement_meter, status)``;
            the three plots are ``None`` when there is nothing to show.
        """
        if image is None:
            return None, None, None, "Please upload a test image first."
        self.current_test_image = image

        results = {}
        connected_groups = []
        for group_num in range(1, group_count + 1):
            group_key = f"group_{group_num}"
            if group_key in self.models and self.models[group_key]['connected']:
                connected_groups.append(group_num)
                predictions = self.predict_with_teachable_machine(group_num, image)
                if predictions:
                    results[f"Group {group_num}"] = predictions[0]  # Top prediction

        if not results:
            return None, None, None, "No models connected. Please connect at least one model in Tab 1."

        pred_grid = self.create_prediction_grid(results)
        confidence_bars = self.create_confidence_bars(results)
        disagreement_viz = self.create_disagreement_meter(results)

        disagreement_level = self.calculate_disagreement(results)
        status_msg = self.get_status_message(disagreement_level, len(connected_groups))

        self.log_prediction(image, results, disagreement_level)
        return pred_grid, confidence_bars, disagreement_viz, status_msg

    def create_prediction_grid(self, results):
        """Return a matplotlib figure with one horizontal bar per group.

        Bars show the confidence of each group's top prediction and are
        coloured by the overall disagreement level, using the same 30% / 60%
        thresholds as the disagreement meter and the status message.
        """
        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            if not results:
                ax.text(0.5, 0.5, 'No predictions yet', ha='center', va='center', fontsize=20)
                ax.axis('off')
                return fig

            groups = list(results.keys())
            predictions = [results[g]['className'] for g in groups]
            confidences = [results[g]['probability'] * 100 for g in groups]
            unique_preds = len(set(predictions))

            # Colour by normalised disagreement rather than by the absolute
            # number of distinct labels, so two models that disagree are not
            # painted as "agreement".
            disagreement = self.calculate_disagreement(results)
            if disagreement < 0.3:
                bar_color = '#2ecc71'  # Green - agreement
            elif disagreement < 0.6:
                bar_color = '#f39c12'  # Orange - moderate
            else:
                bar_color = '#e74c3c'  # Red - high disagreement

            y_pos = np.arange(len(groups))
            bars = ax.barh(y_pos, confidences, color=[bar_color] * len(groups),
                           alpha=0.7, edgecolor='black', linewidth=2)

            # Prediction label centred inside each bar
            for bar, pred, conf in zip(bars, predictions, confidences):
                ax.text(bar.get_width() / 2, bar.get_y() + bar.get_height() / 2,
                        f"{pred}\n{conf:.1f}%",
                        ha='center', va='center', fontsize=11, fontweight='bold', color='white',
                        bbox=dict(boxstyle='round', facecolor='black', alpha=0.3))

            ax.set_yticks(y_pos)
            ax.set_yticklabels(groups, fontsize=12, fontweight='bold')
            ax.set_xlabel('Confidence (%)', fontsize=14, fontweight='bold')
            ax.set_title('Model Predictions Comparison', fontsize=16, fontweight='bold', pad=20)
            ax.set_xlim(0, 100)
            ax.grid(axis='x', alpha=0.3)

            legend_text = f"Unique Predictions: {unique_preds}/{len(groups)}"
            ax.text(0.98, 0.02, legend_text, transform=ax.transAxes,
                    fontsize=10, verticalalignment='bottom', horizontalalignment='right',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            fig.tight_layout()
            return fig
        finally:
            # Unregister the figure from pyplot so repeated analyses do not
            # accumulate open figures; the returned object stays usable.
            plt.close(fig)

    def create_confidence_bars(self, results):
        """Return a plotly bar chart of each group's top-prediction confidence."""
        fig = go.Figure()
        if not results:
            fig.add_annotation(text="No predictions yet", xref="paper", yref="paper",
                               x=0.5, y=0.5, showarrow=False, font=dict(size=20))
            return fig

        for group, result in results.items():
            confidence_pct = result['probability'] * 100
            fig.add_trace(go.Bar(
                name=group,
                x=[result['className']],
                y=[confidence_pct],
                text=[f"{confidence_pct:.1f}%"],
                textposition='auto',
                marker=dict(
                    # Must be an array: the colorscale is only applied to
                    # numerical arrays, not to a scalar number.
                    color=[confidence_pct],
                    colorscale='RdYlGn',
                    cmin=0,
                    cmax=100,
                    line=dict(color='black', width=2),
                    showscale=False,
                    colorbar=dict(title="Confidence %")
                ),
                hovertemplate=f"<b>{group}</b><br>" +
                              f"Prediction: {result['className']}<br>" +
                              f"Confidence: {confidence_pct:.1f}%<br>" +
                              "<extra></extra>"
            ))

        fig.update_layout(
            title="Confidence Levels by Group",
            xaxis_title="Predicted Class",
            yaxis_title="Confidence (%)",
            barmode='group',
            height=500,
            font=dict(size=12),
            showlegend=True,
            yaxis=dict(range=[0, 100])
        )
        return fig

    def create_disagreement_meter(self, results):
        """Return a plotly gauge showing the disagreement level in percent."""
        if not results:
            fig = go.Figure()
            fig.add_annotation(text="No predictions yet", xref="paper", yref="paper",
                               x=0.5, y=0.5, showarrow=False, font=dict(size=20))
            return fig

        disagreement = self.calculate_disagreement(results)
        if disagreement < 0.3:
            gauge_color = "green"
        elif disagreement < 0.6:
            gauge_color = "orange"
        else:
            gauge_color = "darkred"

        fig = go.Figure(go.Indicator(
            mode="gauge + number ",
            value=disagreement * 100,
            domain={'x': [0, 1], 'y': [0, 1]},
            title={'text': "Disagreement Level", 'font': {'size': 24, 'weight': 'bold'}},
            number={'suffix': "%", 'font': {'size': 40}},
            gauge={
                'axis': {'range': [None, 100], 'tickwidth': 2, 'tickcolor': "darkblue"},
                'bar': {'color': gauge_color, 'thickness': 0.75},
                'bgcolor': "white",
                'borderwidth': 2,
                'bordercolor': "gray",
                # Background bands mirror the 30% / 60% colour thresholds
                'steps': [
                    {'range': [0, 30], 'color': "lightgreen"},
                    {'range': [30, 60], 'color': "lightyellow"},
                    {'range': [60, 100], 'color': "lightcoral"}
                ],
                'threshold': {
                    'line': {'color': "red", 'width': 4},
                    'thickness': 0.75,
                    'value': 60
                }
            }
        ))
        fig.update_layout(
            height=350,
            font={'size': 16},
            paper_bgcolor="white",
            margin=dict(l=20, r=20, t=60, b=20)
        )
        return fig

    def calculate_disagreement(self, results):
        """Return disagreement between models in [0, 1].

        0 means every model predicted the same class, 1 means every model
        predicted a different class. Fewer than two results yield 0.0.
        """
        if len(results) <= 1:
            return 0.0
        predictions = [r['className'] for r in results.values()]
        unique_predictions = len(set(predictions))
        total_models = len(predictions)
        # Normalize: 0 = all agree, 1 = all different
        return (unique_predictions - 1) / (total_models - 1)

    def get_status_message(self, disagreement, num_models):
        """Return a markdown status summary for the given disagreement level."""
        if disagreement < 0.3:
            level = "LOW"
            detail = "Models mostly agree. Training data likely similar."
        elif disagreement < 0.6:
            level = "MODERATE"
            detail = "Some variation in predictions. Check training data differences."
        else:
            level = "HIGH"
            detail = "Major conflicts! This reveals significant bias in training data."
        return f"**{level} DISAGREEMENT** ({disagreement*100:.1f}%)\n\n{detail}\n\n*{num_models} models connected and tested*"

    def log_prediction(self, image, results, disagreement):
        """Append one analysis to the in-memory log; save to CSV every 5th entry.

        Args:
            image: The analysed image (accepted for interface symmetry; not stored).
            results: Mapping of group label -> top prediction dict.
            disagreement: Disagreement level computed for ``results``.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_entry = {
            'timestamp': timestamp,
            'disagreement': disagreement,
            'predictions': {group: result['className'] for group, result in results.items()},
            'confidences': {group: result['probability'] for group, result in results.items()}
        }
        self.predictions_log.append(log_entry)
        # Save to CSV periodically
        if len(self.predictions_log) % 5 == 0:
            self._save_log()

    def _save_log(self, path="predictions_log.csv"):
        """Write the whole prediction log to ``path``, one row per group per entry.

        A failure to write is reported on stdout instead of being raised, so
        that logging problems never break an analysis.
        """
        import csv

        try:
            with open(path, "w", newline="", encoding="utf-8") as handle:
                writer = csv.writer(handle)
                writer.writerow(["timestamp", "disagreement", "group", "prediction", "confidence"])
                for entry in self.predictions_log:
                    for group, predicted in entry['predictions'].items():
                        writer.writerow([
                            entry['timestamp'],
                            entry['disagreement'],
                            group,
                            predicted,
                            entry['confidences'].get(group),
                        ])
        except OSError as e:
            print(f"Could not save prediction log: {e}")
# Module-level dashboard instance; the Gradio callbacks wired up in
# create_interface() call its connect_model and analyze_test_image methods.
dashboard = BiasVisualizationDashboard()
# Create Gradio Interface
def create_interface():
    """Build and return the two-tab Gradio Blocks app.

    Tab 1 offers one URL box, connect button and status box per group
    (groups 1-4, laid out two per row). Tab 2 takes a test image and shows
    the status text, disagreement meter, prediction grid and confidence bars
    produced by ``dashboard.analyze_test_image``.
    """
    url_placeholder = "https://teachablemachine.withgoogle.com/models/YOUR_MODEL_ID/"

    def add_group_column(group_num):
        """Create the widgets for one group and wire its connect button."""
        with gr.Column():
            gr.Markdown(f"#### Group {group_num}")
            url_box = gr.Textbox(
                label="Model URL",
                placeholder=url_placeholder,
                lines=2
            )
            connect_btn = gr.Button(f"🔗 Connect Group {group_num}", variant="primary", size="lg")
            status_box = gr.Textbox(label="Status", interactive=False, lines=3)
        # group_num is a parameter of this call, so each handler keeps its
        # own group number (no late-binding closure problem).
        connect_btn.click(
            lambda url: dashboard.connect_model(group_num, url),
            inputs=[url_box],
            outputs=[status_box]
        )

    with gr.Blocks(title="Bias Visualization Dashboard") as app:
        with gr.Tabs():
            # TAB 1: Model Setup
            with gr.Tab("1. Model Setup"):
                gr.Markdown("### Connect Your Teachable Machine Models.")
                for row_groups in ((1, 2), (3, 4)):
                    with gr.Row():
                        for group_num in row_groups:
                            add_group_column(group_num)

            # TAB 2: Test & Compare
            with gr.Tab("2. Test & Compare"):
                gr.Markdown("### Upload Test Image & Compare Predictions")
                with gr.Row():
                    with gr.Column(scale=1):
                        test_image = gr.Image(type="pil", label="📸 Test Image", height=400)
                        analyze_btn = gr.Button("🔍 Analyze with All Models", variant="primary", size="lg")
                    with gr.Column(scale=2):
                        status_msg = gr.Markdown("### Status\nUpload an image to begin...")
                        disagreement_meter = gr.Plot(label="Disagreement Meter")
                gr.Markdown("---")
                with gr.Row():
                    prediction_grid = gr.Plot(label="Model Predictions Comparison")
                with gr.Row():
                    confidence_bars = gr.Plot(label="Confidence Levels by Group")
                # Output order matches the tuple returned by analyze_test_image
                analyze_btn.click(
                    dashboard.analyze_test_image,
                    inputs=[test_image],
                    outputs=[prediction_grid, confidence_bars, disagreement_meter, status_msg]
                )
    return app
# Launch the app when the file is executed as a script
if __name__ == "__main__":
    launch_settings = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": False,  # Set to True for public sharing link
        "debug": True,
        "show_error": True,
    }
    app = create_interface()
    app.launch(**launch_settings)