Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,11 +6,11 @@ import numpy as np
|
|
| 6 |
import html
|
| 7 |
from transformers import AutoTokenizer, AutoModel, logging as hf_logging
|
| 8 |
import pandas as pd
|
| 9 |
-
import matplotlib
|
| 10 |
-
matplotlib.use('Agg')
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
from sklearn.decomposition import PCA
|
| 13 |
-
import plotly.graph_objects as go
|
| 14 |
|
| 15 |
# --- Global Settings and Model Loading ---
|
| 16 |
hf_logging.set_verbosity_error()
|
|
@@ -103,7 +103,6 @@ def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token E
|
|
| 103 |
fig.update_layout(
|
| 104 |
title=dict(text=title, x=0.5, font=dict(size=16)),
|
| 105 |
scene=dict(
|
| 106 |
-
# 수정된 부분: title 속성 내에 text와 font를 포함
|
| 107 |
xaxis=dict(title=dict(text='PCA Comp 1', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
|
| 108 |
yaxis=dict(title=dict(text='PCA Comp 2', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
|
| 109 |
zaxis=dict(title=dict(text='PCA Comp 3', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
|
|
@@ -134,8 +133,9 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
|
|
| 134 |
error_html = f"<p style='color:red;'>Initialization Error: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
|
| 135 |
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 136 |
empty_fig = create_empty_plotly_figure("Model Loading Failed")
|
| 137 |
-
# gr.Label에 대한 오류 반환값 수정
|
| 138 |
-
|
|
|
|
| 139 |
|
| 140 |
try:
|
| 141 |
tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
|
|
@@ -147,7 +147,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
|
|
| 147 |
if input_ids.shape[1] == 0:
|
| 148 |
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 149 |
empty_fig = create_empty_plotly_figure("Invalid Input")
|
| 150 |
-
|
|
|
|
| 151 |
|
| 152 |
input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
|
| 153 |
input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
|
|
@@ -166,7 +167,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
|
|
| 166 |
if input_embeds_for_grad.grad is None:
|
| 167 |
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 168 |
empty_fig = create_empty_plotly_figure("Gradient Error")
|
| 169 |
-
|
|
|
|
| 170 |
|
| 171 |
grads = input_embeds_for_grad.grad.clone().detach()
|
| 172 |
scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
|
|
@@ -210,11 +212,12 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
|
|
| 210 |
|
| 211 |
barplot_df = pd.DataFrame(top_tokens_for_barplot_list) if top_tokens_for_barplot_list else pd.DataFrame(columns=['token', 'score'])
|
| 212 |
|
| 213 |
-
predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index
|
| 214 |
|
| 215 |
prediction_summary_text = f"Predicted Class: {predicted_class_label_str}\nProbability: {pred_prob_val:.3f}"
|
| 216 |
-
|
| 217 |
-
|
|
|
|
| 218 |
pca_fig = create_empty_plotly_figure("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
|
| 219 |
non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
|
| 220 |
if token_id not in [cls_token_id, sep_token_id]]
|
|
@@ -242,7 +245,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
|
|
| 242 |
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 243 |
empty_fig = create_empty_plotly_figure("Analysis Error")
|
| 244 |
# gr.Label에 대한 오류 반환값 수정
|
| 245 |
-
|
|
|
|
| 246 |
|
| 247 |
# --- Gradio UI Definition (Translated and Enhanced) ---
|
| 248 |
theme = gr.themes.Monochrome(
|
|
@@ -272,7 +276,7 @@ with gr.Blocks(title="AI Sentence Analyzer XAI ๐", theme=theme, css=".gradio-
|
|
| 272 |
with gr.Column(scale=2):
|
| 273 |
with gr.Accordion("🎯 Prediction Outcome", open=True):
|
| 274 |
output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2, interactive=False)
|
| 275 |
-
output_prediction_details = gr.Label(label="
|
| 276 |
with gr.Accordion("⭐ Top-K Important Tokens (Table)", open=True):
|
| 277 |
output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
|
| 278 |
row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
|
|
|
|
| 6 |
import html
|
| 7 |
from transformers import AutoTokenizer, AutoModel, logging as hf_logging
|
| 8 |
import pandas as pd
|
| 9 |
+
import matplotlib
|
| 10 |
+
matplotlib.use('Agg')
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
from sklearn.decomposition import PCA
|
| 13 |
+
import plotly.graph_objects as go
|
| 14 |
|
| 15 |
# --- Global Settings and Model Loading ---
|
| 16 |
hf_logging.set_verbosity_error()
|
|
|
|
| 103 |
fig.update_layout(
|
| 104 |
title=dict(text=title, x=0.5, font=dict(size=16)),
|
| 105 |
scene=dict(
|
|
|
|
| 106 |
xaxis=dict(title=dict(text='PCA Comp 1', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
|
| 107 |
yaxis=dict(title=dict(text='PCA Comp 2', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
|
| 108 |
zaxis=dict(title=dict(text='PCA Comp 3', font=dict(size=10)), tickfont=dict(size=9), backgroundcolor="rgba(230, 230, 230, 0.8)"),
|
|
|
|
| 133 |
error_html = f"<p style='color:red;'>Initialization Error: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
|
| 134 |
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 135 |
empty_fig = create_empty_plotly_figure("Model Loading Failed")
|
| 136 |
+
# gr.Label에 대한 오류 반환값 수정 (단순 딕셔너리 또는 문자열)
|
| 137 |
+
error_label_output = {"Status": "Error", "Message": "Model Loading Failed. Check logs."}
|
| 138 |
+
return error_html, [], "Model Loading Failed", error_label_output, [], empty_df, empty_fig
|
| 139 |
|
| 140 |
try:
|
| 141 |
tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
|
|
|
|
| 147 |
if input_ids.shape[1] == 0:
|
| 148 |
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 149 |
empty_fig = create_empty_plotly_figure("Invalid Input")
|
| 150 |
+
error_label_output = {"Status": "Error", "Message": "Invalid input, no valid tokens."}
|
| 151 |
+
return "<p style='color:orange;'>Input Error: No valid tokens found.</p>", [], "Input Error", error_label_output, [], empty_df, empty_fig
|
| 152 |
|
| 153 |
input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
|
| 154 |
input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
|
|
|
|
| 167 |
if input_embeds_for_grad.grad is None:
|
| 168 |
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 169 |
empty_fig = create_empty_plotly_figure("Gradient Error")
|
| 170 |
+
error_label_output = {"Status": "Error", "Message": "Gradient calculation failed."}
|
| 171 |
+
return "<p style='color:red;'>Analysis Error: Gradient calculation failed.</p>", [],"Analysis Error", error_label_output, [], empty_df, empty_fig
|
| 172 |
|
| 173 |
grads = input_embeds_for_grad.grad.clone().detach()
|
| 174 |
scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
|
|
|
|
| 212 |
|
| 213 |
barplot_df = pd.DataFrame(top_tokens_for_barplot_list) if top_tokens_for_barplot_list else pd.DataFrame(columns=['token', 'score'])
|
| 214 |
|
| 215 |
+
predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index ({pred_idx})")
|
| 216 |
|
| 217 |
prediction_summary_text = f"Predicted Class: {predicted_class_label_str}\nProbability: {pred_prob_val:.3f}"
|
| 218 |
+
# 수정된 부분: gr.Label에 적합한 딕셔너리 형태 (클래스명: 확률값)
|
| 219 |
+
prediction_details_for_label = {predicted_class_label_str: float(f"{pred_prob_val:.3f}")} # 확률값을 float으로 전달
|
| 220 |
+
|
| 221 |
pca_fig = create_empty_plotly_figure("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
|
| 222 |
non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
|
| 223 |
if token_id not in [cls_token_id, sep_token_id]]
|
|
|
|
| 245 |
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 246 |
empty_fig = create_empty_plotly_figure("Analysis Error")
|
| 247 |
# gr.Label에 대한 오류 반환값 수정
|
| 248 |
+
error_label_output = {"Status": "Error", "Message": f"Analysis failed: {str(e)}"}
|
| 249 |
+
return error_html, [], "Analysis Failed", error_label_output, [], empty_df, empty_fig
|
| 250 |
|
| 251 |
# --- Gradio UI Definition (Translated and Enhanced) ---
|
| 252 |
theme = gr.themes.Monochrome(
|
|
|
|
| 276 |
with gr.Column(scale=2):
|
| 277 |
with gr.Accordion("🎯 Prediction Outcome", open=True):
|
| 278 |
output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2, interactive=False)
|
| 279 |
+
output_prediction_details = gr.Label(label="Prediction Details & Confidence") # 레이블 이름 변경
|
| 280 |
with gr.Accordion("⭐ Top-K Important Tokens (Table)", open=True):
|
| 281 |
output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
|
| 282 |
row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
|