Spaces:
Paused
Paused
| import csv | |
| import io | |
| import json | |
| import os | |
| import random | |
| import re | |
| import time | |
| from pathlib import Path | |
| import gradio as gr | |
| from huggingface_hub import HfApi | |
# Directory containing this script; every data path below is resolved from it.
BASE_DIR = Path(__file__).resolve().parent
# Generated media. _collect_items reads results/input/<name>.jpg,
# results/ours/<name>.mp4 and results/<baseline>/<name>.mp4 from here.
RESULTS_DIR = BASE_DIR / "results"
# Destination for local CSV/JSON exports (only used when STORE_LOCAL is true).
OUTPUT_DIR = BASE_DIR / "survey_outputs"
# Baseline sub-folders of RESULTS_DIR whose videos are compared against "ours".
BASELINES = ["4dnex", "dimensionx", "geovideo"]
# Target comparisons per baseline per session: extra pairs are randomly
# dropped, and fewer pairs produce a warning in the UI.
EXPECTED_PER_BASELINE = 20
# Dataset repo and token for uploading session exports; uploads are skipped
# unless both are set.
HF_DATASET_REPO = os.getenv("HF_DATASET_REPO")
HF_TOKEN = os.getenv("HF_TOKEN")
# Set STORE_LOCAL=1 to additionally write exports into OUTPUT_DIR.
STORE_LOCAL = os.getenv("STORE_LOCAL", "0") == "1"
# Markdown shown above the survey. The blank lines matter: without one before
# the closing question, CommonMark treats it as a lazy continuation of the
# last bullet and renders it inside that list item.
INSTRUCTION = """
### Instruction

The image below is a **real-world image**. Please compare the two generated videos that animate the image by considering:

- **Motion Naturalness:** Is the movement physically plausible?
- **Geometric Consistency:** Is the 3D structure stable (no warping/distortion)?
- **Visual Quality:** Is the video realistic, sharp, and artifact-free?

**Which video is better overall?**
""".strip()  # noqa: E501
# Styling hooks for the elem_id values assigned to the image/video widgets.
CSS = """
#input_image img {
max-height: 280px;
object-fit: contain;
}
#video_left video,
#video_right video {
height: 320px;
object-fit: contain;
background: #000;
}
"""
| def _sanitize(text: str) -> str: | |
| text = text.strip() or "anonymous" | |
| text = re.sub(r"[^A-Za-z0-9_-]+", "_", text) | |
| return text[:40] | |
| def _has_media(path: Path) -> bool: | |
| try: | |
| return path.exists() and path.stat().st_size > 0 | |
| except OSError: | |
| return False | |
def _collect_items(seed: int):
    """Build the shuffled list of pairwise comparison items for one session.

    For each baseline in BASELINES, every ``RESULTS_DIR/<baseline>/*.mp4`` is
    paired with ``RESULTS_DIR/ours/<name>.mp4`` and the input image
    ``RESULTS_DIR/input/<name>.jpg``. Pairs where any of the three files is
    missing or empty are skipped. At most EXPECTED_PER_BASELINE pairs are kept
    per baseline (a random subset when there are more). Which side shows
    "ours" is randomized per item, and the final list is shuffled.

    Args:
        seed: Seed for a private ``random.Random``. Sampling, side placement
            and ordering are deterministic for a given seed and set of files.

    Returns:
        ``(items, warnings)``: a list of item dicts and a list of
        human-readable warning strings.
    """
    rng = random.Random(seed)
    warnings = []
    items = []
    input_dir = RESULTS_DIR / "input"
    ours_dir = RESULTS_DIR / "ours"
    for baseline in BASELINES:
        matched = []
        skipped = 0
        # sorted() fixes the candidate order so the seeded RNG is reproducible.
        for fpath in sorted((RESULTS_DIR / baseline).glob("*.mp4")):
            name = fpath.stem
            input_path = input_dir / f"{name}.jpg"
            ours_path = ours_dir / f"{name}.mp4"
            if _has_media(input_path) and _has_media(ours_path) and _has_media(fpath):
                matched.append((name, input_path, ours_path, fpath))
            else:
                skipped += 1
        if len(matched) < EXPECTED_PER_BASELINE:
            warnings.append(
                f"{baseline}: only {len(matched)} valid pairs found (expected {EXPECTED_PER_BASELINE})."
            )
        if skipped:
            warnings.append(f"{baseline}: skipped {skipped} items with missing/empty files.")
        if len(matched) > EXPECTED_PER_BASELINE:
            rng.shuffle(matched)
            matched = matched[:EXPECTED_PER_BASELINE]
        for name, input_path, ours_path, baseline_path in matched:
            # Randomize the side of "ours" so position does not reveal the method.
            left_is_ours = rng.random() < 0.5
            left_video = ours_path if left_is_ours else baseline_path
            right_video = baseline_path if left_is_ours else ours_path
            items.append(
                {
                    # The same input name (and the same ours/<name>.mp4) appears
                    # under every baseline. Responses are stored in a dict keyed
                    # by this id, so the bare name would let answers for
                    # different baselines overwrite each other. Prefixing the
                    # baseline makes the id unique per comparison.
                    "id": f"{baseline}/{name}",
                    "baseline": baseline,
                    "input": input_path,
                    "ours": ours_path,
                    "baseline_video": baseline_path,
                    "left_video": left_video,
                    "right_video": right_video,
                    "left_source": "ours" if left_is_ours else baseline,
                    "right_source": baseline if left_is_ours else "ours",
                }
            )
    rng.shuffle(items)
    return items, warnings
| def _current_view(state): | |
| if not state or not state["items"]: | |
| return "", None, None, None, None | |
| idx = state["index"] | |
| item = state["items"][idx] | |
| total = len(state["items"]) | |
| progress = f"Question {idx + 1} / {total}" | |
| choice = None | |
| saved = state["responses"].get(item["id"]) | |
| if saved: | |
| choice = saved["choice_label"] | |
| return ( | |
| progress, | |
| str(item["input"]), | |
| str(item["left_video"]), | |
| str(item["right_video"]), | |
| choice, | |
| ) | |
| def _record_choice(state, choice_label): | |
| item = state["items"][state["index"]] | |
| if choice_label == "Video A": | |
| chosen_source = item["left_source"] | |
| else: | |
| chosen_source = item["right_source"] | |
| state["responses"][item["id"]] = { | |
| "id": item["id"], | |
| "baseline": item["baseline"], | |
| "input": str(item["input"]), | |
| "ours": str(item["ours"]), | |
| "baseline_video": str(item["baseline_video"]), | |
| "left_video": str(item["left_video"]), | |
| "right_video": str(item["right_video"]), | |
| "left_source": item["left_source"], | |
| "right_source": item["right_source"], | |
| "choice_label": choice_label, | |
| "chosen_source": chosen_source, | |
| "choice_side": "left" if choice_label == "Video A" else "right", | |
| "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), | |
| "participant_id": state["participant_id"], | |
| "seed": state["seed"], | |
| } | |
def _write_exports(state):
    """Serialize the session's answers to CSV and JSON and persist them.

    One row is built per item in ``state["items"]``. Response columns are None
    for unanswered items. The rows are encoded as UTF-8 CSV and as a JSON
    document that also carries the participant id and seed. The encoded
    payloads are then:

    * written to ``OUTPUT_DIR/<base>.csv`` and ``.json`` when STORE_LOCAL is
      true, and
    * uploaded to ``sessions/<YYYYmmdd>/<base>.csv`` and ``.json`` in the
      HF_DATASET_REPO dataset (created private if missing) when both
      HF_DATASET_REPO and HF_TOKEN are set.

    ``<base>`` is ``<YYYYmmdd_HHMMSS>_<sanitized participant id>_<seed>``.
    Errors from file writes or ``huggingface_hub`` propagate to the caller.

    Returns:
        None.
    """
    # Explicit column order, identical to the row dicts below. An empty item
    # list therefore still yields a valid header instead of an IndexError on
    # ``rows[0]``.
    fieldnames = [
        "id",
        "baseline",
        "input",
        "ours",
        "baseline_video",
        "left_video",
        "right_video",
        "left_source",
        "right_source",
        "choice_label",
        "chosen_source",
        "choice_side",
        "timestamp",
        "participant_id",
        "seed",
    ]
    # Read the clock once so the file timestamp and the date folder always
    # agree, even when the call straddles midnight.
    now = time.localtime()
    timestamp = time.strftime("%Y%m%d_%H%M%S", now)
    base = f"{timestamp}_{_sanitize(state['participant_id'])}_{state['seed']}"
    date_dir = time.strftime("%Y%m%d", now)
    dataset_prefix = f"sessions/{date_dir}/{base}"

    rows = []
    for item in state["items"]:
        resp = state["responses"].get(item["id"], {})
        rows.append(
            {
                "id": item["id"],
                "baseline": item["baseline"],
                "input": str(item["input"]),
                "ours": str(item["ours"]),
                "baseline_video": str(item["baseline_video"]),
                "left_video": str(item["left_video"]),
                "right_video": str(item["right_video"]),
                "left_source": item["left_source"],
                "right_source": item["right_source"],
                "choice_label": resp.get("choice_label"),
                "chosen_source": resp.get("chosen_source"),
                "choice_side": resp.get("choice_side"),
                "timestamp": resp.get("timestamp"),
                "participant_id": state["participant_id"],
                "seed": state["seed"],
            }
        )

    csv_buffer = io.StringIO()
    writer = csv.DictWriter(csv_buffer, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)
    csv_bytes = csv_buffer.getvalue().encode("utf-8")
    json_payload = json.dumps(
        {
            "participant_id": state["participant_id"],
            "seed": state["seed"],
            "items": rows,
        },
        ensure_ascii=False,
        indent=2,
    ).encode("utf-8")

    if STORE_LOCAL:
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        (OUTPUT_DIR / f"{base}.csv").write_bytes(csv_bytes)
        (OUTPUT_DIR / f"{base}.json").write_bytes(json_payload)

    if HF_DATASET_REPO and HF_TOKEN:
        api = HfApi(token=HF_TOKEN)
        api.create_repo(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            private=True,
            exist_ok=True,
        )
        api.upload_file(
            path_or_fileobj=io.BytesIO(csv_bytes),
            path_in_repo=f"{dataset_prefix}.csv",
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
        )
        api.upload_file(
            path_or_fileobj=io.BytesIO(json_payload),
            path_in_repo=f"{dataset_prefix}.json",
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
        )
    return None
def start_session(participant_id):
    """Start a fresh survey session (handler for the "Start New Session" button).

    Args:
        participant_id: Raw text from the participant-ID textbox; may be
            empty, None or whitespace, in which case "anonymous" is used.

    Returns:
        The 9 values wired to the button's outputs: (state, status markdown,
        question_box update, done_box update, progress text, input image,
        left video, right video, selected choice).
    """
    # Millisecond clock folded into the positive 31-bit range. The seed is
    # stored in the state and exported with every row, so the sampled order
    # can be reconstructed later (given the same results folders).
    seed = int(time.time() * 1000) % (2**31 - 1)
    items, warnings = _collect_items(seed)
    if not STORE_LOCAL and not (HF_DATASET_REPO and HF_TOKEN):
        warnings.append(
            "Results will NOT be saved. Set HF_DATASET_REPO and HF_TOKEN in Space secrets."
        )
    state = {
        "seed": seed,
        "items": items,
        "index": 0,
        "responses": {},
        # Strip before falling back, so a whitespace-only ID becomes
        # "anonymous" rather than an empty string in the exported rows.
        "participant_id": (participant_id or "").strip() or "anonymous",
    }
    if items:
        status = "\n".join(f"- {w}" for w in warnings) or "Session ready."
    else:
        status = "No valid items found. Please check your results folders."
    progress, image, left, right, choice = _current_view(state)
    return (
        state,
        status,
        # Show the question panel only when there is something to answer.
        # Otherwise picking a choice and pressing "Next" would index into an
        # empty item list.
        gr.update(visible=bool(items)),
        gr.update(visible=False),
        progress,
        image,
        left,
        right,
        choice,
    )
def _question_outputs(state, status):
    """Return the 9 handler outputs that display the current question."""
    progress, image, left, right, choice = _current_view(state)
    return (
        state,
        status,
        gr.update(visible=True),
        gr.update(visible=False),
        progress,
        image,
        left,
        right,
        choice,
    )


def go_next(state, choice_label):
    """Handle the "Next" button.

    Records ``choice_label`` for the current question, then either advances to
    the next question or, on the last one, exports the session and switches to
    the "done" panel.

    Returns:
        The 9 values wired to the button's outputs: (state, status markdown,
        question_box update, done_box update, progress text, input image,
        left video, right video, selected choice).
    """
    if not state:
        return (state, "Please start the session.", *(gr.update() for _ in range(7)))
    if not state["items"]:
        # Nothing to answer; _record_choice would raise IndexError here.
        return (
            state,
            "No valid items found. Please check your results folders.",
            *(gr.update() for _ in range(7)),
        )
    if state.get("submitted"):
        # A duplicate click processed after the session was already exported.
        # Keep the "done" panel and do not record or upload the session again.
        return (
            state,
            "",
            gr.update(visible=False),
            gr.update(visible=True),
            *(gr.update() for _ in range(5)),
        )
    if not choice_label:
        return _question_outputs(state, "Please choose Video A or Video B before continuing.")

    _record_choice(state, choice_label)
    if state["index"] < len(state["items"]) - 1:
        state["index"] += 1
        return _question_outputs(state, "")

    _write_exports(state)
    # Set only after a successful export, so a failed upload can be retried.
    state["submitted"] = True
    return (
        state,
        "",
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(value=""),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=None),
    )
def go_prev(state):
    """Handle the "Previous" button by stepping back one question.

    The index is never decremented below zero. Returns the same 9-value
    layout used by the other button handlers: (state, status, question_box
    update, done_box update, progress, image, left video, right video,
    choice).
    """
    if not state:
        unchanged = [gr.update() for _ in range(7)]
        return (state, "Please start the session.", *unchanged)
    if state["index"] > 0:
        state["index"] -= 1
    view = _current_view(state)
    return (state, "", gr.update(visible=True), gr.update(visible=False), *view)
# Survey UI, top to bottom: a title and instructions, the participant-ID row,
# a status line, the question panel (hidden until a session starts) and the
# "done" panel. All three buttons drive the same nine outputs, in the order
# the handlers return them.
with gr.Blocks(title="Real-world Image-to-Video Survey", css=CSS) as demo:
    gr.Markdown("# Real-world Image-to-Video Survey")
    gr.Markdown(INSTRUCTION)

    with gr.Row():
        participant_id = gr.Textbox(label="Participant ID (optional)", placeholder="e.g., P012")
        start_btn = gr.Button("Start New Session", variant="primary")
    status = gr.Markdown("")

    with gr.Column(visible=False) as question_box:
        progress_md = gr.Markdown("")
        input_image = gr.Image(label="Input Image", interactive=False, elem_id="input_image")
        with gr.Row():
            left_video = gr.Video(label="Video A", interactive=False, elem_id="video_left")
            right_video = gr.Video(label="Video B", interactive=False, elem_id="video_right")
        choice = gr.Radio(
            choices=["Video A", "Video B"],
            label="Which video is better overall?",
        )
        with gr.Row():
            prev_btn = gr.Button("Previous")
            next_btn = gr.Button("Next", variant="primary")

    with gr.Column(visible=False) as done_box:
        gr.Markdown("## All done")
        gr.Markdown("Thank you! You can close this tab now.")

    session_state = gr.State()

    # Shared output order: (state, status, question_box, done_box, progress,
    # image, left video, right video, choice).
    view_outputs = [
        session_state,
        status,
        question_box,
        done_box,
        progress_md,
        input_image,
        left_video,
        right_video,
        choice,
    ]
    start_btn.click(start_session, inputs=[participant_id], outputs=view_outputs)
    next_btn.click(go_next, inputs=[session_state, choice], outputs=view_outputs)
    prev_btn.click(go_prev, inputs=[session_state], outputs=view_outputs)
# Launch settings come from the environment. GRADIO_SHARE defaults to "1", so
# a share link is requested unless the variable is set to something else.
server_name = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
share_link = os.getenv("GRADIO_SHARE", "1") == "1"
demo.launch(server_name=server_name, server_port=server_port, share=share_link)