Spaces:
Build error
Build error
| import json | |
| from pathlib import Path | |
| import gradio as gr | |
| from uuid import uuid4 | |
| from datasets import load_dataset, Audio | |
| from collections import Counter | |
| import numpy as np | |
| from configs import configs | |
| from clients import backend, logger | |
| from backend.helpers import get_random_session_samples | |
# Short dataset names accepted in ``configs.DS_NAME`` -> Hugging Face Hub repo ids.
ds_name_map = {
    "stressBench": "iyosha-huji/stressBench",
    "StressPresso": "slprl/StressPresso",
}

if configs.DS_NAME not in ds_name_map:
    raise ValueError(
        f"Invalid DS_NAME {configs.DS_NAME}. Must be one of {list(ds_name_map.keys())}"
    )

# Load the "test" split of the configured dataset. The cached copy is used;
# uncomment ``download_mode`` to force re-downloading the latest version.
dataset = load_dataset(
    ds_name_map[configs.DS_NAME],
    token=configs.HF_API_TOKEN,
    # download_mode="force_redownload",
)["test"]
# Cast the audio column so clips are decoded at a 16 kHz sampling rate.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

# HTML shown above every question in the evaluation tab.
INSTRUCTIONS = """<div align='center'>You are given an audio sample and a question with 2 answer options.\n\nListen to the audio and select the correct answer from the options below.\n\n<b>Note:</b> The question is the same for all samples, but the audio and the corresponding answers change.</div>"""

# Per-dataset subdirectory names.
# NOTE(review): not referenced in this view; the stage-indices path below does
# not use it — confirm whether it was meant to select a per-dataset file.
indices_subdir_map = {
    "stressBench": "stressbench",
    "StressPresso": "stresspresso",
}

# Stage name ("stage1", "stage2", "stage3") -> dataset row indices in that stage.
with open(
    Path(__file__).parent / "data" / "stage_indices.json", encoding="utf-8"
) as f:
    STAGE_SPLITS = json.load(f)
def human_eval_tab():
    """Build the participant-facing "Evaluation" tab.

    A participant logs in with a username and the shared
    ``configs.USER_PASSWORD``. A successful login assigns the session a list
    of dataset row indices via ``get_random_session_samples``. The participant
    then steps through those samples. Each one shows its audio and a
    multiple-choice question. The selected answer is written to ``backend``
    when the participant moves on. After the last sample a thank-you page is
    shown.

    Must be called inside an active ``gr.Blocks`` context.
    """
    with gr.Tab(label="Evaluation"):
        # ==== State ====
        # Position in the session's sample list; -1 = not started / finished.
        i = gr.State(-1)
        # Radio selection for the sample currently on screen.
        selected_answer = gr.State(None)
        # Session position -> answer selected for that sample.
        answers_dict = gr.State({})
        logged_in = gr.State(False)
        session_id = gr.State(None)
        user_name = gr.State(None)
        # Dataset row indices assigned to this session at login.
        session_sample_indices = gr.State([])

        # === Login UI ===
        with gr.Group(visible=True) as login_group:
            gr.Markdown("### 🔐 Login to Continue")
            with gr.Row():
                username = gr.Text(label="Username", placeholder="Enter username")
                password = gr.Text(
                    label="Password", type="password", placeholder="Enter password"
                )
            login_error = gr.Markdown(
                "\u274c Incorrect login, try again. Enter username and password.",
                visible=False,
            )
            login_btn = gr.Button("Login")

        def login(usr, p):
            """Validate credentials and, on success, start a new session.

            Returns values for: logged_in, login_group, login_error,
            session_id, session_sample_indices, user_name.
            """
            # Normalize once, so the name that passed validation is the same
            # name used for sample assignment and stored with the answers.
            usr = (usr or "").strip()
            if p == configs.USER_PASSWORD and usr != "":
                new_session_id = str(uuid4())
                sample_indices, stage = get_random_session_samples(
                    backend, dataset, STAGE_SPLITS, usr, num_samples=17
                )
                logger.info(f"Session ID: {new_session_id}, Stage: {stage}")
                return (
                    True,
                    gr.update(visible=False),
                    gr.update(visible=False),
                    new_session_id,
                    sample_indices,
                    usr,
                )
            return (
                False,
                gr.update(visible=True),
                gr.update(visible=True),
                None,
                [],
                None,
            )

        # === Login Button ===
        login_btn.click(
            fn=login,
            inputs=[username, password],
            outputs=[
                logged_in,
                login_group,
                login_error,
                session_id,
                session_sample_indices,
                user_name,
            ],
        )

        # === UI Elements ===
        # Navigation controls; made visible by ``login_callback`` after login.
        next_btn = gr.Button("Start", visible=False)
        prev_btn = gr.Button("Previous Sample", visible=False)
        warning_msg = gr.Markdown(
            "<span style='color:red;'>\u26a0\ufe0f Please select an answer before continuing.</span>",
            visible=False,
        )

        with gr.Group(visible=False) as app_group:
            with gr.Group():
                gr.Markdown("<div align='center'><big><b>Instructions</b></big></div>")
                gr.Markdown(INSTRUCTIONS)
            with gr.Group(visible=False) as question_group:
                with gr.Row(show_progress=True):
                    with gr.Column(variant="compact"):
                        sample_info = gr.Markdown()
                        gr.Markdown("**Question:**")
                        question_md = gr.Markdown()
                        radio = gr.Radio(label="Answer:", interactive=True)
                    with gr.Column(variant="compact"):
                        audio_output = gr.Audio(
                            interactive=False, type="numpy", label="Audio:"
                        )
            with gr.Group(
                visible=False, elem_id="final_page"
            ) as final_group:  # Final page, not visible until the end
                gr.Markdown(
                    """
# 🎉 Thanks for your help!
You helped moving science forward 🤓
Your responses have been recorded.
You may now close this tab.
"""
                )

        # === Logic ===
        def update_ui(i, answers, session_sample_indices):
            """Render the sample at session position ``i``, or hide it at -1.

            Returns values for: question_group, sample_info, question_md,
            audio_output, radio, selected_answer.
            """
            if i == -1:  # Not started yet, or the session just finished.
                return (
                    gr.update(visible=False),
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False),
                    None,
                )
            # Map the session position to the underlying dataset row.
            true_index = session_sample_indices[i]
            sample = dataset[true_index]
            audio_data = (sample["audio"]["sampling_rate"], sample["audio"]["array"])
            # Restore the answer if the participant already visited this sample.
            previous_answer = answers.get(i, None)
            return (
                gr.update(visible=True),
                f"<div align='center'>Sample <b>{i+1}</b> out of <b>{len(session_sample_indices)}</b></div>",
                "Out of the following answers, according to the speaker's stressed words, what is most likely the underlying intention of the speaker?",
                gr.update(value=audio_data),
                gr.update(
                    choices=sample["possible_answers"],
                    value=previous_answer,
                ),
                previous_answer,
            )

        def update_next_index(
            i, answer, answers, session_id, session_sample_indices, user_name
        ):
            """Persist the current answer (if any) and advance one sample.

            At ``i == -1`` this acts as the "Start" button. Returns values for:
            i, warning_msg, next_btn, answers_dict, final_group, prev_btn.
            """
            if answer is None and i != -1:  # no answer selected
                # Show the warning and stay on the current sample.
                return (
                    gr.update(),
                    gr.update(visible=True),
                    gr.update(),
                    answers,
                    gr.update(visible=False),
                    gr.update(visible=True),
                )
            # Only save while a sample is on screen. At i == -1 there is no
            # current sample, and answers[-1] / session_sample_indices[-1]
            # would silently target the last entry.
            if answer and i != -1:
                answers[i] = answer
                true_index = session_sample_indices[i]
                sample = dataset[true_index]
                interp_id = sample["interpretation_id"]
                trans_id = sample["transcription_id"]
                user_id = session_id
                user_name_str = user_name or "anonymous"
                label = sample["label"]
                gt = sample["possible_answers"][label]
                is_correct = gt == answer
                logger.info(
                    "saving answer to backend",
                    context={
                        "i": true_index,
                        "interp_id": interp_id,
                        "answer": answer,
                        "user_id": user_id,
                        "is_correct": is_correct,
                    },
                )
                # Overwrite an existing answer (e.g. after "Previous"); if
                # update_row reports nothing was updated, insert a new row.
                if not backend.update_row(
                    true_index, interp_id, user_id, answer, is_correct=is_correct
                ):
                    backend.add_row(
                        true_index,
                        interp_id,
                        trans_id,
                        user_id,
                        answer,
                        user_name_str,
                        is_correct=is_correct,
                    )
            if i + 1 == len(session_sample_indices):  # Last question just answered
                return (
                    -1,  # reset i to stop showing question
                    gr.update(visible=False),
                    gr.update(visible=False),
                    answers,
                    gr.update(visible=True),  # show final page
                    gr.update(visible=False),  # hide previous button
                )
            # Go to the next question (the last-question case returned above).
            return (
                i + 1,
                gr.update(visible=False),
                gr.update(value="Submit answer and go to Next"),
                answers,
                gr.update(visible=False),
                gr.update(visible=True),
            )

        def update_prev_index(i):
            """Step back one sample; a no-op on the first sample and before start.

            Returns values for: i, warning_msg (always hidden).
            """
            # Prevent going back from the first question or the start page.
            if i <= 0:
                return i, gr.update(visible=False)
            return i - 1, gr.update(visible=False)

        def answer_change_callback(answer, i, answers):
            """Record the radio selection for the current session position."""
            answers[i] = answer
            return answer, answers

        def login_callback(logged_in):
            """Show or hide the app after a login attempt.

            Returns values for: app_group, next_btn, prev_btn, warning_msg.
            ``app_group`` and ``next_btn`` follow the login result; the other
            two are always hidden.
            """
            shown = bool(logged_in)
            return (
                gr.update(visible=shown),
                gr.update(visible=shown),
                gr.update(visible=False),
                gr.update(visible=False),
            )

        # === Events ===
        next_btn.click(
            update_next_index,
            [
                i,
                selected_answer,
                answers_dict,
                session_id,
                session_sample_indices,
                user_name,
            ],
            [i, warning_msg, next_btn, answers_dict, final_group, prev_btn],
        )
        prev_btn.click(update_prev_index, i, [i, warning_msg])
        # Every change of position re-renders the question area.
        i.change(
            update_ui,
            [i, answers_dict, session_sample_indices],
            [
                question_group,
                sample_info,
                question_md,
                audio_output,
                radio,
                selected_answer,
            ],
        )
        radio.change(
            answer_change_callback,
            [radio, i, answers_dict],
            [selected_answer, answers_dict],
        )
        logged_in.change(
            login_callback, logged_in, [app_group, next_btn, prev_btn, warning_msg]
        )
def compute_random_sampled_accuracy(df, dataset, n_rounds=100, seed=42):
    """Estimate single-annotator accuracy by repeated random sampling.

    Only interpretation ids answered by at least 3 distinct users are used.
    Each round draws one submitted answer per eligible interpretation id. That
    answer is scored against ``possible_answers[label]`` of the dataset row
    it refers to. There are ``n_rounds`` such rounds.

    Args:
        df: Answers table with ``interpretation_id``, ``user_id``, ``answer``
            and ``index_in_dataset`` columns.
        dataset: Indexable collection whose items expose ``possible_answers``
            and ``label``.
        n_rounds: Number of sampling rounds.
        seed: Seed for the NumPy generator that drives all sampling.

    Returns:
        ``(mean_acc, std_acc, mean_total, ci_95)``:
        - ``std_acc`` is the sample standard deviation (``ddof=1``) of the
          per-round accuracies.
        - ``mean_total`` is the integer mean number of items scored per round.
        - ``ci_95`` is ``1.96 * std_acc / sqrt(n_rounds)``.
        ``(None, None, 0, None)`` is returned when nothing could be scored.
    """
    rng = np.random.default_rng(seed)

    # Keep only interpretation ids answered by at least 3 distinct users.
    counts = df.groupby("interpretation_id")["user_id"].nunique()
    eligible_ids = set(counts[counts >= 3].index)
    grouped = df[df["interpretation_id"].isin(eligible_ids)].groupby(
        "interpretation_id"
    )

    # Ground truth per dataset index, looked up once. The same rows are
    # revisited every round, and indexing an audio dataset may decode audio.
    gt_cache = {}

    def ground_truth(idx):
        if idx not in gt_cache:
            sample = dataset[idx]
            gt_cache[idx] = sample["possible_answers"][sample["label"]]
        return gt_cache[idx]

    all_scores = []
    total_answered_per_round = []
    for _ in range(n_rounds):
        correct = 0
        total = 0
        for _interp_id, group in grouped:
            if group.empty:
                continue
            # Randomly pick one answer; the per-draw seed comes from ``rng``,
            # so the whole run is reproducible from ``seed``.
            draw_seed = int(rng.integers(1_000_000))
            row = group.sample(1, random_state=draw_seed).iloc[0]
            total += 1
            if row["answer"] == ground_truth(int(row["index_in_dataset"])):
                correct += 1
        if total > 0:
            all_scores.append(correct / total)
            total_answered_per_round.append(total)

    if not all_scores:
        return None, None, 0, None
    mean_acc = np.mean(all_scores)
    mean_total = int(np.mean(total_answered_per_round))
    std_acc = np.std(all_scores, ddof=1)  # sample std (NaN if n_rounds == 1)
    ci_95 = 1.96 * std_acc / np.sqrt(n_rounds)
    return mean_acc, std_acc, mean_total, ci_95
# Distinct users that must answer an interpretation before it counts as complete.
_REQUIRED_VOTES = 3


def _sample_meta(dataset, idx, cache):
    """Return ``(interpretation_id, ground_truth_answer)`` for ``dataset[idx]``.

    Results are memoized in ``cache``, so each row is read at most once per
    report. Rows of this app's dataset carry an audio column, which
    presumably makes repeated indexing costly.
    """
    if idx not in cache:
        sample = dataset[idx]
        cache[idx] = (
            sample["interpretation_id"],
            sample["possible_answers"][sample["label"]],
        )
    return cache[idx]


def _admin_report(df, dataset, stage_splits):
    """Render the admin Markdown report for the answers table ``df``.

    ``df`` needs ``interpretation_id``, ``user_id``, ``answer`` and
    ``index_in_dataset`` columns. ``stage_splits`` maps ``"stage1"``,
    ``"stage2"`` and ``"stage3"`` to lists of dataset row indices.
    """
    stages = ["stage1", "stage2", "stage3"]
    meta_cache = {}

    # Majority-vote answer per interpretation id. On ties,
    # Counter.most_common keeps the answer encountered first.
    majority_answers = {}
    for interp_id, group in df.groupby("interpretation_id"):
        answer_counts = Counter(group["answer"])
        if answer_counts:
            majority_answers[interp_id] = answer_counts.most_common(1)[0][0]
    # Distinct users per interpretation id.
    counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()
    total_answers = len(df)
    users_count = df["user_id"].nunique()

    stage_acc, stage_completes, stage_counts, stage_remaining = {}, {}, {}, {}
    for stage in stages:
        indices = stage_splits[stage]
        correct = total = complete = remaining = 0
        for idx in indices:
            interp_id, gt = _sample_meta(dataset, idx, meta_cache)
            n = counts.get(interp_id, 0)
            if n >= _REQUIRED_VOTES:
                complete += 1
            # Count only each item's shortfall: surplus answers on one item
            # do not reduce the answers still needed on the others.
            remaining += max(0, _REQUIRED_VOTES - n)
            if interp_id in majority_answers:
                total += 1
                if majority_answers[interp_id] == gt:
                    correct += 1
        stage_counts[stage] = len(indices)
        stage_completes[stage] = complete
        stage_remaining[stage] = remaining
        if complete == len(indices):
            acc = correct / total if total > 0 else 0
            stage_acc[stage] = (acc, correct, total)
        else:
            stage_acc[stage] = None  # not shown until the whole stage is complete

    # Active stage = first stage that still has incomplete items.
    if stage_completes["stage1"] < stage_counts["stage1"]:
        current_stage = "Stage 1"
    elif stage_completes["stage2"] < stage_counts["stage2"]:
        current_stage = "Stage 2"
    else:
        current_stage = "Stage 3"

    # Cumulative majority-vote accuracy: Stage 1, Stage 1+2, All Stages.
    # Stop at the first incomplete stage, because each cumulative line needs
    # every earlier stage's totals.
    cumulative_labels = {
        "stage1": "Stage 1",
        "stage2": "Stage 1+2",
        "stage3": "All Stages",
    }
    agg_lines = []
    cum_correct = cum_total = 0
    for stage in stages:
        if stage_acc[stage] is None:
            break
        _, stage_correct, stage_total = stage_acc[stage]
        cum_correct += stage_correct
        cum_total += stage_total
        cum_acc = cum_correct / cum_total if cum_total > 0 else 0
        agg_lines.append(
            f"- **{cumulative_labels[stage]}:** {cum_acc:.2%} ({cum_correct}/{cum_total})"
        )
    agg_msg = "\n".join(agg_lines) if agg_lines else "No completed stages yet."

    # Random-sampled (single annotator) accuracy.
    n_rounds = 100
    rand_acc, rand_std, rand_total, rand_ci = compute_random_sampled_accuracy(
        df, dataset, n_rounds=n_rounds
    )
    if rand_acc is not None:
        rand_acc_msg = (
            f"**Accuracy:** {rand_acc:.2%} ± {rand_ci:.2%} (95% CI)\n\n"
            f"Standard deviation: {rand_std:.2%}\n\n"
            f"Samples used: {rand_total} × {n_rounds} rounds"
        )
    else:
        rand_acc_msg = "Random sampling failed (no data)."

    # Every submitted answer scored individually against the ground truth.
    correct = total = 0
    for idx, answer in zip(df["index_in_dataset"], df["answer"]):
        idx = int(idx)
        if not 0 <= idx < len(dataset):
            continue  # skip out-of-range indices
        _, gt = _sample_meta(dataset, idx, meta_cache)
        total += 1
        if answer == gt:
            correct += 1
    if total > 0:
        overall_acc_msg = (
            f"Overall Accuracy: {correct / total:.2%} ({correct}/{total})"
        )
    else:
        overall_acc_msg = "No data available."

    stage_rows = "\n".join(
        f"| {num} | {stage_completes[s]} / {stage_counts[s]} | {stage_counts[s]} | {stage_remaining[s]} |"
        for num, s in enumerate(stages, start=1)
    )

    # Markdown must not be indented. Blank lines keep `---` a thematic break
    # (not a setext heading) and end the table before the last line.
    # NOTE(review): 300 is a hard-coded global answer target — confirm.
    return f"""
## ✅ Accuracy Summary

### Overall Accuracy
{overall_acc_msg}

---

### Majority Vote
{agg_msg}

---

### Random-Sampled Accuracy
{rand_acc_msg}

---

## 📊 Answer Progress
- **Total answers submitted:** {total_answers}
- **Answers to go (global):** {300 - total_answers}
- **Unique users:** {users_count}

---

## 🧱 Stage Breakdown
| Stage | Completed | Total | Remaining Answers |
|-------|-----------|--------|-------------------|
{stage_rows}

**➡️ Current Active Stage:** {current_stage}
"""


def get_admin_tab():
    """Build the password-protected "Admin Console" tab.

    Entering ``configs.ADMIN_PASSWORD`` renders a Markdown report built from
    ``backend.get_all_rows()``. The report covers overall, majority-vote and
    random-sampled accuracy plus collection progress per stage.

    Must be called inside an active ``gr.Blocks`` context.
    """
    with gr.Tab("Admin Console"):
        admin_password = gr.Text(label="Enter Admin Password", type="password")
        check_btn = gr.Button("Enter")
        error_box = gr.Markdown("", visible=False)
        output_box = gr.Markdown("", visible=False)

        def calculate_majority_vote_accuracy(pw):
            """Check ``pw`` and return updates for ``(error_box, output_box)``."""
            if pw != configs.ADMIN_PASSWORD:
                return gr.update(
                    visible=True, value="❌ Incorrect password."
                ), gr.update(visible=False)
            df = backend.get_all_rows()
            if df.empty:
                return gr.update(visible=True, value="No data available."), gr.update(
                    visible=False
                )
            msg = _admin_report(df, dataset, STAGE_SPLITS)
            return gr.update(visible=False), gr.update(visible=True, value=msg)

        check_btn.click(
            fn=calculate_majority_vote_accuracy,
            inputs=admin_password,
            outputs=[error_box, output_box],
        )
# === App UI ===
# Assemble the participant tab and the admin tab into one Blocks app, then serve it.
demo = gr.Blocks()
with demo:
    human_eval_tab()  # participant-facing evaluation flow
    get_admin_tab()  # password-protected results dashboard
demo.launch()