# MovieBench2 / app.py
# evanzyfan
# update app.py
# 4787686
"""
MovieBench Preference Ranking User Study Application (Gradio Version)
A simplified Gradio web application for collecting human preference rankings
of AI-generated movies. For each story, presents results from different methods
side-by-side with shuffled anonymous labels, and collects preference ordering.
"""
import json
import logging
import os
import random
import threading
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
from huggingface_hub import CommitScheduler, snapshot_download
# Root logging: timestamped records at INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
# The commit scheduler's logger gets its own, lower threshold (DEBUG).
logging.getLogger("huggingface_hub._commit_scheduler").setLevel(logging.DEBUG)
# ============================================================================
# Configuration
# ============================================================================
# Settings below are read from the environment, falling back to the defaults shown.
DATA_DIR = os.getenv("DATA_DIR", "./data")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./results")
NUM_GROUPS = int(os.getenv("NUM_GROUPS", "10"))
RESULTS_REPO_ID = os.getenv("RESULTS_REPO_ID", "MovieBench/moviebench-results2")
DATA_REPO_ID = os.getenv("DATA_REPO_ID", "MovieBench/moviebench-data2")
HF_TOKEN = os.getenv("HF_TOKEN")

# Fixed settings: number of method columns built in the UI, and the agent
# whose character portraits are shown as references.
MAX_METHODS = 8
REFERENCE_AGENT = "EvoStoryGraph"

# Fetch the dataset snapshot only when a repo is configured and the local
# data directory is absent.
if DATA_REPO_ID and not Path(DATA_DIR).exists():
    print(f"Downloading data from {DATA_REPO_ID} ...")
    downloaded = snapshot_download(
        repo_id=DATA_REPO_ID,
        repo_type="dataset",
        token=HF_TOKEN,
        local_dir=DATA_DIR,
    )
    print(f"Data downloaded to {downloaded}")

# The results folder must exist before the scheduler is pointed at it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Stays None when no results repo is configured; save_ranking_result checks for that.
scheduler: Optional[CommitScheduler] = None
if RESULTS_REPO_ID:
    print(f"Using scheduler for {RESULTS_REPO_ID} ...")
    scheduler = CommitScheduler(
        repo_id=RESULTS_REPO_ID,
        repo_type="dataset",
        token=HF_TOKEN,
        folder_path=OUTPUT_DIR,
        path_in_repo="final_results/rk_pref",
        every=1,
    )
# ============================================================================
# Data Loading Functions
# ============================================================================
def _load_story_scripts() -> Dict[str, str]:
    """Read ``vistory_test_lite.json`` under DATA_DIR into a story-id -> script map.

    Each entry contributes ``entry["id"]`` as the key and
    ``entry["script"]["cn"]`` as the value. Returns an empty dict when the
    file does not exist.
    """
    source = Path(DATA_DIR) / "vistory_test_lite.json"
    if not source.exists():
        return {}
    with open(source, "r", encoding="utf-8-sig") as handle:
        records = json.load(handle)
    scripts: Dict[str, str] = {}
    for record in records:
        scripts[record["id"]] = record["script"]["cn"]
    return scripts
# Story scripts are loaded once at import time and looked up by story id.
STORY_SCRIPTS = _load_story_scripts()
def load_summary() -> List[Dict[str, str]]:
    """Return the parsed contents of ``summary.json`` under DATA_DIR.

    The file links each sample to its agent and story. Returns an empty
    list when the file does not exist.
    """
    source = Path(DATA_DIR) / "summary.json"
    if not source.exists():
        return []
    with open(source, "r", encoding="utf-8-sig") as handle:
        return json.load(handle)
def get_available_samples() -> List[str]:
    """Return the sorted names of the sub-directories of DATA_DIR.

    Returns an empty list when DATA_DIR does not exist.
    """
    root = Path(DATA_DIR)
    if not root.exists():
        return []
    names = [child.name for child in root.iterdir() if child.is_dir()]
    names.sort()
    return names
def get_stories_with_agents() -> Dict[str, List[Dict[str, str]]]:
    """Group summary entries by story: story_id -> [{agent, shuffled_id}, ...].

    Entries whose ``shuffled_id`` has no matching sample directory are
    skipped. Each story's list is ordered by agent name.
    """
    summary = load_summary()
    present = set(get_available_samples())
    by_story: Dict[str, List[Dict[str, str]]] = {}
    for record in summary:
        shuffled_id = record["shuffled_id"]
        if shuffled_id in present:
            story_id = record["story_id"]
            if story_id not in by_story:
                by_story[story_id] = []
            by_story[story_id].append(
                {"agent": record["agent"], "shuffled_id": shuffled_id}
            )
    for story_id in by_story:
        by_story[story_id] = sorted(
            by_story[story_id], key=lambda item: item["agent"]
        )
    return by_story
def get_movie_video_path(shuffled_id: str) -> str:
    """Return the path of ``final_video.mp4`` for a sample, or "" if the file is missing."""
    video = Path(DATA_DIR) / shuffled_id / "final_video.mp4"
    if video.exists():
        return str(video)
    return ""
def load_characters(sample_id: str) -> List[Dict]:
    """Return the parsed ``characters.json`` of a sample, or [] if the file is missing."""
    source = Path(DATA_DIR) / sample_id / "characters.json"
    if not source.exists():
        return []
    with open(source, "r", encoding="utf-8-sig") as handle:
        return json.load(handle)
def get_character_portraits(sample_id: str, characters: List[Dict]) -> List[Tuple[str, str]]:
    """Collect the front portrait of every character that has one on disk.

    Portraits are looked up at
    ``<DATA_DIR>/<sample_id>/character_portraits/{idx}_{name}/front.png``,
    where ``idx`` and ``name`` come from the character's ``idx`` and
    ``identifier_in_scene`` fields (both default to ""). Returns
    ``(absolute_path, name)`` tuples in the order of ``characters``;
    characters without a portrait file are left out.
    """
    base = Path(DATA_DIR).resolve() / sample_id / "character_portraits"
    found = []
    for character in characters:
        name = character.get("identifier_in_scene", "")
        folder = "{}_{}".format(character.get("idx", ""), name)
        image = base / folder / "front.png"
        if image.exists():
            found.append((str(image), name))
    return found
def get_reference_portraits(story_id: str) -> List[Tuple[str, str]]:
    """Return the character portraits of the REFERENCE_AGENT sample for a story.

    Returns an empty list when the story has no sample from that agent.
    """
    candidates = get_stories_with_agents().get(story_id, [])
    # First matching entry wins; "" marks "no reference sample".
    ref_sid = next(
        (item["shuffled_id"] for item in candidates if item["agent"] == REFERENCE_AGENT),
        "",
    )
    if not ref_sid:
        return []
    return get_character_portraits(ref_sid, load_characters(ref_sid))
# Held while writing group mapping files and ranking result files.
_save_lock = threading.Lock()
# ============================================================================
# Group Management
# ============================================================================
def _partition_list(items: List, num_chunks: int) -> List[List]:
    """Cut ``items`` into ``num_chunks`` consecutive slices of near-equal length.

    When the length is not divisible, the leading slices each hold one
    extra element. Slices may be empty if there are fewer items than chunks.
    """
    base, extra = divmod(len(items), num_chunks)
    sizes = [base + 1 if k < extra else base for k in range(num_chunks)]
    pieces: List[List] = []
    offset = 0
    for size in sizes:
        pieces.append(items[offset:offset + size])
        offset += size
    return pieces
def get_or_create_group_config(group_id: str) -> Dict[str, Any]:
    """Return the stored config of a group, creating one when none is stored.

    The config is read from ``<OUTPUT_DIR>/group_<group_id>/mapping.json``
    if that file exists; otherwise ``create_group_config`` is called.
    """
    mapping_path = Path(OUTPUT_DIR) / f"group_{group_id}" / "mapping.json"
    if not mapping_path.exists():
        return create_group_config(group_id)
    with open(mapping_path, "r", encoding="utf-8-sig") as handle:
        return json.load(handle)
def create_group_config(group_id: str) -> Dict[str, Any]:
    """Create and persist a group config with deterministic partitioning.

    Stories are shuffled with a fixed global seed and split into NUM_GROUPS
    non-overlapping chunks; the group receives one chunk. The agent display
    order is shuffled with a per-group seed so that anonymous labels
    (Method A, B, ...) are consistent within a group but differ across
    groups. The presentation order of the group's stories is shuffled with
    a second per-group seed.

    Args:
        group_id: Group identifier. A numeric id ``n`` selects partition
            ``(n - 1) % NUM_GROUPS``; any other string is mapped to a
            partition through a stable digest of its text.

    Returns:
        The config dict, which is also written to
        ``<OUTPUT_DIR>/group_<group_id>/mapping.json``.

    Raises:
        ValueError: If ``group_id`` contains a path separator or a NUL
            character.
    """
    import hashlib  # only needed for the non-numeric fallback below

    # group_id is typed in by the user and becomes part of a directory name
    # that is created with parents=True; refuse anything that could address
    # a different directory.
    if any(ch in group_id for ch in ("/", "\\", "\0")):
        raise ValueError(f"Invalid group id: {group_id!r}")

    group_dir = Path(OUTPUT_DIR) / f"group_{group_id}"
    group_dir.mkdir(parents=True, exist_ok=True)
    stories_map = get_stories_with_agents()

    try:
        group_index = (int(group_id) - 1) % NUM_GROUPS
    except ValueError:
        # Built-in hash() of a str is salted per interpreter process, so it
        # would assign a different partition after every restart. A digest
        # of the id text is the same in every process.
        digest = hashlib.sha256(group_id.encode("utf-8")).digest()
        group_index = int.from_bytes(digest[:8], "big") % NUM_GROUPS

    # Same seed for every group: all groups see the same shuffled story list
    # and therefore receive disjoint chunks of it.
    unique_stories = sorted(stories_map.keys())
    story_rng = random.Random("moviebench_pref_story_partition")
    story_rng.shuffle(unique_stories)
    story_chunks = _partition_list(unique_stories, NUM_GROUPS)
    selected_stories = story_chunks[group_index]

    # Agents that appear in at least one of this group's stories.
    all_agents = set()
    for story_id in selected_stories:
        for entry in stories_map.get(story_id, []):
            all_agents.add(entry["agent"])
    all_agents_sorted = sorted(all_agents)

    # Per-group seed: label assignment differs between groups but is
    # reproducible for a given group id.
    method_rng = random.Random(f"moviebench_pref_group_{group_id}")
    shuffled_agents = list(all_agents_sorted)
    method_rng.shuffle(shuffled_agents)
    method_display_map = {
        f"Method {chr(ord('A') + i)}": agent
        for i, agent in enumerate(shuffled_agents)
    }

    presentation_rng = random.Random(f"moviebench_pref_order_{group_id}")
    story_order = list(selected_stories)
    presentation_rng.shuffle(story_order)

    config = {
        "group_id": group_id,
        "group_index": group_index,
        "num_groups": NUM_GROUPS,
        "created_at": datetime.now().isoformat(),
        "stories": story_order,
        "total_stories": len(unique_stories),
        "stories_in_group": len(story_order),
        "agents": all_agents_sorted,
        "method_order": shuffled_agents,
        "method_display_map": method_display_map,
    }
    with _save_lock:
        with open(group_dir / "mapping.json", "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, ensure_ascii=False)
    return config
def _on_push_done(future):
    """Print the outcome of a finished push future.

    Reports success (with the commit URL), a skipped push (falsy result),
    or the exception raised while obtaining the result.
    """
    try:
        outcome = future.result()
        message = (
            f"[CommitScheduler] Push succeeded: {outcome.commit_url}"
            if outcome
            else "[CommitScheduler] Push skipped: no changed files detected"
        )
    except Exception as err:
        message = f"[CommitScheduler] Push FAILED: {err}"
    print(message)
def save_ranking_result(
    group_id: str,
    story_id: str,
    evaluator_id: str,
    method_display_map: Dict[str, str],
    ranking: Dict[str, int],
    comment: str,
) -> str:
    """Write one preference ranking to JSON and request a push.

    The file is written to
    ``<OUTPUT_DIR>/group_<group_id>/<story_id>/<story_id>_<evaluator>.json``.
    A second submission by the same evaluator for the same story overwrites
    the first.

    Args:
        group_id: Group the evaluator belongs to.
        story_id: Story that was ranked.
        evaluator_id: Evaluator identifier as entered by the user. It is
            stored unchanged in the JSON; in the file name, path separators
            and NUL characters are replaced with ``_``.
        method_display_map: Display label -> agent mapping shown to the user
            (stored under the ``method_order`` key).
        ranking: Display label -> assigned rank.
        comment: Free-text comment.

    Returns:
        A status message naming the written file.
    """
    # evaluator_id is user-typed text. A separator inside it would turn the
    # file name into a path through another directory (usually a missing
    # one, so open() would fail), so neutralise those characters.
    safe_evaluator = str(evaluator_id).translate(
        str.maketrans({"/": "_", "\\": "_", "\0": "_"})
    )
    story_dir = Path(OUTPUT_DIR) / f"group_{group_id}" / story_id
    filepath = story_dir / f"{story_id}_{safe_evaluator}.json"

    result_data = {
        "evaluator_id": evaluator_id,
        "group_id": group_id,
        "timestamp": datetime.now().isoformat(),
        "story_id": story_id,
        "method_order": method_display_map,
        "ranking": ranking,
        "comment": comment,
    }

    with _save_lock:
        story_dir.mkdir(parents=True, exist_ok=True)
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(result_data, f, indent=4, ensure_ascii=False)
    print(f"[Save] Written {filepath}")

    if scheduler is not None:
        # Ask for a push now instead of waiting for the scheduler's interval;
        # the callback prints the outcome once the push finishes.
        print("[CommitScheduler] Triggering immediate push after save ...")
        future = scheduler.trigger()
        future.add_done_callback(_on_push_done)
    else:
        print("[CommitScheduler] WARNING: scheduler is None — RESULTS_REPO_ID not set?")
    return f"Saved to {filepath}"
# ============================================================================
# Gradio Interface
# ============================================================================
# Stylesheet passed to gr.Blocks: container width, the gradient title
# (class "title-text") and the per-method label (class "method-label").
CUSTOM_CSS = """
.gradio-container {
max-width: 1600px !important;
margin-left: auto !important;
margin-right: auto !important;
}
.title-text {
text-align: center;
background: linear-gradient(135deg, #7c5cff 0%, #ff6b9d 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 2rem;
font-weight: 700;
margin-bottom: 1rem;
}
.method-label {
text-align: center;
font-size: 1.1rem;
font-weight: 600;
padding: 6px 0;
}
"""
def create_app():
    """Build the Gradio Blocks application for the preference-ranking study.

    The UI has two tabs. "Setup" takes an evaluator id and a group id and
    loads (or creates) the group config. "Preference Evaluation" shows, for
    one story at a time, the script, reference character portraits and up
    to MAX_METHODS anonymised method videos, each with a rank dropdown, and
    saves the submitted ranking.

    Returns:
        The constructed ``gr.Blocks`` object (not launched).
    """

    def _shown_labels(method_display_map: Dict[str, str]) -> List[str]:
        """Return the display labels that have a UI column, in display order.

        Only MAX_METHODS columns exist, so both the display and the submit
        validation must work from this same capped list; otherwise a config
        with more methods would demand ranks for columns that are not there.
        """
        return sorted(method_display_map)[:MAX_METHODS]

    with gr.Blocks(
        title="MovieBench: Preference Ranking",
        css=CUSTOM_CSS,
        theme=gr.themes.Soft(
            primary_hue="purple",
            secondary_hue="pink",
            neutral_hue="slate",
        ),
    ) as app:
        # Per-session state shared between the two tabs.
        current_evaluator = gr.State("anonymous")
        current_group = gr.State("")
        group_config_state = gr.State({})
        current_story_idx = gr.State(0)

        gr.Markdown(
            "# MovieBench: Preference Ranking",
            elem_classes=["title-text"],
        )

        # ================================================================
        # Tab 1: Setup
        # ================================================================
        with gr.Tab("Setup", id="tab_setup"):
            gr.Markdown("### Enter your evaluator ID and group ID to begin")
            with gr.Row():
                evaluator_input = gr.Textbox(
                    label="Evaluator ID",
                    placeholder="Enter your name or ID",
                    value="anonymous",
                    scale=2,
                )
                group_input = gr.Textbox(
                    label="Group ID (auto-assigned, you may override)",
                    placeholder=f"Auto-assigned (1-{NUM_GROUPS})",
                    value="",
                    scale=2,
                )
            load_group_btn = gr.Button("Load / Create Group", variant="primary")
            group_info = gr.Markdown("*Enter a Group ID and click 'Load / Create Group'*")

            def load_group(group_id: str, evaluator_id: str):
                """Load or create the group config and summarise it as Markdown.

                Returns the values for: info panel, evaluator state, group
                state, config state, and the group textbox.
                """
                # Both ids end up in file names; surrounding whitespace would
                # create look-alike directories/files, and an emptied
                # evaluator box would produce an empty id.
                group_id = (group_id or "").strip()
                evaluator_id = (evaluator_id or "").strip() or "anonymous"
                if not group_id:
                    group_id = str(random.randint(1, NUM_GROUPS))
                config = get_or_create_group_config(group_id)
                stories = config.get("stories", [])
                agents = config.get("agents", [])
                method_map = config.get("method_display_map", {})
                display_lines = ", ".join(sorted(method_map.keys()))
                info_md = (
                    f"### Group `{group_id}` loaded "
                    f"(partition {config.get('group_index', 0) + 1}/{config.get('num_groups', NUM_GROUPS)})\n\n"
                    f"**Stories in group:** {len(stories)}/{config.get('total_stories', '?')}\n\n"
                    f"**Agents:** {len(agents)} ({', '.join(agents)})\n\n"
                    f"**Display labels:** {display_lines}\n\n"
                    f"**Story order:** {', '.join(stories)}\n\n"
                    f"**Created:** {config.get('created_at', 'N/A')}\n\n"
                    f"Go to the **Preference Evaluation** tab to start ranking."
                )
                return info_md, evaluator_id, group_id, config, gr.update(value=group_id)

            load_group_btn.click(
                load_group,
                inputs=[group_input, evaluator_input],
                outputs=[group_info, current_evaluator, current_group, group_config_state, group_input],
            )

        # ================================================================
        # Tab 2: Preference Evaluation
        # ================================================================
        with gr.Tab("Preference Evaluation", id="tab_eval"):
            gr.Markdown("### Rank the methods by preference for each story")
            gr.Markdown(
                "> **Note:** 不需要考虑音频质量、音画同步,重点关注**视觉一致性**、"
                "**空间连贯性**、**叙事连贯性**、**剧本忠实度**、**视觉吸引力**。"
            )
            with gr.Row():
                story_progress = gr.Markdown("**Progress:** Load a group first")
                story_nav_prev = gr.Button("Previous Story", size="sm")
                story_nav_next = gr.Button("Next Story", size="sm")
            with gr.Accordion("Story Script", open=True):
                story_script_display = gr.Markdown(
                    "*Load a group and go to this tab to see stories*"
                )
            with gr.Accordion("Character References", open=True):
                char_gallery = gr.Gallery(
                    # Named after the configured reference agent rather than a
                    # hard-coded copy of its name.
                    label=f"Characters (from {REFERENCE_AGENT})",
                    columns=6,
                    height=180,
                    object_fit="contain",
                )
            gr.Markdown("---")
            gr.Markdown("### Method Videos")

            # One hidden column per possible method; columns are revealed and
            # filled per story by _build_story_display.
            method_cols: List[gr.Column] = []
            method_videos: List[gr.Video] = []
            method_labels: List[gr.Markdown] = []
            method_ranks: List[gr.Dropdown] = []
            with gr.Row():
                for i in range(MAX_METHODS):
                    with gr.Column(visible=False) as col:
                        lbl = gr.Markdown(
                            f"**Method {chr(ord('A') + i)}**",
                            elem_classes=["method-label"],
                        )
                        vid = gr.Video(
                            label=f"Method {chr(ord('A') + i)}",
                            height=300,
                        )
                        rank = gr.Dropdown(
                            label="Rank",
                            choices=[],
                            value=None,
                            interactive=True,
                        )
                    method_cols.append(col)
                    method_videos.append(vid)
                    method_labels.append(lbl)
                    method_ranks.append(rank)
            gr.Markdown("---")
            rank_comment = gr.Textbox(
                label="Comment (optional)",
                placeholder="Any additional notes about your ranking decision...",
                lines=2,
            )
            with gr.Row():
                submit_btn = gr.Button("Submit & Next Story", variant="primary")
            eval_status = gr.Markdown("")

            # ============================================================
            # Helper functions
            # ============================================================
            def _hidden_slot() -> list:
                """Output values that hide and clear one method column."""
                return [
                    gr.update(visible=False),
                    None,
                    "",
                    gr.update(choices=[], value=None),
                ]

            def _build_story_display(story_idx: int, config: Dict[str, Any]):
                """Build all output values for displaying a given story.

                Returns a flat list matching the outputs wired to the UI:
                [progress_md, script_md, gallery_items,
                 col_0_visible, vid_0, lbl_0, rank_0_choices,
                 col_1_visible, vid_1, lbl_1, rank_1_choices,
                 ... (MAX_METHODS times)]
                """
                stories = config.get("stories", [])
                method_display_map: Dict[str, str] = config.get("method_display_map", {})

                if not stories or story_idx >= len(stories):
                    outputs: list = [
                        "**Progress:** No stories loaded",
                        "*Load a group first*",
                        [],
                    ]
                    for _ in range(MAX_METHODS):
                        outputs.extend(_hidden_slot())
                    return outputs

                story_id = stories[story_idx]
                script_text = STORY_SCRIPTS.get(story_id, "(Script not available)")
                progress_md = f"**Progress:** Story {story_idx + 1}/{len(stories)} (`{story_id}`)"
                script_md = f"**Story ID:** `{story_id}`\n\n{script_text}"
                gallery_items = get_reference_portraits(story_id)

                # agent -> sample id for this story.
                agent_to_sid: Dict[str, str] = {}
                for entry in get_stories_with_agents().get(story_id, []):
                    agent_to_sid[entry["agent"]] = entry["shuffled_id"]

                sorted_labels = _shown_labels(method_display_map)
                # Offer exactly as many ranks as there are visible methods, so
                # a complete, duplicate-free ranking is always possible.
                rank_choices = [str(r) for r in range(1, len(sorted_labels) + 1)]

                outputs = [progress_md, script_md, gallery_items]
                for i in range(MAX_METHODS):
                    if i >= len(sorted_labels):
                        outputs.extend(_hidden_slot())
                        continue
                    label = sorted_labels[i]
                    sid = agent_to_sid.get(method_display_map[label], "")
                    video_path = get_movie_video_path(sid) if sid else ""
                    outputs.extend([
                        gr.update(visible=True),
                        video_path if video_path else None,
                        f"**{label}**",
                        gr.update(choices=rank_choices, value=None),
                    ])
                return outputs

            def update_story_display(story_idx: int, config: Dict[str, Any]):
                """Refresh the display for the current story index."""
                return _build_story_display(story_idx, config)

            def go_prev_story(story_idx: int):
                """Step back one story, stopping at the first."""
                return max(0, story_idx - 1)

            def go_next_story(story_idx: int, config: Dict[str, Any]):
                """Step forward one story, stopping at the last."""
                stories = config.get("stories", [])
                return min(len(stories) - 1, story_idx + 1) if stories else 0

            def submit_ranking(
                story_idx: int,
                evaluator_id: str,
                group_id: str,
                config: Dict[str, Any],
                comment: str,
                *rank_values,
            ):
                """Validate and save the ranking, then advance to the next story.

                Returns (status message, story index, comment box value). On
                a validation failure the index is unchanged and the comment
                box is left as is.
                """
                if not group_id or not config:
                    return "Please load a group first", story_idx, gr.update()
                stories = config.get("stories", [])
                if not stories or story_idx >= len(stories):
                    return "No stories available", story_idx, gr.update()

                method_display_map = config.get("method_display_map", {})
                # Same capped label list as the display uses.
                sorted_labels = _shown_labels(method_display_map)

                ranking: Dict[str, int] = {}
                used_ranks = set()
                for i, label in enumerate(sorted_labels):
                    val = rank_values[i] if i < len(rank_values) else None
                    if val is None or val == "":
                        return (
                            f"Please assign a rank to **{label}**",
                            story_idx,
                            gr.update(),
                        )
                    r = int(val)
                    if r in used_ranks:
                        return (
                            f"Duplicate rank {r} — each method must have a unique rank",
                            story_idx,
                            gr.update(),
                        )
                    used_ranks.add(r)
                    ranking[label] = r

                story_id = stories[story_idx]
                status = save_ranking_result(
                    group_id=group_id,
                    story_id=story_id,
                    evaluator_id=evaluator_id,
                    method_display_map=method_display_map,
                    ranking=ranking,
                    comment=comment or "",
                )

                next_idx = min(len(stories) - 1, story_idx + 1)
                if next_idx == story_idx:
                    # Already on the last story: nothing left to advance to.
                    return (
                        f"{status}\n\nAll stories evaluated! Thank you!",
                        next_idx,
                        "",
                    )
                return (
                    f"{status} | Moving to next story...",
                    next_idx,
                    "",
                )

            # ============================================================
            # Wire up events
            # ============================================================
            # Order must match the flat list built by _build_story_display.
            display_outputs = [story_progress, story_script_display, char_gallery]
            for i in range(MAX_METHODS):
                display_outputs.extend([
                    method_cols[i],
                    method_videos[i],
                    method_labels[i],
                    method_ranks[i],
                ])

            # When group config changes, reset to story 0
            group_config_state.change(
                lambda cfg: [0] + _build_story_display(0, cfg),
                inputs=[group_config_state],
                outputs=[current_story_idx] + display_outputs,
            )
            # When story idx changes, update display
            current_story_idx.change(
                update_story_display,
                inputs=[current_story_idx, group_config_state],
                outputs=display_outputs,
            )
            story_nav_prev.click(
                go_prev_story,
                inputs=[current_story_idx],
                outputs=[current_story_idx],
            )
            story_nav_next.click(
                go_next_story,
                inputs=[current_story_idx, group_config_state],
                outputs=[current_story_idx],
            )

            # The rank dropdowns arrive in submit_ranking as *rank_values.
            submit_inputs = [
                current_story_idx,
                current_evaluator,
                current_group,
                group_config_state,
                rank_comment,
            ] + method_ranks
            submit_btn.click(
                submit_ranking,
                inputs=submit_inputs,
                outputs=[eval_status, current_story_idx, rank_comment],
            )

        def _assign_random_group():
            """Pick a random group number to pre-fill the group textbox."""
            return str(random.randint(1, NUM_GROUPS))

        app.load(_assign_random_group, outputs=[group_input])
    return app
# ============================================================================
# Main Entry Point
# ============================================================================
# The app is built at import time and exposed as ``demo``.
demo = create_app()

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        # Absolute form of the data directory.
        allowed_paths=[str(Path(DATA_DIR).resolve())],
    )