akaburia committed on
Commit
91cc9b0
·
verified ·
1 Parent(s): 0d921e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -22
app.py CHANGED
@@ -44,9 +44,10 @@ def custom_forward(input_ids, attention_mask):
44
  lig = LayerIntegratedGradients(custom_forward, model.roberta.embeddings.word_embeddings)
45
 
46
  llm_client = InferenceClient("Qwen/Qwen3-8B", token=HF_TOKEN)
 
47
  def generate_row_explanation(a_list, idx, text_b):
48
  if not a_list or idx >= len(a_list) or not text_b:
49
- return "", ""
50
 
51
  policy_a = clean_policy_text(a_list[idx])
52
  policy_b = clean_policy_text(text_b)
@@ -77,7 +78,12 @@ def generate_row_explanation(a_list, idx, text_b):
77
  attributions = attributions.cpu().detach().numpy()
78
 
79
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
80
- score_list = [f"'{t.replace('Ġ', '').strip()}': {s:.3f}" for t, s in zip(tokens, attributions) if t.replace('Ġ', '').strip()]
 
 
 
 
 
81
  formatted_scores = ", ".join(score_list)
82
 
83
  # 2. Call Qwen LLM
@@ -98,11 +104,13 @@ Write a highly analytical, 2 to 3 sentence explanation of the model's reasoning.
98
  think_content = match.group(1).strip()
99
  final_answer = raw_output.replace(match.group(0), '').strip()
100
  html_out = f"""<details style="margin-bottom: 12px; padding: 10px; background-color: #f3f4f6; border-radius: 6px; border: 1px solid #e5e7eb;"><summary style="cursor: pointer; font-weight: bold; color: #4b5563; outline: none;">🧠 Click to peek into the AI's thought process</summary><div style="margin-top: 10px; font-size: 0.9em; color: #6b7280; white-space: pre-wrap;">{think_content}</div></details>"""
101
- return html_out, final_answer
102
 
103
- return "", raw_output
 
 
 
104
  except Exception as e:
105
- return "", f"⚠️ Explainability Error: {str(e)}"
106
 
107
 
108
  def bucket_score(score):
@@ -364,6 +372,7 @@ def load_hf_dataset():
364
  "Target_Column", "Target_A_Row", "Target_B_Row",
365
  "Context_Column", "Context_A_Chunk", "Context_B_Chunk",
366
  "Model_Coarse_Label", "Model_Drill_Down_Label", "Model_Confidences", # New Columns
 
367
  "Coherence_Label", "Drill_Down_Label", "Justification", "AnnotatorUsername"
368
  ])
369
 
@@ -516,6 +525,8 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
516
  m_coarse_st = gr.State("")
517
  m_drill_st = gr.State("")
518
  m_conf_st = gr.State("")
 
 
519
 
520
  with gr.Row(equal_height=True):
521
  b_text = gr.Textbox(label=f"Target B", interactive=False, scale=4, min_width=200, lines=3, max_lines=8)
@@ -535,11 +546,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
535
  explain_btn.click(
536
  fn=generate_row_explanation,
537
  inputs=[target_a_list_state, current_index_state, b_text],
538
- outputs=[explain_html, just_box]
539
- )
 
540
 
541
  # Tracking 9 items per row now
542
- eval_rows.append((row_container, b_text, rel_radio, conf_md, inter_dd, just_box, m_coarse_st, m_drill_st, m_conf_st))
 
543
 
544
  with gr.Row():
545
  skip_btn = gr.Button("Skip Target A", size="lg")
@@ -614,7 +627,8 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
614
  updates = []
615
 
616
  # 9 components per row to reset
617
- empty_row = [gr.update(visible=False), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "", "", ""]
 
618
 
619
  if not a_list:
620
  prog_txt = t_text("**Progress:** No unannotated items found.", lang)
@@ -646,7 +660,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
646
  gr.update(value=""), # conf_md
647
  gr.update(choices=[], value=None),
648
  gr.update(value=""), # just_box
649
- "", "", "" # Reset the 3 hidden model states
650
  ])
651
  else:
652
  updates.extend(empty_row)
@@ -656,11 +670,11 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
656
  def load_workspace(dom_a, pol_a, dom_b, pol_b, tar_col, ctx_col, hf_df, user_tag, lang):
657
  if not pol_a or not pol_b:
658
  err = t_text("Error: Select both policies.", lang)
659
- return [gr.update(value=err)] + [gr.skip()] * (14 + MAX_ROWS*9)
660
 
661
  if tar_col == ctx_col:
662
  err = t_text("Error: Target and Context cannot be the same.", lang)
663
- return [gr.update(value=err)] + [gr.skip()] * (14 + MAX_ROWS*9)
664
 
665
  df_a = DOMAIN_MAP[dom_a]
666
  df_b = DOMAIN_MAP[dom_b]
@@ -724,14 +738,16 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
724
  b_val_eng = b_eng_list[i]
725
 
726
  # Row data length is now 8 elements: [b_text, rel_radio, conf_md, inter_dd, just_box, m_coarse, m_drill, m_conf]
727
- rel = row_data[i*8 + 1]
728
- inter = row_data[i*8 + 3]
729
- just = row_data[i*8 + 4]
730
 
731
  # Extract the independent model predictions from hidden states
732
- model_coarse = row_data[i*8 + 5]
733
- model_drill = row_data[i*8 + 6]
734
- model_conf = row_data[i*8 + 7]
 
 
735
 
736
  has_rel = bool(rel)
737
  has_inter = bool(inter)
@@ -758,11 +774,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
758
  "Model_Coarse_Label": model_coarse, # Log model prediction
759
  "Model_Drill_Down_Label": model_drill, # Log model drill
760
  "Model_Confidences": model_conf, # Log model JSON confidence
761
- "Coherence_Label": rel, # Log User prediction
 
 
762
  "Drill_Down_Label": inter, # Log User prediction
763
  "Justification": just.strip(),
764
  "AnnotatorUsername": user_tag
765
  })
 
766
 
767
  if new_rows:
768
  new_df = pd.DataFrame(new_rows)
@@ -819,9 +838,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
819
  first_pass_outputs = []
820
 
821
  # Unpack 9 items per row
822
- for container, b, r, c_md, inter, j, m_co, m_dr, m_cf in eval_rows:
823
- row_outputs.extend([container, b, r, c_md, inter, j, m_co, m_dr, m_cf])
824
- row_inputs.extend([b, r, c_md, inter, j, m_co, m_dr, m_cf])
 
 
 
 
825
  first_pass_outputs.extend([r, c_md, inter, m_co, m_dr, m_cf])
826
 
827
  lang_selector.change(
 
44
  lig = LayerIntegratedGradients(custom_forward, model.roberta.embeddings.word_embeddings)
45
 
46
  llm_client = InferenceClient("Qwen/Qwen3-8B", token=HF_TOKEN)
47
+
48
  def generate_row_explanation(a_list, idx, text_b):
49
  if not a_list or idx >= len(a_list) or not text_b:
50
+ return "", "", "", ""
51
 
52
  policy_a = clean_policy_text(a_list[idx])
53
  policy_b = clean_policy_text(text_b)
 
78
  attributions = attributions.cpu().detach().numpy()
79
 
80
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
81
+
82
+ # NEW: Create a pure JSON dictionary of the attributions to save to the CSV
83
+ ig_dict = {t.replace('Ġ', '').strip(): float(s) for t, s in zip(tokens, attributions) if t.replace('Ġ', '').strip()}
84
+ ig_json_str = json.dumps(ig_dict)
85
+
86
+ score_list = [f"'{k}': {v:.3f}" for k, v in ig_dict.items()]
87
  formatted_scores = ", ".join(score_list)
88
 
89
  # 2. Call Qwen LLM
 
104
  think_content = match.group(1).strip()
105
  final_answer = raw_output.replace(match.group(0), '').strip()
106
  html_out = f"""<details style="margin-bottom: 12px; padding: 10px; background-color: #f3f4f6; border-radius: 6px; border: 1px solid #e5e7eb;"><summary style="cursor: pointer; font-weight: bold; color: #4b5563; outline: none;">🧠 Click to peek into the AI's thought process</summary><div style="margin-top: 10px; font-size: 0.9em; color: #6b7280; white-space: pre-wrap;">{think_content}</div></details>"""
 
107
 
108
+ # Return: UI HTML, UI TextBox, Hidden AI Text, Hidden IG JSON
109
+ return html_out, final_answer, raw_output, ig_json_str
110
+
111
+ return "", raw_output, raw_output, ig_json_str
112
  except Exception as e:
113
+ return "", f"⚠️ Explainability Error: {str(e)}", "", ""
114
 
115
 
116
  def bucket_score(score):
 
372
  "Target_Column", "Target_A_Row", "Target_B_Row",
373
  "Context_Column", "Context_A_Chunk", "Context_B_Chunk",
374
  "Model_Coarse_Label", "Model_Drill_Down_Label", "Model_Confidences", # New Columns
375
+ "AI_Justification", "IG_JSON",
376
  "Coherence_Label", "Drill_Down_Label", "Justification", "AnnotatorUsername"
377
  ])
378
 
 
525
  m_coarse_st = gr.State("")
526
  m_drill_st = gr.State("")
527
  m_conf_st = gr.State("")
528
+ m_ai_just_st = gr.State("")
529
+ m_ig_json_st = gr.State("")
530
 
531
  with gr.Row(equal_height=True):
532
  b_text = gr.Textbox(label=f"Target B", interactive=False, scale=4, min_width=200, lines=3, max_lines=8)
 
546
  explain_btn.click(
547
  fn=generate_row_explanation,
548
  inputs=[target_a_list_state, current_index_state, b_text],
549
+ outputs=[explain_html, just_box, m_ai_just_st, m_ig_json_st] # <-- ADD OUTPUTS
550
+ )
551
+
552
 
553
  # Tracking 11 items per row now (container + 5 visible components + 5 hidden states)
554
+ # eval_rows.append((row_container, b_text, rel_radio, conf_md, inter_dd, just_box, m_coarse_st, m_drill_st, m_conf_st))
555
+ eval_rows.append((row_container, b_text, rel_radio, conf_md, inter_dd, just_box, m_coarse_st, m_drill_st, m_conf_st, m_ai_just_st, m_ig_json_st))
556
 
557
  with gr.Row():
558
  skip_btn = gr.Button("Skip Target A", size="lg")
 
627
  updates = []
628
 
629
  # 11 components per row to reset
630
+ # empty_row = [gr.update(visible=False), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "", "", ""]
631
+ empty_row = [gr.update(visible=False), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "", "", "", "", ""]
632
 
633
  if not a_list:
634
  prog_txt = t_text("**Progress:** No unannotated items found.", lang)
 
660
  gr.update(value=""), # conf_md
661
  gr.update(choices=[], value=None),
662
  gr.update(value=""), # just_box
663
+ "", "", "", "", "" # Reset the 5 hidden model states
664
  ])
665
  else:
666
  updates.extend(empty_row)
 
670
  def load_workspace(dom_a, pol_a, dom_b, pol_b, tar_col, ctx_col, hf_df, user_tag, lang):
671
  if not pol_a or not pol_b:
672
  err = t_text("Error: Select both policies.", lang)
673
+ return [gr.update(value=err)] + [gr.skip()] * (14 + MAX_ROWS*11)
674
 
675
  if tar_col == ctx_col:
676
  err = t_text("Error: Target and Context cannot be the same.", lang)
677
+ return [gr.update(value=err)] + [gr.skip()] * (14 + MAX_ROWS*11)
678
 
679
  df_a = DOMAIN_MAP[dom_a]
680
  df_b = DOMAIN_MAP[dom_b]
 
738
  b_val_eng = b_eng_list[i]
739
 
740
  # Row data length is now 10 elements: [b_text, rel_radio, conf_md, inter_dd, just_box, m_coarse, m_drill, m_conf, m_ai_just, m_ig_json]
741
+ rel = row_data[i*10 + 1]
742
+ inter = row_data[i*10 + 3]
743
+ just = row_data[i*10 + 4]
744
 
745
  # Extract the independent model predictions from hidden states
746
+ model_coarse = row_data[i*10 + 5]
747
+ model_drill = row_data[i*10 + 6]
748
+ model_conf = row_data[i*10 + 7]
749
+ ai_just = row_data[i*10 + 8] #
750
+ ig_json = row_data[i*10 + 9]
751
 
752
  has_rel = bool(rel)
753
  has_inter = bool(inter)
 
774
  "Model_Coarse_Label": model_coarse, # Log model prediction
775
  "Model_Drill_Down_Label": model_drill, # Log model drill
776
  "Model_Confidences": model_conf, # Log model JSON confidence
777
+ "AI_Justification": ai_just, # Log pure AI Thoughts
778
+ "IG_JSON": ig_json, # Log Captum Gradients
779
+ "Coherence_Label": rel,
780
  "Drill_Down_Label": inter, # Log User prediction
781
  "Justification": just.strip(),
782
  "AnnotatorUsername": user_tag
783
  })
784
+
785
 
786
  if new_rows:
787
  new_df = pd.DataFrame(new_rows)
 
838
  first_pass_outputs = []
839
 
840
  # Unpack 11 items per row
841
+ # for container, b, r, c_md, inter, j, m_co, m_dr, m_cf in eval_rows:
842
+ # row_outputs.extend([container, b, r, c_md, inter, j, m_co, m_dr, m_cf])
843
+ # row_inputs.extend([b, r, c_md, inter, j, m_co, m_dr, m_cf])
844
+ # first_pass_outputs.extend([r, c_md, inter, m_co, m_dr, m_cf])
845
+ for container, b, r, c_md, inter, j, m_co, m_dr, m_cf, m_ai_j, m_ig_j in eval_rows:
846
+ row_outputs.extend([container, b, r, c_md, inter, j, m_co, m_dr, m_cf, m_ai_j, m_ig_j])
847
+ row_inputs.extend([b, r, c_md, inter, j, m_co, m_dr, m_cf, m_ai_j, m_ig_j])
848
  first_pass_outputs.extend([r, c_md, inter, m_co, m_dr, m_cf])
849
 
850
  lang_selector.change(