bshfang's picture
Upload folder using huggingface_hub
dedbd79 verified
import csv
import io
import json
import os
import random
import re
import time
from pathlib import Path
import gradio as gr
from huggingface_hub import HfApi
# All paths are resolved relative to the directory containing this file.
BASE_DIR = Path(__file__).resolve().parent
# Media root: contains "input/", "ours/" and one sub-folder per baseline name.
RESULTS_DIR = BASE_DIR / "results"
# Destination for local CSV/JSON exports (only written when STORE_LOCAL is on).
OUTPUT_DIR = BASE_DIR / "survey_outputs"
# Baseline methods compared against "ours"; each is a sub-folder of RESULTS_DIR.
BASELINES = ["4dnex", "dimensionx", "geovideo"]
# Pairs kept per baseline: more are randomly subsampled, fewer raise a warning.
EXPECTED_PER_BASELINE = 20
# Dataset repo and token used to upload finished sessions; both must be set.
HF_DATASET_REPO = os.getenv("HF_DATASET_REPO")
HF_TOKEN = os.getenv("HF_TOKEN")
# Set the environment variable STORE_LOCAL=1 to write exports under OUTPUT_DIR.
STORE_LOCAL = os.getenv("STORE_LOCAL", "0") == "1"
# Markdown instructions rendered at the top of the survey page.
INSTRUCTION = """
### Instruction
The image below is a **real-world image**. Please compare the two generated videos that animate the image by considering:
- **Motion Naturalness:** Is the movement physically plausible?
- **Geometric Consistency:** Is the 3D structure stable (no warping/distortion)?
- **Visual Quality:** Is the video realistic, sharp, and artifact-free?
**Which video is better overall?**
""".strip() # noqa: E501
# Custom CSS passed to gr.Blocks; the selectors match the elem_id values given
# to the input image and the two video players.
CSS = """
#input_image img {
max-height: 280px;
object-fit: contain;
}
#video_left video,
#video_right video {
height: 320px;
object-fit: contain;
background: #000;
}
"""
def _sanitize(text: str) -> str:
    """Turn free-form text into a short token that is safe inside a file name.

    Surrounding whitespace is removed; an empty result becomes "anonymous".
    Every run of characters other than ASCII letters, digits, "_" and "-" is
    collapsed into a single "_", and the result is cut to 40 characters.
    """
    cleaned = text.strip()
    if not cleaned:
        cleaned = "anonymous"
    return re.sub(r"[^A-Za-z0-9_-]+", "_", cleaned)[:40]
def _has_media(path: Path) -> bool:
    """Report whether ``path`` exists and contains at least one byte.

    Any ``OSError`` raised while probing the file system is treated as
    "no usable media" and yields ``False``.
    """
    try:
        if not path.exists():
            return False
        return path.stat().st_size > 0
    except OSError:
        return False
def _collect_items(seed: int):
    """Build the randomized list of survey questions for one session.

    For every baseline in ``BASELINES`` the ``*.mp4`` files in its folder are
    paired, by file stem, with ``input/<stem>.jpg`` and ``ours/<stem>.mp4``.
    Only triples whose three files all exist and are non-empty are kept. When
    a baseline has more than ``EXPECTED_PER_BASELINE`` valid triples, a random
    subset of that size is used. For each kept triple the side on which "ours"
    is shown is chosen at random, and the combined list is shuffled at the end.

    Args:
        seed: Seed for the session's private random generator; the same seed
            and the same files give the same items in the same order.

    Returns:
        A ``(items, warnings)`` tuple: ``items`` is a list of dicts describing
        each question, ``warnings`` a list of human-readable strings about
        missing or skipped media.
    """
    rng = random.Random(seed)
    input_dir = RESULTS_DIR / "input"
    ours_dir = RESULTS_DIR / "ours"
    notes = []
    survey_items = []

    for baseline in BASELINES:
        pairs = []
        n_skipped = 0
        # Sorted so that the seeded shuffle below starts from a stable order.
        for baseline_clip in sorted((RESULTS_DIR / baseline).glob("*.mp4")):
            stem = baseline_clip.stem
            image = input_dir / f"{stem}.jpg"
            ours_clip = ours_dir / f"{stem}.mp4"
            usable = (
                _has_media(image)
                and _has_media(ours_clip)
                and _has_media(baseline_clip)
            )
            if not usable:
                n_skipped += 1
                continue
            pairs.append((stem, image, ours_clip, baseline_clip))

        if len(pairs) < EXPECTED_PER_BASELINE:
            notes.append(
                f"{baseline}: only {len(pairs)} valid pairs found (expected {EXPECTED_PER_BASELINE})."
            )
        if n_skipped:
            notes.append(f"{baseline}: skipped {n_skipped} items with missing/empty files.")
        if len(pairs) > EXPECTED_PER_BASELINE:
            rng.shuffle(pairs)
            pairs = pairs[:EXPECTED_PER_BASELINE]

        for stem, image, ours_clip, baseline_clip in pairs:
            # One random draw per item decides which side shows "ours".
            ours_on_left = rng.random() < 0.5
            if ours_on_left:
                left_clip, left_name = ours_clip, "ours"
                right_clip, right_name = baseline_clip, baseline
            else:
                left_clip, left_name = baseline_clip, baseline
                right_clip, right_name = ours_clip, "ours"
            survey_items.append(
                {
                    "id": stem,
                    "baseline": baseline,
                    "input": image,
                    "ours": ours_clip,
                    "baseline_video": baseline_clip,
                    "left_video": left_clip,
                    "right_video": right_clip,
                    "left_source": left_name,
                    "right_source": right_name,
                }
            )

    rng.shuffle(survey_items)
    return survey_items, notes
def _response_key(item):
    """Return the key under which an item's answer is stored in ``state["responses"]``.

    The same clip name can be matched for more than one baseline, so the item
    ``id`` alone is not guaranteed to be unique within a session. Including
    the baseline name keeps the answer to each comparison separate.
    """
    return f"{item['baseline']}/{item['id']}"


def _item_fields(item):
    """Return the export columns describing ``item``, with paths as strings."""
    return {
        "id": item["id"],
        "baseline": item["baseline"],
        "input": str(item["input"]),
        "ours": str(item["ours"]),
        "baseline_video": str(item["baseline_video"]),
        "left_video": str(item["left_video"]),
        "right_video": str(item["right_video"]),
        "left_source": item["left_source"],
        "right_source": item["right_source"],
    }


def _current_view(state):
    """Return the widget values for the question the session is currently on.

    Args:
        state: Session dict with "items", "index" and "responses", or a falsy
            value when no session has been started.

    Returns:
        A 5-tuple ``(progress, image, left_video, right_video, choice)``.
        Without a session, or with an empty item list, this is
        ``("", None, None, None, None)``. ``choice`` is the label previously
        saved for this question, or ``None`` if it is unanswered.
    """
    if not state or not state["items"]:
        return "", None, None, None, None

    idx = state["index"]
    item = state["items"][idx]
    total = len(state["items"])

    saved = state["responses"].get(_response_key(item))
    choice = saved["choice_label"] if saved else None

    return (
        f"Question {idx + 1} / {total}",
        str(item["input"]),
        str(item["left_video"]),
        str(item["right_video"]),
        choice,
    )


def _record_choice(state, choice_label):
    """Store the participant's answer for the current question in ``state``.

    "Video A" selects the left clip; any other label is treated as the right
    clip. Answering the same question again replaces the earlier answer.

    Args:
        state: Session dict; its "responses" mapping is updated in place.
        choice_label: The radio label that was selected.
    """
    item = state["items"][state["index"]]
    picked_left = choice_label == "Video A"

    record = _item_fields(item)
    record.update(
        {
            "choice_label": choice_label,
            "chosen_source": item["left_source"] if picked_left else item["right_source"],
            "choice_side": "left" if picked_left else "right",
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "participant_id": state["participant_id"],
            "seed": state["seed"],
        }
    )
    state["responses"][_response_key(item)] = record


def _write_exports(state):
    """Serialize the session to CSV and JSON and store it where configured.

    One row is produced per item, in presentation order; unanswered items get
    empty choice columns. When ``STORE_LOCAL`` is set the files are written
    under ``OUTPUT_DIR``. When both ``HF_DATASET_REPO`` and ``HF_TOKEN`` are
    set they are uploaded to that dataset repo (created as private if it does
    not exist) under ``sessions/<YYYYMMDD>/``. With neither configured,
    nothing is stored.

    Args:
        state: Session dict with "items", "responses", "participant_id" and
            "seed".

    Returns:
        None.
    """
    # Read the clock once so the date folder and the file name always agree.
    now = time.localtime()
    timestamp = time.strftime("%Y%m%d_%H%M%S", now)
    date_dir = time.strftime("%Y%m%d", now)
    base = f"{timestamp}_{_sanitize(state['participant_id'])}_{state['seed']}"
    dataset_prefix = f"sessions/{date_dir}/{base}"

    rows = []
    for item in state["items"]:
        resp = state["responses"].get(_response_key(item), {})
        row = _item_fields(item)
        row.update(
            {
                "choice_label": resp.get("choice_label"),
                "chosen_source": resp.get("chosen_source"),
                "choice_side": resp.get("choice_side"),
                "timestamp": resp.get("timestamp"),
                "participant_id": state["participant_id"],
                "seed": state["seed"],
            }
        )
        rows.append(row)

    if not rows:
        # Nothing to export; the CSV header is derived from the first row.
        return None

    csv_buffer = io.StringIO()
    writer = csv.DictWriter(csv_buffer, fieldnames=list(rows[0]))
    writer.writeheader()
    writer.writerows(rows)
    csv_bytes = csv_buffer.getvalue().encode("utf-8")

    json_payload = json.dumps(
        {
            "participant_id": state["participant_id"],
            "seed": state["seed"],
            "items": rows,
        },
        ensure_ascii=False,
        indent=2,
    ).encode("utf-8")

    if STORE_LOCAL:
        OUTPUT_DIR.mkdir(exist_ok=True)
        (OUTPUT_DIR / f"{base}.csv").write_bytes(csv_bytes)
        (OUTPUT_DIR / f"{base}.json").write_bytes(json_payload)

    if HF_DATASET_REPO and HF_TOKEN:
        api = HfApi(token=HF_TOKEN)
        api.create_repo(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            private=True,
            exist_ok=True,
        )
        api.upload_file(
            path_or_fileobj=io.BytesIO(csv_bytes),
            path_in_repo=f"{dataset_prefix}.csv",
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
        )
        api.upload_file(
            path_or_fileobj=io.BytesIO(json_payload),
            path_in_repo=f"{dataset_prefix}.json",
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
        )
    return None
def start_session(participant_id):
    """Create a fresh survey session and return the initial widget values.

    A seed is derived from the current time in milliseconds, the items are
    collected with it, and a new state dict is built. Warnings from item
    collection, plus a notice when results cannot be stored anywhere, are
    shown as the status text. If no items were found the status is replaced
    by a single "no valid items" message.

    Args:
        participant_id: Text entered by the participant; ``None``, empty or
            whitespace-only input is recorded as "anonymous".

    Returns:
        A 9-tuple matching the click handler's outputs: state, status text,
        question-box update, done-box update, progress text, input image,
        left video, right video and the selected choice.
    """
    seed = int(time.time() * 1000) % (2**31 - 1)
    items, warnings = _collect_items(seed)
    if not STORE_LOCAL and (not HF_DATASET_REPO or not HF_TOKEN):
        warnings.append(
            "Results will NOT be saved. Set HF_DATASET_REPO and HF_TOKEN in Space secrets."
        )

    state = {
        "seed": seed,
        "items": items,
        "index": 0,
        "responses": {},
        # Strip first, then fall back, so "   " is also recorded as anonymous.
        "participant_id": (participant_id or "").strip() or "anonymous",
    }

    warning_text = "\n".join(f"- {w}" for w in warnings) if warnings else ""
    if not items:
        warning_text = "No valid items found. Please check your results folders."

    progress, image, left, right, choice = _current_view(state)
    status = warning_text or "Session ready."
    return (
        state,
        status,
        # Keep the question box hidden when there is nothing to ask.
        gr.update(visible=bool(items)),
        gr.update(visible=False),
        progress,
        image,
        left,
        right,
        choice,
    )
def go_next(state, choice_label):
    """Save the current answer and advance, or finish after the last question.

    Args:
        state: Session dict produced by ``start_session``, or a falsy value
            when no session has been started.
        choice_label: The selected radio label, or a falsy value when nothing
            is selected.

    Returns:
        A 9-tuple matching the click handler's outputs: state, status text,
        question-box update, done-box update, progress text, input image,
        left video, right video and the selected choice. Without a session,
        or with a session that has no items, only the status text changes.
        Without a choice the current question is shown again with a reminder.
        After the last question the results are exported and the "done" box
        replaces the question box.
    """
    if not state:
        return (state, "Please start the session.", *(gr.update() for _ in range(7)))

    if not state["items"]:
        # There is no current item to record an answer against.
        return (
            state,
            "No valid items found. Please check your results folders.",
            *(gr.update() for _ in range(7)),
        )

    if not choice_label:
        return (
            state,
            "Please choose Video A or Video B before continuing.",
            gr.update(visible=True),
            gr.update(visible=False),
            *_current_view(state),
        )

    _record_choice(state, choice_label)

    if state["index"] < len(state["items"]) - 1:
        state["index"] += 1
        return (
            state,
            "",
            gr.update(visible=True),
            gr.update(visible=False),
            *_current_view(state),
        )

    # Last question answered: persist the session, then show the "done" box
    # and clear the question widgets.
    _write_exports(state)
    return (
        state,
        "",
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(value=""),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=None),
    )
def go_prev(state):
    """Step back to the previous question, staying put on the first one.

    Args:
        state: Session dict produced by ``start_session``, or a falsy value
            when no session has been started.

    Returns:
        A 9-tuple matching the click handler's outputs: state, status text,
        question-box update, done-box update, progress text, input image,
        left video, right video and the selected choice. Without a session
        only the status text changes.
    """
    if not state:
        untouched = tuple(gr.update() for _ in range(7))
        return (state, "Please start the session.", *untouched)

    if state["index"] > 0:
        state["index"] -= 1

    show_questions = gr.update(visible=True)
    hide_done = gr.update(visible=False)
    return (state, "", show_questions, hide_done, *_current_view(state))
# Page layout and event wiring for the survey.
with gr.Blocks(title="Real-world Image-to-Video Survey", css=CSS) as demo:
    gr.Markdown("# Real-world Image-to-Video Survey")
    gr.Markdown(INSTRUCTION)

    # Session controls.
    with gr.Row():
        participant_id = gr.Textbox(
            label="Participant ID (optional)",
            placeholder="e.g., P012",
        )
        start_btn = gr.Button("Start New Session", variant="primary")
    status = gr.Markdown("")

    # Question panel; created hidden and shown by the handlers.
    with gr.Column(visible=False) as question_box:
        progress_md = gr.Markdown("")
        input_image = gr.Image(
            label="Input Image",
            interactive=False,
            elem_id="input_image",
        )
        with gr.Row():
            left_video = gr.Video(
                label="Video A",
                interactive=False,
                elem_id="video_left",
            )
            right_video = gr.Video(
                label="Video B",
                interactive=False,
                elem_id="video_right",
            )
        choice = gr.Radio(
            choices=["Video A", "Video B"],
            label="Which video is better overall?",
        )
        with gr.Row():
            prev_btn = gr.Button("Previous")
            next_btn = gr.Button("Next", variant="primary")

    # Completion panel; shown by go_next after the last question.
    with gr.Column(visible=False) as done_box:
        gr.Markdown("## All done")
        gr.Markdown("Thank you! You can close this tab now.")

    session_state = gr.State()

    # All three handlers return values for the same components, in this order.
    view_outputs = [
        session_state,
        status,
        question_box,
        done_box,
        progress_md,
        input_image,
        left_video,
        right_video,
        choice,
    ]
    start_btn.click(start_session, inputs=[participant_id], outputs=view_outputs)
    next_btn.click(go_next, inputs=[session_state, choice], outputs=view_outputs)
    prev_btn.click(go_prev, inputs=[session_state], outputs=view_outputs)
# Launch settings come from the environment: host defaults to 127.0.0.1, port
# to 7860, and a share link is requested unless GRADIO_SHARE is set to
# something other than "1".
server_name = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
share_enabled = os.getenv("GRADIO_SHARE", "1") == "1"
demo.launch(share=share_enabled, server_name=server_name, server_port=server_port)