Human_study_AVG / app.py
mariam-hassan's picture
Update app.py
2e759f1 verified
import html
import json
import os
import random
import re
import uuid
from datetime import datetime, timezone
from pathlib import Path

import gradio as gr
from huggingface_hub import HfApi
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# Root directory that the good/bad file names in the mapping are joined onto.
STUDY_ROOT = Path("avg_study")
# JSON file describing every prompt and its good/bad video pair.
PROMPT_JSON = "prompts_mapping.json"

# Hub upload is optional: without a token no API client is created and
# responses are only written to LOCAL_RESULTS_DIR.
HF_TOKEN = os.getenv("HF_TOKEN")
RESULTS_REPO_ID = os.getenv("RESULTS_REPO_ID")
api = None
if HF_TOKEN:
    api = HfApi(token=HF_TOKEN)

# Every response is always saved here first, created at import time.
LOCAL_RESULTS_DIR = Path("local_results")
LOCAL_RESULTS_DIR.mkdir(exist_ok=True)
# -----------------------------------------------------------------------------
# Load prompt mapping and build pairs
# -----------------------------------------------------------------------------
# Trailing bookkeeping tag such as " idx12 r3" appended to mapping prompts.
_TRAILING_INDEX_TAG = re.compile(r"\s*idx\d+\s+r\d+\s*$")


def clean_prompt(prompt: str) -> str:
    """Return *prompt* without its trailing ``idx<N> r<M>`` tag, whitespace-stripped.

    Falsy input (``None`` or ``""``) yields ``""``.
    """
    if prompt:
        return _TRAILING_INDEX_TAG.sub("", prompt).strip()
    return ""
def format_prompt_html(prompt: str) -> str:
    """Wrap *prompt* as escaped HTML inside the styled ``#prompt_box`` container.

    ``&``, ``<`` and ``>`` are escaped so prompt or error text cannot inject
    markup into the page. Quotes are deliberately left alone (``quote=False``):
    the text is element content, not an attribute value, so they are harmless
    there. This also keeps the output identical to the previous hand-rolled
    replace chain.

    Args:
        prompt: Plain-text prompt or status message. ``None`` is treated as an
            empty string instead of raising.

    Returns:
        An HTML snippet of the form ``<div id='prompt_box'><strong>...</strong></div>``.
    """
    safe_prompt = html.escape(prompt or "", quote=False)
    return f"<div id='prompt_box'><strong>{safe_prompt}</strong></div>"
def load_prompt_mapping(mapping_file: str):
    """Read and parse the JSON prompt mapping stored at *mapping_file*.

    Args:
        mapping_file: Path to the mapping JSON file, decoded as UTF-8.

    Returns:
        The parsed JSON document.

    Raises:
        FileNotFoundError: If *mapping_file* does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    source = Path(mapping_file)
    if source.exists():
        return json.loads(source.read_text(encoding="utf-8"))
    raise FileNotFoundError(
        f"Mapping file '{mapping_file}' not found. "
        "Ensure that 'prompts_mapping.json' exists in the repository."
    )
def build_pairs_from_mapping(prompt_map):
    """Turn the raw prompt mapping into a list of good/bad video pairs.

    Entries missing either ``good_file`` or ``bad_file`` (or with an empty
    value) are skipped. File names are joined onto STUDY_ROOT and resolved to
    absolute path strings. This function does not check that the files exist.

    Args:
        prompt_map: Mapping of pair id -> entry dict, as returned by
            ``load_prompt_mapping``. ``prompt`` and ``category`` are optional
            entry keys.

    Returns:
        A list of dicts with keys ``pair_id``, ``category``, ``prompt``,
        ``good_path`` and ``bad_path``, in mapping iteration order.
    """
    collected = []
    for pair_id, entry in prompt_map.items():
        text = clean_prompt(entry.get("prompt", ""))
        category = entry.get("category", "")
        good_name = entry.get("good_file")
        bad_name = entry.get("bad_file")
        if not (good_name and bad_name):
            continue
        collected.append(
            {
                "pair_id": pair_id,
                "category": category,
                "prompt": text,
                "good_path": str((STUDY_ROOT / good_name).resolve()),
                "bad_path": str((STUDY_ROOT / bad_name).resolve()),
            }
        )
    return collected
prompt_map = load_prompt_mapping(PROMPT_JSON)
PAIRS = build_pairs_from_mapping(prompt_map)
print(f"Loaded {len(PAIRS)} pairs.")
if not PAIRS:
    # Fail fast at startup: the study cannot run without at least one pair.
    raise RuntimeError(
        "No matched good/bad pairs found in mapping. "
        "Check that 'avg_study' contains matching files and that the mapping file lists them."
    )
print("First pair:", PAIRS[0])
# -----------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------
def video_value(path_str: str) -> str:
    """Validate that a video file exists and return its absolute path string.

    Prints the path and its existence status as debug output.

    Args:
        path_str: Path to the video file.

    Returns:
        ``str`` of the resolved absolute path, suitable as a ``gr.Video`` value.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    video_path = Path(path_str)
    found = video_path.exists()
    print(f"Loading video: {video_path}")
    print(f"Exists: {found}")
    if not found:
        raise FileNotFoundError(f"Video file not found: {video_path}")
    return str(video_path.resolve())
def sample_pair(pair):
    """Randomly assign a pair's good and bad videos to the A and B slots.

    The input dict is not modified. A shallow copy is returned with
    ``A_path``, ``B_path``, ``A_label``, ``B_label`` ("good"/"bad") and
    ``good_on_A`` appended after the original keys.

    Args:
        pair: Dict with at least ``good_path`` and ``bad_path``.

    Returns:
        The augmented copy of *pair*.
    """
    good_on_A = random.choice([True, False])
    slots = [("good", pair["good_path"]), ("bad", pair["bad_path"])]
    if not good_on_A:
        slots.reverse()
    (A_label, A_path), (B_label, B_path) = slots
    return {
        **pair,
        "A_path": A_path,
        "B_path": B_path,
        "A_label": A_label,
        "B_label": B_label,
        "good_on_A": good_on_A,
    }
def local_save(record: dict) -> str:
    """Write *record* as indented UTF-8 JSON to ``LOCAL_RESULTS_DIR/<response_id>.json``.

    Args:
        record: Response dict. Must contain ``response_id``.

    Returns:
        The path of the written file, as a string.
    """
    out_file = LOCAL_RESULTS_DIR / f"{record['response_id']}.json"
    with out_file.open("w", encoding="utf-8") as handle:
        json.dump(record, handle, ensure_ascii=False, indent=2)
    return str(out_file)
def save_response(record: dict) -> dict:
    """Persist one response locally and, when configured, to the results dataset repo.

    The record is always written to LOCAL_RESULTS_DIR first (via ``local_save``),
    so a copy survives even if the Hub upload fails. When RESULTS_REPO_ID,
    HF_TOKEN and the ``api`` client are all set, that same local JSON file is
    uploaded to ``responses/<timestamp[:10]>/<response_id>.json`` in the
    dataset repo.

    Args:
        record: Response dict. Must contain ``response_id`` and ``timestamp``.
            Presumably ``timestamp`` is ISO-8601, so its first 10 characters
            are the date used as the folder name.

    Returns:
        A status dict with keys ``ok``, ``saved_local``, ``saved_hub`` and a
        human-readable ``message``. Upload failures are reported here rather
        than raised.
    """
    local_path = local_save(record)
    if not (RESULTS_REPO_ID and HF_TOKEN and api):
        return {
            "ok": True,
            "saved_local": True,
            "saved_hub": False,
            "message": f"Saved locally to {local_path}. Hub upload is not configured.",
        }

    remote_path = f"responses/{record['timestamp'][:10]}/{record['response_id']}.json"
    try:
        # Upload the file local_save just wrote. It holds exactly the JSON that
        # is needed, so there is no second write to a hard-coded (and
        # never-cleaned) /tmp path that is not available on every platform.
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=remote_path,
            repo_id=RESULTS_REPO_ID,
            repo_type="dataset",
        )
    except Exception as e:
        # Best effort: the local copy already exists, so report the failure
        # (and log it server-side) instead of failing the participant's submission.
        print(f"save_response upload error: {type(e).__name__}: {e}")
        return {
            "ok": False,
            "saved_local": True,
            "saved_hub": False,
            "message": f"Saved locally to {local_path}, but Hub upload failed: {e}",
        }
    return {
        "ok": True,
        "saved_local": True,
        "saved_hub": True,
        "message": f"Saved to dataset repo: {remote_path}",
    }
# -----------------------------------------------------------------------------
# Gradio callback functions
# -----------------------------------------------------------------------------
def start_session():
    """Begin a study session for a new participant.

    Generates a participant ID and copies PAIRS into a per-session list of
    remaining pairs. One pair is drawn at random and removed from that list,
    and its A/B placement is randomized with ``sample_pair``.

    Returns:
        A 13-tuple for the ``start_btn.click`` outputs: (prompt HTML, video A,
        video B, participant label, participant ID, remaining pairs, current
        pair, the three answer radios, status message, start-button update,
        next-button update). On failure, the error is shown in the prompt box
        and status, and the start button stays visible.
    """
    try:
        new_id = str(uuid.uuid4())
        pool = PAIRS.copy()
        drawn = random.choice(pool)
        pool.remove(drawn)
        shown = sample_pair(drawn)
        return (
            format_prompt_html(shown["prompt"]),
            video_value(shown["A_path"]),
            video_value(shown["B_path"]),
            f"Participant ID: {new_id}",
            new_id,
            pool,
            shown,
            None,  # clear the adherence answer
            None,  # clear the logic answer
            None,  # clear the visual-quality answer
            "Study loaded. Read the prompt, watch both videos, and answer the questions below.",
            gr.update(visible=False),  # hide "Start Study"
            gr.update(visible=True),  # show "Submit and continue"
        )
    except Exception as e:
        detail = f"{type(e).__name__}: {e}"
        print(f"start_session error: {detail}")
        return (
            format_prompt_html(f"Error loading study: {detail}"),
            None,
            None,
            "",
            None,
            [],
            None,
            None,
            None,
            None,
            f"Start failed: {detail}",
            gr.update(visible=True),
            gr.update(visible=False),
        )
def submit_and_next(ans_prompt, ans_logic, ans_visual, participant_id, remaining_pairs, current):
    """Record the answers for the current pair and load the next one.

    Callback for the "Submit and continue" button. It checks that all three
    questions were answered and saves a response record via ``save_response``.
    It then draws the next pair at random from *remaining_pairs*, removing it
    from that list in place, and randomizes its A/B placement with
    ``sample_pair``.

    Args:
        ans_prompt, ans_logic, ans_visual: Radio selections, or None when a
            question was left unanswered.
        participant_id: ID produced by ``start_session`` for this session.
        remaining_pairs: Pairs not yet shown to this participant.
        current: Pair currently on screen (as returned by ``sample_pair``), or
            None when nothing is loaded or the study is finished.

    Returns:
        A 6-tuple for the ``next_btn.click`` outputs: (prompt HTML, video A,
        video B, remaining pairs, current pair, status message). Errors are
        reported in the status message rather than raised.
    """
    try:
        if current is None:
            return (
                format_prompt_html("No current pair loaded."),
                None,
                None,
                remaining_pairs,
                current,
                "No current pair loaded.",
            )
        if ans_prompt is None or ans_logic is None or ans_visual is None:
            # Incomplete answers: leave the prompt and both players exactly as
            # they are (gr.update() with no arguments is a no-op). Re-sending
            # the same video paths could reload the players and reset
            # playback, and would re-run the file checks for media already shown.
            return (
                gr.update(),
                gr.update(),
                gr.update(),
                remaining_pairs,
                current,
                "Please answer all questions before continuing.",
            )

        timestamp = datetime.now(timezone.utc).isoformat()
        record = {
            "response_id": str(uuid.uuid4()),
            "timestamp": timestamp,
            "participant_id": participant_id,
            "category": current["category"],
            "prompt": current["prompt"],
            "pair_id": current["pair_id"],
            "A_video": current["A_path"],
            "B_video": current["B_path"],
            "good_video": current["good_path"],
            "bad_video": current["bad_path"],
            "good_on_A": current["good_on_A"],
            "A_label": current["A_label"],
            "B_label": current["B_label"],
            "adherence_answer": ans_prompt,
            "logic_answer": ans_logic,
            "visual_quality_answer": ans_visual,
        }
        save_result = save_response(record)

        if not remaining_pairs:
            finish_msg = save_result["message"] + "\n\nYou have completed all pairs. Thank you!"
            return (
                format_prompt_html("Study complete."),
                None,
                None,
                remaining_pairs,
                None,
                finish_msg,
            )

        pair = random.choice(remaining_pairs)
        remaining_pairs.remove(pair)
        next_pair = sample_pair(pair)
        status_msg = save_result["message"]
        if not save_result["saved_hub"]:
            status_msg += "\n\nYour response was still saved locally."
        return (
            format_prompt_html(next_pair["prompt"]),
            video_value(next_pair["A_path"]),
            video_value(next_pair["B_path"]),
            remaining_pairs,
            next_pair,
            status_msg,
        )
    except Exception as e:
        print(f"submit_and_next error: {type(e).__name__}: {e}")
        return (
            format_prompt_html(current["prompt"] if current else "Error"),
            None,
            None,
            remaining_pairs,
            current,
            f"Submit failed: {type(e).__name__}: {e}",
        )
# -----------------------------------------------------------------------------
# Build Gradio UI
# -----------------------------------------------------------------------------
def _clear_answers_if_submitted(ans_prompt, ans_logic, ans_visual):
    """Reset the three answer radios only after a complete submission.

    Chained after ``submit_and_next``, which does not output the radios, so
    they still hold the values that were just submitted. ``submit_and_next``
    rejects a submission in which any question is unanswered (None). In that
    case the participant's partial answers are kept (``gr.update()`` with no
    arguments leaves a component unchanged) instead of being wiped.
    """
    if ans_prompt is None or ans_logic is None or ans_visual is None:
        return gr.update(), gr.update(), gr.update()
    return None, None, None


with gr.Blocks(
    title="Human Study",
    css="""
#prompt_box {
width: 100%;
text-align: center !important;
font-size: 30px !important;
font-weight: 900 !important;
line-height: 1.35 !important;
margin-top: 16px !important;
margin-bottom: 22px !important;
}
""",
) as demo:
    gr.Markdown(
        """
# Human Study
Read the prompt, watch both videos, and answer the questions below. Each response is saved
automatically when you click **Submit and continue**. The study ends once you have gone through
all available pairs. Thank you for your participation!
"""
    )
    participant_label = gr.Markdown()
    status = gr.Markdown()
    start_btn = gr.Button("Start Study", visible=True)
    prompt_md = gr.HTML("<div id='prompt_box'><strong>Prompt will appear here.</strong></div>")

    with gr.Row():
        video_A = gr.Video(label="A", interactive=False, format="mp4")
        video_B = gr.Video(label="B", interactive=False, format="mp4")

    q_adherence = gr.Radio(
        choices=["A", "B", "No preference"],
        label="Which video has better adherence to the prompt?",
    )
    q_logic = gr.Radio(
        choices=["A", "B", "No preference"],
        label="Which video has a more logical change of events?",
    )
    q_visual = gr.Radio(
        choices=["A", "B", "No preference"],
        label="Which video has better visual quality?",
    )
    next_btn = gr.Button("Submit and continue", visible=False)

    # Per-session state: participant ID, pairs not yet shown, and the pair on screen.
    participant_id_state = gr.State()
    remaining_pairs_state = gr.State([])
    current_pair_state = gr.State()

    # Output order must match the 13-tuple returned by start_session.
    start_btn.click(
        fn=start_session,
        outputs=[
            prompt_md,
            video_A,
            video_B,
            participant_label,
            participant_id_state,
            remaining_pairs_state,
            current_pair_state,
            q_adherence,
            q_logic,
            q_visual,
            status,
            start_btn,
            next_btn,
        ],
    )

    # Output order must match the 6-tuple returned by submit_and_next. The
    # radios are reset afterwards only if the submission was complete, so a
    # rejected (incomplete) submission does not erase the answers already given.
    next_btn.click(
        fn=submit_and_next,
        inputs=[
            q_adherence,
            q_logic,
            q_visual,
            participant_id_state,
            remaining_pairs_state,
            current_pair_state,
        ],
        outputs=[
            prompt_md,
            video_A,
            video_B,
            remaining_pairs_state,
            current_pair_state,
            status,
        ],
    ).then(
        fn=_clear_answers_if_submitted,
        inputs=[q_adherence, q_logic, q_visual],
        outputs=[q_adherence, q_logic, q_visual],
    )

demo.launch()