# app.py — Explainability Sandbox for Transformers (Hugging Face Space)
| # app.py (fixed version) | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from utils.explainers import LimeExplainer, ShapExplainer, CaptumExplainer | |
| from utils.visualization import create_visualization, create_attribution_plot, create_confidence_chart | |
| from utils.export import export_to_csv, export_to_json, export_plot_as_png | |
# Registry of selectable checkpoints.  Keys are the display names shown in
# the UI; each value holds:
#   path       -- Hugging Face Hub id passed to from_pretrained()
#   trained_on -- pretraining corpora (shown in the model card section)
#   domain     -- short description of the intended text domain
MODELS = {
    "BERT Base (English)": {
        "path": "bert-base-uncased",
        "trained_on": ["BookCorpus", "English Wikipedia"],
        "domain": "General text"
    },
    "DistilBERT (English)": {
        "path": "distilbert-base-uncased",
        "trained_on": ["BookCorpus", "English Wikipedia"],
        "domain": "General text"
    },
    "RoBERTa Base (English)": {
        "path": "roberta-base",
        "trained_on": ["BookCorpus", "English Wikipedia", "CommonCrawl", "OpenWebText"],
        "domain": "General text"
    },
    "ALBERT Base (English)": {
        "path": "albert-base-v2",
        "trained_on": ["BookCorpus", "English Wikipedia"],
        "domain": "General text"
    },
}

# Process-wide cache mapping display name -> (tokenizer, model, model_info);
# populated lazily by load_model() so repeated analyses reuse loaded weights.
model_cache = {}
def load_model(model_name):
    """Return ``(tokenizer, model, model_info)`` for *model_name*, with caching.

    On any failure the error is printed and ``(None, None, None)`` is
    returned so the caller can skip this model instead of crashing the app.
    """
    cached = model_cache.get(model_name)
    if cached is not None:
        return cached
    try:
        model_info = MODELS[model_name]
        print(f"Loading model: {model_info['path']}")
        tokenizer = AutoTokenizer.from_pretrained(model_info['path'])
        # Some checkpoints ship without a pad token; reuse EOS so padded
        # batches can still be built.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForSequenceClassification.from_pretrained(
            model_info['path'],
            num_labels=2,
            output_attentions=False,
            output_hidden_states=False,
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None, None
    entry = (tokenizer, model, model_info)
    model_cache[model_name] = entry
    return entry
def predict_and_explain(text, model_choices, explainer_choice, compare_mode):
    """Classify *text* with each selected model and build explanation artifacts.

    Parameters
    ----------
    text : str
        Input text to classify.
    model_choices : list[str]
        Display names of the models to run (keys of ``MODELS``).
    explainer_choice : str
        One of ``"LIME"``, ``"SHAP"`` or ``"Captum"``.
    compare_mode : bool
        When True and more than one model is selected, a side-by-side
        comparison summary replaces the per-model detail views.

    Returns
    -------
    tuple
        Exactly five values matching the Gradio outputs:
        (prediction text, token visualization HTML, attribution plot HTML,
        explanation data, confidence chart HTML).
    """
    if not text.strip():
        # BUG FIX: this early return previously yielded SIX values while the
        # event handlers bind only FIVE outputs; keep the arity at five.
        return "Please enter some text to analyze.", None, None, None, None

    results = []
    visualizations = []
    plots = []
    explanations = []
    confidence_charts = []

    for model_choice in model_choices:
        tokenizer, model, model_info = load_model(model_choice)
        if model is None:
            # Keep all per-model lists index-aligned even on failure.
            results.append(f"Error loading {model_choice}")
            visualizations.append(None)
            plots.append(None)
            explanations.append(None)
            confidence_charts.append(None)
            continue
        try:
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512
            )
            # Forward pass without gradients; softmax over the two labels.
            model.eval()
            with torch.no_grad():
                outputs = model(**inputs)
                probabilities = torch.softmax(outputs.logits, dim=1).numpy()[0]
            predicted_class = int(np.argmax(probabilities))  # plain int for display/JSON
            confidence = probabilities[predicted_class]
            results.append(f"{model_choice}: Class {predicted_class} ({confidence:.2%})")
            confidence_charts.append(
                create_confidence_chart(probabilities, ["Negative", "Positive"])
            )
            explanation = _run_explainer(explainer_choice, model, tokenizer, text, model_choice)
            explanations.append(explanation)
            visualizations.append(create_visualization(text, explanation, tokenizer, explainer_choice))
            plots.append(create_attribution_plot(explanation, explainer_choice))
        except Exception as e:
            print(f"Prediction error for {model_choice}: {e}")
            results.append(f"{model_choice}: Error - {str(e)}")
            visualizations.append(None)
            plots.append(None)
            explanations.append(None)
            confidence_charts.append(None)

    if compare_mode and len(model_choices) > 1:
        # Comparison mode: the same summary HTML fills the visualization,
        # plot and confidence slots; explanation data carries a marker dict.
        comparison_html = _build_comparison_html(model_choices, results)
        return (
            "\n".join(results),
            comparison_html,
            comparison_html,
            {"comparison_mode": True, "results": results},
            comparison_html
        )
    # Single-model view: surface the first (only) model's artifacts.
    result_output = results[0] if results else "No results"
    vis_output = visualizations[0] if visualizations else None
    plot_output = plots[0] if plots else None
    explanation_output = explanations[0] if explanations else None
    confidence_output = confidence_charts[0] if confidence_charts else None
    return result_output, vis_output, plot_output, explanation_output, confidence_output


def _run_explainer(explainer_choice, model, tokenizer, text, model_choice):
    """Build the selected explainer and return its attributions ([] on error)."""
    try:
        if explainer_choice == "LIME":
            return LimeExplainer(model, tokenizer).explain(text, num_features=15)
        if explainer_choice == "SHAP":
            return ShapExplainer(model, tokenizer).explain(text)
        # Any other choice falls through to Captum (matches the UI options).
        return CaptumExplainer(model, tokenizer).explain(text)
    except Exception as e:
        print(f"Error generating explanation for {model_choice}: {e}")
        return []


def _build_comparison_html(model_choices, results):
    """Render the side-by-side model comparison summary card as HTML."""
    comparison_html = """
    <div style="padding: 20px; background: #f8f9fa; border-radius: 10px; border: 2px solid #e9ecef;">
        <h3 style="margin-top: 0; color: #495057;">π Model Comparison Results</h3>
        <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 15px;">
    """
    for i, model_choice in enumerate(model_choices):
        comparison_html += f"""
        <div style="padding: 15px; background: white; border-radius: 8px; border: 1px solid #dee2e6;">
            <h4 style="margin: 0 0 10px 0; color: #6c757d;">{model_choice}</h4>
            <p style="margin: 0; font-weight: bold; color: #495057;">{results[i] if i < len(results) else 'N/A'}</p>
        </div>
        """
    comparison_html += """
        </div>
        <p style="margin: 15px 0 0 0; color: #6c757d; font-style: italic;">
            Select individual models from the checkbox to see detailed explanations.
        </p>
    </div>
    """
    return comparison_html
# ---------------------------------------------------------------------------
# Gradio interface: layout, example inputs, model card and event wiring.
# NOTE(review): several Markdown/HTML literals contain mojibake (e.g. "π")
# where emoji were mis-encoded upstream — left byte-for-byte as found.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Explainability Sandbox for Transformers", css="footer {visibility: hidden}") as demo:
    # Page header banner.
    gr.Markdown("""
    <div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white; margin-bottom: 20px;">
        <h1 style="margin: 0; font-size: 2.5em;">π Explainability Sandbox for Transformers</h1>
        <p style="margin: 10px 0 0 0; font-size: 1.2em; opacity: 0.9;">Advanced model interpretability with multiple comparison</p>
    </div>
    """)
    with gr.Row():
        # Left column: text input, model/explainer selection, export controls.
        with gr.Column(scale=1):
            gr.Markdown("### βοΈ Input Settings")
            text_input = gr.Textbox(
                label="Input Text",
                lines=5,
                placeholder="Enter text to analyze...",
                value="The movie was fantastic with great acting and an engaging plot."
            )
            # Multiple models may be checked; comparison mode uses >1 selection.
            model_choices = gr.CheckboxGroup(
                choices=list(MODELS.keys()),
                label="Select Models",
                value=["BERT Base (English)"],
                interactive=True
            )
            explainer_choice = gr.Radio(
                choices=["LIME", "SHAP", "Captum"],
                label="Explanation Method",
                value="LIME"
            )
            compare_mode = gr.Checkbox(
                label="Enable Comparison Mode",
                value=False,
                info="Compare multiple models side-by-side"
            )
            analyze_btn = gr.Button("Analyze Text", variant="primary")
            gr.Markdown("""
            ---
            ### π Export Results
            """)
            export_btn = gr.Button("Export Results", variant="secondary")
            export_output = gr.HTML()
        # Right column: prediction text, confidence chart, attributions, raw data.
        with gr.Column(scale=2):
            gr.Markdown("### π Results")
            output_text = gr.Textbox(label="Prediction Result")
            gr.Markdown("#### π Confidence Distribution")
            confidence_output = gr.HTML()
            gr.Markdown("#### π¨ Token Attributions")
            output_vis = gr.HTML(label="Visualization")
            gr.Markdown("#### π Attribution Plot")
            output_plot = gr.HTML()
            gr.Markdown("#### π Explanation Data")
            explanation_output = gr.JSON(label="Detailed Data")
    # Export handler: renders CSV/JSON/PNG export snippets as an HTML panel.
    # Receives the current explanation JSON and plot HTML from the outputs.
    def export_results(explanation_data, plot_html):
        # Comparison-mode payloads are summary dicts (see predict_and_explain),
        # not per-token attribution data, so export is disabled for them.
        if explanation_data and isinstance(explanation_data, dict) and explanation_data.get("comparison_mode"):
            return "<div style='color: #6c757d; padding: 10px;'>Export not available in comparison mode. Select individual models to export.</div>"
        csv_export = export_to_csv(explanation_data) if explanation_data else "No data to export"
        json_export = export_to_json(explanation_data) if explanation_data else "No data to export"
        png_export = export_plot_as_png(plot_html) if plot_html else "No plot to export"
        return f"""
        <div style="padding: 15px; background: #f8f9fa; border-radius: 8px; border: 1px solid #ddd;">
            <h4 style="margin-top: 0;">Export Options:</h4>
            <div style="display: flex; gap: 10px; flex-wrap: wrap;">
                <div style="padding: 10px; background: white; border-radius: 5px; border: 1px solid #ccc;">{csv_export}</div>
                <div style="padding: 10px; background: white; border-radius: 5px; border: 1px solid #ccc;">{json_export}</div>
                <div style="padding: 10px; background: white; border-radius: 5px; border: 1px solid #ccc;">{png_export}</div>
            </div>
        </div>
        """
    # Clickable example inputs; cache disabled so predictions run on demand.
    gr.Markdown("### π Quick Examples")
    examples = gr.Examples(
        examples=[
            ["This movie was absolutely fantastic! The acting was superb.", ["BERT Base (English)"], "LIME", False],
            ["The patient shows symptoms of fever and cough.", ["BERT Base (English)", "RoBERTa Base (English)"], "SHAP", True],
            ["The financial report indicates strong growth.", ["DistilBERT (English)", "ALBERT Base (English)"], "Captum", True]
        ],
        inputs=[text_input, model_choices, explainer_choice, compare_mode],
        outputs=[output_text, output_vis, output_plot, explanation_output, confidence_output],
        fn=predict_and_explain,
        cache_examples=False
    )
    # Static model card / limitations section (display only).
    gr.Markdown("---")
    gr.Markdown("""
    ### π Expanded Model Card & Ethical Considerations
    **Datasets Used for Pretraining:**
    - BookCorpus (800M words)
    - English Wikipedia (2,500M words)
    - CommonCrawl News Dataset
    - Various domain-specific datasets for fine-tuning
    **β οΈ Important Limitations & Warnings:**
    **Not for Clinical/Diagnostic Use:**
    - This tool is for research and educational purposes only
    - NOT suitable for medical diagnosis, clinical decisions, or patient care
    - Models may produce incorrect or biased outputs
    **Explanation Method Limitations:**
    - LIME: Local approximations, may not capture global model behavior
    - SHAP: Game-theoretic approach, computationally intensive
    - Captum: Gradient-based, sensitive to model architecture
    - Different methods may produce conflicting explanations
    **Bias Awareness:**
    - Models may reproduce and amplify societal biases present in training data
    - Performance may vary across demographic groups
    - Always validate with domain experts for critical applications
    **Interpretability β Ground Truth:**
    - Explanations are approximations of model behavior
    - They show correlation, not necessarily causation
    - Use multiple methods to validate findings
    """)
    # Event wiring: both buttons feed from / into the components above.
    analyze_btn.click(
        fn=predict_and_explain,
        inputs=[text_input, model_choices, explainer_choice, compare_mode],
        outputs=[output_text, output_vis, output_plot, explanation_output, confidence_output]
    )
    export_btn.click(
        fn=export_results,
        inputs=[explanation_output, output_plot],
        outputs=[export_output]
    )

# Launch only when run as a script; share=False keeps the app local.
if __name__ == "__main__":
    demo.launch(share=False)