import ast
import html
import io
import json
import os
import re

import gradio as gr
import pandas as pd
from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import HfApi, InferenceClient


# Load variables from a local .env file (if one exists) into the environment
# before any setting below is read.
load_dotenv()

# When True the app reads/writes the local CSV files instead of the Hub dataset.
DEBUG_TESTING = False
LOCAL_DATASET_PATH = "policy_evaluations.csv"
PREDICTIONS_CSV = "model_predictions.csv"

HF_DATASET_REPO = "akaburia/policy-evaluations"

# SECURITY: the access token used to be assembled from string fragments that
# were hard-coded in this file. A secret in source control must be treated as
# leaked: revoke that token and supply a fresh one through the HF_TOKEN
# environment variable (or the .env file loaded above). A missing/blank value
# becomes None, which the login flow already reports as "HF_TOKEN is missing".
HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None


# Allow-list used only when the APPROVED_EMAILS variable cannot be interpreted
# as a JSON object. Maps a lower-case email address to the annotator tag that
# is stored with each saved annotation.
FALLBACK_APPROVED_EMAILS = {
    "kaburiaaustin1@tahmo.org": "user1",
    "e.ramos@tudelft.nl": "user2",
    "eunice.pramos@gmail.com": "user3",
    "e.abraham@tudelft.nl": "user4",
    "dene.abv@gmail.com": "user5",
    "rafatoufofana.abv@gmail.com": "user6",
    "annorfrank@tahmo.org": "user7",
    "n.marley@tahmo.org": "user8",
    "h.f.hagenaars@tudelft.nl": "user9",
    "kaburiaaustin1@gmail.com": "user10",
    "faridakone@gmail.com": "user11",
}


def parse_approved_emails(raw):
    """Parse the APPROVED_EMAILS setting into an ``{email: user_tag}`` dict.

    Args:
        raw: JSON text expected to encode an object mapping email -> user tag.

    Returns:
        The parsed mapping with every key stripped and lower-cased (the login
        handler strips and lower-cases the typed address before the lookup).
        If ``raw`` is not valid JSON, or is valid JSON that is not an object,
        a copy of ``FALLBACK_APPROVED_EMAILS`` is returned instead.
    """
    try:
        parsed = json.loads(raw)
        # AttributeError covers valid JSON that is not an object (list, string,
        # number, null): those values have no .items().
        return {str(k).strip().lower(): v for k, v in parsed.items()}
    except (json.JSONDecodeError, TypeError, AttributeError) as e:
        print(f"⚠️ Error parsing APPROVED_EMAILS from environment variables: {e}")
        return dict(FALLBACK_APPROVED_EMAILS)


# NOTE: when the variable is unset the default "{}" parses to an empty
# allow-list (nobody can log in); the fallback is used only on parse errors.
emails_env_string = os.environ.get("APPROVED_EMAILS", "{}")
APPROVED_EMAILS = parse_approved_emails(emails_env_string)


# For each 3-class verdict, the 7-point drill-down labels offered in the
# "Step 2" dropdown. A class with a single label gets it pre-selected.
DRILL_DOWN_MAP = dict(
    coherent=["+3 Indivisible", "+2 Reinforcing", "+1 Enabling"],
    neutral=["0 Consistent"],
    incoherent=["-1 Constraining", "-2 Counteracting", "-3 Cancelling"],
)

# Choices of the "Step 1" radio, in display order.
VERIFY_CHOICES = ["neutral", "coherent", "incoherent"]

# Shared chat client used to turn attribution scores into a prose summary.
llm_client = InferenceClient(model="Qwen/Qwen3-8B", token=HF_TOKEN, timeout=120)


def _format_token_scores(raw_data):
    """Return ``"'token': score"`` strings for each non-empty token with a numeric score."""
    formatted = []
    for token, score in raw_data:
        if not isinstance(score, (int, float)):
            continue
        # 'Ġ' is the word-boundary marker carried by the stored tokens.
        clean_token = str(token).replace('Ġ', '').strip()
        if clean_token:
            formatted.append(f"'{clean_token}': {score:.3f}")
    return formatted


def _render_llm_output(raw_output):
    """Move a ``<think>...</think>`` section of the reply into a collapsible HTML block."""
    match = re.search(r'<think>(.*?)</think>', raw_output, flags=re.DOTALL)
    if not match:
        return raw_output

    # The reasoning text is model output placed inside raw HTML: escape it so
    # any markup it contains is displayed literally instead of being rendered.
    think_content = html.escape(match.group(1).strip())
    final_answer = raw_output.replace(match.group(0), '').strip()

    return f"""
<details style="margin-bottom: 12px; padding: 10px; background-color: #f3f4f6; border-radius: 6px; border: 1px solid #e5e7eb;">
<summary style="cursor: pointer; font-weight: bold; color: #4b5563; outline: none;">🧠 Click to peek into the AI's thought process</summary>
<div style="margin-top: 10px; font-size: 0.9em; color: #6b7280; white-space: pre-wrap;">{think_content}</div>
</details>

**Final Explanation:**
{final_answer}
"""


def generate_llm_explanation(policy_a, policy_b, prediction, exp_str):
    """Ask the LLM for a short prose explanation of one model prediction.

    Args:
        policy_a: Text of the first policy, quoted in the prompt.
        policy_b: Text of the second policy, quoted in the prompt.
        prediction: The classifier's label; upper-cased in the prompt.
        exp_str: Python-literal string of ``(token, score)`` pairs.

    Returns:
        Markdown/HTML text for the summary panel: the model's answer (with any
        ``<think>`` section folded into a ``<details>`` element), a fixed
        "no attribution data" message when ``exp_str`` is missing, blank, or
        contains no numerically scored token, or a "⚠️ Could not generate"
        message if parsing or the API call fails. Never raises.
    """
    no_data_message = "No attribution data available to generate a summary."

    # Missing values arrive as None/NaN; anything that is not a non-blank
    # string cannot be parsed below.
    if not isinstance(exp_str, str) or not exp_str.strip():
        return no_data_message

    try:
        score_list = _format_token_scores(ast.literal_eval(exp_str))
        if not score_list:
            # With nothing to quote, the prompt could only produce an
            # ungrounded answer, so skip the API call.
            return no_data_message

        formatted_scores = ", ".join(score_list)
        label = str(prediction).upper()

        prompt = f"""You are an expert AI auditor interpreting an Explainable AI (XAI) output.
A RoBERTa sequence classification model evaluated two policies and predicted their relationship as: {label}

Policy A: "{policy_a}"
Policy B: "{policy_b}"

Below is the complete Integrated Gradients feature attribution data for the sequence.
- Positive scores (> 0) mean the word acted as supporting evidence, pushing the model TOWARD the {label} prediction.
- Negative scores (< 0) mean the word acted as contradicting evidence, pushing the model AWAY from the prediction.
- Scores near 0.000 are neutral filler.

Token Scores:
[{formatted_scores}]

Write a highly analytical, 2 to 3 sentence explanation of the model's reasoning.
Explicitly ground your explanation in the provided text and quote the specific words that have the highest positive and highest negative scores.
Explain WHY the model saw those specific words as alignment or contradiction based on the context of the two policies. Do not hallucinate.

Auditor's Explanation:"""

        response = llm_client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1500,
            temperature=0.1,
            top_p=0.9,
        )

        # Guard against a reply whose content is None.
        raw_output = (response.choices[0].message.content or "").strip()
        return _render_llm_output(raw_output)

    except Exception as e:
        # UI boundary: report the failure in the panel instead of raising.
        print(f"LLM API Error: {e}")
        return f"⚠️ Could not generate AI summary. Error: {str(e)[:60]}"


def load_data_from_hub(token):
    """Fetch the annotation table from the Hugging Face Hub dataset repo.

    Args:
        token: Access token passed to ``load_dataset``; a falsy value is
            rejected before any request is made.

    Returns:
        ``(full_df, pending_df, status)``. ``full_df`` is the "train" split as
        a DataFrame with the three annotation columns guaranteed to exist and
        a ``key`` column built from ``PolicyA`` and ``PolicyB``; ``pending_df``
        holds the rows whose ``UserVerifiedClass`` is still null, re-indexed
        from 0. On any failure both frames are ``None`` and ``status``
        carries the error text.
    """
    if not token:
        return None, None, "Error: Hugging Face Token is not configured."

    try:
        dataset = load_dataset(HF_DATASET_REPO, token=token, split="train", cache_dir="./cache")
        full_df = dataset.to_pandas()

        # Add whichever annotation columns the stored table does not have yet.
        for column in ("UserVerifiedClass", "DrillDownInteraction", "AnnotatorUsername"):
            if column not in full_df.columns:
                full_df[column] = pd.NA

        # Row identity used later to write an annotation back into full_df.
        full_df['key'] = full_df['PolicyA'].astype(str) + '||' + full_df['PolicyB'].astype(str)

        not_yet_verified = full_df['UserVerifiedClass'].isnull()
        pending_df = full_df[not_yet_verified].reset_index(drop=True)

        completed = len(full_df) - len(pending_df)
        status = f"Loaded {len(pending_df)} remaining items to annotate. ({completed} already complete) [LIVE: HF Hub]"
        return full_df, pending_df, status

    except Exception as e:
        return None, None, f"Error loading dataset from Hub: {e}"


def load_data_from_local():
    """Load the annotation table from the local CSV (debug mode).

    If ``LOCAL_DATASET_PATH`` does not exist yet it is first created from
    ``PREDICTIONS_CSV`` with the three annotation columns set to missing.

    Returns:
        ``(full_df, pending_df, status)`` with the same meaning as
        ``load_data_from_hub``: the full table (annotation columns present,
        ``key`` column added), the rows whose ``UserVerifiedClass`` is null
        re-indexed from 0, and a status line. On any failure both frames are
        ``None`` and ``status`` carries the error text.
    """
    annotation_cols = ["UserVerifiedClass", "DrillDownInteraction", "AnnotatorUsername"]

    try:
        if not os.path.exists(LOCAL_DATASET_PATH):
            if not os.path.exists(PREDICTIONS_CSV):
                return None, None, f"Error: '{PREDICTIONS_CSV}' not found. Please run batch_inference.py first."

            df = pd.read_csv(PREDICTIONS_CSV)
            if "model_label" not in df.columns:
                return None, None, f"Error: '{PREDICTIONS_CSV}' is missing 'model_label' column."

            for col in annotation_cols:
                df[col] = pd.NA
            df.to_csv(LOCAL_DATASET_PATH, index=False)

        full_df = pd.read_csv(LOCAL_DATASET_PATH)
        for col in annotation_cols:
            if col not in full_df.columns:
                full_df[col] = pd.NA
            # A column with no values yet is read back as float64 (all NaN).
            # Writing string labels into a float column is deprecated silent
            # upcasting in pandas >= 2.1, so keep these columns as object.
            full_df[col] = full_df[col].astype("object")

        # Row identity used later to write an annotation back into full_df.
        full_df['key'] = full_df['PolicyA'].astype(str) + '||' + full_df['PolicyB'].astype(str)
        pending_df = full_df[full_df['UserVerifiedClass'].isnull()].reset_index(drop=True)

        status = f"Loaded {len(pending_df)} remaining items to annotate. [DEBUG: Local CSV]"
        return full_df, pending_df, status
    except Exception as e:
        return None, None, f"Error loading local dataset: {e}"


def save_annotation_to_hub(index, verified_class, drill_down, user_tag, token, full_df, pending_df):
    """Record one annotation and upload the whole table to the Hub as CSV.

    Args:
        index: Position of the current item in ``pending_df``.
        verified_class: The annotator's 3-class verdict.
        drill_down: The selected drill-down label; required.
        user_tag: Tag of the logged-in annotator; required.
        token: Hugging Face token used for the upload.
        full_df: Complete table; updated in place for rows sharing the key.
        pending_df: Table of items being annotated in this session.

    Returns:
        A dict of Gradio component updates: the next item plus a status line
        and the updated ``full_df`` state, or only a status-box error message
        when validation or the save fails.
    """
    if not drill_down:
        return {status_box: "Error: Please select a drill-down interaction."}
    if not user_tag:
        return {status_box: "Error: User tag is missing. Please re-login."}

    try:
        row_key = pending_df.loc[index, 'key']
        is_current = full_df['key'] == row_key
        annotation = {
            'UserVerifiedClass': verified_class,
            'DrillDownInteraction': drill_down,
            'AnnotatorUsername': user_tag,
        }
        for column, value in annotation.items():
            full_df.loc[is_current, column] = value

        # The helper 'key' column is session-only and is not persisted.
        # NOTE(review): the complete in-session table replaces the file in the
        # repo, so concurrent annotators may overwrite each other's rows —
        # confirm whether that can happen in practice.
        csv_bytes = full_df.drop(columns=['key']).to_csv(index=False).encode('utf-8')

        HfApi().upload_file(
            path_or_fileobj=io.BytesIO(csv_bytes),
            path_in_repo="policy_evaluations.csv",
            repo_id=HF_DATASET_REPO,
            token=token,
            repo_type="dataset",
        )

        ui_updates = load_next_item(pending_df, index + 1)
        ui_updates[status_box] = f"Saved to Hub: {verified_class} | {drill_down} by {user_tag}"
        ui_updates[full_df_state] = full_df
        return ui_updates
    except Exception as e:
        return {status_box: f"Error saving to Hub: {e}"}


def save_annotation_to_local(index, verified_class, drill_down, user_tag, full_df, pending_df):
    """Record one annotation and rewrite the local CSV (debug mode).

    Args:
        index: Position of the current item in ``pending_df``.
        verified_class: The annotator's 3-class verdict.
        drill_down: The selected drill-down label; required.
        user_tag: Tag of the logged-in annotator; required.
        full_df: Complete table; updated in place for rows sharing the key.
        pending_df: Table of items being annotated in this session.

    Returns:
        A dict of Gradio component updates: the next item plus a status line
        and the updated ``full_df`` state, or only a status-box error message
        when validation or the save fails.
    """
    if not drill_down:
        return {status_box: "Error: Please select a drill-down interaction."}
    if not user_tag:
        return {status_box: "Error: User tag is missing. Please re-login."}

    try:
        row_key = pending_df.loc[index, 'key']
        is_current = full_df['key'] == row_key
        annotation = {
            'UserVerifiedClass': verified_class,
            'DrillDownInteraction': drill_down,
            'AnnotatorUsername': user_tag,
        }
        for column, value in annotation.items():
            full_df.loc[is_current, column] = value

        # The helper 'key' column is session-only and is not persisted.
        full_df.drop(columns=['key']).to_csv(LOCAL_DATASET_PATH, index=False)

        ui_updates = load_next_item(pending_df, index + 1)
        ui_updates[status_box] = f"Saved (Local): {verified_class} | {drill_down} by {user_tag}"
        ui_updates[full_df_state] = full_df
        return ui_updates
    except Exception as e:
        return {status_box: f"Error saving locally: {e}"}


def generate_heatmap_html(raw_data, policy_a_text):
    """Render token attribution scores as two side-by-side HTML heatmaps.

    Args:
        raw_data: List of ``(token, score)`` pairs for the whole sequence.
            Pairs whose score is not an int/float are not drawn.
        policy_a_text: Text of Policy A; its length without spaces is used to
            estimate where Policy A's tokens end and Policy B's begin.

    Returns:
        An HTML string: a colour legend followed by a "Policy A Impact" and a
        "Policy B Impact" panel. Positive scores are shaded green, negative
        ones red, with opacity ``|score| / max |score|``; scores below 0.02 in
        magnitude get no shading. Returns a short italic placeholder when
        ``raw_data`` is empty or not a list, or when rendering fails.
    """
    if not raw_data or not isinstance(raw_data, list):
        return "<i>No explainability data found for this row.</i>"

    try:
        magnitudes = [abs(s) for _, s in raw_data if isinstance(s, (int, float))]
        # `or 0.001` covers both "no numeric scores" and "all scores are 0",
        # either of which would otherwise make the opacity division fail.
        max_abs = max(magnitudes, default=0) or 0.001

        # Walk the tokens until roughly as many characters as Policy A holds
        # (spaces ignored, 2 characters of slack) have been consumed.
        target_length = len(str(policy_a_text).replace(" ", "").strip())
        current_length = 0
        split_index = len(raw_data) - 1
        for i, (token, _) in enumerate(raw_data):
            current_length += len(str(token).replace('Ġ', '').replace(' ', ''))
            if current_length >= target_length - 2:
                split_index = i
                break

        policy_a_data = raw_data[:split_index + 1]
        policy_b_data = raw_data[split_index + 1:]

        def render_block(data_subset, title):
            """Build one titled panel of shaded token spans."""
            parts = [
                "<div style='flex: 1; padding: 16px; background: #ffffff; border-radius: 8px; border: 1px solid #e5e7eb; box-shadow: 0 1px 2px rgba(0,0,0,0.05);'>\n"
                f"<h4 style='margin-top: 0; margin-bottom: 12px; color: #4b5563; font-family: sans-serif; border-bottom: 1px solid #e5e7eb; padding-bottom: 8px;'>{title}</h4>\n"
                "<div style='line-height: 2.2; font-size: 15px; font-family: sans-serif; white-space: pre-wrap;'>\n"
            ]

            for token, score in data_subset:
                if not isinstance(score, (int, float)):
                    continue
                alpha = min(abs(score) / max_abs, 1.0)

                if abs(score) < 0.02:
                    bg_color = "transparent"
                elif score > 0:
                    bg_color = f"rgba(16, 185, 129, {alpha})"
                else:
                    bg_color = f"rgba(239, 68, 68, {alpha})"

                clean_token = str(token)
                margin_left = "0px"
                # A leading space or 'Ġ' marks a word start: drop the marker
                # and show the gap as a left margin instead.
                if clean_token.startswith((' ', 'Ġ')):
                    clean_token = clean_token[1:]
                    margin_left = "4px"

                # Tokens such as '<s>' / '</s>' would otherwise be parsed as
                # HTML tags and corrupt the rest of the panel.
                clean_token = html.escape(clean_token)

                parts.append(
                    f"<span title='Score: {score:.4f}' style='background-color: {bg_color}; border-radius: 3px; padding: 2px 0px; margin-left: {margin_left}; display: inline-block;'>{clean_token}</span>"
                )

            parts.append("</div></div>")
            return "".join(parts)

        legend_html = (
            "<div style='margin-bottom: 15px; font-family: sans-serif;'>\n"
            "<div style='display: flex; justify-content: space-between; width: 100%; max-width: 450px; font-size: 12px; font-weight: bold; color: #4b5563; margin-bottom: 4px;'>\n"
            "<span>Strong Negative (-)</span>\n"
            "<span>Neutral</span>\n"
            "<span>Strong Positive (+)</span>\n"
            "</div>\n"
            "<div style='display: flex; align-items: center; width: 100%; max-width: 450px; height: 14px; background: linear-gradient(to right, rgba(239, 68, 68, 1), rgba(255,255,255,0), rgba(16, 185, 129, 1)); border-radius: 4px; border: 1px solid #d1d5db;'></div>\n"
            "</div>\n"
        )
        panels_html = (
            "<div style='display: flex; gap: 20px;'>\n"
            f"{render_block(policy_a_data, 'Policy A Impact')}\n"
            f"{render_block(policy_b_data, 'Policy B Impact')}\n"
            "</div>\n"
        )
        return "\n" + legend_html + "\n" + panels_html
    except Exception as e:
        print(f"Error generating HTML: {e}")
        return "<i>Error rendering heatmap.</i>"


# Gradio UI: layout, per-session state, and event wiring for the annotation
# workflow (login -> review item -> save or skip -> next item).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Policy Coherence Annotation Tool")
    gr.Markdown(
        """
        Welcome! This tool is for human-in-the-loop annotation.
        1. Log in with your authorized email.
        2. The model's prediction for two policies will be shown.
        3. **Step 1:** Verify if the model's 3-class prediction is correct, or change it.
        4. **Step 2:** Select a 7-class drill-down label.
        5. Click 'Save & Next'. If you are unsure, you can click 'Skip & Next'.
        """
    )

    # Per-session values carried between event handlers.
    full_df_state = gr.State()
    pending_df_state = gr.State()
    current_index_state = gr.State(value=0)
    hf_token_state = gr.State()
    user_tag_state = gr.State()

    with gr.Group() as login_box:
        with gr.Row():
            email_box = gr.Textbox(label="Email", placeholder="Enter your authorized email...")
            login_btn = gr.Button("Login & Load Dataset", variant="primary")
    # NOTE(review): placed outside login_box so progress stays visible once the
    # login group is hidden — confirm against the intended layout.
    progress_bar = gr.Markdown(value="Waiting for login...")

    with gr.Group(visible=False) as annotation_box:
        with gr.Row():
            with gr.Group():
                gr.Markdown("### 📄 Policy / Objective A")
                policy_a_display = gr.Textbox(show_label=False, interactive=False, lines=4, container=False)
            with gr.Group():
                gr.Markdown("### 📄 Policy / Objective B")
                policy_b_display = gr.Textbox(show_label=False, interactive=False, lines=4, container=False)

        gr.Markdown("### 🔍 Model Reasoning (Explainability Heatmap)")
        gr.Markdown("This shows the sequential flow of the text. Subwords are stitched back together for readability. **Hover your mouse over any word to see its exact mathematical impact score.**")

        explanation_display = gr.HTML(label="Token Attributions")

        with gr.Accordion("🤖 Click to view AI-Generated Summary", open=False) as ai_accordion:
            generate_ai_btn = gr.Button("✨ Ask AI to Summarize Reasoning", size="sm")
            ai_summary_display = gr.Markdown("Click the button above to generate a summary of the model's reasoning.")
        gr.Markdown("---")

        with gr.Row():
            model_confidence_label = gr.Label(label="Model Confidence")
            user_verified_radio = gr.Radio(
                label="Step 1: Verify/Correct Classification",
                choices=VERIFY_CHOICES,
                info="The model's prediction is selected by default.",
            )

        user_drill_down_dropdown = gr.Dropdown(
            label="Step 2: Drill-Down Interaction",
            choices=[],
            interactive=True,
        )

        with gr.Row():
            skip_btn = gr.Button("Skip & Next (Unsure)")
            save_btn = gr.Button("Save & Next", variant="primary")

        # NOTE(review): kept inside annotation_box — confirm intended nesting.
        status_box = gr.Textbox(label="Status", interactive=False)

    def fetch_ai_summary(index, pending_df):
        """Return the LLM summary for the item at ``index`` of ``pending_df``."""
        if pending_df is None or index >= len(pending_df):
            return "No data available."
        row = pending_df.iloc[index]
        return generate_llm_explanation(
            policy_a=row["PolicyA"],
            policy_b=row["PolicyB"],
            prediction=row["model_label"],
            exp_str=row.get("explanation_data"),
        )

    generate_ai_btn.click(
        fn=fetch_ai_summary,
        inputs=[current_index_state, pending_df_state],
        outputs=[ai_summary_display],
    )

    def update_drill_down_choices(verified_class):
        """Build the drill-down dropdown for a 3-class verdict.

        A class with exactly one label gets it pre-selected and the dropdown
        locked; otherwise the dropdown is left empty for the user to choose.
        """
        choices = DRILL_DOWN_MAP.get(verified_class, [])
        value = choices[0] if len(choices) == 1 else None
        return gr.Dropdown(choices=choices, value=value, interactive=len(choices) > 1)

    def load_next_item(pending_df, index):
        """Return the component updates that display item ``index``.

        Returns a status-box error if no data is loaded, a "complete" view
        (annotation box hidden) once ``index`` is past the last pending row,
        and otherwise the item's texts, heatmap, confidences and defaults.
        """
        if pending_df is None:
            return {status_box: "Data not loaded."}

        total_items = len(pending_df)
        if index >= total_items:
            return {
                progress_bar: gr.Markdown(f"**Annotation Complete! ({total_items} items total)**"),
                policy_a_display: "All items annotated.",
                policy_b_display: "",
                explanation_display: "<div>Done.</div>",
                annotation_box: gr.Group(visible=False),
            }

        row = pending_df.iloc[index]
        model_pred = row["model_label"]

        if "model_confidence" in row:
            # Only the predicted class has a stored confidence; the remainder
            # is split evenly between the other two classes for display.
            confidence = row["model_confidence"]
            remaining_prob = (1.0 - confidence) / 2.0
            conf_dict = {
                label: confidence if label == model_pred else remaining_prob
                for label in VERIFY_CHOICES
            }
        else:
            conf_dict = {
                "neutral": row.get("Confidence_Neutral", 0.0),
                "coherent": row.get("Confidence_Coherent", 0.0),
                "incoherent": row.get("Confidence_Incoherent", 0.0),
            }

        html_output = "<i>No explainability data found for this row.</i>"
        exp_str = row.get("explanation_data")

        if pd.notna(exp_str) and isinstance(exp_str, str) and exp_str.strip() != "":
            try:
                raw_data = ast.literal_eval(exp_str)
                html_output = generate_heatmap_html(raw_data, row["PolicyA"])
            except Exception as e:
                print(f"Failed to parse explanation string using ast: {e}")
                html_output = "<i>Error parsing explanation data for this row.</i>"

        return {
            progress_bar: gr.Markdown(f"**Annotating Item {index + 1} of {total_items}**"),
            policy_a_display: row["PolicyA"],
            policy_b_display: row["PolicyB"],
            explanation_display: html_output,
            model_confidence_label: conf_dict,
            user_verified_radio: model_pred,
            # Same rule as the radio's change handler, so both stay in sync.
            user_drill_down_dropdown: update_drill_down_choices(model_pred),
            current_index_state: index,
            # Reset the summary panel so the previous item's text is not shown.
            ai_summary_display: "Click the button above to generate a summary of the model's reasoning.",
            ai_accordion: gr.Accordion(open=False),
            annotation_box: gr.Group(visible=True),
        }

    def login_and_load(email):
        """Check the email against the allow-list, load the data, show item 0."""
        raw_email = email or ""
        clean_email = raw_email.strip().lower()

        if clean_email not in APPROVED_EMAILS:
            # The typed value is echoed into HTML: escape it first.
            shown_email = html.escape(raw_email)
            return {
                progress_bar: gr.Markdown(f"<font color='red'>Error: Email '{shown_email}' is not authorized.</font>"),
                login_box: gr.Group(visible=True),
            }

        user_tag = APPROVED_EMAILS[clean_email]

        if DEBUG_TESTING:
            full_df, pending_df, status = load_data_from_local()
            token_to_store = "debug_mode"
        else:
            if not HF_TOKEN:
                return {
                    progress_bar: gr.Markdown("<font color='red'>Error: HF_TOKEN is missing.</font>"),
                    login_box: gr.Group(visible=True),
                }
            full_df, pending_df, status = load_data_from_hub(HF_TOKEN)
            token_to_store = HF_TOKEN

        if full_df is None:
            return {
                progress_bar: gr.Markdown(f"<font color='red'>{status}</font>"),
                login_box: gr.Group(visible=True),
            }

        first_item_updates = load_next_item(pending_df, 0)
        first_item_updates[full_df_state] = full_df
        first_item_updates[pending_df_state] = pending_df
        first_item_updates[progress_bar] = f"Login successful as **{user_tag}**. {status}"
        first_item_updates[hf_token_state] = token_to_store
        first_item_updates[user_tag_state] = user_tag
        first_item_updates[login_box] = gr.Group(visible=False)
        # annotation_box visibility is left as load_next_item decided: shown
        # when there is an item, hidden when nothing is pending, so an empty
        # form is never displayed.

        return first_item_updates

    login_btn.click(
        fn=login_and_load,
        inputs=[email_box],
        outputs=[
            progress_bar, policy_a_display, policy_b_display, explanation_display,
            model_confidence_label, user_verified_radio, user_drill_down_dropdown,
            current_index_state, annotation_box, login_box,
            full_df_state, pending_df_state, hf_token_state, user_tag_state, status_box,
            ai_summary_display, ai_accordion,
        ],
    )

    def save_wrapper(index, verified_class, drill_down, user_tag, token, full_df, pending_df):
        """Dispatch the save to the local CSV (debug mode) or to the Hub."""
        if DEBUG_TESTING:
            return save_annotation_to_local(index, verified_class, drill_down, user_tag, full_df, pending_df)
        return save_annotation_to_hub(index, verified_class, drill_down, user_tag, token, full_df, pending_df)

    def skip_item(index, pending_df):
        """Advance to the next item without saving anything."""
        ui_updates = load_next_item(pending_df, index + 1)
        ui_updates[status_box] = f"Skipped item {index + 1}."
        return ui_updates

    skip_btn.click(
        fn=skip_item,
        inputs=[current_index_state, pending_df_state],
        outputs=[
            progress_bar, policy_a_display, policy_b_display, explanation_display,
            model_confidence_label, user_verified_radio, user_drill_down_dropdown,
            current_index_state, annotation_box, status_box,
            ai_summary_display, ai_accordion,
        ],
    )

    user_verified_radio.change(
        fn=update_drill_down_choices,
        inputs=user_verified_radio,
        outputs=user_drill_down_dropdown,
    )

    save_btn.click(
        fn=save_wrapper,
        inputs=[
            current_index_state, user_verified_radio, user_drill_down_dropdown,
            user_tag_state, hf_token_state, full_df_state, pending_df_state,
        ],
        outputs=[
            progress_bar, policy_a_display, policy_b_display, explanation_display,
            model_confidence_label, user_verified_radio, user_drill_down_dropdown,
            current_index_state, annotation_box, status_box, full_df_state,
            ai_summary_display, ai_accordion,
        ],
    )


if __name__ == "__main__":
    demo.launch(debug=True)