Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import json | |
| import uuid | |
| import random | |
| from pathlib import Path | |
| from datetime import datetime, timezone | |
| import gradio as gr | |
| from huggingface_hub import HfApi | |
| # ----------------------------------------------------------------------------- | |
| # Configuration | |
| # ----------------------------------------------------------------------------- | |
# Folder that every relative video path from the mapping file is joined onto.
STUDY_ROOT = Path("avg_study")
# JSON file describing the prompt / good-video / bad-video triples.
PROMPT_JSON = "prompts_mapping.json"

# Hub credentials and destination come from the environment; both are optional.
HF_TOKEN = os.environ.get("HF_TOKEN")
RESULTS_REPO_ID = os.environ.get("RESULTS_REPO_ID")
# Without a token there is no Hub client; responses are then only kept locally.
api = None if not HF_TOKEN else HfApi(token=HF_TOKEN)

# Every response is written here as its own JSON file.
LOCAL_RESULTS_DIR = Path("local_results")
LOCAL_RESULTS_DIR.mkdir(exist_ok=True)
| # ----------------------------------------------------------------------------- | |
| # Load prompt mapping and build pairs | |
| # ----------------------------------------------------------------------------- | |
def clean_prompt(prompt: str) -> str:
    """Strip a trailing ``idx<N> r<M>`` tag and surrounding whitespace from a prompt.

    Args:
        prompt: Raw prompt text; may be empty or ``None``.

    Returns:
        The cleaned prompt, or ``""`` when ``prompt`` is falsy.
    """
    if prompt:
        # The tag is only removed when it sits at the very end of the text.
        without_tag = re.sub(r"\s*idx\d+\s+r\d+\s*$", "", prompt)
        return without_tag.strip()
    return ""
def format_prompt_html(prompt: str) -> str:
    """Wrap a prompt in the styled ``#prompt_box`` HTML container.

    The text is HTML-escaped first so that ``&``, ``<`` and ``>`` coming from
    the mapping file (or from an exception message shown to the user) are
    displayed literally instead of being interpreted as markup.

    Args:
        prompt: Plain text to display.

    Returns:
        An HTML snippet: a ``div`` with id ``prompt_box`` holding the escaped
        text in bold.
    """
    # Function-scope import so this block does not depend on the file's import list.
    from html import escape

    # quote=False: the text is placed in element content, not in an attribute,
    # so only &, < and > need to be converted to entities.
    safe_prompt = escape(prompt, quote=False)
    return f"<div id='prompt_box'><strong>{safe_prompt}</strong></div>"
def load_prompt_mapping(mapping_file: str):
    """Read the prompt mapping JSON file and return its parsed content.

    Args:
        mapping_file: Path of the JSON mapping file.

    Returns:
        Whatever ``json.load`` produces for the file.

    Raises:
        FileNotFoundError: If ``mapping_file`` does not exist.
    """
    mapping_path = Path(mapping_file)
    if mapping_path.exists():
        with mapping_path.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    raise FileNotFoundError(
        f"Mapping file '{mapping_file}' not found. "
        "Ensure that 'prompts_mapping.json' exists in the repository."
    )
def build_pairs_from_mapping(prompt_map):
    """Turn the parsed mapping into a list of good/bad video pair records.

    Entries lacking either ``good_file`` or ``bad_file`` are skipped.

    Args:
        prompt_map: Mapping of pair id to an entry dict with the keys
            ``prompt``, ``category``, ``good_file`` and ``bad_file``.

    Returns:
        A list of dicts with the keys ``pair_id``, ``category``, ``prompt``,
        ``good_path`` and ``bad_path``; both paths are resolved under
        ``STUDY_ROOT`` and stored as strings.
    """
    collected = []
    for pair_id, entry in prompt_map.items():
        cleaned = clean_prompt(entry.get("prompt", ""))
        category = entry.get("category", "")
        good_rel, bad_rel = entry.get("good_file"), entry.get("bad_file")
        if not (good_rel and bad_rel):
            # An incomplete entry cannot form a comparison pair.
            continue
        collected.append(
            dict(
                pair_id=pair_id,
                category=category,
                prompt=cleaned,
                good_path=str((STUDY_ROOT / good_rel).resolve()),
                bad_path=str((STUDY_ROOT / bad_rel).resolve()),
            )
        )
    return collected
# Load the mapping once at import time; the app cannot run without pairs.
prompt_map = load_prompt_mapping(PROMPT_JSON)
PAIRS = build_pairs_from_mapping(prompt_map)
print(f"Loaded {len(PAIRS)} pairs.")
if not PAIRS:
    raise RuntimeError(
        "No matched good/bad pairs found in mapping. "
        "Check that 'avg_study' contains matching files and that the mapping file lists them."
    )
# Reaching this line means at least one pair exists; show it as a sanity check.
print("First pair:", PAIRS[0])
| # ----------------------------------------------------------------------------- | |
| # Helper functions | |
| # ----------------------------------------------------------------------------- | |
def video_value(path_str: str) -> str:
    """Check that a video file exists and return its resolved path.

    Args:
        path_str: Path of the video file.

    Returns:
        The resolved path as a string.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    video_path = Path(path_str)
    print(f"Loading video: {video_path}")
    found = video_path.exists()
    print(f"Exists: {found}")
    if found:
        return str(video_path.resolve())
    raise FileNotFoundError(f"Video file not found: {video_path}")
def sample_pair(pair):
    """Randomly assign the good and bad video of a pair to slots A and B.

    Args:
        pair: Pair record containing at least ``good_path`` and ``bad_path``.

    Returns:
        A new dict holding everything from ``pair`` plus ``A_path``,
        ``B_path``, ``A_label``, ``B_label`` and ``good_on_A``. The input
        dict is left unmodified.
    """
    good_on_A = random.choice([True, False])
    # The label shown in slot A decides the label for slot B.
    label_a, label_b = ("good", "bad") if good_on_A else ("bad", "good")
    path_a = pair[f"{label_a}_path"]
    path_b = pair[f"{label_b}_path"]
    return {
        **pair,
        "A_path": path_a,
        "B_path": path_b,
        "A_label": label_a,
        "B_label": label_b,
        "good_on_A": good_on_A,
    }
def local_save(record: dict) -> str:
    """Write one response record to ``LOCAL_RESULTS_DIR`` as a JSON file.

    Args:
        record: Response data; its ``response_id`` value becomes the file name.

    Returns:
        The path of the written file as a string.
    """
    target = LOCAL_RESULTS_DIR / f"{record['response_id']}.json"
    with target.open("w", encoding="utf-8") as out:
        json.dump(record, out, ensure_ascii=False, indent=2)
    return str(target)
def save_response(record: dict) -> dict:
    """Persist a response locally and, when configured, upload it to the Hub.

    The record is always written to the local results directory first. If a
    results repo id, a token and a Hub client are all available, that same
    file is uploaded to the dataset repo under
    ``responses/<YYYY-MM-DD>/<response_id>.json``.

    Args:
        record: Response data; must contain ``response_id`` and an ISO
            ``timestamp`` string (its first 10 characters name the folder).

    Returns:
        A dict with the keys ``ok``, ``saved_local``, ``saved_hub`` and
        ``message``. ``ok`` is ``False`` only when an upload was attempted
        and failed; the local copy exists in every case.
    """
    local_path = local_save(record)

    if not (RESULTS_REPO_ID and HF_TOKEN and api):
        return {
            "ok": True,
            "saved_local": True,
            "saved_hub": False,
            "message": f"Saved locally to {local_path}. Hub upload is not configured.",
        }

    remote_path = f"responses/{record['timestamp'][:10]}/{record['response_id']}.json"
    try:
        # Upload the file that local_save just wrote. Writing a second copy to
        # a hard-coded /tmp path left one stray file per response behind and
        # assumed a POSIX temp directory.
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=remote_path,
            repo_id=RESULTS_REPO_ID,
            repo_type="dataset",
        )
    except Exception as e:
        # Best effort: a failed upload must not lose the participant's answer,
        # so report the failure while pointing at the local copy.
        return {
            "ok": False,
            "saved_local": True,
            "saved_hub": False,
            "message": f"Saved locally to {local_path}, but Hub upload failed: {e}",
        }
    return {
        "ok": True,
        "saved_local": True,
        "saved_hub": True,
        "message": f"Saved to dataset repo: {remote_path}",
    }
| # ----------------------------------------------------------------------------- | |
| # Gradio callback functions | |
| # ----------------------------------------------------------------------------- | |
def start_session():
    """Begin a study session and load the first randomly chosen pair.

    Returns:
        A 13-tuple matching the outputs wired to the start button: prompt
        HTML, video A path, video B path, participant label text, participant
        id, remaining pairs, current pair, three ``None`` values that reset
        the answer radios, the status message, and visibility updates for
        the start and submit buttons. On failure the tuple carries the error
        text, keeps the start button visible and hides the submit button.
    """
    try:
        participant_id = str(uuid.uuid4())

        # Work on a copy so the module-level pair list is never modified.
        remaining_pairs = PAIRS.copy()
        chosen = random.choice(remaining_pairs)
        remaining_pairs.remove(chosen)
        current = sample_pair(chosen)

        status_msg = "Study loaded. Read the prompt, watch both videos, and answer the questions below."
        prompt_html = format_prompt_html(current["prompt"])
        video_a = video_value(current["A_path"])
        video_b = video_value(current["B_path"])
        return (
            prompt_html,
            video_a,
            video_b,
            f"Participant ID: {participant_id}",
            participant_id,
            remaining_pairs,
            current,
            None,
            None,
            None,
            status_msg,
            gr.update(visible=False),
            gr.update(visible=True),
        )
    except Exception as e:
        problem = f"{type(e).__name__}: {e}"
        print(f"start_session error: {problem}")
        return (
            format_prompt_html(f"Error loading study: {problem}"),
            None,
            None,
            "",
            None,
            [],
            None,
            None,
            None,
            None,
            f"Start failed: {problem}",
            gr.update(visible=True),
            gr.update(visible=False),
        )
def submit_and_next(ans_prompt, ans_logic, ans_visual, participant_id, remaining_pairs, current):
    """Save the answers for the current pair and advance to the next one.

    Args:
        ans_prompt: Answer to the prompt-adherence question.
        ans_logic: Answer to the logical-change-of-events question.
        ans_visual: Answer to the visual-quality question.
        participant_id: Id generated when the session started.
        remaining_pairs: Pairs not yet shown; the next pair is removed from it.
        current: The pair currently displayed, or ``None``.

    Returns:
        A 6-tuple: prompt HTML, video A path, video B path, remaining pairs,
        current pair and status message. Nothing is saved when there is no
        current pair or when any answer is missing.
    """
    try:
        if current is None:
            notice = "No current pair loaded."
            return (
                format_prompt_html(notice),
                None,
                None,
                remaining_pairs,
                current,
                notice,
            )

        unanswered = any(answer is None for answer in (ans_prompt, ans_logic, ans_visual))
        if unanswered:
            # Show the same pair again and ask for the missing answers.
            return (
                format_prompt_html(current["prompt"]),
                video_value(current["A_path"]),
                video_value(current["B_path"]),
                remaining_pairs,
                current,
                "Please answer all questions before continuing.",
            )

        record = {
            "response_id": str(uuid.uuid4()),
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "participant_id": participant_id,
            "category": current["category"],
            "prompt": current["prompt"],
            "pair_id": current["pair_id"],
            "A_video": current["A_path"],
            "B_video": current["B_path"],
            "good_video": current["good_path"],
            "bad_video": current["bad_path"],
            "good_on_A": current["good_on_A"],
            "A_label": current["A_label"],
            "B_label": current["B_label"],
            "adherence_answer": ans_prompt,
            "logic_answer": ans_logic,
            "visual_quality_answer": ans_visual,
        }
        outcome = save_response(record)

        if not remaining_pairs:
            # That was the last pair: clear the videos and the current pair.
            return (
                format_prompt_html("Study complete."),
                None,
                None,
                remaining_pairs,
                None,
                outcome["message"] + "\n\nYou have completed all pairs. Thank you!",
            )

        chosen = random.choice(remaining_pairs)
        remaining_pairs.remove(chosen)
        upcoming = sample_pair(chosen)

        if outcome["saved_hub"]:
            status_msg = outcome["message"]
        else:
            status_msg = outcome["message"] + "\n\nYour response was still saved locally."

        return (
            format_prompt_html(upcoming["prompt"]),
            video_value(upcoming["A_path"]),
            video_value(upcoming["B_path"]),
            remaining_pairs,
            upcoming,
            status_msg,
        )
    except Exception as e:
        problem = f"{type(e).__name__}: {e}"
        print(f"submit_and_next error: {problem}")
        return (
            format_prompt_html(current["prompt"] if current else "Error"),
            None,
            None,
            remaining_pairs,
            current,
            f"Submit failed: {problem}",
        )
| # ----------------------------------------------------------------------------- | |
| # Build Gradio UI | |
| # ----------------------------------------------------------------------------- | |
# Styling for the prompt banner produced by format_prompt_html.
prompt_box_css = """
#prompt_box {
width: 100%;
text-align: center !important;
font-size: 30px !important;
font-weight: 900 !important;
line-height: 1.35 !important;
margin-top: 16px !important;
margin-bottom: 22px !important;
}
"""


def _preference_radio(question):
    """Create one A / B / No preference radio group with the given label."""
    return gr.Radio(choices=["A", "B", "No preference"], label=question)


def _clear_answers():
    """Return the values that blank all three answer radios."""
    return (None, None, None)


with gr.Blocks(title="Human Study", css=prompt_box_css) as demo:
    gr.Markdown(
        """
# Human Study
Read the prompt, watch both videos, and answer the questions below. Each response is saved
automatically when you click **Submit and continue**. The study ends once you have gone through
all available pairs. Thank you for your participation!
"""
    )

    participant_label = gr.Markdown()
    status = gr.Markdown()
    start_btn = gr.Button("Start Study", visible=True)
    prompt_md = gr.HTML("<div id='prompt_box'><strong>Prompt will appear here.</strong></div>")

    with gr.Row():
        video_A = gr.Video(label="A", interactive=False, format="mp4")
        video_B = gr.Video(label="B", interactive=False, format="mp4")

    q_adherence = _preference_radio("Which video has better adherence to the prompt?")
    q_logic = _preference_radio("Which video has a more logical change of events?")
    q_visual = _preference_radio("Which video has better visual quality?")

    # Hidden until a session has been started successfully.
    next_btn = gr.Button("Submit and continue", visible=False)

    # Per-session state.
    participant_id_state = gr.State()
    remaining_pairs_state = gr.State([])
    current_pair_state = gr.State()

    answer_radios = [q_adherence, q_logic, q_visual]
    session_states = [participant_id_state, remaining_pairs_state, current_pair_state]

    # Output order must match the tuple returned by start_session.
    start_btn.click(
        fn=start_session,
        outputs=[
            prompt_md,
            video_A,
            video_B,
            participant_label,
            *session_states,
            *answer_radios,
            status,
            start_btn,
            next_btn,
        ],
    )

    # Input and output order must match submit_and_next's signature and return tuple.
    submit_event = next_btn.click(
        fn=submit_and_next,
        inputs=[*answer_radios, *session_states],
        outputs=[
            prompt_md,
            video_A,
            video_B,
            remaining_pairs_state,
            current_pair_state,
            status,
        ],
    )
    # Once the submit handler has finished, blank the radios for the next pair.
    submit_event.then(_clear_answers, outputs=answer_radios)

demo.launch()