Update app.py
Browse files
app.py
CHANGED
|
@@ -44,9 +44,10 @@ def custom_forward(input_ids, attention_mask):
|
|
| 44 |
lig = LayerIntegratedGradients(custom_forward, model.roberta.embeddings.word_embeddings)
|
| 45 |
|
| 46 |
llm_client = InferenceClient("Qwen/Qwen3-8B", token=HF_TOKEN)
|
|
|
|
| 47 |
def generate_row_explanation(a_list, idx, text_b):
|
| 48 |
if not a_list or idx >= len(a_list) or not text_b:
|
| 49 |
-
return "", ""
|
| 50 |
|
| 51 |
policy_a = clean_policy_text(a_list[idx])
|
| 52 |
policy_b = clean_policy_text(text_b)
|
|
@@ -77,7 +78,12 @@ def generate_row_explanation(a_list, idx, text_b):
|
|
| 77 |
attributions = attributions.cpu().detach().numpy()
|
| 78 |
|
| 79 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
formatted_scores = ", ".join(score_list)
|
| 82 |
|
| 83 |
# 2. Call Qwen LLM
|
|
@@ -98,11 +104,13 @@ Write a highly analytical, 2 to 3 sentence explanation of the model's reasoning.
|
|
| 98 |
think_content = match.group(1).strip()
|
| 99 |
final_answer = raw_output.replace(match.group(0), '').strip()
|
| 100 |
html_out = f"""<details style="margin-bottom: 12px; padding: 10px; background-color: #f3f4f6; border-radius: 6px; border: 1px solid #e5e7eb;"><summary style="cursor: pointer; font-weight: bold; color: #4b5563; outline: none;">🧠 Click to peek into the AI's thought process</summary><div style="margin-top: 10px; font-size: 0.9em; color: #6b7280; white-space: pre-wrap;">{think_content}</div></details>"""
|
| 101 |
-
return html_out, final_answer
|
| 102 |
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
| 104 |
except Exception as e:
|
| 105 |
-
return "", f"⚠️ Explainability Error: {str(e)}"
|
| 106 |
|
| 107 |
|
| 108 |
def bucket_score(score):
|
|
@@ -364,6 +372,7 @@ def load_hf_dataset():
|
|
| 364 |
"Target_Column", "Target_A_Row", "Target_B_Row",
|
| 365 |
"Context_Column", "Context_A_Chunk", "Context_B_Chunk",
|
| 366 |
"Model_Coarse_Label", "Model_Drill_Down_Label", "Model_Confidences", # New Columns
|
|
|
|
| 367 |
"Coherence_Label", "Drill_Down_Label", "Justification", "AnnotatorUsername"
|
| 368 |
])
|
| 369 |
|
|
@@ -516,6 +525,8 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 516 |
m_coarse_st = gr.State("")
|
| 517 |
m_drill_st = gr.State("")
|
| 518 |
m_conf_st = gr.State("")
|
|
|
|
|
|
|
| 519 |
|
| 520 |
with gr.Row(equal_height=True):
|
| 521 |
b_text = gr.Textbox(label=f"Target B", interactive=False, scale=4, min_width=200, lines=3, max_lines=8)
|
|
@@ -535,11 +546,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 535 |
explain_btn.click(
|
| 536 |
fn=generate_row_explanation,
|
| 537 |
inputs=[target_a_list_state, current_index_state, b_text],
|
| 538 |
-
outputs=[explain_html, just_box]
|
| 539 |
-
|
|
|
|
| 540 |
|
| 541 |
# Tracking 9 items per row now
|
| 542 |
-
eval_rows.append((row_container, b_text, rel_radio, conf_md, inter_dd, just_box, m_coarse_st, m_drill_st, m_conf_st))
|
|
|
|
| 543 |
|
| 544 |
with gr.Row():
|
| 545 |
skip_btn = gr.Button("Skip Target A", size="lg")
|
|
@@ -614,7 +627,8 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 614 |
updates = []
|
| 615 |
|
| 616 |
# 9 components per row to reset
|
| 617 |
-
empty_row = [gr.update(visible=False), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "", "", ""]
|
|
|
|
| 618 |
|
| 619 |
if not a_list:
|
| 620 |
prog_txt = t_text("**Progress:** No unannotated items found.", lang)
|
|
@@ -646,7 +660,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 646 |
gr.update(value=""), # conf_md
|
| 647 |
gr.update(choices=[], value=None),
|
| 648 |
gr.update(value=""), # just_box
|
| 649 |
-
"", "", ""
|
| 650 |
])
|
| 651 |
else:
|
| 652 |
updates.extend(empty_row)
|
|
@@ -656,11 +670,11 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 656 |
def load_workspace(dom_a, pol_a, dom_b, pol_b, tar_col, ctx_col, hf_df, user_tag, lang):
|
| 657 |
if not pol_a or not pol_b:
|
| 658 |
err = t_text("Error: Select both policies.", lang)
|
| 659 |
-
return [gr.update(value=err)] + [gr.skip()] * (14 + MAX_ROWS*
|
| 660 |
|
| 661 |
if tar_col == ctx_col:
|
| 662 |
err = t_text("Error: Target and Context cannot be the same.", lang)
|
| 663 |
-
return [gr.update(value=err)] + [gr.skip()] * (14 + MAX_ROWS*
|
| 664 |
|
| 665 |
df_a = DOMAIN_MAP[dom_a]
|
| 666 |
df_b = DOMAIN_MAP[dom_b]
|
|
@@ -724,14 +738,16 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 724 |
b_val_eng = b_eng_list[i]
|
| 725 |
|
| 726 |
# Row data length is now 8 elements: [b_text, rel_radio, conf_md, inter_dd, just_box, m_coarse, m_drill, m_conf]
|
| 727 |
-
rel = row_data[i*
|
| 728 |
-
inter = row_data[i*
|
| 729 |
-
just = row_data[i*
|
| 730 |
|
| 731 |
# Extract the independent model predictions from hidden states
|
| 732 |
-
model_coarse = row_data[i*
|
| 733 |
-
model_drill = row_data[i*
|
| 734 |
-
model_conf = row_data[i*
|
|
|
|
|
|
|
| 735 |
|
| 736 |
has_rel = bool(rel)
|
| 737 |
has_inter = bool(inter)
|
|
@@ -758,11 +774,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 758 |
"Model_Coarse_Label": model_coarse, # Log model prediction
|
| 759 |
"Model_Drill_Down_Label": model_drill, # Log model drill
|
| 760 |
"Model_Confidences": model_conf, # Log model JSON confidence
|
| 761 |
-
"
|
|
|
|
|
|
|
| 762 |
"Drill_Down_Label": inter, # Log User prediction
|
| 763 |
"Justification": just.strip(),
|
| 764 |
"AnnotatorUsername": user_tag
|
| 765 |
})
|
|
|
|
| 766 |
|
| 767 |
if new_rows:
|
| 768 |
new_df = pd.DataFrame(new_rows)
|
|
@@ -819,9 +838,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 819 |
first_pass_outputs = []
|
| 820 |
|
| 821 |
# Unpack 9 items per row
|
| 822 |
-
for container, b, r, c_md, inter, j, m_co, m_dr, m_cf in eval_rows:
|
| 823 |
-
|
| 824 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 825 |
first_pass_outputs.extend([r, c_md, inter, m_co, m_dr, m_cf])
|
| 826 |
|
| 827 |
lang_selector.change(
|
|
|
|
| 44 |
lig = LayerIntegratedGradients(custom_forward, model.roberta.embeddings.word_embeddings)
|
| 45 |
|
| 46 |
llm_client = InferenceClient("Qwen/Qwen3-8B", token=HF_TOKEN)
|
| 47 |
+
|
| 48 |
def generate_row_explanation(a_list, idx, text_b):
|
| 49 |
if not a_list or idx >= len(a_list) or not text_b:
|
| 50 |
+
return "", "", "", ""
|
| 51 |
|
| 52 |
policy_a = clean_policy_text(a_list[idx])
|
| 53 |
policy_b = clean_policy_text(text_b)
|
|
|
|
| 78 |
attributions = attributions.cpu().detach().numpy()
|
| 79 |
|
| 80 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
| 81 |
+
|
| 82 |
+
# NEW: Create a pure JSON dictionary of the attributions to save to the CSV
|
| 83 |
+
ig_dict = {t.replace('Ġ', '').strip(): float(s) for t, s in zip(tokens, attributions) if t.replace('Ġ', '').strip()}
|
| 84 |
+
ig_json_str = json.dumps(ig_dict)
|
| 85 |
+
|
| 86 |
+
score_list = [f"'{k}': {v:.3f}" for k, v in ig_dict.items()]
|
| 87 |
formatted_scores = ", ".join(score_list)
|
| 88 |
|
| 89 |
# 2. Call Qwen LLM
|
|
|
|
| 104 |
think_content = match.group(1).strip()
|
| 105 |
final_answer = raw_output.replace(match.group(0), '').strip()
|
| 106 |
html_out = f"""<details style="margin-bottom: 12px; padding: 10px; background-color: #f3f4f6; border-radius: 6px; border: 1px solid #e5e7eb;"><summary style="cursor: pointer; font-weight: bold; color: #4b5563; outline: none;">🧠 Click to peek into the AI's thought process</summary><div style="margin-top: 10px; font-size: 0.9em; color: #6b7280; white-space: pre-wrap;">{think_content}</div></details>"""
|
|
|
|
| 107 |
|
| 108 |
+
# Return: UI HTML, UI TextBox, Hidden AI Text, Hidden IG JSON
|
| 109 |
+
return html_out, final_answer, raw_output, ig_json_str
|
| 110 |
+
|
| 111 |
+
return "", raw_output, raw_output, ig_json_str
|
| 112 |
except Exception as e:
|
| 113 |
+
return "", f"⚠️ Explainability Error: {str(e)}", "", ""
|
| 114 |
|
| 115 |
|
| 116 |
def bucket_score(score):
|
|
|
|
| 372 |
"Target_Column", "Target_A_Row", "Target_B_Row",
|
| 373 |
"Context_Column", "Context_A_Chunk", "Context_B_Chunk",
|
| 374 |
"Model_Coarse_Label", "Model_Drill_Down_Label", "Model_Confidences", # New Columns
|
| 375 |
+
"AI_Justification", "IG_JSON",
|
| 376 |
"Coherence_Label", "Drill_Down_Label", "Justification", "AnnotatorUsername"
|
| 377 |
])
|
| 378 |
|
|
|
|
| 525 |
m_coarse_st = gr.State("")
|
| 526 |
m_drill_st = gr.State("")
|
| 527 |
m_conf_st = gr.State("")
|
| 528 |
+
m_ai_just_st = gr.State("")
|
| 529 |
+
m_ig_json_st = gr.State("")
|
| 530 |
|
| 531 |
with gr.Row(equal_height=True):
|
| 532 |
b_text = gr.Textbox(label=f"Target B", interactive=False, scale=4, min_width=200, lines=3, max_lines=8)
|
|
|
|
| 546 |
explain_btn.click(
|
| 547 |
fn=generate_row_explanation,
|
| 548 |
inputs=[target_a_list_state, current_index_state, b_text],
|
| 549 |
+
outputs=[explain_html, just_box, m_ai_just_st, m_ig_json_st] # <-- ADD OUTPUTS
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
|
| 553 |
# Tracking 9 items per row now
|
| 554 |
+
# eval_rows.append((row_container, b_text, rel_radio, conf_md, inter_dd, just_box, m_coarse_st, m_drill_st, m_conf_st))
|
| 555 |
+
eval_rows.append((row_container, b_text, rel_radio, conf_md, inter_dd, just_box, m_coarse_st, m_drill_st, m_conf_st, m_ai_just_st, m_ig_json_st))
|
| 556 |
|
| 557 |
with gr.Row():
|
| 558 |
skip_btn = gr.Button("Skip Target A", size="lg")
|
|
|
|
| 627 |
updates = []
|
| 628 |
|
| 629 |
# 9 components per row to reset
|
| 630 |
+
# empty_row = [gr.update(visible=False), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "", "", ""]
|
| 631 |
+
empty_row = [gr.update(visible=False), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "", "", "", "", ""]
|
| 632 |
|
| 633 |
if not a_list:
|
| 634 |
prog_txt = t_text("**Progress:** No unannotated items found.", lang)
|
|
|
|
| 660 |
gr.update(value=""), # conf_md
|
| 661 |
gr.update(choices=[], value=None),
|
| 662 |
gr.update(value=""), # just_box
|
| 663 |
+
"", "", "", "", "" # Reset the 5 hidden model states
|
| 664 |
])
|
| 665 |
else:
|
| 666 |
updates.extend(empty_row)
|
|
|
|
| 670 |
def load_workspace(dom_a, pol_a, dom_b, pol_b, tar_col, ctx_col, hf_df, user_tag, lang):
|
| 671 |
if not pol_a or not pol_b:
|
| 672 |
err = t_text("Error: Select both policies.", lang)
|
| 673 |
+
return [gr.update(value=err)] + [gr.skip()] * (14 + MAX_ROWS*11)
|
| 674 |
|
| 675 |
if tar_col == ctx_col:
|
| 676 |
err = t_text("Error: Target and Context cannot be the same.", lang)
|
| 677 |
+
return [gr.update(value=err)] + [gr.skip()] * (14 + MAX_ROWS*11)
|
| 678 |
|
| 679 |
df_a = DOMAIN_MAP[dom_a]
|
| 680 |
df_b = DOMAIN_MAP[dom_b]
|
|
|
|
| 738 |
b_val_eng = b_eng_list[i]
|
| 739 |
|
| 740 |
# Row data length is now 8 elements: [b_text, rel_radio, conf_md, inter_dd, just_box, m_coarse, m_drill, m_conf]
|
| 741 |
+
rel = row_data[i*10 + 1]
|
| 742 |
+
inter = row_data[i*10 + 3]
|
| 743 |
+
just = row_data[i*10 + 4]
|
| 744 |
|
| 745 |
# Extract the independent model predictions from hidden states
|
| 746 |
+
model_coarse = row_data[i*10 + 5]
|
| 747 |
+
model_drill = row_data[i*10 + 6]
|
| 748 |
+
model_conf = row_data[i*10 + 7]
|
| 749 |
+
ai_just = row_data[i*10 + 8] #
|
| 750 |
+
ig_json = row_data[i*10 + 9]
|
| 751 |
|
| 752 |
has_rel = bool(rel)
|
| 753 |
has_inter = bool(inter)
|
|
|
|
| 774 |
"Model_Coarse_Label": model_coarse, # Log model prediction
|
| 775 |
"Model_Drill_Down_Label": model_drill, # Log model drill
|
| 776 |
"Model_Confidences": model_conf, # Log model JSON confidence
|
| 777 |
+
"AI_Justification": ai_just, # Log pure AI Thoughts
|
| 778 |
+
"IG_JSON": ig_json, # Log Captum Gradients
|
| 779 |
+
"Coherence_Label": rel,
|
| 780 |
"Drill_Down_Label": inter, # Log User prediction
|
| 781 |
"Justification": just.strip(),
|
| 782 |
"AnnotatorUsername": user_tag
|
| 783 |
})
|
| 784 |
+
|
| 785 |
|
| 786 |
if new_rows:
|
| 787 |
new_df = pd.DataFrame(new_rows)
|
|
|
|
| 838 |
first_pass_outputs = []
|
| 839 |
|
| 840 |
# Unpack 9 items per row
|
| 841 |
+
# for container, b, r, c_md, inter, j, m_co, m_dr, m_cf in eval_rows:
|
| 842 |
+
# row_outputs.extend([container, b, r, c_md, inter, j, m_co, m_dr, m_cf])
|
| 843 |
+
# row_inputs.extend([b, r, c_md, inter, j, m_co, m_dr, m_cf])
|
| 844 |
+
# first_pass_outputs.extend([r, c_md, inter, m_co, m_dr, m_cf])
|
| 845 |
+
for container, b, r, c_md, inter, j, m_co, m_dr, m_cf, m_ai_j, m_ig_j in eval_rows:
|
| 846 |
+
row_outputs.extend([container, b, r, c_md, inter, j, m_co, m_dr, m_cf, m_ai_j, m_ig_j])
|
| 847 |
+
row_inputs.extend([b, r, c_md, inter, j, m_co, m_dr, m_cf, m_ai_j, m_ig_j])
|
| 848 |
first_pass_outputs.extend([r, c_md, inter, m_co, m_dr, m_cf])
|
| 849 |
|
| 850 |
lang_selector.change(
|