Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,22 +5,27 @@ import torch
|
|
| 5 |
import numpy as np
|
| 6 |
import html
|
| 7 |
from transformers import AutoTokenizer, AutoModel, logging as hf_logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
|
|
|
| 9 |
# Hugging Face Transformers λ‘κΉ
λ 벨 μ€μ
|
| 10 |
hf_logging.set_verbosity_error()
|
| 11 |
|
| 12 |
-
#
|
| 13 |
MODEL_NAME = "bert-base-uncased"
|
| 14 |
DEVICE = "cpu"
|
| 15 |
-
SAVE_DIR = "μ μ₯μ μ₯1"
|
| 16 |
LAYER_ID = 4
|
| 17 |
SEED = 0
|
| 18 |
CLF_NAME = "linear"
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
|
| 22 |
-
TOKENIZER_GLOBAL = None
|
| 23 |
-
MODEL_GLOBAL = None
|
| 24 |
W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
|
| 25 |
CLASS_NAMES_GLOBAL = None
|
| 26 |
MODELS_LOADED_SUCCESSFULLY = False
|
|
@@ -32,7 +37,7 @@ try:
|
|
| 32 |
clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")
|
| 33 |
|
| 34 |
if not os.path.isdir(SAVE_DIR):
|
| 35 |
-
raise FileNotFoundError(f"μ€λ₯: λͺ¨λΈ μ μ₯ λλ ν 리 '{SAVE_DIR}'λ₯Ό μ°Ύμ μ μμ΅λλ€.
|
| 36 |
if not os.path.exists(lda_file_path):
|
| 37 |
raise FileNotFoundError(f"μ€λ₯: LDA λͺ¨λΈ νμΌ '{lda_file_path}'λ₯Ό μ°Ύμ μ μμ΅λλ€.")
|
| 38 |
if not os.path.exists(clf_file_path):
|
|
@@ -41,8 +46,7 @@ try:
|
|
| 41 |
lda = joblib.load(lda_file_path)
|
| 42 |
clf = joblib.load(clf_file_path)
|
| 43 |
|
| 44 |
-
if hasattr(clf, "base_estimator"):
|
| 45 |
-
clf = clf.base_estimator
|
| 46 |
|
| 47 |
W_GLOBAL = torch.tensor(lda.scalings_, dtype=torch.float32, device=DEVICE)
|
| 48 |
MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
|
|
@@ -50,184 +54,259 @@ try:
|
|
| 50 |
B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)
|
| 51 |
|
| 52 |
TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
|
| 53 |
-
MODEL_GLOBAL = AutoModel.from_pretrained(
|
| 54 |
-
MODEL_NAME, output_hidden_states=True
|
| 55 |
).to(DEVICE).eval()
|
| 56 |
|
| 57 |
-
if hasattr(lda, 'classes_'):
|
| 58 |
-
|
| 59 |
-
elif hasattr(clf, 'classes_'):
|
| 60 |
-
CLASS_NAMES_GLOBAL = clf.classes_
|
| 61 |
|
| 62 |
MODELS_LOADED_SUCCESSFULLY = True
|
| 63 |
print("Gradio App: λͺ¨λ λͺ¨λΈ λ° λ°μ΄ν° λ‘λ μ±κ³΅!")
|
| 64 |
|
| 65 |
except Exception as e:
|
|
|
|
| 66 |
MODEL_LOADING_ERROR_MESSAGE = f"λͺ¨λΈ λ‘λ© μ€ μ¬κ°ν μ€λ₯ λ°μ: {str(e)}\n'μ μ₯μ μ₯1' ν΄λμ λ΄μ©λ¬Όμ νμΈν΄μ£ΌμΈμ."
|
| 67 |
print(MODEL_LOADING_ERROR_MESSAGE)
|
| 68 |
-
# μ΄ μ€λ₯λ Gradio UIλ₯Ό ν΅ν΄ μ¬μ©μμκ² μ λ¬λ μ μλλ‘ μ²λ¦¬ν μ μμ΅λλ€.
|
| 69 |
|
| 70 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
def analyze_sentence_for_gradio(sentence_text, top_k_value):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
if not MODELS_LOADED_SUCCESSFULLY:
|
| 73 |
-
# λͺ¨λΈ λ‘λ© μ€ν¨ μ Gradio μΆλ ₯ νμμ λ§μΆ° μ€λ₯ λ©μμ§ λ°ν
|
| 74 |
error_html = f"<p style='color:red;'>μ΄κΈ°ν μ€λ₯: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
|
| 75 |
-
|
| 76 |
-
|
|
|
|
| 77 |
|
| 78 |
try:
|
| 79 |
-
|
| 80 |
-
tokenizer = TOKENIZER_GLOBAL
|
| 81 |
-
model = MODEL_GLOBAL
|
| 82 |
W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
|
| 83 |
class_names = CLASS_NAMES_GLOBAL
|
| 84 |
|
| 85 |
-
# 1) ν ν°ν
|
| 86 |
enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
|
| 87 |
-
input_ids = enc["input_ids"].to(DEVICE)
|
| 88 |
-
attn_mask = enc["attention_mask"].to(DEVICE)
|
| 89 |
|
| 90 |
if input_ids.shape[1] == 0:
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
input_embeds = model.embeddings.word_embeddings(input_ids).clone().detach()
|
| 95 |
-
input_embeds.requires_grad_(True)
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
|
| 100 |
-
|
| 101 |
-
# 4) LDA ν¬μ λ° λΆλ₯
|
| 102 |
z_projected = (cls_vec - mu) @ W
|
| 103 |
logit_output = z_projected @ w_p.T + b_p
|
| 104 |
probs = torch.softmax(logit_output, dim=1)
|
| 105 |
-
pred_idx = torch.argmax(probs, dim=1).item()
|
| 106 |
-
pred_prob_val = probs[0, pred_idx].item()
|
| 107 |
|
| 108 |
-
|
| 109 |
-
if input_embeds.grad is not None:
|
| 110 |
-
input_embeds.grad.zero_()
|
| 111 |
logit_output[0, pred_idx].backward()
|
| 112 |
-
if
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
| 118 |
scores_np = scores.cpu().numpy()
|
| 119 |
valid_scores = scores_np[np.isfinite(scores_np)]
|
| 120 |
-
if len(valid_scores) > 0 and valid_scores.max() > 0
|
| 121 |
-
scores_np = scores_np / (valid_scores.max() + 1e-9)
|
| 122 |
-
else:
|
| 123 |
-
scores_np = np.zeros_like(scores_np)
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
| 132 |
clean_tok_str = tok_str.replace("##", "") if "##" not in tok_str else tok_str[2:]
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
|
|
|
|
| 135 |
else:
|
| 136 |
-
|
| 137 |
-
color = f"rgba(255, 0, 0, {max(0, min(1, score_val)):.2f})"
|
| 138 |
html_tokens_list.append(f"<span style='background-color:{color}; padding: 1px 2px; margin: 1px; border-radius: 3px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
|
| 139 |
-
|
|
|
|
| 140 |
html_output_str = " ".join(html_tokens_list).replace(" ##", "")
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
sorted_valid_indices = sorted(valid_indices, key=lambda idx: -scores_np[idx])
|
| 147 |
for token_idx in sorted_valid_indices[:top_k_value]:
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
# μμΈ‘ ν΄λμ€ λ μ΄λΈ
|
| 151 |
predicted_class_label_str = str(pred_idx)
|
| 152 |
if class_names is not None and 0 <= pred_idx < len(class_names):
|
| 153 |
predicted_class_label_str = str(class_names[pred_idx])
|
| 154 |
|
| 155 |
prediction_summary_text = f"ν΄λμ€: {predicted_class_label_str}\nνλ₯ : {pred_prob_val:.3f}"
|
| 156 |
prediction_details_for_label = {"μμΈ‘ ν΄λμ€": predicted_class_label_str, "νλ₯ ": f"{pred_prob_val:.3f}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
| 160 |
|
| 161 |
except Exception as e:
|
| 162 |
import traceback
|
| 163 |
tb_str = traceback.format_exc()
|
| 164 |
error_html = f"<p style='color:red;'>λΆμ μ€ μ€λ₯ λ°μ: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
|
| 165 |
print(f"Analyze_sentence_for_gradio error: {e}\n{tb_str}")
|
| 166 |
-
|
|
|
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
-
# ββββββββββ Gradio μΈν°νμ΄μ€ μ μ ββββββββββ
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
-
# μΆλ ₯ μ»΄ν¬λνΈ
|
| 175 |
-
output_html_visualization = gr.HTML(label="ν ν° μ€μλ μκ°ν")
|
| 176 |
-
output_prediction_summary = gr.Textbox(label="μμΈ‘ μμ½", lines=2) # κ°λ¨ν ν
μ€νΈ μμ½μ©
|
| 177 |
-
output_prediction_details = gr.Label(label="μμΈ‘ μμΈ") # Labelμ λμ
λ리λ₯Ό μ 보μ¬μ€
|
| 178 |
-
output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Top-K μ€μ ν ν°", row_count=(1,"dynamic"), col_count=(2,"fixed"))
|
| 179 |
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
-
|
| 182 |
-
with gr.
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
| 187 |
with gr.Column(scale=2):
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
output_html_visualization.render()
|
| 196 |
-
output_top_tokens_df.render()
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
gr.Markdown("---")
|
| 199 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
-
# λ²νΌ ν΄λ¦ μ ν¨μ μ°κ²°
|
| 202 |
submit_button.click(
|
| 203 |
fn=analyze_sentence_for_gradio,
|
| 204 |
inputs=[input_sentence, input_top_k],
|
| 205 |
-
outputs=[
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
examples=[
|
| 211 |
-
["This is a great movie and I really enjoyed it!", 5],
|
| 212 |
-
["The weather is quite gloomy today.", 3],
|
| 213 |
-
["I am not sure if this is the right way to do it, but let's try.", 4]
|
| 214 |
],
|
| 215 |
-
|
| 216 |
-
outputs=[output_html_visualization, output_prediction_summary, output_prediction_details, output_top_tokens_df], # μμ μ€ν μμλ λͺ¨λ μΆλ ₯ μ»΄ν¬λνΈ νμ
|
| 217 |
-
fn=analyze_sentence_for_gradio, # μμ μ€ν μμλ λμΌ ν¨μ μ¬μ©
|
| 218 |
-
cache_examples=False # λͺ¨λΈμ΄ μλ κ²½μ° Trueλ‘ νλ©΄ μμ λ‘λ©μ΄ λΉ¨λΌμ§ μ μμΌλ, λλ²κΉ
μ€μλ False κΆμ₯
|
| 219 |
)
|
| 220 |
|
| 221 |
-
# Gradio μ± μ€ν (Hugging Face Spacesμμλ μ΄ λΆλΆμ΄ μλμΌλ‘ μ²λ¦¬λ¨)
|
| 222 |
-
# λ‘컬μμ ν
μ€νΈ μ: demo.launch()
|
| 223 |
if __name__ == "__main__":
|
| 224 |
if not MODELS_LOADED_SUCCESSFULLY:
|
| 225 |
print("*"*80)
|
| 226 |
-
print("κ²½κ³ : λͺ¨λΈ
|
| 227 |
-
print(
|
| 228 |
-
print("Gradio UIλ νμλμ§λ§, 'λΆμ μ€ννκΈ°' λ²νΌμ λλ μ λ μ€λ₯κ° λ°μν©λλ€.")
|
| 229 |
-
print("`μ μ₯μ μ₯1` ν΄λ λ° λ΄λΆ νμΌλ€μ΄ `app.py`μ λμΌν λλ ν 리μ μλμ§ νμΈνμΈμ.")
|
| 230 |
print("*"*80)
|
| 231 |
-
# Hugging Face Spacesλ app.pyλ₯Ό μ€ννκ³ demo.launch()λ₯Ό μ°Ύκ±°λ
|
| 232 |
-
# demoλΌλ μ΄λ¦μ launchable Blocks/Interface κ°μ²΄λ₯Ό μ°Ύμ΅λλ€.
|
| 233 |
demo.launch()
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
import html
|
| 7 |
from transformers import AutoTokenizer, AutoModel, logging as hf_logging
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import matplotlib
|
| 10 |
+
matplotlib.use('Agg') # Matplotlib λ°±μλ μ€μ
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
# from mpl_toolkits.mplot3d import Axes3D # 3D νλ‘―μ νμ
|
| 13 |
+
from sklearn.decomposition import PCA
|
| 14 |
|
| 15 |
+
# --- Existing settings and global model loading ---
|
| 16 |
# Set the Hugging Face Transformers logging level
|
| 17 |
hf_logging.set_verbosity_error()
|
| 18 |
|
| 19 |
+
# Settings
|
| 20 |
MODEL_NAME = "bert-base-uncased"
|
| 21 |
DEVICE = "cpu"
|
| 22 |
+
SAVE_DIR = "μ μ₯μ μ₯1"
|
| 23 |
LAYER_ID = 4
|
| 24 |
SEED = 0
|
| 25 |
CLF_NAME = "linear"
|
| 26 |
|
| 27 |
+
# Global model handles (populated by the loading code below)
|
| 28 |
+
TOKENIZER_GLOBAL, MODEL_GLOBAL = None, None
|
|
|
|
|
|
|
| 29 |
W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
|
| 30 |
CLASS_NAMES_GLOBAL = None
|
| 31 |
MODELS_LOADED_SUCCESSFULLY = False
|
|
|
|
| 37 |
clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")
|
| 38 |
|
| 39 |
if not os.path.isdir(SAVE_DIR):
|
| 40 |
+
raise FileNotFoundError(f"μ€λ₯: λͺ¨λΈ μ μ₯ λλ ν 리 '{SAVE_DIR}'λ₯Ό μ°Ύμ μ μμ΅λλ€.")
|
| 41 |
if not os.path.exists(lda_file_path):
|
| 42 |
raise FileNotFoundError(f"μ€λ₯: LDA λͺ¨λΈ νμΌ '{lda_file_path}'λ₯Ό μ°Ύμ μ μμ΅λλ€.")
|
| 43 |
if not os.path.exists(clf_file_path):
|
|
|
|
| 46 |
lda = joblib.load(lda_file_path)
|
| 47 |
clf = joblib.load(clf_file_path)
|
| 48 |
|
| 49 |
+
if hasattr(clf, "base_estimator"): clf = clf.base_estimator
|
|
|
|
| 50 |
|
| 51 |
W_GLOBAL = torch.tensor(lda.scalings_, dtype=torch.float32, device=DEVICE)
|
| 52 |
MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
|
|
|
|
| 54 |
B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)
|
| 55 |
|
| 56 |
TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
|
| 57 |
+
MODEL_GLOBAL = AutoModel.from_pretrained( # output_attentions μ κ±° λλ False
|
| 58 |
+
MODEL_NAME, output_hidden_states=True, output_attentions=False
|
| 59 |
).to(DEVICE).eval()
|
| 60 |
|
| 61 |
+
if hasattr(lda, 'classes_'): CLASS_NAMES_GLOBAL = lda.classes_
|
| 62 |
+
elif hasattr(clf, 'classes_'): CLASS_NAMES_GLOBAL = clf.classes_
|
|
|
|
|
|
|
| 63 |
|
| 64 |
MODELS_LOADED_SUCCESSFULLY = True
|
| 65 |
print("Gradio App: λͺ¨λ λͺ¨λΈ λ° λ°μ΄ν° λ‘λ μ±κ³΅!")
|
| 66 |
|
| 67 |
except Exception as e:
|
| 68 |
+
MODELS_LOADED_SUCCESSFULLY = False
|
| 69 |
MODEL_LOADING_ERROR_MESSAGE = f"λͺ¨λΈ λ‘λ© μ€ μ¬κ°ν μ€λ₯ λ°μ: {str(e)}\n'μ μ₯μ μ₯1' ν΄λμ λ΄μ©λ¬Όμ νμΈν΄μ£ΌμΈμ."
|
| 70 |
print(MODEL_LOADING_ERROR_MESSAGE)
|
|
|
|
| 71 |
|
| 72 |
+
# Helper function: PCA visualization (3D version)
|
| 73 |
+
def plot_token_pca_3d(token_embeddings_3d, tokens, scores, title="Token Embeddings 3D PCA (Colored by Importance)"):
    """Draw a 3D scatter plot of PCA-reduced token embeddings coloured by importance.

    Args:
        token_embeddings_3d: 2-D array indexed as ``[i, 0]``, ``[i, 1]``, ``[i, 2]``;
            row ``i`` holds the three plotted coordinates of ``tokens[i]``.
        tokens: Token strings, one per row of ``token_embeddings_3d``.
        scores: Per-token importance values in the same order as ``tokens``. They
            set the point colours and decide which points get a text label.
        title: Title shown above the axes.

    Returns:
        The matplotlib ``Figure``. It has already been passed to ``plt.close`` so
        pyplot no longer keeps it registered; the object itself is still returned
        to the caller for rendering.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # "coolwarm" runs blue (low) -> red (high), so the most important tokens are
    # the deepest red. The reversed map "coolwarm_r" used previously painted them blue.
    scatter = ax.scatter(
        token_embeddings_3d[:, 0], token_embeddings_3d[:, 1], token_embeddings_3d[:, 2],
        c=scores, cmap="coolwarm", s=50, alpha=0.8, depthshade=True,
    )

    # Label only the highest-scoring tokens (at most 15) to keep the plot readable.
    num_annotations = min(len(tokens), 15)
    # Guard the zero case: a ``[-0:]`` slice would select every index instead of none.
    indices_to_annotate = np.argsort(scores)[-num_annotations:] if num_annotations else []
    for i in indices_to_annotate:
        ax.text(token_embeddings_3d[i, 0], token_embeddings_3d[i, 1], token_embeddings_3d[i, 2],
                f' {tokens[i]}', size=8, zorder=1, color='k')

    ax.set_title(title, fontsize=14)
    ax.set_xlabel("PCA Component 1", fontsize=10)
    ax.set_ylabel("PCA Component 2", fontsize=10)
    ax.set_zlabel("PCA Component 3", fontsize=10)

    # Use the figure's own methods rather than pyplot's implicit "current figure".
    cbar = fig.colorbar(scatter, ax=ax, label="Importance Score", shrink=0.7)
    cbar.ax.tick_params(labelsize=8)
    ax.tick_params(axis='both', which='major', labelsize=8)

    fig.tight_layout()
    # Deregister from pyplot so repeated calls do not accumulate open figures
    # (same approach as create_empty_plot in this file).
    plt.close(fig)
    return fig
|
| 101 |
+
|
| 102 |
+
# ────────── Core analysis function (no attention map, 3D PCA, 7 return values) ──────────
|
| 103 |
def analyze_sentence_for_gradio(sentence_text, top_k_value):
|
| 104 |
+
# Empty-plot factory (used when an error occurs)
|
| 105 |
+
def create_empty_plot(message="N/A"):
    """Return a small placeholder figure that shows only ``message``.

    Args:
        message: Text drawn centred in the otherwise blank figure.

    Returns:
        A 2x2-inch matplotlib ``Figure`` with its axes hidden. The figure is
        passed to ``plt.close`` before being returned, so pyplot does not keep
        it registered.
    """
    placeholder = plt.figure(figsize=(2, 2))
    axes = placeholder.add_subplot(111)
    axes.axis('off')
    axes.text(0.5, 0.5, message, ha='center', va='center', fontsize=10)
    # Close right away for memory management; the Figure object is still returned.
    plt.close(placeholder)
    return placeholder
|
| 112 |
+
|
| 113 |
if not MODELS_LOADED_SUCCESSFULLY:
|
|
|
|
| 114 |
error_html = f"<p style='color:red;'>μ΄κΈ°ν μ€λ₯: {html.escape(MODEL_LOADING_ERROR_MESSAGE)}</p>"
|
| 115 |
+
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 116 |
+
empty_fig_placeholder = create_empty_plot()
|
| 117 |
+
return error_html, [], "λͺ¨λΈ λ‘λ© μ€ν¨", "N/A", [], empty_df, empty_fig_placeholder # 7κ° λ°ν
|
| 118 |
|
| 119 |
try:
|
| 120 |
+
tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
|
|
|
|
|
|
|
| 121 |
W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
|
| 122 |
class_names = CLASS_NAMES_GLOBAL
|
| 123 |
|
|
|
|
| 124 |
enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
|
| 125 |
+
input_ids, attn_mask = enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE)
|
|
|
|
| 126 |
|
| 127 |
if input_ids.shape[1] == 0:
|
| 128 |
+
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 129 |
+
empty_fig_placeholder = create_empty_plot()
|
| 130 |
+
return "<p style='color:orange;'>μ
λ ₯ μ€λ₯: μ ν¨ν ν ν°μ΄ μμ΅λλ€.</p>", [], "μ
λ ₯ μ€λ₯", "N/A", [], empty_df, empty_fig_placeholder
|
|
|
|
|
|
|
| 131 |
|
| 132 |
+
input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
|
| 133 |
+
input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
|
| 134 |
+
|
| 135 |
+
outputs = model(inputs_embeds=input_embeds_for_grad, attention_mask=attn_mask,
|
| 136 |
+
output_hidden_states=True, output_attentions=False) # μ΄ν
μ
μ μ΄μ νμ μμ
|
| 137 |
+
|
| 138 |
cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
|
| 139 |
+
|
|
|
|
| 140 |
z_projected = (cls_vec - mu) @ W
|
| 141 |
logit_output = z_projected @ w_p.T + b_p
|
| 142 |
probs = torch.softmax(logit_output, dim=1)
|
| 143 |
+
pred_idx, pred_prob_val = torch.argmax(probs, dim=1).item(), probs[0, torch.argmax(probs, dim=1).item()].item()
|
|
|
|
| 144 |
|
| 145 |
+
if input_embeds_for_grad.grad is not None: input_embeds_for_grad.grad.zero_()
|
|
|
|
|
|
|
| 146 |
logit_output[0, pred_idx].backward()
|
| 147 |
+
if input_embeds_for_grad.grad is None:
|
| 148 |
+
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 149 |
+
empty_fig_placeholder = create_empty_plot()
|
| 150 |
+
return "<p style='color:red;'>λΆμ μ€λ₯: κ·ΈλλμΈνΈ κ³μ° μ€ν¨.</p>", [],"λΆμ μ€λ₯", "N/A", [], empty_df, empty_fig_placeholder
|
| 151 |
+
|
| 152 |
+
grads = input_embeds_for_grad.grad.clone().detach()
|
| 153 |
+
scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
|
| 154 |
scores_np = scores.cpu().numpy()
|
| 155 |
valid_scores = scores_np[np.isfinite(scores_np)]
|
| 156 |
+
scores_np = scores_np / (valid_scores.max() + 1e-9) if len(valid_scores) > 0 and valid_scores.max() > 0 else np.zeros_like(scores_np)
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
+
tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=False)
|
| 159 |
+
actual_tokens = [tok for i, tok in enumerate(tokens_raw) if input_ids[0,i] != tokenizer.pad_token_id]
|
| 160 |
+
actual_scores_np = scores_np[:len(actual_tokens)]
|
| 161 |
+
actual_input_embeds = input_embeds_detached[0, :len(actual_tokens), :].cpu().numpy()
|
| 162 |
|
| 163 |
+
html_tokens_list, highlighted_text_data = [], []
|
| 164 |
+
cls_token_id, sep_token_id = tokenizer.cls_token_id, tokenizer.sep_token_id
|
| 165 |
+
|
| 166 |
+
for i, tok_str in enumerate(actual_tokens):
|
| 167 |
clean_tok_str = tok_str.replace("##", "") if "##" not in tok_str else tok_str[2:]
|
| 168 |
+
current_score = actual_scores_np[i]
|
| 169 |
+
current_score_clipped = max(0, min(1, current_score))
|
| 170 |
+
current_token_id = input_ids[0, i].item()
|
| 171 |
+
|
| 172 |
+
if current_token_id == cls_token_id or current_token_id == sep_token_id:
|
| 173 |
html_tokens_list.append(f"<span style='font-weight:bold;'>{html.escape(clean_tok_str)}</span>")
|
| 174 |
+
highlighted_text_data.append((clean_tok_str + " ", None))
|
| 175 |
else:
|
| 176 |
+
color = f"rgba(255, 0, 0, {current_score_clipped:.2f})"
|
|
|
|
| 177 |
html_tokens_list.append(f"<span style='background-color:{color}; padding: 1px 2px; margin: 1px; border-radius: 3px; display:inline-block;'>{html.escape(clean_tok_str)}</span>")
|
| 178 |
+
highlighted_text_data.append((clean_tok_str + " ", round(current_score_clipped, 3)))
|
| 179 |
+
|
| 180 |
html_output_str = " ".join(html_tokens_list).replace(" ##", "")
|
| 181 |
|
| 182 |
+
top_tokens_for_df, top_tokens_for_barplot_list = [], []
|
| 183 |
+
valid_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
|
| 184 |
+
if token_id not in [cls_token_id, sep_token_id]]
|
| 185 |
+
sorted_valid_indices = sorted(valid_indices, key=lambda idx: -actual_scores_np[idx])
|
|
|
|
| 186 |
for token_idx in sorted_valid_indices[:top_k_value]:
|
| 187 |
+
token_str = actual_tokens[token_idx]
|
| 188 |
+
score_val_str = f"{actual_scores_np[token_idx]:.3f}"
|
| 189 |
+
top_tokens_for_df.append([token_str, score_val_str])
|
| 190 |
+
top_tokens_for_barplot_list.append({"token": token_str, "score": actual_scores_np[token_idx]})
|
| 191 |
+
|
| 192 |
+
barplot_df = pd.DataFrame(top_tokens_for_barplot_list) if top_tokens_for_barplot_list else pd.DataFrame(columns=['token', 'score'])
|
| 193 |
|
|
|
|
| 194 |
predicted_class_label_str = str(pred_idx)
|
| 195 |
if class_names is not None and 0 <= pred_idx < len(class_names):
|
| 196 |
predicted_class_label_str = str(class_names[pred_idx])
|
| 197 |
|
| 198 |
prediction_summary_text = f"ν΄λμ€: {predicted_class_label_str}\nνλ₯ : {pred_prob_val:.3f}"
|
| 199 |
prediction_details_for_label = {"μμΈ‘ ν΄λμ€": predicted_class_label_str, "νλ₯ ": f"{pred_prob_val:.3f}"}
|
| 200 |
+
|
| 201 |
+
# ν ν° μλ² λ© PCA μκ°ν (3D)
|
| 202 |
+
non_special_token_indices = [idx for idx, token_id in enumerate(input_ids[0,:len(actual_tokens)].tolist())
|
| 203 |
+
if token_id not in [cls_token_id, sep_token_id]]
|
| 204 |
+
|
| 205 |
+
# PCAλ n_samples >= n_components μ¬μΌ ν¨. μ¬κΈ°μλ 3κ°μ μ»΄ν¬λνΈ.
|
| 206 |
+
if len(non_special_token_indices) >= 3 :
|
| 207 |
+
pca_tokens = [actual_tokens[i] for i in non_special_token_indices]
|
| 208 |
+
pca_embeddings = actual_input_embeds[non_special_token_indices, :]
|
| 209 |
+
pca_scores = actual_scores_np[non_special_token_indices]
|
| 210 |
+
|
| 211 |
+
pca = PCA(n_components=3, random_state=SEED)
|
| 212 |
+
token_embeddings_3d = pca.fit_transform(pca_embeddings)
|
| 213 |
+
pca_fig = plot_token_pca_3d(token_embeddings_3d, pca_tokens, pca_scores)
|
| 214 |
+
else:
|
| 215 |
+
pca_fig = create_empty_plot("PCA Plot N/A\n(Not enough non-special tokens for 3D)")
|
| 216 |
|
| 217 |
+
return (html_output_str, highlighted_text_data,
|
| 218 |
+
prediction_summary_text, prediction_details_for_label,
|
| 219 |
+
top_tokens_for_df, barplot_df,
|
| 220 |
+
pca_fig) # μ΄ν
μ
λ§΅ λμ pca_figλ§ λ°ν (μ΄ 7κ°)
|
| 221 |
|
| 222 |
except Exception as e:
|
| 223 |
import traceback
|
| 224 |
tb_str = traceback.format_exc()
|
| 225 |
error_html = f"<p style='color:red;'>λΆμ μ€ μ€λ₯ λ°μ: {html.escape(str(e))}</p><pre>{html.escape(tb_str)}</pre>"
|
| 226 |
print(f"Analyze_sentence_for_gradio error: {e}\n{tb_str}")
|
| 227 |
+
empty_df = pd.DataFrame(columns=['token', 'score'])
|
| 228 |
+
empty_fig_placeholder = create_empty_plot("Error during plot generation")
|
| 229 |
+
return error_html, [], "λΆμ μ€ν¨", {"μ€λ₯": str(e)}, [], empty_df, empty_fig_placeholder
|
| 230 |
|
| 231 |
|
| 232 |
+
# ────────── Gradio interface definition (attention-map tab removed) ──────────
|
| 233 |
+
theme = gr.themes.Glass(primary_hue="blue", secondary_hue="cyan", neutral_hue="sky").set(
|
| 234 |
+
body_background_fill="linear-gradient(to right, #c9d6ff, #e2e2e2)", # λ°°κ²½ κ·ΈλΌλ°μ΄μ
|
| 235 |
+
block_background_fill="rgba(255,255,255,0.8)", # λΈλ‘ λ°°κ²½ λ°ν¬λͺ
|
| 236 |
+
block_border_width="1px",
|
| 237 |
+
block_shadow="*shadow_drop_lg"
|
| 238 |
+
)
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
+
with gr.Blocks(title="AI λ¬Έμ₯ λΆμκΈ° XAI π", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
|
| 242 |
+
gr.Markdown("# π AI λ¬Έμ₯ λΆμκΈ° XAI: λͺ¨λΈ ν΄μ νν")
|
| 243 |
+
gr.Markdown("BERT λͺ¨λΈ μμΈ‘μ κ·Όκ±°λ₯Ό λ€μν μκ°ν κΈ°λ²μΌλ‘ νμν©λλ€. ν ν°μ μ€μλμ μλ² λ© κ³΅κ°μμμ λΆν¬λ₯Ό νμΈν΄λ³΄μΈμ.")
|
| 244 |
|
| 245 |
+
with gr.Row(equal_height=False):
|
| 246 |
+
with gr.Column(scale=1, min_width=300):
|
| 247 |
+
with gr.Group():
|
| 248 |
+
gr.Markdown("### βοΈ λ¬Έμ₯ μ
λ ₯ & μ€μ ")
|
| 249 |
+
input_sentence = gr.Textbox(lines=5, label="λΆμν μμ΄ λ¬Έμ₯", placeholder="μ¬κΈ°μ λΆμνκ³ μΆμ μμ΄ λ¬Έμ₯μ μ
λ ₯νμΈμ...")
|
| 250 |
+
input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top-K ν ν° μ")
|
| 251 |
+
submit_button = gr.Button("λΆμ μμ π«", variant="primary", scale=1)
|
| 252 |
+
|
| 253 |
with gr.Column(scale=2):
|
| 254 |
+
with gr.Accordion("π― μμΈ‘ κ²°κ³Ό", open=True):
|
| 255 |
+
output_prediction_summary = gr.Textbox(label="κ°λ¨ μμ½", lines=2, interactive=False)
|
| 256 |
+
output_prediction_details = gr.Label(label="μμΈ μ 보") # Labelμ λμ
λ리 νμ
|
| 257 |
+
with gr.Accordion("β Top-K μ€μ ν ν° (ν)", open=True):
|
| 258 |
+
output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="μ€μλ λμ ν ν°",
|
| 259 |
+
row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
+
with gr.Tabs() as tabs:
|
| 262 |
+
with gr.TabItem("π¨ HTML νμ΄λΌμ΄νΈ", id=0):
|
| 263 |
+
output_html_visualization = gr.HTML(label="ν ν°λ³ μ€μλ (Gradient x Input)")
|
| 264 |
+
with gr.TabItem("ποΈ ν
μ€νΈ νμ΄λΌμ΄νΈ", id=1):
|
| 265 |
+
output_highlighted_text = gr.HighlightedText(
|
| 266 |
+
label="μ€μλ κΈ°λ° ν
μ€νΈ νμ΄λΌμ΄νΈ (μ μ: 0~1)",
|
| 267 |
+
show_legend=True,
|
| 268 |
+
combine_adjacent=False
|
| 269 |
+
)
|
| 270 |
+
with gr.TabItem("π Top-K λ§λ κ·Έλν", id=2):
|
| 271 |
+
# Bar chart of the Top-K token scores, plotting the "token" column against "score".
# The original call ended with a garbled keyword (`color_legend_ …=None`) that is not
# a valid Python identifier; it is dropped here, since its value was None.
output_top_tokens_barplot = gr.BarPlot(label="Top-K ν ν° μ€μλ", x="token", y="score", tooltip=['token', 'score'], min_width=300)
|
| 272 |
+
with gr.TabItem("π ν ν° μλ² λ© 3D PCA", id=3): # μ΄ν
μ
λ§΅ λμ PCA ν
|
| 273 |
+
output_pca_plot = gr.Plot(label="ν ν° μλ² λ© 3D PCA (μ€μλ μμ)")
|
| 274 |
+
|
| 275 |
gr.Markdown("---")
|
| 276 |
+
gr.Examples(
|
| 277 |
+
examples=[
|
| 278 |
+
["This movie is an absolute masterpiece, captivating from start to finish.", 5],
|
| 279 |
+
["Despite some flaws, the film offers a compelling narrative.", 3],
|
| 280 |
+
["I was thoroughly disappointed with the lackluster performance and predictable plot.", 4]
|
| 281 |
+
],
|
| 282 |
+
inputs=[input_sentence, input_top_k],
|
| 283 |
+
outputs=[ # λ°ν κ° κ°μμ λ§μΆ° 7κ°λ‘ μμ
|
| 284 |
+
output_html_visualization, output_highlighted_text,
|
| 285 |
+
output_prediction_summary, output_prediction_details,
|
| 286 |
+
output_top_tokens_df, output_top_tokens_barplot,
|
| 287 |
+
output_pca_plot # μ΄ν
μ
νλ‘― μ κ±°, PCA νλ‘―λ§ λ¨κΉ
|
| 288 |
+
],
|
| 289 |
+
fn=analyze_sentence_for_gradio,
|
| 290 |
+
cache_examples=False
|
| 291 |
+
)
|
| 292 |
+
gr.Markdown("<p style='text-align: center; color: #666;'>Explainable AI Demo with Gradio & Transformers</p>", unsafe_allow_html=True)
|
| 293 |
|
|
|
|
| 294 |
submit_button.click(
|
| 295 |
fn=analyze_sentence_for_gradio,
|
| 296 |
inputs=[input_sentence, input_top_k],
|
| 297 |
+
outputs=[ # λ°ν κ° κ°μμ λ§μΆ° 7κ°λ‘ μμ
|
| 298 |
+
output_html_visualization, output_highlighted_text,
|
| 299 |
+
output_prediction_summary, output_prediction_details,
|
| 300 |
+
output_top_tokens_df, output_top_tokens_barplot,
|
| 301 |
+
output_pca_plot # μ΄ν
μ
νλ‘― μ κ±°, PCA νλ‘―λ§ λ¨κΉ
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
],
|
| 303 |
+
api_name="explain_sentence_xai"
|
|
|
|
|
|
|
|
|
|
| 304 |
)
|
| 305 |
|
|
|
|
|
|
|
| 306 |
if __name__ == "__main__":
|
| 307 |
if not MODELS_LOADED_SUCCESSFULLY:
|
| 308 |
print("*"*80)
|
| 309 |
+
print(f"κ²½κ³ : λͺ¨λΈ λ‘λ© μ€ν¨! {MODEL_LOADING_ERROR_MESSAGE}")
|
| 310 |
+
print("Gradio UIλ νμλμ§λ§ λΆμ κΈ°λ₯μ΄ μ λλ‘ μλνμ§ μμ΅λλ€.")
|
|
|
|
|
|
|
| 311 |
print("*"*80)
|
|
|
|
|
|
|
| 312 |
demo.launch()
|