Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import os | |
| from huggingface_hub import HfApi | |
| from datasets import load_dataset, Dataset | |
| import io | |
| # from dotenv import load_dotenv | |
| # # Load environment variables from a .env file (if present) and read HF token | |
| # load_dotenv() | |
| # HF_TOKEN = os.getenv("HF_TOKEN", "YOUR_HF_WRITE_TOKEN_HERE") | |
# --- 1. CONFIGURATION ---

# --- DEBUG/TESTING MODE ---
# Set to True to use local CSV files instead of the Hugging Face Hub.
# Debug mode reads from PREDICTIONS_CSV and reads/writes LOCAL_DATASET_PATH.
DEBUG_TESTING = False
LOCAL_DATASET_PATH = "policy_evaluations.csv"
PREDICTIONS_CSV = "model_predictions.csv"  # From batch_inference.py
# --- End Debug Config ---

HF_DATASET_REPO = "kaburia/policy-evaluations"  # Target HF Dataset repo

# SECURITY FIX: a live write token was previously hard-coded here (split
# across two variables), i.e. a leaked credential committed to source.
# Read it from the environment instead. The rest of the app already treats
# the placeholder below as "not configured", so behavior is unchanged for
# properly configured deployments (set HF_TOKEN as a Space secret / env var).
HF_TOKEN = os.getenv("HF_TOKEN", "YOUR_HF_WRITE_TOKEN_HERE")

# --- Email Authentication ---
# Maps each authorized annotator email to the anonymous tag stored
# alongside every annotation they submit.
APPROVED_EMAILS = {
    "kaburiaaustin1@tahmo.org": "user1",
    "E.Ramos@tudelft.nl": "user2",
    "eunice.pramos@gmail.com": "user3",
    "E.Abraham@tudelft.nl": "user4",
    "dene.abv@gmail.com": "user5",
    "rafatoufofana.abv@gmail.com": "user6",
    "annorfrank@tahmo.org": "user7",
    "n.marley@tahmo.org": "user8",
    "H.F.Hagenaars@tudelft.nl": "user9",
}

# --- Define Interaction Choices ---
# 3-class verification label -> permitted 7-class drill-down labels.
DRILL_DOWN_MAP = {
    "coherent": ["+3 Indivisible", "+2 Reinforcing", "+1 Enabling"],
    "neutral": ["0 Consistent"],
    "incoherent": ["-1 Constraining", "-2 Counteracting", "-3 Cancelling"]
}
ALL_DRILL_DOWN_CHOICES = DRILL_DOWN_MAP["coherent"] + DRILL_DOWN_MAP["neutral"] + DRILL_DOWN_MAP["incoherent"]
VERIFY_CHOICES = ["neutral", "coherent", "incoherent"]
| # --- 2. DATA LOADING FUNCTIONS --- | |
def load_data_from_hub(token):
    """
    (LIVE MODE) Pull the annotation dataset from the Hugging Face Hub.

    Returns a tuple ``(full_df, pending_df, status)``: ``full_df`` holds
    every row, ``pending_df`` only the rows still missing a
    ``UserVerifiedClass``, and ``status`` a human-readable summary.
    On failure both frames are ``None`` and ``status`` carries the error.
    """
    if not token or token == "YOUR_HF_WRITE_TOKEN_HERE":
        return None, None, "Error: Hugging Face Token is not configured."
    try:
        # Fetch the "train" split (backed by policy_evaluations.csv in the repo).
        full_df = load_dataset(
            HF_DATASET_REPO, token=token, split="train", cache_dir="./cache"
        ).to_pandas()

        # Guarantee the three annotation columns exist on older datasets.
        for column in ["UserVerifiedClass", "DrillDownInteraction", "AnnotatorUsername"]:
            if column not in full_df.columns:
                print(f"Adding missing column to DataFrame: {column}")
                full_df[column] = pd.NA

        # Synthetic unique key: the two policy texts joined with '||'.
        full_df['key'] = full_df['PolicyA'].astype(str) + '||' + full_df['PolicyB'].astype(str)

        # Anything without a verified class is still pending annotation.
        pending_df = full_df[full_df['UserVerifiedClass'].isnull()].reset_index(drop=True)
        done_count = len(full_df) - len(pending_df)
        status = f"Loaded {len(pending_df)} remaining items to annotate. ({done_count} already complete) [LIVE: HF Hub]"
        return full_df, pending_df, status
    except Exception as e:
        return None, None, f"Error loading dataset from Hub: {e}"
def load_data_from_local():
    """
    (DEBUG MODE) Load the dataset from a local CSV file.

    If the local file does not exist yet, it is bootstrapped from
    'model_predictions.csv' with empty annotation columns. Returns the
    same ``(full_df, pending_df, status)`` tuple as the Hub loader.
    """
    try:
        if not os.path.exists(LOCAL_DATASET_PATH):
            # First run: seed the local annotation file from the predictions.
            print(f"'{LOCAL_DATASET_PATH}' not found. Initializing from '{PREDICTIONS_CSV}'...")
            if not os.path.exists(PREDICTIONS_CSV):
                return None, None, f"Error: '{PREDICTIONS_CSV}' not found. Please run batch_inference.py first."
            seed_df = pd.read_csv(PREDICTIONS_CSV)
            # The predictions file must already carry the model's 3-class label.
            if "model_label" not in seed_df.columns:
                return None, None, f"Error: '{PREDICTIONS_CSV}' is missing 'model_label' column. Please run batch_inference.py"
            for column in ("UserVerifiedClass", "DrillDownInteraction", "AnnotatorUsername"):
                seed_df[column] = pd.NA
            seed_df.to_csv(LOCAL_DATASET_PATH, index=False)
            print(f"Initialized '{LOCAL_DATASET_PATH}'.")

        # Load the (now guaranteed to exist) local file.
        full_df = pd.read_csv(LOCAL_DATASET_PATH)

        # Backfill annotation columns on pre-existing local files.
        for column in ("UserVerifiedClass", "DrillDownInteraction", "AnnotatorUsername"):
            if column not in full_df.columns:
                full_df[column] = pd.NA

        # Synthetic unique key, matching the Hub loader's scheme.
        full_df['key'] = full_df['PolicyA'].astype(str) + '||' + full_df['PolicyB'].astype(str)
        pending_df = full_df[full_df['UserVerifiedClass'].isnull()].reset_index(drop=True)
        done_count = len(full_df) - len(pending_df)
        status = f"Loaded {len(pending_df)} remaining items to annotate. ({done_count} complete) [DEBUG: Local CSV]"
        return full_df, pending_df, status
    except Exception as e:
        return None, None, f"Error loading local dataset: {e}"
| # --- 3. DATA SAVING FUNCTIONS --- | |
def save_annotation_to_hub(index, verified_class, drill_down, user_tag, token, full_df, pending_df):
    """
    (LIVE MODE) Record one annotation and push the whole dataset back to the Hub.

    Updates ``full_df`` in place, uploads it as policy_evaluations.csv to the
    dataset repo, and returns the component-update dict for the next item.
    """
    if not drill_down:
        return {status_box: "Error: Please select a drill-down interaction."}
    if not user_tag:
        return {status_box: "Error: User tag is missing. Please re-login."}
    try:
        # Match every row of the full dataset sharing the annotated item's key.
        row_key = pending_df.loc[index, 'key']
        mask = full_df['key'] == row_key
        full_df.loc[mask, 'UserVerifiedClass'] = verified_class
        full_df.loc[mask, 'DrillDownInteraction'] = drill_down
        full_df.loc[mask, 'AnnotatorUsername'] = user_tag

        # Serialize to CSV bytes (dropping the synthetic 'key' column) and
        # overwrite the backing file in the dataset repo.
        payload = full_df.drop(columns=['key']).to_csv(index=False).encode('utf-8')
        HfApi().upload_file(
            path_or_fileobj=io.BytesIO(payload),
            path_in_repo="policy_evaluations.csv",  # Explicitly overwrite this file
            repo_id=HF_DATASET_REPO,
            token=token,
            repo_type="dataset"
        )

        # Advance the UI to the next pending item.
        ui_updates = load_next_item(pending_df, index + 1)
        ui_updates[status_box] = f"Saved to Hub: {verified_class} | {drill_down} by {user_tag}"
        ui_updates[full_df_state] = full_df  # Persist the updated frame in session state
        return ui_updates
    except Exception as e:
        return {status_box: f"Error saving to Hub: {e}"}
def save_annotation_to_local(index, verified_class, drill_down, user_tag, full_df, pending_df):
    """
    (DEBUG MODE) Record one annotation and rewrite the local CSV.

    Mirrors save_annotation_to_hub but persists to LOCAL_DATASET_PATH
    instead of uploading to the Hub.
    """
    if not drill_down:
        return {status_box: "Error: Please select a drill-down interaction."}
    if not user_tag:
        return {status_box: "Error: User tag is missing. Please re-login."}
    try:
        # Match every row of the full dataset sharing the annotated item's key.
        row_key = pending_df.loc[index, 'key']
        mask = full_df['key'] == row_key
        full_df.loc[mask, 'UserVerifiedClass'] = verified_class
        full_df.loc[mask, 'DrillDownInteraction'] = drill_down
        full_df.loc[mask, 'AnnotatorUsername'] = user_tag

        # Overwrite the local CSV without the synthetic 'key' column.
        full_df.drop(columns=['key']).to_csv(LOCAL_DATASET_PATH, index=False)

        # Advance the UI to the next pending item.
        ui_updates = load_next_item(pending_df, index + 1)
        ui_updates[status_box] = f"Saved (Local): {verified_class} | {drill_down} by {user_tag}"
        ui_updates[full_df_state] = full_df  # Persist the updated frame in session state
        return ui_updates
    except Exception as e:
        return {status_box: f"Error saving locally: {e}"}
| # --- 4. GRADIO UI --- | |
# --- 4. GRADIO UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Policy Coherence Annotation Tool")
    gr.Markdown(
        """
        Welcome! This tool is for human-in-the-loop annotation.
        1. Log in with your authorized email.
        2. The model's prediction for two policies will be shown.
        3. **Step 1:** Verify if the model's 3-class prediction (neutral, coherent, incoherent) is correct, or change it.
        4. **Step 2:** Based on your verified choice, select a 7-class drill-down label.
        5. Click 'Save & Next' to submit your annotation and load the next item.
        ---
        ### Drill-Down Definitions
        - **+3 Indivisible**: Inextricably linked to the achievement of another goal.
        - **+2 Reinforcing**: Aids the achievement of another goal.
        - **+1 Enabling**: Creates conditions that further another goal.
        - **0 Consistent**: No significant positive or negative interactions.
        - **-1 Constraining**: Limits options on another goal.
        - **-2 Counteracting**: Clashes with another goal.
        - **-3 Cancelling**: Makes it impossible to reach another goal.
        """
    )
    # --- State variables (per browser session) ---
    full_df_state = gr.State()               # full dataset, incl. completed rows
    pending_df_state = gr.State()            # rows still awaiting annotation
    current_index_state = gr.State(value=0)  # position within pending_df
    hf_token_state = gr.State()              # HF token, or "debug_mode" placeholder
    user_tag_state = gr.State()              # anonymous tag of the logged-in annotator
    # --- Section 1: Login ---
    with gr.Group() as login_box:
        with gr.Row():
            email_box = gr.Textbox(label="Email", placeholder="Enter your authorized email...")
            login_btn = gr.Button("Login & Load Dataset", variant="primary")
    progress_bar = gr.Markdown(value="Waiting for login...")
    # --- Section 2: Annotation (hidden until data is loaded) ---
    with gr.Group(visible=False) as annotation_box:
        # Side-by-side display of the two policies being compared.
        with gr.Row():
            policy_a_display = gr.Textbox(label="Policy / Objective A", interactive=False, lines=5, container=True)
            policy_b_display = gr.Textbox(label="Policy / Objective B", interactive=False, lines=5, container=True)
        with gr.Row():
            # Per-class confidence scores for the model's prediction.
            model_confidence_label = gr.Label(label="Model Confidence")
            user_verified_radio = gr.Radio(
                label="Step 1: Verify/Correct Classification",
                choices=VERIFY_CHOICES,
                info="The model's prediction is selected by default."
            )
        user_drill_down_dropdown = gr.Dropdown(
            label="Step 2: Drill-Down Interaction",
            choices=[],  # Populated dynamically from the Step 1 selection
            interactive=True
        )
        save_btn = gr.Button("Save & Next", variant="stop")
        status_box = gr.Textbox(label="Status", interactive=False)
| # --- 5. UI Event Handlers --- | |
| def update_drill_down_choices(verified_class): | |
| """ | |
| Updates the drill-down dropdown based on the 3-class selection. | |
| """ | |
| choices = DRILL_DOWN_MAP.get(verified_class, []) | |
| value = choices[0] if len(choices) == 1 else None # Auto-select "0 Consistent" | |
| # --- FIX: Return the constructor (Gradio 4.x syntax) --- | |
| return gr.Dropdown( | |
| choices=choices, | |
| value=value, | |
| interactive=len(choices) > 1 # Disable interaction if only one choice | |
| ) | |
| def load_next_item(pending_df, index): | |
| """ | |
| Loads the item at 'index' from the PENDING DataFrame into the UI. | |
| """ | |
| if pending_df is None: | |
| return {status_box: "Data not loaded."} | |
| total_items = len(pending_df) | |
| if index >= total_items: | |
| return { | |
| progress_bar: gr.Markdown(f"**Annotation Complete! ({total_items} items total)**"), | |
| policy_a_display: "All items annotated.", | |
| policy_b_display: "", | |
| annotation_box: gr.Group(visible=False) | |
| } | |
| row = pending_df.iloc[index] | |
| # --- FIX: Use "model_label" from CSV --- | |
| model_pred = row["model_label"] | |
| # --- NEW: Build conf_dict conditionally --- | |
| if "model_confidence" in row: | |
| # New format: "model_label" + "model_confidence" | |
| confidence = row["model_confidence"] | |
| conf_dict = {} | |
| # Distribute probability | |
| remaining_prob = (1.0 - confidence) / 2.0 | |
| for l in VERIFY_CHOICES: # ["neutral", "coherent", "incoherent"] | |
| if l == model_pred: | |
| conf_dict[l] = confidence | |
| else: | |
| conf_dict[l] = remaining_prob | |
| else: | |
| # Old format: "Confidence_Neutral", etc. | |
| conf_dict = { | |
| "neutral": row.get("Confidence_Neutral", 0.0), | |
| "coherent": row.get("Confidence_Coherent", 0.0), | |
| "incoherent": row.get("Confidence_Incoherent", 0.0) | |
| } | |
| # --- END NEW --- | |
| # --- NEW: Update drill-down based on model_pred --- | |
| drill_down_choices = DRILL_DOWN_MAP.get(model_pred, []) | |
| drill_down_value = drill_down_choices[0] if len(drill_down_choices) == 1 else None | |
| drill_down_interactive = len(drill_down_choices) > 1 | |
| return { | |
| progress_bar: gr.Markdown(f"**Annotating Item {index + 1} of {total_items}**"), | |
| policy_a_display: row["PolicyA"], | |
| policy_b_display: row["PolicyB"], | |
| model_confidence_label: conf_dict, | |
| user_verified_radio: model_pred, | |
| # --- FIX: Return the constructor (Gradio 4.x syntax) --- | |
| user_drill_down_dropdown: gr.Dropdown( | |
| choices=drill_down_choices, | |
| value=drill_down_value, | |
| interactive=drill_down_interactive | |
| ), | |
| current_index_state: index, | |
| annotation_box: gr.Group(visible=True) | |
| } | |
| # When 'Login' is clicked: | |
| def login_and_load(email): | |
| # --- Authentication Step --- | |
| if email not in APPROVED_EMAILS: | |
| return { | |
| progress_bar: gr.Markdown(f"<font color='red'>Error: Email '{email}' is not authorized.</font>"), | |
| login_box: gr.Group(visible=True) | |
| } | |
| user_tag = APPROVED_EMAILS[email] # Get the tag (e.g., "user1") | |
| # --- NEW: Branching Logic for Debug/Live --- | |
| if DEBUG_TESTING: | |
| print("--- DEBUG MODE: Loading from local CSV ---") | |
| full_df, pending_df, status = load_data_from_local() | |
| token_to_store = "debug_mode" # Placeholder | |
| else: | |
| print("--- LIVE MODE: Loading from Hugging Face Hub ---") | |
| if HF_TOKEN == "YOUR_HF_WRITE_TOKEN_HERE" or not HF_TOKEN: | |
| return { | |
| progress_bar: gr.Markdown(f"<font color='red'>Error: App is not configured. HF_TOKEN is missing.</font>"), | |
| login_box: gr.Group(visible=True) | |
| } | |
| full_df, pending_df, status = load_data_from_hub(HF_TOKEN) | |
| token_to_store = HF_TOKEN | |
| # --- Common Logic --- | |
| if full_df is None: | |
| return { | |
| progress_bar: gr.Markdown(f"<font color='red'>{status}</font>"), | |
| login_box: gr.Group(visible=True) | |
| } | |
| # --- Load the first item --- | |
| first_item_updates = load_next_item(pending_df, 0) | |
| # --- Save all data to state and update UI --- | |
| first_item_updates[full_df_state] = full_df | |
| first_item_updates[pending_df_state] = pending_df | |
| first_item_updates[progress_bar] = f"Login successful as **{user_tag}**. {status}" | |
| first_item_updates[hf_token_state] = token_to_store # Save token/debug_flag to state | |
| first_item_updates[user_tag_state] = user_tag | |
| first_item_updates[login_box] = gr.Group(visible=False) # Hide login box | |
| first_item_updates[annotation_box] = gr.Group(visible=True) # Show annotation box | |
| return first_item_updates | |
    # Login flow: authenticate, load data, and populate the first item.
    login_btn.click(
        fn=login_and_load,
        inputs=[email_box],  # Input is ONLY the email box
        outputs=[
            progress_bar, policy_a_display, policy_b_display,
            model_confidence_label, user_verified_radio, user_drill_down_dropdown,
            current_index_state, annotation_box, login_box,
            full_df_state, pending_df_state, hf_token_state, user_tag_state, status_box
        ]
    )
| # --- NEW: Wrapper for Save Button --- | |
| def save_wrapper(index, verified_class, drill_down, user_tag, token, full_df, pending_df): | |
| if DEBUG_TESTING: | |
| return save_annotation_to_local(index, verified_class, drill_down, user_tag, full_df, pending_df) | |
| else: | |
| return save_annotation_to_hub(index, verified_class, drill_down, user_tag, token, full_df, pending_df) | |
| # --- NEW: Event listener for dynamic drill-down --- | |
    # Re-populate the drill-down options whenever the Step 1 class changes.
    user_verified_radio.change(
        fn=update_drill_down_choices,
        inputs=user_verified_radio,
        outputs=user_drill_down_dropdown
    )
| # When 'Save & Next' is clicked | |
    # Save flow: persist the annotation, then advance to the next item.
    save_btn.click(
        fn=save_wrapper,  # Dispatches between local/Hub save paths
        inputs=[
            current_index_state,
            user_verified_radio,
            user_drill_down_dropdown,
            user_tag_state,   # Annotator tag from session state
            hf_token_state,   # HF token (or debug placeholder) from state
            full_df_state,
            pending_df_state
        ],
        outputs=[
            progress_bar, policy_a_display, policy_b_display,
            model_confidence_label, user_verified_radio, user_drill_down_dropdown,
            current_index_state, annotation_box, status_box, full_df_state
        ]
    )
# Entry point: print the active mode, then launch the Gradio app.
if __name__ == "__main__":
    if DEBUG_TESTING:
        print("\n" + "="*30)
        print("--- RUNNING IN DEBUG MODE ---")
        print(f"--- Data will be read/written to '{LOCAL_DATASET_PATH}' ---")
        print("="*30 + "\n")
    elif HF_TOKEN == "YOUR_HF_WRITE_TOKEN_HERE":
        # Live mode without a token: the app will start but login will fail.
        print("\n--- WARNING: HF_TOKEN NOT SET ---")
        print("Please edit 'annotation_app.py' and add your HF_TOKEN to the top.")
    demo.launch(debug=True, share=True)