# human_study / app.py
# Last change: commit 7fd2e84 ("Update app.py") by mariam-hassan.
import hashlib
import json
import os
import random
import re
import shutil
import tempfile
import uuid
from datetime import datetime, timezone
from pathlib import Path

import gradio as gr
from huggingface_hub import HfApi
# ----------------------------
# Config
# ----------------------------
# Root folder; each category has its own "good" and "bad" sub-directories.
STUDY_ROOT = Path("study")
CATEGORIES = [
    "search_t2v",
    "search_i2v",
    "opt_t2v",
    "opt_i2v",
]

# Hub credentials and target dataset repo are read from the environment.
HF_TOKEN = os.getenv("HF_TOKEN")
RESULTS_REPO_ID = os.getenv("RESULTS_REPO_ID")

# A Hub client is only created when a token is available.
api = None
if HF_TOKEN:
    api = HfApi(token=HF_TOKEN)

# Every response is also written here, whether or not Hub upload works.
LOCAL_RESULTS_DIR = Path("local_results")
LOCAL_RESULTS_DIR.mkdir(exist_ok=True)
# ----------------------------
# Pair discovery
# ----------------------------
def normalize_stem(path: Path) -> str:
    """Return the pairing key for a video file.

    Files named ``<digits>_good.<ext>`` or ``<digits>_bad.<ext>`` (for example
    ``0001_good.mp4`` / ``0001_bad.mp4``) map to their numeric prefix
    (``"0001"``). Any other file maps to its plain stem.
    """
    stem = path.stem
    found = re.match(r"^(\d+)_(good|bad)$", stem)
    return found.group(1) if found else stem
def _index_videos(directory: Path) -> dict:
    """Group the .mp4/.webm/.ogg files directly inside *directory* by pair key.

    Keys come from ``normalize_stem``. Each value is the sorted list of paths
    that share that key. A missing directory simply yields no files.
    """
    by_key = {}
    for extension in ("mp4", "webm", "ogg"):
        for path in directory.glob(f"*.{extension}"):
            by_key.setdefault(normalize_stem(path), []).append(path)
    for paths in by_key.values():
        paths.sort()
    return by_key


def build_pairs():
    """Discover every matched good/bad video pair under ``STUDY_ROOT``.

    For each category in ``CATEGORIES``, videos in ``<category>/good`` and
    ``<category>/bad`` are grouped by key (see ``_index_videos``). For every key
    present on both sides, the i-th sorted good file is paired with the i-th
    sorted bad file. Surplus files on the longer side are ignored.

    Returns a list of dicts with keys ``category``, ``pair_id``
    (``"<category>::<key>::<i>"``), ``key``, ``good_path``/``bad_path``
    (absolute path strings) and ``good_file``/``bad_file`` (bare file names).
    """
    all_pairs = []
    for category in CATEGORIES:
        good_map = _index_videos(STUDY_ROOT / category / "good")
        bad_map = _index_videos(STUDY_ROOT / category / "bad")
        for key in sorted(good_map.keys() & bad_map.keys()):
            # zip stops at the shorter list, matching min(len(goods), len(bads)).
            for i, (good, bad) in enumerate(zip(good_map[key], bad_map[key])):
                all_pairs.append(
                    {
                        "category": category,
                        "pair_id": f"{category}::{key}::{i}",
                        "key": key,
                        "good_path": str(good.resolve()),
                        "bad_path": str(bad.resolve()),
                        "good_file": good.name,
                        "bad_file": bad.name,
                    }
                )
    return all_pairs
# Pairs are discovered once at import time; the app cannot run without any.
PAIRS = build_pairs()
if not PAIRS:
    raise RuntimeError("No matched good/bad pairs found.")
# ----------------------------
# Video helper
# ----------------------------
# Neutrally named copies of the study videos are kept here.
_ANON_VIDEO_DIR = Path(tempfile.gettempdir()) / "human_study_videos"


def video_value(path_str: str):
    """Return a label-free path to the video at *path_str* for the player.

    The study files live under ``.../good/`` or ``.../bad/`` and are typically
    named ``*_good`` / ``*_bad``. Handing those paths to the UI would let the
    file name give away the hidden label. This function returns a copy named
    only by a hash of the source path, its size and its mtime, keeping the
    original extension. The copy is created once and reused afterwards. It is
    written to a temporary name first and then renamed into place, so
    concurrent callers never see a half-written file.

    NOTE(review): assumes the Gradio video player exposes the served file name
    (URL / download) to participants — confirm.

    Args:
        path_str: Path to an existing video file.

    Returns:
        Absolute path (str) of the anonymized copy.

    Raises:
        FileNotFoundError: If *path_str* does not exist.
    """
    path = Path(path_str)
    if not path.exists():
        raise FileNotFoundError(f"Video file not found: {path}")
    resolved = path.resolve()
    info = resolved.stat()
    fingerprint = f"{resolved}:{info.st_size}:{info.st_mtime_ns}"
    digest = hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()[:16]
    alias = _ANON_VIDEO_DIR / f"{digest}{resolved.suffix}"
    if not alias.exists():
        alias.parent.mkdir(parents=True, exist_ok=True)
        staging = alias.with_name(f"{alias.name}.{uuid.uuid4().hex}.part")
        shutil.copyfile(resolved, staging)
        os.replace(staging, alias)
    return str(alias)
# ----------------------------
# Sampling logic
# ----------------------------
def sample_pair(seen_pair_ids):
    """Pick the next pair and randomly decide which side shows the good video.

    Pairs whose ``pair_id`` is in *seen_pair_ids* (may be None) are skipped.
    Once every pair has been seen, sampling falls back to all of ``PAIRS``.
    Returns a new dict: the chosen pair's fields plus ``left_path``,
    ``right_path``, ``left_hidden_label``, ``right_hidden_label`` and
    ``good_on_left``.
    """
    already_seen = set(seen_pair_ids or [])
    candidates = [p for p in PAIRS if p["pair_id"] not in already_seen] or PAIRS
    chosen = random.choice(candidates)
    good_on_left = random.choice([True, False])

    left_label, right_label = ("good", "bad") if good_on_left else ("bad", "good")
    path_for = {"good": chosen["good_path"], "bad": chosen["bad_path"]}
    return {
        **chosen,
        "left_path": path_for[left_label],
        "right_path": path_for[right_label],
        "left_hidden_label": left_label,
        "right_hidden_label": right_label,
        "good_on_left": good_on_left,
    }
# ----------------------------
# Hub / persistence helpers
# ----------------------------
def local_save(record):
    """Write *record* as indented UTF-8 JSON to ``LOCAL_RESULTS_DIR``.

    The file is named ``<response_id>.json``. Returns its path as a string.
    """
    destination = LOCAL_RESULTS_DIR / f"{record['response_id']}.json"
    with destination.open("w", encoding="utf-8") as handle:
        json.dump(record, handle, ensure_ascii=False, indent=2)
    return str(destination)
def check_hub_setup():
    """Describe why Hub saving will not work, or return None if it should.

    Checks that ``RESULTS_REPO_ID`` and ``HF_TOKEN`` are set, in that order,
    and then that the dataset repo is reachable via ``api.repo_info``. Any
    error from that call is reported as a string instead of being raised.
    """
    requirements = (
        (RESULTS_REPO_ID, "RESULTS_REPO_ID is not set. Responses will be saved locally only."),
        (HF_TOKEN, "HF_TOKEN is not set. Responses will be saved locally only."),
    )
    for value, problem in requirements:
        if not value:
            return problem

    try:
        api.repo_info(repo_id=RESULTS_REPO_ID, repo_type="dataset")
    except Exception as exc:
        return f"Hub dataset check failed: {type(exc).__name__}: {exc}"
    return None


# Computed once at import; shown to participants when the session starts.
HUB_WARNING = check_hub_setup()
def save_response(record):
    """Persist one study response locally and, when configured, to the Hub.

    The local copy is written first via ``local_save``. The Hub upload is then
    sent from an in-memory UTF-8 JSON payload, so no stray temp file is left
    on disk. Every failure is caught and reported in the result, so the UI
    callback never sees an exception.

    Returns:
        dict with keys
        ``ok`` — True when the intended destination received the record (the
        Hub if configured, otherwise the local directory);
        ``saved_local`` / ``saved_hub`` — which copies were written;
        ``message`` — a human-readable status.
    """
    try:
        local_path = local_save(record)
        local_error = None
    except Exception as e:  # UI boundary: report, don't raise
        local_path = None
        local_error = f"{type(e).__name__}: {e}"

    saved_local = local_path is not None
    if saved_local:
        local_msg = f"Saved locally to {local_path}"
    else:
        local_msg = f"Local save failed ({local_error})"

    if not RESULTS_REPO_ID or not HF_TOKEN or api is None:
        return {
            "ok": saved_local,
            "saved_local": saved_local,
            "saved_hub": False,
            "message": f"{local_msg}. Hub upload is not configured."
        }

    try:
        remote_path = f"responses/{record['timestamp'][:10]}/{record['response_id']}.json"
        payload = json.dumps(record, ensure_ascii=False, indent=2).encode("utf-8")
        api.upload_file(
            path_or_fileobj=payload,
            path_in_repo=remote_path,
            repo_id=RESULTS_REPO_ID,
            repo_type="dataset",
        )
    except Exception as e:
        return {
            "ok": False,
            "saved_local": saved_local,
            "saved_hub": False,
            "message": f"{local_msg}, but Hub upload failed: {type(e).__name__}: {e}"
        }

    message = f"Saved to dataset repo: {remote_path}"
    if not saved_local:
        message += f" ({local_msg.lower()})"
    return {
        "ok": True,
        "saved_local": saved_local,
        "saved_hub": True,
        "message": message
    }
# ----------------------------
# Gradio callbacks
# ----------------------------
def start_session():
    """Create a participant and load the first pair when the study starts.

    Returns values for, in order: left video, right video, participant label,
    participant-id state, seen-pair-ids state, current-pair state, the two
    answer radios (cleared), the status text, the start button (hidden) and
    the submit button (shown).
    """
    new_id = str(uuid.uuid4())
    first_seen = []
    first_pair = sample_pair(first_seen)

    message = "Study loaded. Watch both videos and answer the questions below."
    if HUB_WARNING:
        message = message + f"\n\nWarning: {HUB_WARNING}"

    left_video = video_value(first_pair["left_path"])
    right_video = video_value(first_pair["right_path"])
    hide_start_button = gr.update(visible=False)
    show_submit_button = gr.update(visible=True)
    return (
        left_video,
        right_video,
        f"Participant ID: {new_id}",
        new_id,
        first_seen,
        first_pair,
        None,
        None,
        message,
        hide_start_button,
        show_submit_button,
    )
def submit_and_next(
    plausible_answer,
    quality_answer,
    participant_id,
    seen_pair_ids,
    current
):
    """Record the answers for the current pair and load the next one.

    Returns a 5-tuple ``(left_video, right_video, seen_pair_ids, current_pair,
    status_message)``. If no pair is loaded, or either question is unanswered,
    nothing is saved and the current state is returned with an explanatory
    status.
    """
    if current is None:
        return (
            None, None, seen_pair_ids, current,
            "No current pair loaded."
        )
    if plausible_answer is None or quality_answer is None:
        return (
            video_value(current["left_path"]),
            video_value(current["right_path"]),
            seen_pair_ids,
            current,
            "Please answer both questions before continuing."
        )

    timestamp = datetime.now(timezone.utc).isoformat()
    record = {
        "response_id": str(uuid.uuid4()),
        "timestamp": timestamp,
        "participant_id": participant_id,
        "category": current["category"],
        "pair_id": current["pair_id"],
        "pair_key": current["key"],
        "left_video": current["left_path"],
        "right_video": current["right_path"],
        "good_video": current["good_path"],
        "bad_video": current["bad_path"],
        "good_on_left": current["good_on_left"],
        "left_hidden_label": current["left_hidden_label"],
        "right_hidden_label": current["right_hidden_label"],
        "physical_plausibility_answer": plausible_answer,
        "visual_quality_answer": quality_answer,
    }
    save_result = save_response(record)

    # Keep first-seen order (a set would scramble it) and tolerate None state.
    seen = list(seen_pair_ids or [])
    if current["pair_id"] not in seen:
        seen.append(current["pair_id"])
    next_pair = sample_pair(seen)

    status_msg = save_result["message"]
    # Only reassure the participant about the local copy if it actually exists.
    if not save_result["saved_hub"] and save_result["saved_local"]:
        status_msg += "\n\nYour response was still saved locally."
    return (
        video_value(next_pair["left_path"]),
        video_value(next_pair["right_path"]),
        seen,
        next_pair,
        status_msg
    )
# ----------------------------
# UI
# ----------------------------
def _submit_and_reset(
    plausible_answer,
    quality_answer,
    participant_id,
    seen_pair_ids,
    current,
):
    """Submit the answers, clearing the radios only if the response was accepted.

    Wraps ``submit_and_next``. The radios are cleared (``None``) only when a
    pair is loaded and both questions are answered, which are the conditions
    under which ``submit_and_next`` records a response. Otherwise they are
    left untouched (``gr.update()``), so a partial answer is not wiped.
    """
    outputs = submit_and_next(
        plausible_answer, quality_answer, participant_id, seen_pair_ids, current
    )
    accepted = (
        current is not None
        and plausible_answer is not None
        and quality_answer is not None
    )
    radio_reset = None if accepted else gr.update()
    return (*outputs, radio_reset, radio_reset)


with gr.Blocks(title="Human Study") as demo:
    gr.Markdown(
        """
# Human Study
Please answer as many video pairs as you can.
You do not need to finish the full study. Each response is saved automatically when you click **Submit and continue**, so you may stop at any time.
### Physical plausibility
- Watch both videos and choose the one that is relatively better in terms of physical realism.
- Sometimes both videos may be imperfect, but please select the one with better physics than the other.
- Pay attention to object motion, interactions, and overall realism.
- Please avoid choosing **No preference** for physical plausibility as much as possible.
### Visual quality
- Choose the video with better visual quality.
- Focus on overall visual appearance and rendering quality.
"""
    )
    participant_label = gr.Markdown()
    with gr.Row():
        video_left = gr.Video(
            label="Left",
            interactive=False,
            format="mp4",
        )
        video_right = gr.Video(
            label="Right",
            interactive=False,
            format="mp4",
        )
    plausible = gr.Radio(
        choices=["Left", "Right", "No preference"],
        label="Which video is more physically plausible?"
    )
    quality = gr.Radio(
        choices=["Left", "Right", "No preference"],
        label="Which video has better visual quality?"
    )
    status = gr.Markdown()
    start_btn = gr.Button("Start study", visible=True)
    next_btn = gr.Button("Submit and continue", visible=False)
    participant_id_state = gr.State()
    seen_pair_ids_state = gr.State([])
    current_pair_state = gr.State()

    start_btn.click(
        fn=start_session,
        outputs=[
            video_left,
            video_right,
            participant_label,
            participant_id_state,
            seen_pair_ids_state,
            current_pair_state,
            plausible,
            quality,
            status,
            start_btn,
            next_btn,
        ]
    )
    # Radio reset happens in the same handler, so a rejected submission
    # (missing answer) keeps whatever the participant already selected.
    next_btn.click(
        fn=_submit_and_reset,
        inputs=[
            plausible,
            quality,
            participant_id_state,
            seen_pair_ids_state,
            current_pair_state
        ],
        outputs=[
            video_left,
            video_right,
            seen_pair_ids_state,
            current_pair_state,
            status,
            plausible,
            quality,
        ]
    )

demo.launch()