kikikara commited on
Commit
c9044fd
·
verified ·
1 Parent(s): d607281

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -114
app.py CHANGED
@@ -6,41 +6,47 @@ import numpy as np
6
  import html
7
  from transformers import AutoTokenizer, AutoModel, logging as hf_logging
8
  import pandas as pd
9
- import matplotlib
10
- matplotlib.use('Agg') # Matplotlib ๋ฐฑ์—”๋“œ ์„ค์ • (Gradio์™€ ํ•จ๊ป˜ ์‚ฌ์šฉ ์‹œ ์ค‘์š”)
11
  import matplotlib.pyplot as plt
12
  from sklearn.decomposition import PCA
 
13
 
14
- # --- 기존 설정 및 전역 모델 로드 부분 ---
15
- # Hugging Face Transformers 로깅 레벨 설정
16
  hf_logging.set_verbosity_error()
17
 
18
- # 설정
19
  MODEL_NAME = "bert-base-uncased"
20
  DEVICE = "cpu"
21
- SAVE_DIR = "์ €์žฅ์ €์žฅ1"
22
  LAYER_ID = 4
23
  SEED = 0
24
  CLF_NAME = "linear"
25
 
26
- # 전역 모델 로드
 
 
 
 
 
 
 
27
  TOKENIZER_GLOBAL, MODEL_GLOBAL = None, None
28
  W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
29
- CLASS_NAMES_GLOBAL = None
30
  MODELS_LOADED_SUCCESSFULLY = False
31
  MODEL_LOADING_ERROR_MESSAGE = ""
32
 
33
  try:
34
- print("Gradio App: ๋ชจ๋ธ ๋กœ๋”ฉ์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค...")
35
  lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
36
  clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")
37
 
38
  if not os.path.isdir(SAVE_DIR):
39
- raise FileNotFoundError(f"์˜ค๋ฅ˜: ๋ชจ๋ธ ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ '{SAVE_DIR}'๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
40
  if not os.path.exists(lda_file_path):
41
- raise FileNotFoundError(f"์˜ค๋ฅ˜: LDA ๋ชจ๋ธ ํŒŒ์ผ '{lda_file_path}'๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
42
  if not os.path.exists(clf_file_path):
43
- raise FileNotFoundError(f"์˜ค๋ฅ˜: ๋ถ„๋ฅ˜๊ธฐ ๋ชจ๋ธ ํŒŒ์ผ '{clf_file_path}'๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
44
 
45
  lda = joblib.load(lda_file_path)
46
  clf = joblib.load(clf_file_path)
@@ -57,84 +63,101 @@ try:
57
  MODEL_NAME, output_hidden_states=True, output_attentions=False
58
  ).to(DEVICE).eval()
59
 
60
- if hasattr(lda, 'classes_'): CLASS_NAMES_GLOBAL = lda.classes_
61
- elif hasattr(clf, 'classes_'): CLASS_NAMES_GLOBAL = clf.classes_
62
-
63
  MODELS_LOADED_SUCCESSFULLY = True
64
- print("Gradio App: ๋ชจ๋“  ๋ชจ๋ธ ๋ฐ ๋ฐ์ดํ„ฐ ๋กœ๋“œ ์„ฑ๊ณต!")
65
 
66
  except Exception as e:
67
  MODELS_LOADED_SUCCESSFULLY = False
68
- MODEL_LOADING_ERROR_MESSAGE = f"๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘ ์‹ฌ๊ฐํ•œ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n'์ €์žฅ์ €์žฅ1' ํด๋”์™€ ๋‚ด์šฉ๋ฌผ์„ ํ™•์ธํ•ด์ฃผ์„ธ์š”."
69
  print(MODEL_LOADING_ERROR_MESSAGE)
70
 
71
- # 헬퍼 함수: PCA 시각화 (3D)
72
- def plot_token_pca_3d(token_embeddings_3d, tokens, scores, title="Token Embeddings 3D PCA (Colored by Importance)"):
73
- fig = plt.figure(figsize=(10, 8))
74
- ax = fig.add_subplot(111, projection='3d')
75
-
76
- num_annotations = min(len(tokens), 15)
77
- if len(scores) > 0 and len(tokens) > 0: # scores์™€ tokens๊ฐ€ ๋น„์–ด์žˆ์ง€ ์•Š์€์ง€ ํ™•์ธ
78
- # scores๊ฐ€ NumPy ๋ฐฐ์—ด์ด ์•„๋‹ ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ, ๋ฆฌ์ŠคํŠธ์ธ ๊ฒฝ์šฐ np.array๋กœ ๋ณ€ํ™˜
79
- scores_np_array = np.array(scores)
80
- indices_to_annotate = np.argsort(scores_np_array)[-num_annotations:]
81
- else:
82
- indices_to_annotate = np.array([])
83
-
84
- scatter = ax.scatter(token_embeddings_3d[:, 0], token_embeddings_3d[:, 1], token_embeddings_3d[:, 2],
85
- c=scores, cmap="coolwarm_r", s=50, alpha=0.8, depthshade=True)
86
 
87
- for i in range(len(tokens)):
88
- if i in indices_to_annotate:
89
- ax.text(token_embeddings_3d[i, 0], token_embeddings_3d[i, 1], token_embeddings_3d[i, 2],
90
- f' {tokens[i]}', size=8, zorder=1, color='k')
91
-
92
- ax.set_title(title, fontsize=14)
93
- ax.set_xlabel("PCA Component 1", fontsize=10)
94
- ax.set_ylabel("PCA Component 2", fontsize=10)
95
- ax.set_zlabel("PCA Component 3", fontsize=10)
 
96
 
97
- cbar = plt.colorbar(scatter, label="Importance Score", shrink=0.7)
98
- cbar.ax.tick_params(labelsize=8)
99
- ax.tick_params(axis='both', which='major', labelsize=8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
 
102
  return fig
103
 
104
- # ────────── 핵심 분석 함수 (반환 값 7개) ──────────
105
- def analyze_sentence_for_gradio(sentence_text, top_k_value):
106
- def create_empty_plot(message="N/A"):
107
- fig = plt.figure(figsize=(2,2));
108
- ax = fig.add_subplot(111)
109
- ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=10)
110
- ax.axis('off')
111
- return fig
 
 
 
 
112
 
 
 
113
  if not MODELS_LOADED_SUCCESSFULLY:
114
- error_html = f"<p style='color:red;'>์ดˆ๊ธฐํ™” ์˜ค๋ฅ˜: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
115
  empty_df = pd.DataFrame(columns=['token', 'score'])
116
- empty_fig_placeholder = create_empty_plot()
117
- return error_html, [], "๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ", {"์˜ค๋ฅ˜":"๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ"}, [], empty_df, empty_fig_placeholder
118
 
119
  try:
120
  tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
121
  W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
122
- class_names = CLASS_NAMES_GLOBAL
123
 
124
  enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
125
  input_ids, attn_mask = enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE)
126
 
127
  if input_ids.shape[1] == 0:
128
  empty_df = pd.DataFrame(columns=['token', 'score'])
129
- empty_fig_placeholder = create_empty_plot()
130
- return "<p style='color:orange;'>์ž…๋ ฅ ์˜ค๋ฅ˜: ์œ ํšจํ•œ ํ† ํฐ์ด ์—†์Šต๋‹ˆ๋‹ค.</p>", [], "์ž…๋ ฅ ์˜ค๋ฅ˜", {"์˜ค๋ฅ˜":"์ž…๋ ฅ ์˜ค๋ฅ˜"}, [], empty_df, empty_fig_placeholder
131
 
132
  input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
133
  input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
134
 
135
  outputs = model(inputs_embeds=input_embeds_for_grad, attention_mask=attn_mask,
136
  output_hidden_states=True, output_attentions=False)
137
-
138
  cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
139
 
140
  z_projected = (cls_vec - mu) @ W
@@ -146,14 +169,14 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
146
  logit_output[0, pred_idx].backward()
147
  if input_embeds_for_grad.grad is None:
148
  empty_df = pd.DataFrame(columns=['token', 'score'])
149
- empty_fig_placeholder = create_empty_plot()
150
- return "<p style='color:red;'>๋ถ„์„ ์˜ค๋ฅ˜: ๊ทธ๋ž˜๋””์–ธํŠธ ๊ณ„์‚ฐ ์‹คํŒจ.</p>", [],"๋ถ„์„ ์˜ค๋ฅ˜", {"์˜ค๋ฅ˜":"๋ถ„์„ ์˜ค๋ฅ˜"}, [], empty_df, empty_fig_placeholder
151
 
152
  grads = input_embeds_for_grad.grad.clone().detach()
153
  scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
154
  scores_np = scores.cpu().numpy()
155
- valid_scores = scores_np[np.isfinite(scores_np)]
156
- scores_np = scores_np / (valid_scores.max() + 1e-9) if len(valid_scores) > 0 and valid_scores.max() > 0 else np.zeros_like(scores_np)
157
 
158
  tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)
159
  actual_tokens = [tok for i, tok in enumerate(tokens_raw) if input_ids[0,i] != tokenizer.pad_token_id]
@@ -173,8 +196,8 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
173
  html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
174
  highlighted_text_data.append((clean_tok_str + " ", None))
175
  else:
176
- color = f"rgba(255, 0, 0, {current_score_clipped:.2f})"
177
- html_tokens_list.append(f"<span style='background-color:{color}; padding: 1px 2px; margin: 1px; border-radius: 3px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
178
  highlighted_text_data.append((clean_tok_str + " ", round(current_score_clipped, 3)))
179
 
180
  html_output_str = " ".join(html_tokens_list).replace(" ##", "")
@@ -191,14 +214,12 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
191
 
192
  barplot_df = pd.DataFrame(top_tokens_for_barplot_list) if top_tokens_for_barplot_list else pd.DataFrame(columns=['token', 'score'])
193
 
194
- predicted_class_label_str = str(pred_idx)
195
- if class_names is not None and 0 <= pred_idx < len(class_names):
196
- predicted_class_label_str = str(class_names[pred_idx])
197
 
198
- prediction_summary_text = f"ํด๋ž˜์Šค: {predicted_class_label_str}\nํ™•๋ฅ : {pred_prob_val:.3f}"
199
- prediction_details_for_label = {"์˜ˆ์ธก ํด๋ž˜์Šค": predicted_class_label_str, "ํ™•๋ฅ ": f"{pred_prob_val:.3f}"}
200
 
201
- pca_fig = create_empty_plot("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
202
  non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
203
  if token_id not in [cls_token_id, sep_token_id]]
204
 
@@ -206,12 +227,11 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
206
  pca_tokens = [actual_tokens[i] for i in non_special_token_indices]
207
  if len(pca_tokens) > 0:
208
  pca_embeddings = actual_input_embeds[non_special_token_indices, :]
209
- pca_scores = actual_scores_np[non_special_token_indices]
210
 
211
  pca = PCA(n_components=3, random_state=SEED)
212
  token_embeddings_3d = pca.fit_transform(pca_embeddings)
213
- # plt.close(pca_fig) # 이전 빈 그림 닫기
214
- pca_fig = plot_token_pca_3d(token_embeddings_3d, pca_tokens, pca_scores)
215
 
216
  return (html_output_str, highlighted_text_data,
217
  prediction_summary_text, prediction_details_for_label,
@@ -221,59 +241,71 @@ def analyze_sentence_for_gradio(sentence_text, top_k_value):
221
  except Exception as e:
222
  import traceback
223
  tb_str = traceback.format_exc()
224
- error_html = f"<p style='color:red;'>๋ถ„์„ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
225
- print(f"Analyze_sentence_for_gradio error: {e}\n{tb_str}")
226
  empty_df = pd.DataFrame(columns=['token', 'score'])
227
- empty_fig_placeholder = create_empty_plot("Error during plot generation")
228
- return error_html, [], "๋ถ„์„ ์‹คํŒจ", {"์˜ค๋ฅ˜": str(e)}, [], empty_df, empty_fig_placeholder
229
-
230
- # ────────── Gradio 인터페이스 정의 ──────────
231
- theme = gr.themes.Glass(primary_hue="blue", secondary_hue="cyan", neutral_hue="sky").set(
232
- body_background_fill="linear-gradient(to right, #c9d6ff, #e2e2e2)",
233
- block_background_fill="rgba(255,255,255,0.8)",
234
- block_border_width="1px",
235
- block_shadow="*shadow_drop_lg"
 
 
 
 
 
236
  )
237
 
238
- with gr.Blocks(title="AI ๋ฌธ์žฅ ๋ถ„์„๊ธฐ XAI ๐Ÿš€", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
239
- gr.Markdown("# ๐Ÿš€ AI ๋ฌธ์žฅ ๋ถ„์„๊ธฐ XAI: ๋ชจ๋ธ ํ•ด์„ ํƒํ—˜")
240
- gr.Markdown("BERT ๋ชจ๋ธ ์˜ˆ์ธก์˜ ๊ทผ๊ฑฐ๋ฅผ ๋‹ค์–‘ํ•œ ์‹œ๊ฐํ™” ๊ธฐ๋ฒ•์œผ๋กœ ํƒ์ƒ‰ํ•ฉ๋‹ˆ๋‹ค. ํ† ํฐ์˜ ์ค‘์š”๋„์™€ ์ž„๋ฒ ๋”ฉ ๊ณต๊ฐ„์—์„œ์˜ ๋ถ„ํฌ๋ฅผ ํ™•์ธํ•ด๋ณด์„ธ์š”.")
 
 
241
 
242
  with gr.Row(equal_height=False):
243
- with gr.Column(scale=1, min_width=300):
244
  with gr.Group():
245
- gr.Markdown("### โœ๏ธ ๋ฌธ์žฅ ์ž…๋ ฅ & ์„ค์ •")
246
- input_sentence = gr.Textbox(lines=5, label="๋ถ„์„ํ•  ์˜์–ด ๋ฌธ์žฅ", placeholder="์—ฌ๊ธฐ์— ๋ถ„์„ํ•˜๊ณ  ์‹ถ์€ ์˜์–ด ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜์„ธ์š”...")
247
- input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top-K ํ† ํฐ ์ˆ˜")
248
- submit_button = gr.Button("๋ถ„์„ ์‹œ์ž‘ ๐Ÿ’ซ", variant="primary", scale=1)
249
 
250
  with gr.Column(scale=2):
251
- with gr.Accordion("๐ŸŽฏ ์˜ˆ์ธก ๊ฒฐ๊ณผ", open=True):
252
- output_prediction_summary = gr.Textbox(label="๊ฐ„๋‹จ ์š”์•ฝ", lines=2, interactive=False)
253
- output_prediction_details = gr.Label(label="์ƒ์„ธ ์ •๋ณด")
254
- with gr.Accordion("โญ Top-K ์ค‘์š” ํ† ํฐ (ํ‘œ)", open=True):
255
- output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="์ค‘์š”๋„ ๋†’์€ ํ† ํฐ",
256
  row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
257
 
258
  with gr.Tabs() as tabs:
259
- with gr.TabItem("๐ŸŽจ HTML ํ•˜์ด๋ผ์ดํŠธ", id=0):
260
- output_html_visualization = gr.HTML(label="ํ† ํฐ๋ณ„ ์ค‘์š”๋„ (Gradient x Input)")
261
- with gr.TabItem("๐Ÿ–๏ธ ํ…์ŠคํŠธ ํ•˜์ด๋ผ์ดํŠธ", id=1):
262
  output_highlighted_text = gr.HighlightedText(
263
- label="์ค‘์š”๋„ ๊ธฐ๋ฐ˜ ํ…์ŠคํŠธ ํ•˜์ด๋ผ์ดํŠธ (์ ์ˆ˜: 0~1)",
264
  show_legend=True,
 
 
 
 
265
  combine_adjacent=False
266
  )
267
- with gr.TabItem("๐Ÿ“Š Top-K ๋ง‰๋Œ€ ๊ทธ๋ž˜ํ”„", id=2):
268
  output_top_tokens_barplot = gr.BarPlot(
269
- label="Top-K ํ† ํฐ ์ค‘์š”๋„",
270
  x="token",
271
  y="score",
272
- tooltip=['token', 'score'], # SyntaxError 수정됨
273
- min_width=300
 
274
  )
275
- with gr.TabItem("๐ŸŒ ํ† ํฐ ์ž„๋ฒ ๋”ฉ 3D PCA", id=3):
276
- output_pca_plot = gr.Plot(label="ํ† ํฐ ์ž„๋ฒ ๋”ฉ 3D PCA (์ค‘์š”๋„ ์ƒ‰์ƒ)")
277
 
278
  gr.Markdown("---")
279
  gr.Examples(
@@ -290,10 +322,9 @@ with gr.Blocks(title="AI 문장 분석기 XAI 🚀", theme=theme, css=".gradio-c
290
  output_pca_plot
291
  ],
292
  fn=analyze_sentence_for_gradio,
293
- cache_examples=False
294
  )
295
- # gr.Markdown์„ gr.HTML๋กœ ๋ณ€๊ฒฝํ•˜์—ฌ HTML ํƒœ๊ทธ ์ง์ ‘ ์‚ฌ์šฉ
296
- gr.HTML("<p style='text-align: center; color: #666;'>Explainable AI Demo with Gradio & Transformers</p>")
297
 
298
  submit_button.click(
299
  fn=analyze_sentence_for_gradio,
@@ -310,7 +341,7 @@ with gr.Blocks(title="AI 문장 분석기 XAI 🚀", theme=theme, css=".gradio-c
310
  if __name__ == "__main__":
311
  if not MODELS_LOADED_SUCCESSFULLY:
312
  print("*"*80)
313
- print(f"๊ฒฝ๊ณ : ๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ! {MODEL_LOADING_ERROR_MESSAGE}")
314
- print("Gradio UI๋Š” ํ‘œ์‹œ๋˜์ง€๋งŒ ๋ถ„์„ ๊ธฐ๋Šฅ์ด ์ œ๋Œ€๋กœ ์ž‘๋™ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.")
315
  print("*"*80)
316
  demo.launch()
 
6
  import html
7
  from transformers import AutoTokenizer, AutoModel, logging as hf_logging
8
  import pandas as pd
9
+ import matplotlib # Still used for a basic empty plot if Plotly one is too complex for that
10
+ matplotlib.use('Agg') # Matplotlib backend setting
11
  import matplotlib.pyplot as plt
12
  from sklearn.decomposition import PCA
13
+ import plotly.graph_objects as go # For interactive 3D PCA plot
14
 
15
+ # --- Global Settings and Model Loading ---
 
16
  hf_logging.set_verbosity_error()
17
 
 
18
  MODEL_NAME = "bert-base-uncased"
19
  DEVICE = "cpu"
20
# Folder that holds the pickled LDA and classifier artifacts. The name is
# Korean and must be real UTF-8 text: a mis-decoded (mojibake) spelling can
# never match the directory on disk, so every os.path check on it would fail.
SAVE_DIR = "저장저장1"
LAYER_ID = 4  # index into outputs.hidden_states; also part of the artifact file names
SEED = 0  # part of the artifact file names; also used as the PCA random_state
CLF_NAME = "linear"  # classifier prefix of the artifact file name

# Maps the predicted class index to the label shown in the UI.
CLASS_LABEL_MAP = {
    0: "World",
    1: "Sports",
    2: "Business",
    3: "Sci/Tech",
}

# Placeholders read by the analysis function; presumably filled in by the
# model-loading block that follows (its assignments are outside this view).
TOKENIZER_GLOBAL, MODEL_GLOBAL = None, None
W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
# NOTE: there is no CLASS_NAMES_GLOBAL; CLASS_LABEL_MAP is used for class names.

# Load status: the flag stays False (and the message is set) when loading fails.
MODELS_LOADED_SUCCESSFULLY = False
MODEL_LOADING_ERROR_MESSAGE = ""
38
 
39
  try:
40
+ print("Gradio App: Initializing model loading...")
41
  lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
42
  clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")
43
 
44
  if not os.path.isdir(SAVE_DIR):
45
+ raise FileNotFoundError(f"Error: Model storage directory '{SAVE_DIR}' not found.")
46
  if not os.path.exists(lda_file_path):
47
+ raise FileNotFoundError(f"Error: LDA model file '{lda_file_path}' not found.")
48
  if not os.path.exists(clf_file_path):
49
+ raise FileNotFoundError(f"Error: Classifier model file '{clf_file_path}' not found.")
50
 
51
  lda = joblib.load(lda_file_path)
52
  clf = joblib.load(clf_file_path)
 
63
  MODEL_NAME, output_hidden_states=True, output_attentions=False
64
  ).to(DEVICE).eval()
65
 
 
 
 
66
  MODELS_LOADED_SUCCESSFULLY = True
67
+ print("Gradio App: All models and data loaded successfully!")
68
 
69
  except Exception as e:
70
  MODELS_LOADED_SUCCESSFULLY = False
71
+ MODEL_LOADING_ERROR_MESSAGE = f"Critical error during model loading: {str(e)}\nPlease ensure the '{SAVE_DIR}' folder and its contents are correct."
72
  print(MODEL_LOADING_ERROR_MESSAGE)
73
 
74
+ # Helper function: 3D PCA Visualization using Plotly
75
def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token Embeddings 3D PCA (Colored by Importance)"):
    """Build an interactive Plotly 3D scatter of PCA-projected token embeddings.

    Args:
        token_embeddings_3d: 2-D array whose columns 0, 1 and 2 are used as the
            x, y and z coordinates of each token.
        tokens: Token strings, one per plotted point.
        scores: Importance scores; flattened to 1-D and used as marker colours.
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure`` containing a single ``Scatter3d`` trace.
    """
    max_labels = 20  # label only the highest-scoring tokens to avoid clutter

    # Plotly expects a flat array for per-marker colours.
    scores_array = np.asarray(scores).flatten()

    # Empty strings suppress the label; only the top-scoring tokens get text.
    text_annotations = [''] * len(tokens)
    if len(scores_array) > 0 and len(tokens) > 0:
        num_annotations = min(len(tokens), max_labels)
        for i in np.argsort(scores_array)[-num_annotations:]:
            if i < len(tokens):  # scores may be longer than tokens
                text_annotations[i] = tokens[i]

    def _axis(label):
        # The axis title font goes under title.font; the flat `titlefont`
        # attribute is deprecated and rejected by newer Plotly releases.
        return dict(
            title=dict(text=label, font=dict(size=10)),
            tickfont=dict(size=9),
            backgroundcolor="rgba(230, 230, 230, 0.8)",
        )

    trace = go.Scatter3d(
        x=token_embeddings_3d[:, 0],
        y=token_embeddings_3d[:, 1],
        z=token_embeddings_3d[:, 2],
        mode='markers+text',
        text=text_annotations,
        textfont=dict(size=9, color='#333333'),
        textposition='top center',
        marker=dict(
            size=6,
            color=scores_array,
            colorscale='RdBu',
            reversescale=True,  # red = high importance, blue = low
            opacity=0.8,
            colorbar=dict(
                title=dict(text='Importance'),
                tickfont=dict(size=9),
                len=0.75,
                yanchor='middle',
            ),
        ),
        hoverinfo='text',  # hover shows the custom text below for every token
        hovertext=[f"Token: {t}<br>Score: {s:.3f}" for t, s in zip(tokens, scores_array)],
    )

    fig = go.Figure(data=[trace])
    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=16)),
        scene=dict(
            xaxis=_axis('PCA Comp 1'),
            yaxis=_axis('PCA Comp 2'),
            zaxis=_axis('PCA Comp 3'),
            bgcolor="rgba(255, 255, 255, 0.95)",
            camera=dict(eye=dict(x=1.5, y=1.5, z=0.5)),  # initial camera angle
        ),
        margin=dict(l=5, r=5, b=5, t=45),
        paper_bgcolor='rgba(0,0,0,0)',  # transparent paper background
    )
    return fig
122
 
123
+ # Helper function: Create an empty Plotly figure for placeholders
124
def create_empty_plotly_figure(message="N/A"):
    """Return a blank Plotly figure that shows only a centred grey message.

    Used as a placeholder wherever a real plot cannot be produced.

    Args:
        message: Text to display. Newline characters are converted to ``<br>``
            because Plotly annotation text breaks lines on that tag, not on "\\n".

    Returns:
        A ``plotly.graph_objects.Figure`` with hidden axes and transparent background.
    """
    fig = go.Figure()
    fig.add_annotation(
        text=str(message).replace("\n", "<br>"),
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=12, color="grey"),
    )
    fig.update_layout(
        xaxis={'visible': False},
        yaxis={'visible': False},
        height=300,  # fixed height so the empty placeholder does not collapse
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
    )
    return fig
135
 
136
+ # --- Core Analysis Function (returns 7 items for Gradio UI) ---
137
+ def analyze_sentence_for_gradio(sentence_text, top_k_value):
138
  if not MODELS_LOADED_SUCCESSFULLY:
139
+ error_html = f"<p style='color:red;'>Initialization Error: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
140
  empty_df = pd.DataFrame(columns=['token', 'score'])
141
+ empty_fig = create_empty_plotly_figure("Model Loading Failed")
142
+ return error_html, [], "Model Loading Failed", {"Error":"Model Loading Failed"}, [], empty_df, empty_fig
143
 
144
  try:
145
  tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
146
  W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
 
147
 
148
  enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
149
  input_ids, attn_mask = enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE)
150
 
151
  if input_ids.shape[1] == 0:
152
  empty_df = pd.DataFrame(columns=['token', 'score'])
153
+ empty_fig = create_empty_plotly_figure("Invalid Input")
154
+ return "<p style='color:orange;'>Input Error: No valid tokens found.</p>", [], "Input Error", {"Error":"Input Error"}, [], empty_df, empty_fig
155
 
156
  input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
157
  input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
158
 
159
  outputs = model(inputs_embeds=input_embeds_for_grad, attention_mask=attn_mask,
160
  output_hidden_states=True, output_attentions=False)
 
161
  cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
162
 
163
  z_projected = (cls_vec - mu) @ W
 
169
  logit_output[0, pred_idx].backward()
170
  if input_embeds_for_grad.grad is None:
171
  empty_df = pd.DataFrame(columns=['token', 'score'])
172
+ empty_fig = create_empty_plotly_figure("Gradient Error")
173
+ return "<p style='color:red;'>Analysis Error: Gradient calculation failed.</p>", [],"Analysis Error", {"Error":"Analysis Error"}, [], empty_df, empty_fig
174
 
175
  grads = input_embeds_for_grad.grad.clone().detach()
176
  scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
177
  scores_np = scores.cpu().numpy()
178
+ valid_scores_for_norm = scores_np[np.isfinite(scores_np)] # Renamed to avoid conflict
179
+ scores_np = scores_np / (valid_scores_for_norm.max() + 1e-9) if len(valid_scores_for_norm) > 0 and valid_scores_for_norm.max() > 0 else np.zeros_like(scores_np)
180
 
181
  tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)
182
  actual_tokens = [tok for i, tok in enumerate(tokens_raw) if input_ids[0,i] != tokenizer.pad_token_id]
 
196
  html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
197
  highlighted_text_data.append((clean_tok_str + " ", None))
198
  else:
199
+ color = f"rgba(220, 50, 50, {current_score_clipped:.2f})" # Slightly adjusted red
200
+ html_tokens_list.append(f"<span style='background-color:{color}; color:white; padding: 1px 3px; margin: 1px; border-radius: 4px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
201
  highlighted_text_data.append((clean_tok_str + " ", round(current_score_clipped, 3)))
202
 
203
  html_output_str = " ".join(html_tokens_list).replace(" ##", "")
 
214
 
215
  barplot_df = pd.DataFrame(top_tokens_for_barplot_list) if top_tokens_for_barplot_list else pd.DataFrame(columns=['token', 'score'])
216
 
217
+ predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index: {pred_idx}")
 
 
218
 
219
+ prediction_summary_text = f"Predicted Class: {predicted_class_label_str}\nProbability: {pred_prob_val:.3f}"
220
+ prediction_details_for_label = {"Predicted Class": predicted_class_label_str, "Probability": f"{pred_prob_val:.3f}"}
221
 
222
+ pca_fig = create_empty_plotly_figure("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
223
  non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
224
  if token_id not in [cls_token_id, sep_token_id]]
225
 
 
227
  pca_tokens = [actual_tokens[i] for i in non_special_token_indices]
228
  if len(pca_tokens) > 0:
229
  pca_embeddings = actual_input_embeds[non_special_token_indices, :]
230
+ pca_scores_for_plot = actual_scores_np[non_special_token_indices] # Use this for coloring
231
 
232
  pca = PCA(n_components=3, random_state=SEED)
233
  token_embeddings_3d = pca.fit_transform(pca_embeddings)
234
+ pca_fig = plot_token_pca_3d_plotly(token_embeddings_3d, pca_tokens, pca_scores_for_plot)
 
235
 
236
  return (html_output_str, highlighted_text_data,
237
  prediction_summary_text, prediction_details_for_label,
 
241
  except Exception as e:
242
  import traceback
243
  tb_str = traceback.format_exc()
244
+ error_html = f"<p style='color:red;'>Analysis Error: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
245
+ print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
246
  empty_df = pd.DataFrame(columns=['token', 'score'])
247
+ empty_fig = create_empty_plotly_figure("Analysis Error")
248
+ return error_html, [], "Analysis Failed", {"Error": str(e)}, [], empty_df, empty_fig
249
+
250
+ # --- Gradio UI Definition (Translated and Enhanced) ---
251
+ # Using a built-in theme and some CSS for aesthetics
252
# Start from the built-in Monochrome theme with the app's colour hues...
theme = gr.themes.Monochrome(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.sky,
    neutral_hue=gr.themes.colors.slate,
)
# ...then layer the page background, block shadow and primary-button overrides on top.
theme = theme.set(
    body_background_fill="#f0f2f6",
    block_shadow="*shadow_drop_lg",
    button_primary_background_fill="*primary_500",
    button_primary_text_color="white",
)
262
 
263
+
264
+ with gr.Blocks(title="AI Sentence Analyzer XAI ๐Ÿš€", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
265
+ gr.Markdown("# ๐Ÿš€ AI Sentence Analyzer XAI: Exploring Model Explanations")
266
+ gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
267
+ "Explore token importance and their distribution in the embedding space.")
268
 
269
  with gr.Row(equal_height=False):
270
+ with gr.Column(scale=1, min_width=350): # Increased min_width slightly
271
  with gr.Group():
272
+ gr.Markdown("### โœ๏ธ Input Sentence & Settings")
273
+ input_sentence = gr.Textbox(lines=5, label="English Sentence to Analyze", placeholder="Enter the English sentence you want to analyze here...")
274
+ input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Top-K Tokens")
275
+ submit_button = gr.Button("Analyze Sentence ๐Ÿ’ซ", variant="primary")
276
 
277
  with gr.Column(scale=2):
278
+ with gr.Accordion("๐ŸŽฏ Prediction Outcome", open=True):
279
+ output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2, interactive=False)
280
+ output_prediction_details = gr.Label(label="Detailed Prediction")
281
+ with gr.Accordion("โญ Top-K Important Tokens (Table)", open=True):
282
+ output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
283
  row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
284
 
285
  with gr.Tabs() as tabs:
286
+ with gr.TabItem("๐ŸŽจ HTML Highlight (Custom)", id=0):
287
+ output_html_visualization = gr.HTML(label="Token Importance (Gradient x Input based)")
288
+ with gr.TabItem("๐Ÿ–๏ธ Highlighted Text (Gradio)", id=1):
289
  output_highlighted_text = gr.HighlightedText(
290
+ label="Token Importance (Score: 0-1)",
291
  show_legend=True,
292
+ # Color map can be more sophisticated if scores are categorical
293
+ # For numerical scores (0-1), Gradio tries to infer intensity.
294
+ # Example color map (if scores were categories like "LOW", "MEDIUM", "HIGH"):
295
+ # color_map={"LOW": "lightblue", "MEDIUM": "lightgreen", "HIGH": "pink"},
296
  combine_adjacent=False
297
  )
298
+ with gr.TabItem("๐Ÿ“Š Top-K Bar Plot", id=2):
299
  output_top_tokens_barplot = gr.BarPlot(
300
+ label="Top-K Token Importance Scores",
301
  x="token",
302
  y="score",
303
+ tooltip=['token', 'score'],
304
+ min_width=300,
305
+ # title="Top-K Most Important Tokens" # BarPlot may not have a direct title prop
306
  )
307
+ with gr.TabItem("๐ŸŒ Token Embeddings 3D PCA (Interactive)", id=3):
308
+ output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")
309
 
310
  gr.Markdown("---")
311
  gr.Examples(
 
322
  output_pca_plot
323
  ],
324
  fn=analyze_sentence_for_gradio,
325
+ cache_examples=False # Set to True for faster loading of examples if inputs/outputs are static
326
  )
327
+ gr.HTML("<p style='text-align: center; color: #4a5568;'>Explainable AI Demo powered by Gradio & Hugging Face Transformers</p>")
 
328
 
329
  submit_button.click(
330
  fn=analyze_sentence_for_gradio,
 
341
  if __name__ == "__main__":
342
  if not MODELS_LOADED_SUCCESSFULLY:
343
  print("*"*80)
344
+ print(f"WARNING: Models failed to load! {MODEL_LOADING_ERROR_MESSAGE}")
345
+ print("The Gradio UI will be displayed, but analysis will fail.")
346
  print("*"*80)
347
  demo.launch()