Spaces:
Sleeping
Sleeping
import html  # Kept: html.escape may still be needed when building highlighted_text_data
import os
import traceback

import gradio as gr
import joblib
import matplotlib
matplotlib.use('Agg')  # Select the backend before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, AutoModel, logging as hf_logging
# --- Global Settings and Model Loading ---
# Lower the transformers library's log verbosity to errors only.
hf_logging.set_verbosity_error()

# Backbone checkpoint and the device every tensor/model is placed on.
MODEL_NAME = "bert-base-uncased"
DEVICE = "cpu"

# Directory holding the pickled LDA projection and classifier.
# NOTE(review): this name looks mis-encoded (mojibake) in this copy of the
# file - confirm it matches the real folder name on disk.
SAVE_DIR = "μ μ₯μ μ₯1"

# These three values are interpolated into the artifact file names;
# LAYER_ID also selects which hidden state is read during analysis.
LAYER_ID = 4
SEED = 0
CLF_NAME = "linear"

# Classifier output index -> human-readable class name.
CLASS_LABEL_MAP = dict(enumerate(("World", "Sports", "Business", "Sci/Tech")))

# Placeholders that the loading step below fills in on success.
TOKENIZER_GLOBAL = None
MODEL_GLOBAL = None
W_GLOBAL = None
MU_GLOBAL = None
W_P_GLOBAL = None
B_P_GLOBAL = None

# Loading outcome, consulted by the analysis function and the entry point.
MODELS_LOADED_SUCCESSFULLY = False
MODEL_LOADING_ERROR_MESSAGE = ""
# Load every artifact once at import time. Any failure is recorded in the
# two status globals instead of propagating, so the UI can still start.
try:
    print("Gradio App: Initializing model loading...")

    lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
    clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")

    # Fail early with a specific message for whichever piece is missing.
    if not os.path.isdir(SAVE_DIR):
        raise FileNotFoundError(f"Error: Model storage directory '{SAVE_DIR}' not found.")
    if not os.path.exists(lda_file_path):
        raise FileNotFoundError(f"Error: LDA model file '{lda_file_path}' not found.")
    if not os.path.exists(clf_file_path):
        raise FileNotFoundError(f"Error: Classifier model file '{clf_file_path}' not found.")

    # NOTE: joblib.load unpickles its input; only load files you trust.
    lda = joblib.load(lda_file_path)
    clf = joblib.load(clf_file_path)
    # Unwrap a wrapper estimator when it exposes `base_estimator`.
    if hasattr(clf, "base_estimator"):
        clf = clf.base_estimator

    # LDA projection (scalings and overall mean) as float32 tensors.
    W_GLOBAL = torch.tensor(lda.scalings_, dtype=torch.float32, device=DEVICE)
    MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
    # Linear classifier weights and bias as float32 tensors.
    W_P_GLOBAL = torch.tensor(clf.coef_, dtype=torch.float32, device=DEVICE)
    B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)

    TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    MODEL_GLOBAL = AutoModel.from_pretrained(
        MODEL_NAME, output_hidden_states=True, output_attentions=False
    ).to(DEVICE).eval()

    MODELS_LOADED_SUCCESSFULLY = True
    print("Gradio App: All models and data loaded successfully!")
except Exception as e:
    MODELS_LOADED_SUCCESSFULLY = False
    MODEL_LOADING_ERROR_MESSAGE = (
        f"Critical error during model loading: {e!s}\n"
        f"Please ensure the '{SAVE_DIR}' folder and its contents are correct."
    )
    print(MODEL_LOADING_ERROR_MESSAGE)
# Helper function: 3D PCA Visualization using Plotly
def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token Embeddings 3D PCA (Colored by Importance)"):
    """Return a Plotly 3D scatter of tokens coloured by importance score.

    Args:
        token_embeddings_3d: 2-D array indexed as ``[:, 0]``, ``[:, 1]`` and
            ``[:, 2]`` for the x, y and z coordinates.
        tokens: Token strings, one per point.
        scores: Importance scores; flattened before use.
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure``. Every point has hover text, but
        at most 20 points (those with the highest scores) get a visible
        text label.
    """
    flat_scores = np.array(scores).flatten()
    token_count = len(tokens)
    label_budget = min(token_count, 20)

    # Start with blank labels and fill in only the highest-scoring tokens.
    point_labels = [''] * token_count
    if len(flat_scores) > 0 and token_count > 0:
        for pos in np.argsort(flat_scores)[-label_budget:]:
            if pos < token_count:
                point_labels[pos] = tokens[pos]

    hover_lines = []
    for tok, val in zip(tokens, flat_scores):
        hover_lines.append(f"Token: {tok}<br>Score: {val:.3f}")

    marker_style = dict(
        size=6,
        color=flat_scores,
        colorscale='RdBu',
        reversescale=True,
        opacity=0.8,
        colorbar=dict(title='Importance', tickfont=dict(size=9), len=0.75, yanchor='middle'),
    )
    scatter = go.Scatter3d(
        x=token_embeddings_3d[:, 0],
        y=token_embeddings_3d[:, 1],
        z=token_embeddings_3d[:, 2],
        mode='markers+text',
        text=point_labels,
        textfont=dict(size=9, color='#333333'),
        textposition='top center',
        marker=marker_style,
        hoverinfo='text',
        hovertext=hover_lines,
    )
    fig = go.Figure(data=[scatter])

    def _axis(axis_title):
        # All three axes share the same styling and differ only in title.
        return dict(
            title=dict(text=axis_title, font=dict(size=10)),
            tickfont=dict(size=9),
            backgroundcolor="rgba(230, 230, 230, 0.8)",
        )

    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=16)),
        scene=dict(
            xaxis=_axis('PCA Comp 1'),
            yaxis=_axis('PCA Comp 2'),
            zaxis=_axis('PCA Comp 3'),
            bgcolor="rgba(255, 255, 255, 0.95)",
            camera_eye=dict(x=1.5, y=1.5, z=0.5),
        ),
        margin=dict(l=5, r=5, b=5, t=45),
        paper_bgcolor='rgba(0,0,0,0)',
    )
    return fig
# Helper function: Create an empty Plotly figure for placeholders
def create_empty_plotly_figure(message="N/A"):
    """Return a 300px-high figure with hidden axes and a centred grey message.

    Args:
        message: Text placed in the middle of the otherwise empty figure.

    Returns:
        A ``plotly.graph_objects.Figure`` with transparent backgrounds.
    """
    placeholder = go.Figure()
    # Paper coordinates (0.5, 0.5) put the text at the centre of the figure.
    placeholder.add_annotation(
        text=message,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=12, color="grey"),
    )
    hidden_axis = {'visible': False}
    transparent = 'rgba(0,0,0,0)'
    placeholder.update_layout(
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        height=300,
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
    )
    return placeholder
# --- Core Analysis Function (returns 6 items for Gradio UI) ---
def _analysis_error_outputs(summary_text, label_message, plot_message):
    """Build the 6-item placeholder result returned when analysis cannot run."""
    empty_df = pd.DataFrame(columns=['token', 'score'])
    empty_fig = create_empty_plotly_figure(plot_message)
    # gr.Label takes a plain string or a {label: float} mapping; a mapping
    # with string values is not a valid confidence dict, so send a string.
    return [], summary_text, f"Error: {label_message}", [], empty_df, empty_fig


def analyze_sentence_for_gradio(sentence_text, top_k_value):
    """Classify a sentence and compute per-token importance for the UI.

    The sentence is run through the globally loaded model; the [CLS] vector
    of hidden layer ``LAYER_ID`` is projected with the LDA parameters and
    scored by the linear classifier. Token importance is the L2 norm of
    (gradient of the predicted logit w.r.t. the word embeddings) times
    (the word embeddings), divided by the largest finite score.

    Args:
        sentence_text: Sentence to analyse.
        top_k_value: Number of highest-scoring non-special tokens to list;
            converted to ``int`` before use.

    Returns:
        A 6-tuple matching the Gradio outputs, in order:
        ``(highlighted_text_data, prediction_summary_text,
        prediction_details_for_label, top_tokens_for_df, barplot_df,
        pca_fig)``. On any failure the same six positions carry empty
        placeholders and an error message; no exception is raised.
    """
    if not MODELS_LOADED_SUCCESSFULLY:
        return _analysis_error_outputs(
            "Model Loading Failed", "Model Loading Failed. Check logs.", "Model Loading Failed")

    try:
        tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
        W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL

        # The tokenizer always adds special tokens, so the shape check below
        # never catches blank input; reject it explicitly.
        if not sentence_text or not sentence_text.strip():
            return _analysis_error_outputs(
                "Input Error", "Invalid input, no valid tokens.", "Invalid Input")

        # The UI slider may deliver a float; a slice bound must be an int.
        top_k = max(int(top_k_value), 0)

        enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
        input_ids = enc["input_ids"].to(DEVICE)
        attn_mask = enc["attention_mask"].to(DEVICE)
        if input_ids.shape[1] == 0:
            return _analysis_error_outputs(
                "Input Error", "Invalid input, no valid tokens.", "Invalid Input")

        # Feed the word embeddings in as a leaf tensor so that a gradient
        # with respect to them can be requested.
        input_embeds_detached = model.embeddings.word_embeddings(input_ids).detach().clone()
        input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
        outputs = model(inputs_embeds=input_embeds_for_grad, attention_mask=attn_mask,
                        output_hidden_states=True, output_attentions=False)

        cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
        z_projected = (cls_vec - mu) @ W
        logit_output = z_projected @ w_p.T + b_p
        probs = torch.softmax(logit_output, dim=1)
        pred_idx = int(torch.argmax(probs, dim=1).item())
        pred_prob_val = probs[0, pred_idx].item()

        # autograd.grad returns only the requested gradient, so nothing is
        # accumulated into the model parameters' .grad between requests.
        (grads,) = torch.autograd.grad(
            logit_output[0, pred_idx], input_embeds_for_grad, allow_unused=True)
        if grads is None:
            return _analysis_error_outputs(
                "Analysis Error", "Gradient calculation failed.", "Gradient Error")

        scores_np = (grads.detach() * input_embeds_detached).norm(dim=2).squeeze(0).cpu().numpy()
        finite_mask = np.isfinite(scores_np)
        max_score = scores_np[finite_mask].max() if finite_mask.any() else 0.0
        if max_score > 0:
            # Scale by the largest finite score; non-finite entries become 0
            # so they cannot disturb ranking or colouring.
            scores_np = np.where(finite_mask, scores_np / (max_score + 1e-9), 0.0)
        else:
            scores_np = np.zeros_like(scores_np)

        ids_list = input_ids[0].tolist()
        tokens_raw = tokenizer.convert_ids_to_tokens(ids_list, skip_special_tokens=False)
        pad_id = tokenizer.pad_token_id
        actual_tokens = [tok for tok, tok_id in zip(tokens_raw, ids_list) if tok_id != pad_id]
        n_tokens = len(actual_tokens)
        actual_scores_np = scores_np[:n_tokens]
        actual_input_embeds = input_embeds_detached[0, :n_tokens, :].cpu().numpy()

        special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id}

        # (text, score) pairs for gr.HighlightedText; [CLS]/[SEP] get None.
        highlighted_text_data = []
        for tok_str, tok_id, tok_score in zip(actual_tokens, ids_list, actual_scores_np):
            # Drop the "##" word-piece continuation prefix for display.
            display = (tok_str[2:] if tok_str.startswith("##") else tok_str) + " "
            if tok_id in special_ids:
                highlighted_text_data.append((display, None))
            else:
                clipped = min(1.0, max(0.0, float(tok_score)))
                highlighted_text_data.append((display, round(clipped, 3)))

        # Positions of real (non-[CLS]/[SEP]) tokens, reused for both the
        # top-k ranking and the PCA plot.
        content_indices = [idx for idx, tok_id in enumerate(ids_list[:n_tokens])
                           if tok_id not in special_ids]
        ranked_indices = sorted(content_indices, key=lambda idx: -actual_scores_np[idx])

        top_tokens_for_df, barplot_rows = [], []
        for token_idx in ranked_indices[:top_k]:
            token_str = actual_tokens[token_idx]
            # Plain Python floats keep NumPy scalars out of the UI payload.
            score_val = float(actual_scores_np[token_idx])
            top_tokens_for_df.append([token_str, f"{score_val:.3f}"])
            barplot_rows.append({"token": token_str, "score": score_val})
        barplot_df = pd.DataFrame(barplot_rows, columns=['token', 'score'])

        predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index ({pred_idx})")
        prediction_summary_text = f"Predicted Class: {predicted_class_label_str}\nProbability: {pred_prob_val:.3f}"
        prediction_details_for_label = {predicted_class_label_str: float(f"{pred_prob_val:.3f}")}

        if len(content_indices) >= 3:
            # PCA to 3 components needs at least 3 samples.
            pca_tokens = [actual_tokens[i] for i in content_indices]
            pca_embeddings = actual_input_embeds[content_indices, :]
            pca_scores_for_plot = actual_scores_np[content_indices]
            pca = PCA(n_components=3, random_state=SEED)
            token_embeddings_3d = pca.fit_transform(pca_embeddings)
            pca_fig = plot_token_pca_3d_plotly(token_embeddings_3d, pca_tokens, pca_scores_for_plot)
        else:
            # Plotly annotation text breaks lines on <br>, not on "\n".
            pca_fig = create_empty_plotly_figure(
                "PCA Plot N/A<br>(Not enough non-special tokens for 3D)")

        return (highlighted_text_data,
                prediction_summary_text, prediction_details_for_label,
                top_tokens_for_df, barplot_df,
                pca_fig)  # 6 items
    except Exception as e:
        # Top-level boundary for the UI callback: log the traceback and
        # return placeholders rather than letting the request crash.
        tb_str = traceback.format_exc()
        print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
        return _analysis_error_outputs(
            "Analysis Failed", f"Analysis failed: {str(e)}", "Analysis Error")
# --- Gradio UI Definition (HTML Highlight Tab removed) ---
# Colour families passed to the Monochrome base theme.
_THEME_HUES = {
    "primary_hue": gr.themes.colors.blue,
    "secondary_hue": gr.themes.colors.sky,
    "neutral_hue": gr.themes.colors.slate,
}
# Individual theme variables overridden on top of the base theme.
_THEME_OVERRIDES = {
    "body_background_fill": "#f0f2f6",
    "block_shadow": "*shadow_drop_lg",
    "button_primary_background_fill": "*primary_500",
    "button_primary_text_color": "white",
}
theme = gr.themes.Monochrome(**_THEME_HUES).set(**_THEME_OVERRIDES)
# Page layout: inputs on the left, prediction summary on the right, detailed
# visualizations underneath, then examples and the submit wiring.
# NOTE(review): several labels below contain characters that look like
# mis-encoded emoji in this copy of the file; they are reproduced unchanged.
with gr.Blocks(title="AI Sentence Analyzer XAI π", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
    gr.Markdown("# π AI Sentence Analyzer XAI: Exploring Model Explanations")
    gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
                "Explore token importance and their distribution in the embedding space.")

    with gr.Row(equal_height=False):
        # Left column: sentence input, top-k slider and the submit button.
        with gr.Column(scale=1, min_width=350):
            with gr.Group():
                gr.Markdown("### βοΈ Input Sentence & Settings")
                input_sentence = gr.Textbox(
                    lines=5,
                    label="English Sentence to Analyze",
                    placeholder="Enter the English sentence you want to analyze here...",
                )
                input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Top-K Tokens")
            submit_button = gr.Button("Analyze Sentence π«", variant="primary")

        # Right column: prediction outcome and the top-k table.
        with gr.Column(scale=2):
            with gr.Accordion("π― Prediction Outcome", open=True):
                output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2, interactive=False)
                output_prediction_details = gr.Label(label="Prediction Details & Confidence")
            with gr.Accordion("β Top-K Important Tokens (Table)", open=True):
                output_top_tokens_df = gr.DataFrame(
                    headers=["Token", "Score"],
                    label="Most Important Tokens",
                    row_count=(1, "dynamic"),
                    col_count=(2, "fixed"),
                    interactive=False,
                    wrap=True,
                )

    gr.Markdown("---")
    gr.Markdown("## π Detailed Visualizations")

    # HTML Highlight (Custom) section removed
    with gr.Group():  # HighlightedText
        gr.Markdown("### ποΈ Highlighted Text (Gradio)")
        output_highlighted_text = gr.HighlightedText(
            label="Token Importance (Score: 0-1)",
            show_legend=True,
            combine_adjacent=False,
        )

    with gr.Row():  # BarPlot and PCA Plot Side-by-Side
        with gr.Column(scale=1, min_width=400):
            with gr.Group():
                gr.Markdown("### π Top-K Bar Plot")
                output_top_tokens_barplot = gr.BarPlot(
                    label="Top-K Token Importance Scores",
                    x="token",
                    y="score",
                    tooltip=['token', 'score'],
                    min_width=300,
                )
        with gr.Column(scale=1, min_width=400):
            with gr.Group():
                gr.Markdown("### π Token Embeddings 3D PCA (Interactive)")
                output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")

    gr.Markdown("---")

    # Output components in the order analyze_sentence_for_gradio returns its
    # six values; shared by the examples and the button.
    _analysis_outputs = [
        output_highlighted_text,
        output_prediction_summary,
        output_prediction_details,
        output_top_tokens_df,
        output_top_tokens_barplot,
        output_pca_plot,
    ]

    gr.Examples(
        examples=[
            ["This movie is an absolute masterpiece, captivating from start to finish.", 5],
            ["Despite some flaws, the film offers a compelling narrative.", 3],
            ["I was thoroughly disappointed with the lackluster performance and predictable plot.", 4],
        ],
        inputs=[input_sentence, input_top_k],
        outputs=list(_analysis_outputs),
        fn=analyze_sentence_for_gradio,
        cache_examples=False,
    )

    gr.HTML("<p style='text-align: center; color: #4a5568;'>Explainable AI Demo powered by Gradio & Hugging Face Transformers</p>")

    submit_button.click(
        fn=analyze_sentence_for_gradio,
        inputs=[input_sentence, input_top_k],
        outputs=list(_analysis_outputs),
        api_name="explain_sentence_xai",
    )
# Script entry point: warn on the console if loading failed, then start the
# app regardless so the UI is still reachable.
if __name__ == "__main__":
    if not MODELS_LOADED_SUCCESSFULLY:
        banner = "*" * 80
        warning_lines = (
            banner,
            f"WARNING: Models failed to load! {MODEL_LOADING_ERROR_MESSAGE}",
            "The Gradio UI will be displayed, but analysis will fail.",
            banner,
        )
        for line in warning_lines:
            print(line)
    demo.launch()