# NOTE: removed Hugging Face web-UI scrape residue that was pasted above the
# module header (avatar caption "dyra1222's picture", commit message
# "fixed new changes", commit hash 6182788) — it was not valid Python.
# app.py (fixed version)
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from utils.explainers import LimeExplainer, ShapExplainer, CaptumExplainer
from utils.visualization import create_visualization, create_attribution_plot, create_confidence_chart
from utils.export import export_to_csv, export_to_json, export_plot_as_png
# Available models with dataset information.
# Maps a human-readable display name (shown in the UI checkbox group) to:
#   path       - Hugging Face Hub identifier passed to from_pretrained
#   trained_on - corpora the base model was pretrained on (shown in the model card)
#   domain     - short description of the model's intended text domain
MODELS = {
    "BERT Base (English)": {
        "path": "bert-base-uncased",
        "trained_on": ["BookCorpus", "English Wikipedia"],
        "domain": "General text"
    },
    "DistilBERT (English)": {
        "path": "distilbert-base-uncased",
        "trained_on": ["BookCorpus", "English Wikipedia"],
        "domain": "General text"
    },
    "RoBERTa Base (English)": {
        "path": "roberta-base",
        "trained_on": ["BookCorpus", "English Wikipedia", "CommonCrawl", "OpenWebText"],
        "domain": "General text"
    },
    "ALBERT Base (English)": {
        "path": "albert-base-v2",
        "trained_on": ["BookCorpus", "English Wikipedia"],
        "domain": "General text"
    },
}
# Global variables to cache models.
# Keyed by display name; values are (tokenizer, model, model_info) tuples so
# each model is loaded/downloaded at most once per process.
model_cache = {}
def load_model(model_name):
    """Return (tokenizer, model, model_info) for *model_name*, with caching.

    Looks the model up in ``model_cache`` first; on a miss it loads the
    tokenizer and a 2-label sequence-classification head from the Hub and
    caches the triple. On any failure the error is printed and
    ``(None, None, None)`` is returned so callers can degrade gracefully.
    """
    cached = model_cache.get(model_name)
    if cached is not None:
        return cached
    try:
        info = MODELS[model_name]
        print(f"Loading model: {info['path']}")
        tokenizer = AutoTokenizer.from_pretrained(info['path'])
        # Some tokenizers ship without a pad token; reuse EOS so padding works.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForSequenceClassification.from_pretrained(
            info['path'],
            num_labels=2,
            output_attentions=False,
            output_hidden_states=False,
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None, None
    # Cache the loaded triple for subsequent calls.
    model_cache[model_name] = (tokenizer, model, info)
    return tokenizer, model, info
def predict_and_explain(text, model_choices, explainer_choice, compare_mode):
    """Classify *text* with each selected model and build explanation widgets.

    Args:
        text: Input text to analyze.
        model_choices: Display names of the models to run (keys of ``MODELS``).
        explainer_choice: ``"LIME"``, ``"SHAP"``, or any other value for Captum.
        compare_mode: When True and more than one model is selected, a combined
            comparison summary replaces the per-model detail views.

    Returns:
        A 5-tuple bound to the Gradio outputs:
        (prediction_text, visualization_html, plot_html, explanation_data,
        confidence_html).
    """
    if not text.strip():
        # BUG FIX: this branch previously returned SIX values while both event
        # handlers bind only FIVE outputs, so empty input made Gradio raise a
        # "too many output values" error. Return exactly five.
        return "Please enter some text to analyze.", None, None, None, None
    results = []
    visualizations = []
    plots = []
    explanations = []
    confidence_charts = []
    for model_choice in model_choices:
        # Load selected model (cached after the first call).
        tokenizer, model, model_info = load_model(model_choice)
        if model is None:
            # Keep all per-model lists in lockstep even on load failure.
            results.append(f"Error loading {model_choice}")
            visualizations.append(None)
            plots.append(None)
            explanations.append(None)
            confidence_charts.append(None)
            continue
        try:
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512
            )
            # Inference only — no gradients needed.
            model.eval()
            with torch.no_grad():
                outputs = model(**inputs)
                probabilities = torch.softmax(outputs.logits, dim=1).numpy()[0]
            predicted_class = np.argmax(probabilities)
            confidence = probabilities[predicted_class]
            results.append(f"{model_choice}: Class {predicted_class} ({confidence:.2%})")
            # NOTE(review): the ["Negative", "Positive"] labels assume a binary
            # sentiment head, but load_model initializes a fresh num_labels=2
            # head, so class indices are arbitrary — confirm the label mapping.
            confidence_charts.append(
                create_confidence_chart(probabilities, ["Negative", "Positive"])
            )
            # Generate explanation; a failing explainer must not kill the run.
            try:
                if explainer_choice == "LIME":
                    explanation = LimeExplainer(model, tokenizer).explain(text, num_features=15)
                elif explainer_choice == "SHAP":
                    explanation = ShapExplainer(model, tokenizer).explain(text)
                else:  # Captum
                    explanation = CaptumExplainer(model, tokenizer).explain(text)
            except Exception as e:
                print(f"Error generating explanation for {model_choice}: {e}")
                explanation = []
            explanations.append(explanation)
            # Build per-model visualizations from the attribution data.
            visualizations.append(create_visualization(text, explanation, tokenizer, explainer_choice))
            plots.append(create_attribution_plot(explanation, explainer_choice))
        except Exception as e:
            print(f"Prediction error for {model_choice}: {e}")
            results.append(f"{model_choice}: Error - {str(e)}")
            visualizations.append(None)
            plots.append(None)
            explanations.append(None)
            confidence_charts.append(None)
    # Format outputs based on comparison mode.
    if compare_mode and len(model_choices) > 1:
        comparison_html = """
<div style="padding: 20px; background: #f8f9fa; border-radius: 10px; border: 2px solid #e9ecef;">
<h3 style="margin-top: 0; color: #495057;">πŸ” Model Comparison Results</h3>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 15px;">
"""
        for i, model_choice in enumerate(model_choices):
            comparison_html += f"""
<div style="padding: 15px; background: white; border-radius: 8px; border: 1px solid #dee2e6;">
<h4 style="margin: 0 0 10px 0; color: #6c757d;">{model_choice}</h4>
<p style="margin: 0; font-weight: bold; color: #495057;">{results[i] if i < len(results) else 'N/A'}</p>
</div>
"""
        comparison_html += """
</div>
<p style="margin: 15px 0 0 0; color: #6c757d; font-style: italic;">
Select individual models from the checkbox to see detailed explanations.
</p>
</div>
"""
        # The same summary HTML fills the visualization, plot and confidence
        # slots; explanation_data carries a marker so export_results can refuse.
        return (
            "\n".join(results),
            comparison_html,
            comparison_html,
            {"comparison_mode": True, "results": results},
            comparison_html
        )
    # Single-model view: surface the first model's outputs.
    result_output = results[0] if results else "No results"
    vis_output = visualizations[0] if visualizations else None
    plot_output = plots[0] if plots else None
    explanation_output = explanations[0] if explanations else None
    confidence_output = confidence_charts[0] if confidence_charts else None
    return result_output, vis_output, plot_output, explanation_output, confidence_output
# Create Gradio interface.
# Layout: a header banner, a two-column row (controls on the left, results on
# the right), an export helper, quick examples, a model card, and the event
# wiring at the bottom. Both click handlers bind exactly FIVE outputs, so
# predict_and_explain must always return a 5-tuple.
with gr.Blocks(title="Explainability Sandbox for Transformers", css="footer {visibility: hidden}") as demo:
    # Header banner (raw HTML inside Markdown).
    gr.Markdown("""
<div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white; margin-bottom: 20px;">
<h1 style="margin: 0; font-size: 2.5em;">πŸ” Explainability Sandbox for Transformers</h1>
<p style="margin: 10px 0 0 0; font-size: 1.2em; opacity: 0.9;">Advanced model interpretability with multiple comparison</p>
</div>
""")
    with gr.Row():
        # Left column: input text and analysis controls.
        with gr.Column(scale=1):
            gr.Markdown("### βš™οΈ Input Settings")
            text_input = gr.Textbox(
                label="Input Text",
                lines=5,
                placeholder="Enter text to analyze...",
                value="The movie was fantastic with great acting and an engaging plot."
            )
            # Multiple selection enables the comparison-mode code path.
            model_choices = gr.CheckboxGroup(
                choices=list(MODELS.keys()),
                label="Select Models",
                value=["BERT Base (English)"],
                interactive=True
            )
            explainer_choice = gr.Radio(
                choices=["LIME", "SHAP", "Captum"],
                label="Explanation Method",
                value="LIME"
            )
            compare_mode = gr.Checkbox(
                label="Enable Comparison Mode",
                value=False,
                info="Compare multiple models side-by-side"
            )
            analyze_btn = gr.Button("Analyze Text", variant="primary")
            gr.Markdown("""
---
### πŸ“Š Export Results
""")
            export_btn = gr.Button("Export Results", variant="secondary")
            export_output = gr.HTML()
        # Right column: prediction results and explanation views.
        with gr.Column(scale=2):
            gr.Markdown("### πŸ“ˆ Results")
            output_text = gr.Textbox(label="Prediction Result")
            gr.Markdown("#### πŸ“Š Confidence Distribution")
            confidence_output = gr.HTML()
            gr.Markdown("#### 🎨 Token Attributions")
            output_vis = gr.HTML(label="Visualization")
            gr.Markdown("#### πŸ“‰ Attribution Plot")
            output_plot = gr.HTML()
            gr.Markdown("#### πŸ” Explanation Data")
            explanation_output = gr.JSON(label="Detailed Data")
    # Export functionality.
    def export_results(explanation_data, plot_html):
        """Build an HTML panel with CSV/JSON/PNG export snippets.

        Refuses to export when *explanation_data* is the comparison-mode
        marker dict produced by predict_and_explain; otherwise falls back to
        placeholder text for any missing data instead of raising.
        """
        if explanation_data and isinstance(explanation_data, dict) and explanation_data.get("comparison_mode"):
            return "<div style='color: #6c757d; padding: 10px;'>Export not available in comparison mode. Select individual models to export.</div>"
        csv_export = export_to_csv(explanation_data) if explanation_data else "No data to export"
        json_export = export_to_json(explanation_data) if explanation_data else "No data to export"
        png_export = export_plot_as_png(plot_html) if plot_html else "No plot to export"
        return f"""
<div style="padding: 15px; background: #f8f9fa; border-radius: 8px; border: 1px solid #ddd;">
<h4 style="margin-top: 0;">Export Options:</h4>
<div style="display: flex; gap: 10px; flex-wrap: wrap;">
<div style="padding: 10px; background: white; border-radius: 5px; border: 1px solid #ccc;">{csv_export}</div>
<div style="padding: 10px; background: white; border-radius: 5px; border: 1px solid #ccc;">{json_export}</div>
<div style="padding: 10px; background: white; border-radius: 5px; border: 1px solid #ccc;">{png_export}</div>
</div>
</div>
"""
    # Examples (cache disabled so each run re-executes the pipeline).
    gr.Markdown("### πŸš€ Quick Examples")
    examples = gr.Examples(
        examples=[
            ["This movie was absolutely fantastic! The acting was superb.", ["BERT Base (English)"], "LIME", False],
            ["The patient shows symptoms of fever and cough.", ["BERT Base (English)", "RoBERTa Base (English)"], "SHAP", True],
            ["The financial report indicates strong growth.", ["DistilBERT (English)", "ALBERT Base (English)"], "Captum", True]
        ],
        inputs=[text_input, model_choices, explainer_choice, compare_mode],
        outputs=[output_text, output_vis, output_plot, explanation_output, confidence_output],
        fn=predict_and_explain,
        cache_examples=False
    )
    # Enhanced Model Card & Ethical Considerations.
    gr.Markdown("---")
    gr.Markdown("""
### πŸ“‹ Expanded Model Card & Ethical Considerations
**Datasets Used for Pretraining:**
- BookCorpus (800M words)
- English Wikipedia (2,500M words)
- CommonCrawl News Dataset
- Various domain-specific datasets for fine-tuning
**⚠️ Important Limitations & Warnings:**
**Not for Clinical/Diagnostic Use:**
- This tool is for research and educational purposes only
- NOT suitable for medical diagnosis, clinical decisions, or patient care
- Models may produce incorrect or biased outputs
**Explanation Method Limitations:**
- LIME: Local approximations, may not capture global model behavior
- SHAP: Game-theoretic approach, computationally intensive
- Captum: Gradient-based, sensitive to model architecture
- Different methods may produce conflicting explanations
**Bias Awareness:**
- Models may reproduce and amplify societal biases present in training data
- Performance may vary across demographic groups
- Always validate with domain experts for critical applications
**Interpretability β‰  Ground Truth:**
- Explanations are approximations of model behavior
- They show correlation, not necessarily causation
- Use multiple methods to validate findings
""")
    # Event handlers — note both bind five outputs.
    analyze_btn.click(
        fn=predict_and_explain,
        inputs=[text_input, model_choices, explainer_choice, compare_mode],
        outputs=[output_text, output_vis, output_plot, explanation_output, confidence_output]
    )
    export_btn.click(
        fn=export_results,
        inputs=[explanation_output, output_plot],
        outputs=[export_output]
    )
if __name__ == "__main__":
    # Local launch only; share=False keeps the app off Gradio's public tunnel.
    demo.launch(share=False)