"""Gradio app for a pairwise image-to-video user study.

Each question shows a real-world input image plus two generated videos: one
from "ours" and one from a baseline, in random left/right order. The
participant picks the better one. When the session ends, the answers are
exported as CSV + JSON, either locally, to a private Hugging Face dataset, or
both.
"""

import csv
import io
import json
import logging
import os
import random
import re
import time
from pathlib import Path

import gradio as gr
from huggingface_hub import HfApi

logger = logging.getLogger(__name__)

BASE_DIR = Path(__file__).resolve().parent
RESULTS_DIR = BASE_DIR / "results"
OUTPUT_DIR = BASE_DIR / "survey_outputs"
BASELINES = ["4dnex", "dimensionx", "geovideo"]
EXPECTED_PER_BASELINE = 20

HF_DATASET_REPO = os.getenv("HF_DATASET_REPO")
HF_TOKEN = os.getenv("HF_TOKEN")
STORE_LOCAL = os.getenv("STORE_LOCAL", "0") == "1"

INSTRUCTION = """
### Instruction
The image below is a **real-world image**. Please compare the two generated videos that animate the image by considering:

- **Motion Naturalness:** Is the movement physically plausible?
- **Geometric Consistency:** Is the 3D structure stable (no warping/distortion)?
- **Visual Quality:** Is the video realistic, sharp, and artifact-free?

**Which video is better overall?**
""".strip()  # noqa: E501

CSS = """
#input_image img {
    max-height: 280px;
    object-fit: contain;
}
#video_left video,
#video_right video {
    height: 320px;
    object-fit: contain;
    background: #000;
}
"""

# Number of components updated by every event handler after the state/status
# pair: question_box, done_box, progress_md, input_image, left/right video,
# choice.
_VIEW_OUTPUT_COUNT = 7


def _sanitize(text: str) -> str:
    """Turn free-form text into a short token safe for file and repo paths.

    Empty or whitespace-only input becomes ``"anonymous"``. Every run of
    characters outside ``[A-Za-z0-9_-]`` collapses to one underscore, and the
    result is truncated to 40 characters.
    """
    text = text.strip() or "anonymous"
    text = re.sub(r"[^A-Za-z0-9_-]+", "_", text)
    return text[:40]


def _has_media(path: Path) -> bool:
    """Return True if *path* exists and is non-empty; False on any OS error."""
    try:
        return path.exists() and path.stat().st_size > 0
    except OSError:
        return False


def _collect_items(seed: int):
    """Build the randomized list of comparison questions for one session.

    For each baseline, every ``results/<baseline>/<name>.mp4`` is paired with
    ``results/input/<name>.jpg`` and ``results/ours/<name>.mp4``. Triples with
    a missing or empty file are skipped. At most ``EXPECTED_PER_BASELINE``
    triples per baseline are kept (a random subset when more exist). The
    left/right placement of "ours" is randomized per item, and the final list
    is shuffled. All randomness is drawn from ``seed``, so a session can be
    reproduced.

    Returns:
        ``(items, warnings)``: a list of item dicts and a list of
        human-readable warning strings.
    """
    rng = random.Random(seed)
    warnings = []
    items = []
    for baseline in BASELINES:
        files = sorted((RESULTS_DIR / baseline).glob("*.mp4"))
        matched = []
        skipped = 0
        for fpath in files:
            name = fpath.stem
            input_path = RESULTS_DIR / "input" / f"{name}.jpg"
            ours_path = RESULTS_DIR / "ours" / f"{name}.mp4"
            if _has_media(input_path) and _has_media(ours_path) and _has_media(fpath):
                matched.append((name, input_path, ours_path, fpath))
            else:
                skipped += 1
        if len(matched) < EXPECTED_PER_BASELINE:
            warnings.append(
                f"{baseline}: only {len(matched)} valid pairs found (expected {EXPECTED_PER_BASELINE})."
            )
        if skipped:
            warnings.append(f"{baseline}: skipped {skipped} items with missing/empty files.")
        if len(matched) > EXPECTED_PER_BASELINE:
            rng.shuffle(matched)
            matched = matched[:EXPECTED_PER_BASELINE]
        for name, input_path, ours_path, baseline_path in matched:
            left_is_ours = rng.random() < 0.5
            left_video = ours_path if left_is_ours else baseline_path
            right_video = baseline_path if left_is_ours else ours_path
            items.append(
                {
                    # The same file stem normally appears under every
                    # baseline (input/ours are shared), so the stem alone is
                    # not unique. This key identifies the question.
                    "key": f"{baseline}/{name}",
                    "id": name,
                    "baseline": baseline,
                    "input": input_path,
                    "ours": ours_path,
                    "baseline_video": baseline_path,
                    "left_video": left_video,
                    "right_video": right_video,
                    "left_source": "ours" if left_is_ours else baseline,
                    "right_source": baseline if left_is_ours else "ours",
                }
            )
    rng.shuffle(items)
    return items, warnings


def _current_view(state):
    """Return ``(progress, image, left, right, choice)`` for the current item.

    Yields empty values when there is no session or no items. ``choice`` is
    the previously saved answer for this question, or None.
    """
    if not state or not state["items"]:
        return "", None, None, None, None
    idx = state["index"]
    item = state["items"][idx]
    total = len(state["items"])
    progress = f"Question {idx + 1} / {total}"
    saved = state["responses"].get(item["key"])
    choice = saved["choice_label"] if saved else None
    return (
        progress,
        str(item["input"]),
        str(item["left_video"]),
        str(item["right_video"]),
        choice,
    )


def _record_choice(state, choice_label):
    """Store (or overwrite) the answer for the current item in ``state``."""
    item = state["items"][state["index"]]
    if choice_label == "Video A":
        chosen_source = item["left_source"]
    else:
        chosen_source = item["right_source"]
    state["responses"][item["key"]] = {
        "id": item["id"],
        "baseline": item["baseline"],
        "input": str(item["input"]),
        "ours": str(item["ours"]),
        "baseline_video": str(item["baseline_video"]),
        "left_video": str(item["left_video"]),
        "right_video": str(item["right_video"]),
        "left_source": item["left_source"],
        "right_source": item["right_source"],
        "choice_label": choice_label,
        "chosen_source": chosen_source,
        "choice_side": "left" if choice_label == "Video A" else "right",
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "participant_id": state["participant_id"],
        "seed": state["seed"],
    }


def _export_row(state, item):
    """Build one export row for *item*; response fields are None if unanswered."""
    resp = state["responses"].get(item["key"], {})
    return {
        "id": item["id"],
        "baseline": item["baseline"],
        "input": str(item["input"]),
        "ours": str(item["ours"]),
        "baseline_video": str(item["baseline_video"]),
        "left_video": str(item["left_video"]),
        "right_video": str(item["right_video"]),
        "left_source": item["left_source"],
        "right_source": item["right_source"],
        "choice_label": resp.get("choice_label"),
        "chosen_source": resp.get("chosen_source"),
        "choice_side": resp.get("choice_side"),
        "timestamp": resp.get("timestamp"),
        "participant_id": state["participant_id"],
        "seed": state["seed"],
    }


def _save_local(base, csv_bytes, json_payload):
    """Write ``<base>.csv`` and ``<base>.json`` into ``OUTPUT_DIR``."""
    OUTPUT_DIR.mkdir(exist_ok=True)
    (OUTPUT_DIR / f"{base}.csv").write_bytes(csv_bytes)
    (OUTPUT_DIR / f"{base}.json").write_bytes(json_payload)


def _upload_session(dataset_prefix, csv_bytes, json_payload):
    """Upload both export files to the private dataset ``HF_DATASET_REPO``."""
    api = HfApi(token=HF_TOKEN)
    api.create_repo(
        repo_id=HF_DATASET_REPO,
        repo_type="dataset",
        private=True,
        exist_ok=True,
    )
    for payload, suffix in ((csv_bytes, "csv"), (json_payload, "json")):
        api.upload_file(
            path_or_fileobj=io.BytesIO(payload),
            path_in_repo=f"{dataset_prefix}.{suffix}",
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
        )


def _write_exports(state):
    """Serialize the finished session to CSV + JSON and persist it.

    One row is written per item, answered or not. Files go to ``OUTPUT_DIR``
    when ``STORE_LOCAL`` is set. They are uploaded under
    ``sessions/<YYYYMMDD>/`` of ``HF_DATASET_REPO`` when both
    ``HF_DATASET_REPO`` and ``HF_TOKEN`` are set. If the upload fails, a local
    copy is written as a fallback so the responses are not lost.

    Returns:
        A message describing any save problem for the participant, or an
        empty string when every configured destination succeeded.
    """
    if not state["items"]:
        return ""
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    base = f"{timestamp}_{_sanitize(state['participant_id'])}_{state['seed']}"
    date_dir = time.strftime("%Y%m%d")
    dataset_prefix = f"sessions/{date_dir}/{base}"

    rows = [_export_row(state, item) for item in state["items"]]

    csv_buffer = io.StringIO()
    writer = csv.DictWriter(csv_buffer, fieldnames=list(rows[0].keys()))
    writer.writeheader()
    writer.writerows(rows)
    csv_bytes = csv_buffer.getvalue().encode("utf-8")

    json_payload = json.dumps(
        {
            "participant_id": state["participant_id"],
            "seed": state["seed"],
            "items": rows,
        },
        ensure_ascii=False,
        indent=2,
    ).encode("utf-8")

    problems = []
    stored_locally = False
    if STORE_LOCAL:
        try:
            _save_local(base, csv_bytes, json_payload)
            stored_locally = True
        except OSError:
            logger.exception("Could not write session %s to %s", base, OUTPUT_DIR)
            problems.append("the local copy could not be written")

    if HF_DATASET_REPO and HF_TOKEN:
        try:
            _upload_session(dataset_prefix, csv_bytes, json_payload)
        except Exception:  # network/auth/Hub errors must not lose the session
            logger.exception("Upload of session %s to %s failed", base, HF_DATASET_REPO)
            problems.append("the upload to the results dataset failed")
            if not stored_locally:
                try:
                    _save_local(base, csv_bytes, json_payload)
                    problems.append(f"a backup was kept as {base}")
                except OSError:
                    logger.exception("Fallback local save of session %s failed", base)

    if not problems:
        return ""
    return "Warning: " + "; ".join(problems) + ". Please let the study organizer know."


def _unchanged(state, message):
    """Outputs that only update the status text and leave every view as-is."""
    return (state, message) + tuple(gr.update() for _ in range(_VIEW_OUTPUT_COUNT))


def _question_view(state, message):
    """Outputs that show the question panel for the current item."""
    progress, image, left, right, choice = _current_view(state)
    return (
        state,
        message,
        gr.update(visible=True),
        gr.update(visible=False),
        progress,
        image,
        left,
        right,
        choice,
    )


def start_session(participant_id):
    """Create a new randomized session and show its first question.

    When no valid items exist, the question panel stays hidden and the status
    explains why, so the participant cannot trigger an out-of-range answer.
    """
    seed = int(time.time() * 1000) % (2**31 - 1)
    items, warnings = _collect_items(seed)
    if not STORE_LOCAL and (not HF_DATASET_REPO or not HF_TOKEN):
        warnings.append(
            "Results will NOT be saved. Set HF_DATASET_REPO and HF_TOKEN in Space secrets."
        )
    state = {
        "seed": seed,
        "items": items,
        "index": 0,
        "responses": {},
        # Normalize whitespace-only IDs the same way _sanitize does for
        # filenames, so the stored ID and the exported filename agree.
        "participant_id": (participant_id or "").strip() or "anonymous",
    }
    warning_text = "\n".join(f"- {w}" for w in warnings) if warnings else ""
    if not items:
        status = "No valid items found. Please check your results folders."
        if warning_text:
            status += "\n" + warning_text
        return (
            state,
            status,
            gr.update(visible=False),
            gr.update(visible=False),
            "",
            None,
            None,
            None,
            None,
        )
    return _question_view(state, warning_text or "Session ready.")


def go_next(state, choice_label):
    """Record the current answer and advance.

    After the last question, export the session and show the done panel.
    """
    if not state or not state["items"]:
        return _unchanged(state, "Please start the session.")
    if not choice_label:
        return _question_view(state, "Please choose Video A or Video B before continuing.")
    _record_choice(state, choice_label)
    if state["index"] < len(state["items"]) - 1:
        state["index"] += 1
        return _question_view(state, "")
    save_message = _write_exports(state)
    return (
        state,
        save_message,
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(value=""),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=None),
    )


def go_prev(state, choice_label=None):
    """Go back one question.

    Any selection made on the current question is recorded first, so it is
    not lost.
    """
    if not state or not state["items"]:
        return _unchanged(state, "Please start the session.")
    if choice_label:
        _record_choice(state, choice_label)
    if state["index"] > 0:
        state["index"] -= 1
    return _question_view(state, "")


with gr.Blocks(title="Real-world Image-to-Video Survey", css=CSS) as demo:
    gr.Markdown("# Real-world Image-to-Video Survey")
    gr.Markdown(INSTRUCTION)

    with gr.Row():
        participant_id = gr.Textbox(label="Participant ID (optional)", placeholder="e.g., P012")
        start_btn = gr.Button("Start New Session", variant="primary")

    status = gr.Markdown("")

    with gr.Column(visible=False) as question_box:
        progress_md = gr.Markdown("")
        input_image = gr.Image(label="Input Image", interactive=False, elem_id="input_image")
        with gr.Row():
            left_video = gr.Video(label="Video A", interactive=False, elem_id="video_left")
            right_video = gr.Video(label="Video B", interactive=False, elem_id="video_right")
        choice = gr.Radio(
            choices=["Video A", "Video B"],
            label="Which video is better overall?",
        )
        with gr.Row():
            prev_btn = gr.Button("Previous")
            next_btn = gr.Button("Next", variant="primary")

    with gr.Column(visible=False) as done_box:
        gr.Markdown("## All done")
        gr.Markdown("Thank you! You can close this tab now.")

    session_state = gr.State()

    # Every handler returns values for these components in this order.
    view_outputs = [
        session_state,
        status,
        question_box,
        done_box,
        progress_md,
        input_image,
        left_video,
        right_video,
        choice,
    ]

    start_btn.click(start_session, inputs=[participant_id], outputs=view_outputs)
    next_btn.click(go_next, inputs=[session_state, choice], outputs=view_outputs)
    prev_btn.click(go_prev, inputs=[session_state, choice], outputs=view_outputs)


if __name__ == "__main__":
    server_name = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
    server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    demo.launch(
        share=os.getenv("GRADIO_SHARE", "1") == "1",
        server_name=server_name,
        server_port=server_port,
        # Media is returned as file paths under RESULTS_DIR. Gradio only
        # serves paths under the cwd/tempdir unless they are whitelisted here,
        # so this keeps the app working when launched from another directory.
        allowed_paths=[str(RESULTS_DIR)],
    )