import html
import json
import os
import random
import re
import uuid
from datetime import datetime, timezone
from pathlib import Path

import gradio as gr
from huggingface_hub import HfApi

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

STUDY_ROOT = Path("avg_study")
PROMPT_JSON = "prompts_mapping.json"

HF_TOKEN = os.getenv("HF_TOKEN")
RESULTS_REPO_ID = os.getenv("RESULTS_REPO_ID")
api = HfApi(token=HF_TOKEN) if HF_TOKEN else None

LOCAL_RESULTS_DIR = Path("local_results")
LOCAL_RESULTS_DIR.mkdir(exist_ok=True)

# -----------------------------------------------------------------------------
# Load prompt mapping and build pairs
# -----------------------------------------------------------------------------


def clean_prompt(prompt: str) -> str:
    """Strip a trailing ``idx<N> r<M>`` suffix from *prompt*.

    Returns an empty string for empty or None input.
    """
    if not prompt:
        return ""
    return re.sub(r"\s*idx\d+\s+r\d+\s*$", "", prompt).strip()


def format_prompt_html(prompt: str) -> str:
    """Wrap *prompt* in the ``#prompt_box`` div styled by the page CSS.

    The text is HTML-escaped so that prompts and error messages containing
    ``&``, ``<`` or ``>`` are shown literally instead of being parsed as markup.
    """
    safe_prompt = html.escape(prompt or "", quote=False)
    return f"<div id='prompt_box'>{safe_prompt}</div>"


def load_prompt_mapping(mapping_file: str):
    """Load and return the JSON prompt mapping stored at *mapping_file*.

    Raises:
        FileNotFoundError: if the mapping file does not exist.
    """
    mapping_path = Path(mapping_file)
    if not mapping_path.exists():
        raise FileNotFoundError(
            f"Mapping file '{mapping_file}' not found. "
            "Ensure that 'prompts_mapping.json' exists in the repository."
        )
    with open(mapping_path, "r", encoding="utf-8") as f:
        return json.load(f)


def build_pairs_from_mapping(prompt_map):
    """Turn the prompt mapping into a list of good/bad video pair dicts.

    Each entry of *prompt_map* (iterated via ``.items()``) is read for
    ``prompt``, ``category``, ``good_file`` and ``bad_file``; the file entries
    are resolved relative to STUDY_ROOT.  Entries lacking either file name,
    or whose video files do not exist on disk, are skipped (the latter with a
    printed warning) so a broken pair cannot crash the study mid-session.

    Returns a list of dicts with keys ``pair_id``, ``category``, ``prompt``,
    ``good_path`` and ``bad_path`` (absolute path strings).
    """
    pairs = []
    for idx, entry in prompt_map.items():
        prompt = clean_prompt(entry.get("prompt", ""))
        category = entry.get("category", "")
        good_rel = entry.get("good_file")
        bad_rel = entry.get("bad_file")
        if not good_rel or not bad_rel:
            continue

        good_path = (STUDY_ROOT / good_rel).resolve()
        bad_path = (STUDY_ROOT / bad_rel).resolve()
        if not (good_path.exists() and bad_path.exists()):
            # video_value() would raise for these later, after a response had
            # already been saved; drop them up front instead.
            print(f"Skipping pair {idx}: missing video file(s): {good_path}, {bad_path}")
            continue

        pairs.append(
            {
                "pair_id": idx,
                "category": category,
                "prompt": prompt,
                "good_path": str(good_path),
                "bad_path": str(bad_path),
            }
        )
    return pairs


prompt_map = load_prompt_mapping(PROMPT_JSON)
PAIRS = build_pairs_from_mapping(prompt_map)

print(f"Loaded {len(PAIRS)} pairs.")
if PAIRS:
    print("First pair:", PAIRS[0])

if not PAIRS:
    raise RuntimeError(
        "No matched good/bad pairs found in mapping. "
        "Check that 'avg_study' contains matching files and that the mapping file lists them."
    )

# -----------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------


def video_value(path_str: str) -> str:
    """Return the absolute path of a video file for a gr.Video output.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    path = Path(path_str)
    print(f"Loading video: {path}")
    print(f"Exists: {path.exists()}")
    if not path.exists():
        raise FileNotFoundError(f"Video file not found: {path}")
    return str(path.resolve())


def sample_pair(pair):
    """Return a copy of *pair* with its two videos randomly assigned to A/B.

    Adds ``A_path``/``B_path``, ``A_label``/``B_label`` ("good"/"bad") and
    ``good_on_A`` so the randomisation is recorded with each response.
    The input dict is not modified.
    """
    good_on_A = random.choice([True, False])
    if good_on_A:
        A_path, B_path = pair["good_path"], pair["bad_path"]
        A_label, B_label = "good", "bad"
    else:
        A_path, B_path = pair["bad_path"], pair["good_path"]
        A_label, B_label = "bad", "good"

    current = pair.copy()
    current.update(
        {
            "A_path": A_path,
            "B_path": B_path,
            "A_label": A_label,
            "B_label": B_label,
            "good_on_A": good_on_A,
        }
    )
    return current


def local_save(record: dict) -> str:
    """Write *record* as JSON to LOCAL_RESULTS_DIR/<response_id>.json; return the path."""
    local_name = LOCAL_RESULTS_DIR / f"{record['response_id']}.json"
    with open(local_name, "w", encoding="utf-8") as f:
        json.dump(record, f, ensure_ascii=False, indent=2)
    return str(local_name)


def save_response(record: dict) -> dict:
    """Persist one response locally and, when configured, to the Hub dataset.

    The JSON is always written locally first via local_save().  If
    RESULTS_REPO_ID, HF_TOKEN and the HfApi client are all available, that
    same file is uploaded to ``responses/<YYYY-MM-DD>/<response_id>.json`` in
    the dataset repo RESULTS_REPO_ID.

    Returns a dict with ``ok``, ``saved_local``, ``saved_hub`` and a
    human-readable ``message``.  Upload errors are reported in the result
    rather than raised, because the local copy already exists.
    """
    local_path = local_save(record)

    if not (RESULTS_REPO_ID and HF_TOKEN and api):
        return {
            "ok": True,
            "saved_local": True,
            "saved_hub": False,
            "message": f"Saved locally to {local_path}. Hub upload is not configured.",
        }

    # The first 10 chars of the ISO-8601 timestamp are the date (YYYY-MM-DD),
    # which groups uploaded responses per day.
    remote_path = f"responses/{record['timestamp'][:10]}/{record['response_id']}.json"
    try:
        # Upload the file local_save() just wrote rather than dumping the
        # record again into a hard-coded /tmp file that was never cleaned up.
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=remote_path,
            repo_id=RESULTS_REPO_ID,
            repo_type="dataset",
        )
    except Exception as e:
        return {
            "ok": False,
            "saved_local": True,
            "saved_hub": False,
            "message": f"Saved locally to {local_path}, but Hub upload failed: {e}",
        }
    return {
        "ok": True,
        "saved_local": True,
        "saved_hub": True,
        "message": f"Saved to dataset repo: {remote_path}",
    }


# -----------------------------------------------------------------------------
# Gradio callback functions
# -----------------------------------------------------------------------------


def start_session():
    """Create a participant ID and show the first randomly chosen pair.

    Returns a 13-tuple matching the outputs of ``start_btn.click``: prompt
    HTML, video A, video B, participant label, participant ID, remaining
    pairs, current pair, three cleared answers, status message, and the
    visibility updates for the start and next buttons.
    """
    try:
        participant_id = str(uuid.uuid4())
        remaining_pairs = PAIRS.copy()
        pair = random.choice(remaining_pairs)
        remaining_pairs.remove(pair)
        current = sample_pair(pair)

        status_msg = "Study loaded. Read the prompt, watch both videos, and answer the questions below."
        return (
            format_prompt_html(current["prompt"]),
            video_value(current["A_path"]),
            video_value(current["B_path"]),
            f"Participant ID: {participant_id}",
            participant_id,
            remaining_pairs,
            current,
            None,
            None,
            None,
            status_msg,
            gr.update(visible=False),
            gr.update(visible=True),
        )
    except Exception as e:
        print(f"start_session error: {type(e).__name__}: {e}")
        return (
            format_prompt_html(f"Error loading study: {type(e).__name__}: {e}"),
            None,
            None,
            "",
            None,
            [],
            None,
            None,
            None,
            None,
            f"Start failed: {type(e).__name__}: {e}",
            gr.update(visible=True),
            gr.update(visible=False),
        )


def submit_and_next(ans_prompt, ans_logic, ans_visual, participant_id, remaining_pairs, current):
    """Save the answers for the current pair and advance to the next one.

    Incomplete answers are rejected without saving.  Otherwise a response
    record is saved via save_response(), and a random pair is removed from
    *remaining_pairs* (in place) and shown next.  When no pairs remain the
    study-complete view is returned.

    Returns a 6-tuple matching the outputs of ``next_btn.click``: prompt
    HTML, video A, video B, remaining pairs, current pair, status message.
    """
    try:
        if current is None:
            return (
                format_prompt_html("No current pair loaded."),
                None,
                None,
                remaining_pairs,
                current,
                "No current pair loaded.",
            )

        if ans_prompt is None or ans_logic is None or ans_visual is None:
            return (
                format_prompt_html(current["prompt"]),
                video_value(current["A_path"]),
                video_value(current["B_path"]),
                remaining_pairs,
                current,
                "Please answer all questions before continuing.",
            )

        timestamp = datetime.now(timezone.utc).isoformat()
        record = {
            "response_id": str(uuid.uuid4()),
            "timestamp": timestamp,
            "participant_id": participant_id,
            "category": current["category"],
            "prompt": current["prompt"],
            "pair_id": current["pair_id"],
            "A_video": current["A_path"],
            "B_video": current["B_path"],
            "good_video": current["good_path"],
            "bad_video": current["bad_path"],
            "good_on_A": current["good_on_A"],
            "A_label": current["A_label"],
            "B_label": current["B_label"],
            "adherence_answer": ans_prompt,
            "logic_answer": ans_logic,
            "visual_quality_answer": ans_visual,
        }

        save_result = save_response(record)

        if not remaining_pairs:
            finish_msg = save_result["message"] + "\n\nYou have completed all pairs. Thank you!"
            return (
                format_prompt_html("Study complete."),
                None,
                None,
                remaining_pairs,
                None,
                finish_msg,
            )

        pair = random.choice(remaining_pairs)
        remaining_pairs.remove(pair)
        next_pair = sample_pair(pair)

        status_msg = save_result["message"]
        if not save_result["saved_hub"]:
            status_msg += "\n\nYour response was still saved locally."

        return (
            format_prompt_html(next_pair["prompt"]),
            video_value(next_pair["A_path"]),
            video_value(next_pair["B_path"]),
            remaining_pairs,
            next_pair,
            status_msg,
        )

    except Exception as e:
        print(f"submit_and_next error: {type(e).__name__}: {e}")
        return (
            format_prompt_html(current["prompt"] if current else "Error"),
            None,
            None,
            remaining_pairs,
            current,
            f"Submit failed: {type(e).__name__}: {e}",
        )


def reset_answers_after_submit(ans_prompt, ans_logic, ans_visual):
    """Clear the three answer radios after a submission was accepted.

    submit_and_next() rejects a submission when any answer is missing; in
    that case the user's partial answers are kept (gr.update() with no
    arguments leaves a component unchanged) instead of being wiped.
    """
    if ans_prompt is None or ans_logic is None or ans_visual is None:
        return gr.update(), gr.update(), gr.update()
    return None, None, None


# -----------------------------------------------------------------------------
# Build Gradio UI
# -----------------------------------------------------------------------------

with gr.Blocks(
    title="Human Study",
    css="""
    #prompt_box {
        width: 100%;
        text-align: center !important;
        font-size: 30px !important;
        font-weight: 900 !important;
        line-height: 1.35 !important;
        margin-top: 16px !important;
        margin-bottom: 22px !important;
    }
    """,
) as demo:
    gr.Markdown(
        """
        # Human Study

        Read the prompt, watch both videos, and answer the questions below.
        Each response is saved automatically when you click **Submit and continue**.
        The study ends once you have gone through all available pairs.

        Thank you for your participation!
        """
    )

    participant_label = gr.Markdown()
    status = gr.Markdown()

    start_btn = gr.Button("Start Study", visible=True)

    prompt_md = gr.HTML(format_prompt_html("Prompt will appear here."))

    with gr.Row():
        video_A = gr.Video(label="A", interactive=False, format="mp4")
        video_B = gr.Video(label="B", interactive=False, format="mp4")

    q_adherence = gr.Radio(
        choices=["A", "B", "No preference"],
        label="Which video has better adherence to the prompt?",
    )
    q_logic = gr.Radio(
        choices=["A", "B", "No preference"],
        label="Which video has a more logical change of events?",
    )
    q_visual = gr.Radio(
        choices=["A", "B", "No preference"],
        label="Which video has better visual quality?",
    )

    next_btn = gr.Button("Submit and continue", visible=False)

    # Per-session state: participant ID, pairs not yet shown, and the pair
    # currently on screen (with its A/B assignment).
    participant_id_state = gr.State()
    remaining_pairs_state = gr.State([])
    current_pair_state = gr.State()

    start_btn.click(
        fn=start_session,
        outputs=[
            prompt_md,
            video_A,
            video_B,
            participant_label,
            participant_id_state,
            remaining_pairs_state,
            current_pair_state,
            q_adherence,
            q_logic,
            q_visual,
            status,
            start_btn,
            next_btn,
        ],
    )

    next_btn.click(
        fn=submit_and_next,
        inputs=[
            q_adherence,
            q_logic,
            q_visual,
            participant_id_state,
            remaining_pairs_state,
            current_pair_state,
        ],
        outputs=[
            prompt_md,
            video_A,
            video_B,
            remaining_pairs_state,
            current_pair_state,
            status,
        ],
    ).then(
        # submit_and_next does not write the radios, so they still hold the
        # values that were just submitted; only clear them if it accepted them.
        fn=reset_answers_after_submit,
        inputs=[q_adherence, q_logic, q_visual],
        outputs=[q_adherence, q_logic, q_visual],
    )

if __name__ == "__main__":
    demo.launch()