kikikara committed on
Commit
4223edb
·
verified ·
1 Parent(s): 6ae9375

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -117
app.py CHANGED
@@ -5,22 +5,27 @@ import torch
5
  import numpy as np
6
  import html
7
  from transformers import AutoTokenizer, AutoModel, logging as hf_logging
 
 
 
 
 
 
8
 
 
9
  # Hugging Face Transformers λ‘œκΉ… 레벨 μ„€μ •
10
  hf_logging.set_verbosity_error()
11
 
12
- # ────────── μ„€μ • ──────────
13
  MODEL_NAME = "bert-base-uncased"
14
  DEVICE = "cpu"
15
- SAVE_DIR = "μ €μž₯μ €μž₯1" # 이 폴더가 app.py와 같은 μœ„μΉ˜μ— μžˆμ–΄μ•Ό ν•©λ‹ˆλ‹€.
16
  LAYER_ID = 4
17
  SEED = 0
18
  CLF_NAME = "linear"
19
 
20
- # ────────── μ „μ—­ λͺ¨λΈ λ‘œλ“œ (Gradio μ•± μ‹œμž‘ μ‹œ ν•œ 번 μ‹€ν–‰) ──────────
21
- # Streamlit의 @st.cache_resource λŒ€μ‹ , μ•± μ‹œμž‘ μ‹œ λ‘œλ“œλ˜λ„λ‘ μ „μ—­ λ³€μˆ˜λ‘œ 관리
22
- TOKENIZER_GLOBAL = None
23
- MODEL_GLOBAL = None
24
  W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
25
  CLASS_NAMES_GLOBAL = None
26
  MODELS_LOADED_SUCCESSFULLY = False
@@ -32,7 +37,7 @@ try:
32
  clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")
33
 
34
  if not os.path.isdir(SAVE_DIR):
35
- raise FileNotFoundError(f"였λ₯˜: λͺ¨λΈ μ €μž₯ 디렉토리 '{SAVE_DIR}'λ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€. 'μ €μž₯μ €μž₯1' 폴더λ₯Ό ν™•μΈν•˜μ„Έμš”.")
36
  if not os.path.exists(lda_file_path):
37
  raise FileNotFoundError(f"였λ₯˜: LDA λͺ¨λΈ 파일 '{lda_file_path}'λ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.")
38
  if not os.path.exists(clf_file_path):
@@ -41,8 +46,7 @@ try:
41
  lda = joblib.load(lda_file_path)
42
  clf = joblib.load(clf_file_path)
43
 
44
- if hasattr(clf, "base_estimator"):
45
- clf = clf.base_estimator
46
 
47
  W_GLOBAL = torch.tensor(lda.scalings_, dtype=torch.float32, device=DEVICE)
48
  MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
@@ -50,184 +54,259 @@ try:
50
  B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)
51
 
52
  TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
53
- MODEL_GLOBAL = AutoModel.from_pretrained(
54
- MODEL_NAME, output_hidden_states=True
55
  ).to(DEVICE).eval()
56
 
57
- if hasattr(lda, 'classes_'):
58
- CLASS_NAMES_GLOBAL = lda.classes_
59
- elif hasattr(clf, 'classes_'):
60
- CLASS_NAMES_GLOBAL = clf.classes_
61
 
62
  MODELS_LOADED_SUCCESSFULLY = True
63
  print("Gradio App: λͺ¨λ“  λͺ¨λΈ 및 데이터 λ‘œλ“œ 성곡!")
64
 
65
  except Exception as e:
 
66
  MODEL_LOADING_ERROR_MESSAGE = f"λͺ¨λΈ λ‘œλ”© 쀑 μ‹¬κ°ν•œ 였λ₯˜ λ°œμƒ: {str(e)}\n'μ €μž₯μ €μž₯1' 폴더와 λ‚΄μš©λ¬Όμ„ ν™•μΈν•΄μ£Όμ„Έμš”."
67
  print(MODEL_LOADING_ERROR_MESSAGE)
68
- # 이 였λ₯˜λŠ” Gradio UIλ₯Ό 톡해 μ‚¬μš©μžμ—κ²Œ 전달될 수 μžˆλ„λ‘ μ²˜λ¦¬ν•  수 μžˆμŠ΅λ‹ˆλ‹€.
69
 
70
- # ────────── 핡심 뢄석 ν•¨μˆ˜ (Gradio μΈν„°νŽ˜μ΄μŠ€κ°€ 호좜) ──────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def analyze_sentence_for_gradio(sentence_text, top_k_value):
 
 
 
 
 
 
 
 
 
72
  if not MODELS_LOADED_SUCCESSFULLY:
73
- # λͺ¨λΈ λ‘œλ”© μ‹€νŒ¨ μ‹œ Gradio 좜λ ₯ ν˜•μ‹μ— 맞좰 였λ₯˜ λ©”μ‹œμ§€ λ°˜ν™˜
74
  error_html = f"<p style='color:red;'>μ΄ˆκΈ°ν™” 였λ₯˜: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
75
- # Gradio InterfaceλŠ” μ •μ˜λœ λͺ¨λ“  좜λ ₯에 λŒ€ν•΄ 값을 λ°›μ•„μ•Ό ν•©λ‹ˆλ‹€.
76
- return error_html, "λͺ¨λΈ λ‘œλ”© μ‹€νŒ¨", "N/A", [] # HTML, μ˜ˆμΈ‘κ²°κ³Όν…μŠ€νŠΈ, 상세결과(Label), TopK(DataFrame)
 
77
 
78
  try:
79
- # μ „μ—­μ—μ„œ λ‘œλ“œλœ λͺ¨λΈ μ‚¬μš©
80
- tokenizer = TOKENIZER_GLOBAL
81
- model = MODEL_GLOBAL
82
  W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
83
  class_names = CLASS_NAMES_GLOBAL
84
 
85
- # 1) 토큰화
86
  enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
87
- input_ids = enc["input_ids"].to(DEVICE)
88
- attn_mask = enc["attention_mask"].to(DEVICE)
89
 
90
  if input_ids.shape[1] == 0:
91
- return "<p style='color:orange;'>μž…λ ₯ 였λ₯˜: μœ νš¨ν•œ 토큰이 μ—†μŠ΅λ‹ˆλ‹€.</p>", "μž…λ ₯ 였λ₯˜", "N/A", []
92
-
93
- # 2) μž„λ² λ”© 및 κ·Έλž˜λ””μ–ΈνŠΈ μ„€μ •
94
- input_embeds = model.embeddings.word_embeddings(input_ids).clone().detach()
95
- input_embeds.requires_grad_(True)
96
 
97
- # 3) Forward pass
98
- outputs = model(inputs_embeds=input_embeds, attention_mask=attn_mask, output_hidden_states=True)
 
 
 
 
99
  cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
100
-
101
- # 4) LDA 투영 및 λΆ„λ₯˜
102
  z_projected = (cls_vec - mu) @ W
103
  logit_output = z_projected @ w_p.T + b_p
104
  probs = torch.softmax(logit_output, dim=1)
105
- pred_idx = torch.argmax(probs, dim=1).item()
106
- pred_prob_val = probs[0, pred_idx].item()
107
 
108
- # 5) Gradient 계산
109
- if input_embeds.grad is not None:
110
- input_embeds.grad.zero_()
111
  logit_output[0, pred_idx].backward()
112
- if input_embeds.grad is None:
113
- return "<p style='color:red;'>뢄석 였λ₯˜: κ·Έλž˜λ””μ–ΈνŠΈ 계산 μ‹€νŒ¨.</p>", "뢄석 였λ₯˜", "N/A", []
114
- grads = input_embeds.grad.clone().detach()
115
-
116
- # 6) μ€‘μš”λ„ 점수 계산
117
- scores = (grads * input_embeds.detach()).norm(dim=2).squeeze(0)
 
118
  scores_np = scores.cpu().numpy()
119
  valid_scores = scores_np[np.isfinite(scores_np)]
120
- if len(valid_scores) > 0 and valid_scores.max() > 0:
121
- scores_np = scores_np / (valid_scores.max() + 1e-9)
122
- else:
123
- scores_np = np.zeros_like(scores_np)
124
 
125
- # 7) HTML 생성
126
- tokens = tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)
127
- html_tokens_list = []
128
- cls_token_id, sep_token_id, pad_token_id = tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id
129
 
130
- for i, tok_str in enumerate(tokens):
131
- if input_ids[0, i] == pad_token_id: continue
 
 
132
  clean_tok_str = tok_str.replace("##", "") if "##" not in tok_str else tok_str[2:]
133
- if input_ids[0, i] == cls_token_id or input_ids[0, i] == sep_token_id:
 
 
 
 
134
  html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
 
135
  else:
136
- score_val = scores_np[i] if i < len(scores_np) else 0
137
- color = f"rgba(255, 0, 0, {max(0, min(1, score_val)):.2f})"
138
  html_tokens_list.append(f"<span style='background-color:{color}; padding: 1px 2px; margin: 1px; border-radius: 3px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
139
-
 
140
  html_output_str = " ".join(html_tokens_list).replace(" ##", "")
141
 
142
- # Top-K 토큰 (DataFrame용 리슀트의 리슀트)
143
- top_tokens_for_df = []
144
- valid_indices = [idx for idx, token_id in enumerate(input_ids[0].tolist())
145
- if token_id not in [cls_token_id, sep_token_id, pad_token_id] and idx < len(scores_np)]
146
- sorted_valid_indices = sorted(valid_indices, key=lambda idx: -scores_np[idx])
147
  for token_idx in sorted_valid_indices[:top_k_value]:
148
- top_tokens_for_df.append([tokens[token_idx], f"{scores_np[token_idx]:.3f}"])
 
 
 
 
 
149
 
150
- # 예츑 클래슀 λ ˆμ΄λΈ”
151
  predicted_class_label_str = str(pred_idx)
152
  if class_names is not None and 0 <= pred_idx < len(class_names):
153
  predicted_class_label_str = str(class_names[pred_idx])
154
 
155
  prediction_summary_text = f"클래슀: {predicted_class_label_str}\nν™•λ₯ : {pred_prob_val:.3f}"
156
  prediction_details_for_label = {"예츑 클래슀": predicted_class_label_str, "ν™•λ₯ ": f"{pred_prob_val:.3f}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
-
159
- return html_output_str, prediction_summary_text, prediction_details_for_label, top_tokens_for_df
 
 
160
 
161
  except Exception as e:
162
  import traceback
163
  tb_str = traceback.format_exc()
164
  error_html = f"<p style='color:red;'>뢄석 쀑 였λ₯˜ λ°œμƒ: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
165
  print(f"Analyze_sentence_for_gradio error: {e}\n{tb_str}")
166
- return error_html, "뢄석 μ‹€νŒ¨", {"였λ₯˜": str(e)}, []
 
 
167
 
168
 
169
- # ────────── Gradio μΈν„°νŽ˜μ΄μŠ€ μ •μ˜ ──────────
170
- # μž…λ ₯ μ»΄ν¬λ„ŒνŠΈ
171
- input_sentence = gr.Textbox(lines=3, label="뢄석할 μ˜μ–΄ λ¬Έμž₯", placeholder="여기에 μ˜μ–΄ λ¬Έμž₯을 μž…λ ₯ν•˜μ„Έμš”...")
172
- input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="ν‘œμ‹œν•  Top-K μ€‘μš” 토큰 수")
 
 
 
173
 
174
- # 좜λ ₯ μ»΄ν¬λ„ŒνŠΈ
175
- output_html_visualization = gr.HTML(label="토큰 μ€‘μš”λ„ μ‹œκ°ν™”")
176
- output_prediction_summary = gr.Textbox(label="예츑 μš”μ•½", lines=2) # κ°„λ‹¨ν•œ ν…μŠ€νŠΈ μš”μ•½μš©
177
- output_prediction_details = gr.Label(label="예츑 상세") # Label은 λ”•μ…”λ„ˆλ¦¬λ₯Ό 잘 λ³΄μ—¬μ€Œ
178
- output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Top-K μ€‘μš” 토큰", row_count=(1,"dynamic"), col_count=(2,"fixed"))
179
 
 
 
 
180
 
181
- # Gradio Blocksλ₯Ό μ‚¬μš©ν•˜μ—¬ λ ˆμ΄μ•„μ›ƒ ꡬ성 (선택 사항, Interface보닀 μœ μ—°ν•¨)
182
- with gr.Blocks(title="λ¬Έμž₯ 토큰 μ€‘μš”λ„ 뢄석기 (Gradio)", theme=gr.themes.Soft()) as demo:
183
- gr.Markdown("# πŸ“ λ¬Έμž₯ 토큰 μ€‘μš”λ„ 뢄석기 (Gradio)")
184
- gr.Markdown("BERT와 LDAλ₯Ό ν™œμš©ν•˜μ—¬ λ¬Έμž₯ λ‚΄ 각 ν† ν°μ˜ μ€‘μš”λ„λ₯Ό μ‹œκ°ν™”ν•©λ‹ˆλ‹€.")
185
-
186
- with gr.Row():
 
 
187
  with gr.Column(scale=2):
188
- input_sentence.render()
189
- input_top_k.render()
190
- submit_button = gr.Button("뢄석 μ‹€ν–‰ν•˜κΈ° πŸš€", variant="primary")
191
- with gr.Column(scale=3):
192
- output_prediction_summary.render()
193
- output_prediction_details.render()
194
-
195
- output_html_visualization.render()
196
- output_top_tokens_df.render()
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  gr.Markdown("---")
199
- gr.Markdown("<p style='text-align: center; color: grey;'>BERT 기반 λ¬Έμž₯ 뢄석 데λͺ¨ (Gradio)</p>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
- # λ²„νŠΌ 클릭 μ‹œ ν•¨μˆ˜ μ—°κ²°
202
  submit_button.click(
203
  fn=analyze_sentence_for_gradio,
204
  inputs=[input_sentence, input_top_k],
205
- outputs=[output_html_visualization, output_prediction_summary, output_prediction_details, output_top_tokens_df]
206
- )
207
-
208
- # 예제 μΆ”κ°€
209
- gr.Examples(
210
- examples=[
211
- ["This is a great movie and I really enjoyed it!", 5],
212
- ["The weather is quite gloomy today.", 3],
213
- ["I am not sure if this is the right way to do it, but let's try.", 4]
214
  ],
215
- inputs=[input_sentence, input_top_k],
216
- outputs=[output_html_visualization, output_prediction_summary, output_prediction_details, output_top_tokens_df], # 예제 μ‹€ν–‰ μ‹œμ—λ„ λͺ¨λ“  좜λ ₯ μ»΄ν¬λ„ŒνŠΈ ν•„μš”
217
- fn=analyze_sentence_for_gradio, # 예제 μ‹€ν–‰ μ‹œμ—λ„ 동일 ν•¨μˆ˜ μ‚¬μš©
218
- cache_examples=False # λͺ¨λΈμ΄ μžˆλŠ” 경우 True둜 ν•˜λ©΄ 예제 λ‘œλ”©μ΄ 빨라질 수 μžˆμœΌλ‚˜, 디버깅 μ€‘μ—λŠ” False ꢌμž₯
219
  )
220
 
221
- # Gradio μ•± μ‹€ν–‰ (Hugging Face Spacesμ—μ„œλŠ” 이 뢀뢄이 μžλ™μœΌλ‘œ 처리됨)
222
- # λ‘œμ»¬μ—μ„œ ν…ŒμŠ€νŠΈ μ‹œ: demo.launch()
223
  if __name__ == "__main__":
224
  if not MODELS_LOADED_SUCCESSFULLY:
225
  print("*"*80)
226
- print("κ²½κ³ : λͺ¨λΈ λ‘œλ”©μ— μ‹€νŒ¨ν•˜μ—¬ Gradio 앱이 μ •μƒμ μœΌλ‘œ μž‘λ™ν•˜μ§€ μ•Šμ„ 수 μžˆμŠ΅λ‹ˆλ‹€.")
227
- print(f"였λ₯˜ λ‚΄μš©: {MODEL_LOADING_ERROR_MESSAGE}")
228
- print("Gradio UIλŠ” ν‘œμ‹œλ˜μ§€λ§Œ, '뢄석 μ‹€ν–‰ν•˜κΈ°' λ²„νŠΌμ„ λˆŒλ €μ„ λ•Œ 였λ₯˜κ°€ λ°œμƒν•©λ‹ˆλ‹€.")
229
- print("`μ €μž₯μ €μž₯1` 폴더 및 λ‚΄λΆ€ νŒŒμΌλ“€μ΄ `app.py`와 λ™μΌν•œ 디렉토리에 μžˆλŠ”μ§€ ν™•μΈν•˜μ„Έμš”.")
230
  print("*"*80)
231
- # Hugging Face SpacesλŠ” app.pyλ₯Ό μ‹€ν–‰ν•˜κ³  demo.launch()λ₯Ό μ°Ύκ±°λ‚˜
232
- # demoλΌλŠ” μ΄λ¦„μ˜ launchable Blocks/Interface 객체λ₯Ό μ°ΎμŠ΅λ‹ˆλ‹€.
233
  demo.launch()
 
5
  import numpy as np
6
  import html
7
  from transformers import AutoTokenizer, AutoModel, logging as hf_logging
8
+ import pandas as pd
9
+ import matplotlib
10
+ matplotlib.use('Agg') # Matplotlib λ°±μ—”λ“œ μ„€μ •
11
+ import matplotlib.pyplot as plt
12
+ # from mpl_toolkits.mplot3d import Axes3D # 3D ν”Œλ‘―μ— ν•„μš”
13
+ from sklearn.decomposition import PCA
14
 
15
+ # --- κΈ°μ‘΄ μ„€μ • 및 μ „μ—­ λͺ¨λΈ λ‘œλ“œ λΆ€λΆ„ ---
16
  # Hugging Face Transformers λ‘œκΉ… 레벨 μ„€μ •
17
  hf_logging.set_verbosity_error()
18
 
19
+ # μ„€μ •
20
  MODEL_NAME = "bert-base-uncased"
21
  DEVICE = "cpu"
22
+ SAVE_DIR = "μ €μž₯μ €μž₯1"
23
  LAYER_ID = 4
24
  SEED = 0
25
  CLF_NAME = "linear"
26
 
27
+ # μ „μ—­ λͺ¨λΈ λ‘œλ“œ
28
+ TOKENIZER_GLOBAL, MODEL_GLOBAL = None, None
 
 
29
  W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
30
  CLASS_NAMES_GLOBAL = None
31
  MODELS_LOADED_SUCCESSFULLY = False
 
37
  clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")
38
 
39
  if not os.path.isdir(SAVE_DIR):
40
+ raise FileNotFoundError(f"였λ₯˜: λͺ¨λΈ μ €μž₯ 디렉토리 '{SAVE_DIR}'λ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.")
41
  if not os.path.exists(lda_file_path):
42
  raise FileNotFoundError(f"였λ₯˜: LDA λͺ¨λΈ 파일 '{lda_file_path}'λ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.")
43
  if not os.path.exists(clf_file_path):
 
46
  lda = joblib.load(lda_file_path)
47
  clf = joblib.load(clf_file_path)
48
 
49
+ if hasattr(clf, "base_estimator"): clf = clf.base_estimator
 
50
 
51
  W_GLOBAL = torch.tensor(lda.scalings_, dtype=torch.float32, device=DEVICE)
52
  MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
 
54
  B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)
55
 
56
  TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
57
+ MODEL_GLOBAL = AutoModel.from_pretrained( # output_attentions 제거 λ˜λŠ” False
58
+ MODEL_NAME, output_hidden_states=True, output_attentions=False
59
  ).to(DEVICE).eval()
60
 
61
+ if hasattr(lda, 'classes_'): CLASS_NAMES_GLOBAL = lda.classes_
62
+ elif hasattr(clf, 'classes_'): CLASS_NAMES_GLOBAL = clf.classes_
 
 
63
 
64
  MODELS_LOADED_SUCCESSFULLY = True
65
  print("Gradio App: λͺ¨λ“  λͺ¨λΈ 및 데이터 λ‘œλ“œ 성곡!")
66
 
67
  except Exception as e:
68
+ MODELS_LOADED_SUCCESSFULLY = False
69
  MODEL_LOADING_ERROR_MESSAGE = f"λͺ¨λΈ λ‘œλ”© 쀑 μ‹¬κ°ν•œ 였λ₯˜ λ°œμƒ: {str(e)}\n'μ €μž₯μ €μž₯1' 폴더와 λ‚΄μš©λ¬Όμ„ ν™•μΈν•΄μ£Όμ„Έμš”."
70
  print(MODEL_LOADING_ERROR_MESSAGE)
 
71
 
72
# Helper: 3D PCA visualization of token embeddings.
def plot_token_pca_3d(token_embeddings_3d, tokens, scores, title="Token Embeddings 3D PCA (Colored by Importance)"):
    """Render a 3D scatter plot of PCA-reduced token embeddings.

    Args:
        token_embeddings_3d: array of shape (n_tokens, 3) with PCA coordinates.
        tokens: list of n_tokens token strings (parallel to the rows above).
        scores: array-like of n_tokens importance scores used for coloring.
        title: plot title.

    Returns:
        The matplotlib Figure containing the 3D scatter (caller owns it).
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')  # create a 3D axis

    # Annotate only the highest-scoring tokens (at most 15) to avoid clutter.
    num_annotations = min(len(tokens), 15)
    indices_to_annotate = np.argsort(scores)[-num_annotations:]

    # NOTE(review): colormap is "coolwarm_r", which maps HIGH scores to blue;
    # the original comment claimed high = deep red. Confirm intended palette.
    scatter = ax.scatter(token_embeddings_3d[:, 0], token_embeddings_3d[:, 1], token_embeddings_3d[:, 2],
                         c=scores, cmap="coolwarm_r", s=50, alpha=0.8, depthshade=True)

    # Iterate the selected indices directly instead of scanning every token
    # and testing membership (the original was O(n * m)).
    for i in indices_to_annotate:
        ax.text(token_embeddings_3d[i, 0], token_embeddings_3d[i, 1], token_embeddings_3d[i, 2],
                f' {tokens[i]}', size=8, zorder=1, color='k')

    ax.set_title(title, fontsize=14)
    ax.set_xlabel("PCA Component 1", fontsize=10)
    ax.set_ylabel("PCA Component 2", fontsize=10)
    ax.set_zlabel("PCA Component 3", fontsize=10)

    cbar = plt.colorbar(scatter, label="Importance Score", shrink=0.7)
    cbar.ax.tick_params(labelsize=8)
    ax.tick_params(axis='both', which='major', labelsize=8)

    plt.tight_layout()
    return fig
101
+
102
# ────────── Core analysis function (no attention map; 3D PCA; returns 7 values) ──────────
def analyze_sentence_for_gradio(sentence_text, top_k_value):
    """Analyze one sentence with the globally loaded BERT + LDA pipeline.

    Computes a gradient-x-input importance score per token and returns the
    7 values the Gradio outputs expect, in order:
    (html_visualization, highlighted_text_data, summary_text,
     details_for_label, top_k_rows, barplot_dataframe, pca_figure).

    Error paths return the same 7-tuple with placeholder values so every
    output component always receives data.
    """

    def create_empty_plot(message="N/A"):
        # Placeholder figure for error paths so gr.Plot always gets a Figure.
        fig = plt.figure(figsize=(2, 2))
        ax = fig.add_subplot(111)
        ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=10)
        ax.axis('off')
        plt.close(fig)  # drop pyplot's reference; Gradio copies the Figure
        return fig

    if not MODELS_LOADED_SUCCESSFULLY:
        error_html = f"<p style='color:red;'>초기화 오류: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
        empty_df = pd.DataFrame(columns=['token', 'score'])
        empty_fig_placeholder = create_empty_plot()
        return error_html, [], "모델 로딩 실패", "N/A", [], empty_df, empty_fig_placeholder  # 7 values

    try:
        tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
        W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
        class_names = CLASS_NAMES_GLOBAL

        enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
        input_ids, attn_mask = enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE)

        if input_ids.shape[1] == 0:
            empty_df = pd.DataFrame(columns=['token', 'score'])
            empty_fig_placeholder = create_empty_plot()
            return "<p style='color:orange;'>입력 오류: 유효한 토큰이 없습니다.</p>", [], "입력 오류", "N/A", [], empty_df, empty_fig_placeholder

        # Word embeddings: one detached copy for the score product, one
        # requires-grad copy that is fed through the model.
        input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
        input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)

        outputs = model(inputs_embeds=input_embeds_for_grad, attention_mask=attn_mask,
                        output_hidden_states=True, output_attentions=False)  # attentions no longer needed

        cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]

        # LDA projection followed by the linear classifier head.
        z_projected = (cls_vec - mu) @ W
        logit_output = z_projected @ w_p.T + b_p
        probs = torch.softmax(logit_output, dim=1)
        # BUGFIX: the original computed torch.argmax(probs, dim=1) twice in a
        # single tuple assignment; compute it once.
        pred_idx = torch.argmax(probs, dim=1).item()
        pred_prob_val = probs[0, pred_idx].item()

        if input_embeds_for_grad.grad is not None:
            input_embeds_for_grad.grad.zero_()
        logit_output[0, pred_idx].backward()
        if input_embeds_for_grad.grad is None:
            empty_df = pd.DataFrame(columns=['token', 'score'])
            empty_fig_placeholder = create_empty_plot()
            return "<p style='color:red;'>분석 오류: 그래디언트 계산 실패.</p>", [], "분석 오류", "N/A", [], empty_df, empty_fig_placeholder

        grads = input_embeds_for_grad.grad.clone().detach()
        # Gradient x Input importance: L2 norm over the hidden dimension.
        scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
        scores_np = scores.cpu().numpy()
        valid_scores = scores_np[np.isfinite(scores_np)]
        # Normalize to [0, 1] by the max finite score; all-zero/NaN falls back to zeros.
        scores_np = scores_np / (valid_scores.max() + 1e-9) if len(valid_scores) > 0 and valid_scores.max() > 0 else np.zeros_like(scores_np)

        tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)
        # NOTE(review): slicing by len(actual_tokens) assumes padding only at
        # the end of the sequence — true for a single right-padded sentence.
        actual_tokens = [tok for i, tok in enumerate(tokens_raw) if input_ids[0, i] != tokenizer.pad_token_id]
        actual_scores_np = scores_np[:len(actual_tokens)]
        actual_input_embeds = input_embeds_detached[0, :len(actual_tokens), :].cpu().numpy()

        html_tokens_list, highlighted_text_data = [], []
        cls_token_id, sep_token_id = tokenizer.cls_token_id, tokenizer.sep_token_id

        for i, tok_str in enumerate(actual_tokens):
            clean_tok_str = tok_str.replace("##", "") if "##" not in tok_str else tok_str[2:]
            current_score = actual_scores_np[i]
            current_score_clipped = max(0, min(1, current_score))
            current_token_id = input_ids[0, i].item()

            if current_token_id == cls_token_id or current_token_id == sep_token_id:
                # Special tokens: bold text, no highlight and no score.
                html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
                highlighted_text_data.append((clean_tok_str + " ", None))
            else:
                color = f"rgba(255, 0, 0, {current_score_clipped:.2f})"
                html_tokens_list.append(f"<span style='background-color:{color}; padding: 1px 2px; margin: 1px; border-radius: 3px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
                highlighted_text_data.append((clean_tok_str + " ", round(current_score_clipped, 3)))

        html_output_str = " ".join(html_tokens_list).replace(" ##", "")

        # Top-K tokens for the table and the bar plot (special tokens excluded).
        top_tokens_for_df, top_tokens_for_barplot_list = [], []
        valid_indices = [idx for idx, token_id in enumerate(input_ids[0, :len(actual_tokens)].tolist())
                         if token_id not in [cls_token_id, sep_token_id]]
        sorted_valid_indices = sorted(valid_indices, key=lambda idx: -actual_scores_np[idx])
        for token_idx in sorted_valid_indices[:top_k_value]:
            token_str = actual_tokens[token_idx]
            score_val_str = f"{actual_scores_np[token_idx]:.3f}"
            top_tokens_for_df.append([token_str, score_val_str])
            top_tokens_for_barplot_list.append({"token": token_str, "score": actual_scores_np[token_idx]})

        barplot_df = pd.DataFrame(top_tokens_for_barplot_list) if top_tokens_for_barplot_list else pd.DataFrame(columns=['token', 'score'])

        # Human-readable class label (falls back to the raw index).
        predicted_class_label_str = str(pred_idx)
        if class_names is not None and 0 <= pred_idx < len(class_names):
            predicted_class_label_str = str(class_names[pred_idx])

        prediction_summary_text = f"클래스: {predicted_class_label_str}\n확률: {pred_prob_val:.3f}"
        prediction_details_for_label = {"예측 클래스": predicted_class_label_str, "확률": f"{pred_prob_val:.3f}"}

        # 3D PCA of non-special token input embeddings.
        non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0, :len(actual_tokens)].tolist())
                                     if token_id not in [cls_token_id, sep_token_id]]

        # PCA requires n_samples >= n_components (3 components here).
        if len(non_special_token_indices) >= 3:
            pca_tokens = [actual_tokens[i] for i in non_special_token_indices]
            pca_embeddings = actual_input_embeds[non_special_token_indices, :]
            pca_scores = actual_scores_np[non_special_token_indices]

            pca = PCA(n_components=3, random_state=SEED)
            token_embeddings_3d = pca.fit_transform(pca_embeddings)
            pca_fig = plot_token_pca_3d(token_embeddings_3d, pca_tokens, pca_scores)
        else:
            pca_fig = create_empty_plot("PCA Plot N/A\n(Not enough non-special tokens for 3D)")

        return (html_output_str, highlighted_text_data,
                prediction_summary_text, prediction_details_for_label,
                top_tokens_for_df, barplot_df,
                pca_fig)  # 7 values; the attention map output was replaced by pca_fig

    except Exception as e:
        import traceback
        tb_str = traceback.format_exc()
        error_html = f"<p style='color:red;'>분석 중 오류 발생: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
        print(f"Analyze_sentence_for_gradio error: {e}\n{tb_str}")
        empty_df = pd.DataFrame(columns=['token', 'score'])
        empty_fig_placeholder = create_empty_plot("Error during plot generation")
        return error_html, [], "분석 실패", {"오류": str(e)}, [], empty_df, empty_fig_placeholder
230
 
231
 
232
# ────────── Gradio interface definition (attention-map tab removed) ──────────
theme = gr.themes.Glass(primary_hue="blue", secondary_hue="cyan", neutral_hue="sky").set(
    body_background_fill="linear-gradient(to right, #c9d6ff, #e2e2e2)",  # gradient background
    block_background_fill="rgba(255,255,255,0.8)",  # translucent block background
    block_border_width="1px",
    block_shadow="*shadow_drop_lg"
)

with gr.Blocks(title="AI 문장 분석기 XAI 🚀", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
    gr.Markdown("# 🚀 AI 문장 분석기 XAI: 모델 해석 탐험")
    gr.Markdown("BERT 모델 예측의 근거를 다양한 시각화 기법으로 탐색합니다. 토큰의 중요도와 임베딩 공간에서의 분포를 확인해보세요.")

    with gr.Row(equal_height=False):
        # Left column: input controls.
        with gr.Column(scale=1, min_width=300):
            with gr.Group():
                gr.Markdown("### ✏️ 문장 입력 & 설정")
                input_sentence = gr.Textbox(lines=5, label="분석할 영어 문장", placeholder="여기에 분석하고 싶은 영어 문장을 입력하세요...")
                input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top-K 토큰 수")
                submit_button = gr.Button("분석 시작 💫", variant="primary", scale=1)

        # Right column: prediction results.
        with gr.Column(scale=2):
            with gr.Accordion("🎯 예측 결과", open=True):
                output_prediction_summary = gr.Textbox(label="간단 요약", lines=2, interactive=False)
                output_prediction_details = gr.Label(label="상세 정보")  # Label renders a dict nicely
            with gr.Accordion("⭐ Top-K 중요 토큰 (표)", open=True):
                output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="중요도 높은 토큰",
                                                    row_count=(1, "dynamic"), col_count=(2, "fixed"), interactive=False, wrap=True)

    with gr.Tabs() as tabs:
        with gr.TabItem("🎨 HTML 하이라이트", id=0):
            output_html_visualization = gr.HTML(label="토큰별 중요도 (Gradient x Input)")
        with gr.TabItem("🖍️ 텍스트 하이라이트", id=1):
            output_highlighted_text = gr.HighlightedText(
                label="중요도 기반 텍스트 하이라이트 (점수: 0~1)",
                show_legend=True,
                combine_adjacent=False
            )
        with gr.TabItem("📊 Top-K 막대 그래프", id=2):
            # BUGFIX: removed the garbled keyword "color_legend_ έναντι=None",
            # which contained a space inside the identifier (a SyntaxError).
            output_top_tokens_barplot = gr.BarPlot(label="Top-K 토큰 중요도", x="token", y="score", tooltip=['token', 'score'], min_width=300)
        with gr.TabItem("🌐 토큰 임베딩 3D PCA", id=3):  # PCA tab instead of the attention map
            output_pca_plot = gr.Plot(label="토큰 임베딩 3D PCA (중요도 색상)")

    gr.Markdown("---")
    gr.Examples(
        examples=[
            ["This movie is an absolute masterpiece, captivating from start to finish.", 5],
            ["Despite some flaws, the film offers a compelling narrative.", 3],
            ["I was thoroughly disappointed with the lackluster performance and predictable plot.", 4]
        ],
        inputs=[input_sentence, input_top_k],
        outputs=[  # 7 outputs, matching the analysis function's return tuple
            output_html_visualization, output_highlighted_text,
            output_prediction_summary, output_prediction_details,
            output_top_tokens_df, output_top_tokens_barplot,
            output_pca_plot
        ],
        fn=analyze_sentence_for_gradio,
        cache_examples=False
    )
    # BUGFIX: dropped unsafe_allow_html=True — that is a Streamlit argument;
    # gr.Markdown does not accept it and would raise TypeError at startup.
    gr.Markdown("<p style='text-align: center; color: #666;'>Explainable AI Demo with Gradio & Transformers</p>")

    submit_button.click(
        fn=analyze_sentence_for_gradio,
        inputs=[input_sentence, input_top_k],
        outputs=[  # 7 outputs, matching the analysis function's return tuple
            output_html_visualization, output_highlighted_text,
            output_prediction_summary, output_prediction_details,
            output_top_tokens_df, output_top_tokens_barplot,
            output_pca_plot
        ],
        api_name="explain_sentence_xai"
    )

if __name__ == "__main__":
    if not MODELS_LOADED_SUCCESSFULLY:
        # Launch anyway so the UI can surface the loading error to the user.
        print("*" * 80)
        print(f"경고: 모델 로딩 실패! {MODEL_LOADING_ERROR_MESSAGE}")
        print("Gradio UI는 표시되지만 분석 기능이 제대로 작동하지 않습니다.")
        print("*" * 80)
    demo.launch()