Spaces:
Running
Running
| """ | |
| MovieBench Preference Ranking User Study Application (Gradio Version) | |
| A simplified Gradio web application for collecting human preference rankings | |
| of AI-generated movies. For each story, presents results from different methods | |
| side-by-side with shuffled anonymous labels, and collects preference ordering. | |
| """ | |
| import json | |
| import logging | |
| import os | |
| import random | |
| import threading | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import gradio as gr | |
| from huggingface_hub import CommitScheduler, snapshot_download | |
# Root logger: INFO and above, each record stamped with time, logger name and
# level. The Hub commit scheduler's own logger is additionally opened up to
# DEBUG so that background pushes can be traced in the log output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
logging.getLogger("huggingface_hub._commit_scheduler").setLevel(logging.DEBUG)
| # ============================================================================ | |
| # Configuration | |
| # ============================================================================ | |
# Runtime configuration. Everything except MAX_METHODS and REFERENCE_AGENT can
# be overridden through an environment variable of the same name.
DATA_DIR = os.environ.get("DATA_DIR", "./data")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./results")
# Clamped to >= 1: the story partitioning divides by NUM_GROUPS and the UI
# draws a random group in [1, NUM_GROUPS], so 0 or a negative value would crash.
NUM_GROUPS = max(1, int(os.environ.get("NUM_GROUPS", "10")))
# An empty RESULTS_REPO_ID / DATA_REPO_ID disables the corresponding Hub step.
RESULTS_REPO_ID = os.environ.get("RESULTS_REPO_ID", "MovieBench/moviebench-results2")
DATA_REPO_ID = os.environ.get("DATA_REPO_ID", "MovieBench/moviebench-data2")
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Number of method columns (video + rank dropdown) pre-built in the UI.
MAX_METHODS = 8
# Agent whose sample supplies the character portraits shown as reference.
REFERENCE_AGENT = "EvoStoryGraph"
# First-start bootstrap: pull the dataset snapshot from the Hub only when a
# data repo is configured and DATA_DIR does not exist yet. An existing
# DATA_DIR is used as-is and is never refreshed here.
if DATA_REPO_ID and not Path(DATA_DIR).exists():
    print(f"Downloading data from {DATA_REPO_ID} ...")
    downloaded = snapshot_download(
        repo_type="dataset",
        repo_id=DATA_REPO_ID,
        token=HF_TOKEN,
        local_dir=DATA_DIR,
    )
    print(f"Data downloaded to {downloaded}")
# Results are written under OUTPUT_DIR; make sure it exists before the commit
# scheduler starts watching it.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Background uploader for OUTPUT_DIR. Stays None when no results repo is
# configured, in which case results are only kept on local disk.
scheduler: Optional[CommitScheduler] = None
if RESULTS_REPO_ID:
    print(f"Using scheduler for {RESULTS_REPO_ID} ...")
    scheduler = CommitScheduler(
        repo_type="dataset",
        repo_id=RESULTS_REPO_ID,
        token=HF_TOKEN,
        folder_path=OUTPUT_DIR,
        path_in_repo="final_results/rk_pref",
        every=1,  # push interval; minutes per the CommitScheduler API
    )
| # ============================================================================ | |
| # Data Loading Functions | |
| # ============================================================================ | |
def _load_story_scripts() -> Dict[str, str]:
    """Read the original story scripts, indexed by story id.

    Looks for ``vistory_test_lite.json`` directly under ``DATA_DIR``. Each
    entry contributes its ``id`` mapped to the ``cn`` field of its ``script``.

    Returns:
        ``{story_id: script_text}``, or an empty dict when the file is absent.
    """
    source = Path(DATA_DIR) / "vistory_test_lite.json"
    if not source.exists():
        return {}
    # utf-8-sig tolerates a leading byte-order mark in the JSON file.
    with source.open("r", encoding="utf-8-sig") as handle:
        records = json.load(handle)
    scripts: Dict[str, str] = {}
    for record in records:
        scripts[record["id"]] = record["script"]["cn"]
    return scripts


# Loaded once at import time; looked up by story id when a story is displayed.
STORY_SCRIPTS = _load_story_scripts()
def load_summary() -> List[Dict[str, str]]:
    """Return the parsed contents of ``summary.json`` under ``DATA_DIR``.

    The entries are consumed elsewhere in this module via their
    ``shuffled_id``, ``story_id`` and ``agent`` keys.

    Returns:
        The decoded JSON content, or an empty list when the file is absent.
    """
    source = Path(DATA_DIR) / "summary.json"
    if not source.exists():
        return []
    with source.open("r", encoding="utf-8-sig") as handle:
        return json.load(handle)
def get_available_samples() -> List[str]:
    """List the names of the sub-directories of ``DATA_DIR``, sorted.

    Returns:
        Sorted directory names, or an empty list when ``DATA_DIR`` is missing.
    """
    root = Path(DATA_DIR)
    if root.exists():
        names = [child.name for child in root.iterdir() if child.is_dir()]
        names.sort()
        return names
    return []
def get_stories_with_agents() -> Dict[str, List[Dict[str, str]]]:
    """Group the summary entries by story.

    Only entries whose ``shuffled_id`` has a matching directory under
    ``DATA_DIR`` are kept.

    Returns:
        ``{story_id: [{"agent": ..., "shuffled_id": ...}, ...]}`` with each
        list sorted by agent name.
    """
    records = load_summary()
    on_disk = set(get_available_samples())

    by_story: Dict[str, List[Dict[str, str]]] = {}
    for record in records:
        sample_id = record["shuffled_id"]
        if sample_id in on_disk:
            by_story.setdefault(record["story_id"], []).append({
                "agent": record["agent"],
                "shuffled_id": sample_id,
            })

    # Stable per-story ordering, independent of the order in summary.json.
    for agent_entries in by_story.values():
        agent_entries.sort(key=lambda item: item["agent"])
    return by_story
def get_movie_video_path(shuffled_id: str) -> str:
    """Locate ``final_video.mp4`` inside a sample directory.

    Returns:
        The path as a string, or ``""`` when the file does not exist.
    """
    candidate = Path(DATA_DIR) / shuffled_id / "final_video.mp4"
    if candidate.exists():
        return str(candidate)
    return ""
def load_characters(sample_id: str) -> List[Dict]:
    """Return the parsed ``characters.json`` of a sample.

    Returns:
        The decoded JSON content, or an empty list when the file is absent.
    """
    source = Path(DATA_DIR) / sample_id / "characters.json"
    if not source.exists():
        return []
    with source.open("r", encoding="utf-8-sig") as handle:
        return json.load(handle)
def get_character_portraits(sample_id: str, characters: List[Dict]) -> List[Tuple[str, str]]:
    """Collect the front portrait of every character that has one on disk.

    Portraits are expected at
    ``<DATA_DIR>/<sample_id>/character_portraits/{idx}_{name}/front.png``,
    where ``idx`` and ``name`` come from the character's ``idx`` and
    ``identifier_in_scene`` fields (both default to ``""`` when missing).

    Returns:
        ``(absolute_path, name)`` tuples, in the order of ``characters``;
        characters without a portrait file are skipped.
    """
    base = Path(DATA_DIR).resolve() / sample_id / "character_portraits"
    found: List[Tuple[str, str]] = []
    for character in characters:
        idx = character.get("idx", "")
        name = character.get("identifier_in_scene", "")
        image = base / f"{idx}_{name}" / "front.png"
        if not image.exists():
            continue
        found.append((str(image), name))
    return found
def get_reference_portraits(story_id: str) -> List[Tuple[str, str]]:
    """Return the character portraits of the ``REFERENCE_AGENT`` sample.

    Returns:
        ``(path, name)`` tuples from the first entry of ``story_id`` whose
        agent is ``REFERENCE_AGENT``; an empty list when the story has no
        such entry.
    """
    candidates = get_stories_with_agents().get(story_id, [])
    ref_sid = next(
        (item["shuffled_id"] for item in candidates if item["agent"] == REFERENCE_AGENT),
        "",
    )
    if not ref_sid:
        return []
    return get_character_portraits(ref_sid, load_characters(ref_sid))
# Serializes writes to files under OUTPUT_DIR (group mappings and ranking
# results) so that concurrent requests do not interleave their output.
_save_lock = threading.Lock()
| # ============================================================================ | |
| # Group Management | |
| # ============================================================================ | |
| def _partition_list(items: List, num_chunks: int) -> List[List]: | |
| """Split items into num_chunks chunks as evenly as possible.""" | |
| chunk_size, remainder = divmod(len(items), num_chunks) | |
| chunks: List[List] = [] | |
| start = 0 | |
| for i in range(num_chunks): | |
| end = start + chunk_size + (1 if i < remainder else 0) | |
| chunks.append(items[start:end]) | |
| start = end | |
| return chunks | |
def get_or_create_group_config(group_id: str) -> Dict[str, Any]:
    """Return the stored configuration of a group, creating it on first use.

    The configuration lives in ``<OUTPUT_DIR>/group_<group_id>/mapping.json``.
    When that file is missing, ``create_group_config`` builds and persists a
    new one.
    """
    mapping_path = Path(OUTPUT_DIR) / f"group_{group_id}" / "mapping.json"
    if not mapping_path.exists():
        return create_group_config(group_id)
    with mapping_path.open("r", encoding="utf-8-sig") as handle:
        return json.load(handle)
def create_group_config(group_id: str) -> Dict[str, Any]:
    """Create and persist a group config with deterministic partitioning.

    Stories are shuffled with a fixed global seed and split into NUM_GROUPS
    non-overlapping chunks; the group receives one chunk. The agent display
    order is shuffled per group, so the anonymous labels (Method A, B, ...)
    are consistent within a group but differ across groups. The presentation
    order of the group's stories is shuffled per group as well.

    Args:
        group_id: Group identifier. A numeric ID ``n`` selects partition
            ``(n - 1) % NUM_GROUPS``; any other string is mapped to a
            partition through a stable hash of the ID.

    Returns:
        The configuration dict, which is also written to
        ``<OUTPUT_DIR>/group_<group_id>/mapping.json``.
    """
    import hashlib  # local import: only needed for the non-numeric fallback

    group_dir = Path(OUTPUT_DIR) / f"group_{group_id}"
    group_dir.mkdir(parents=True, exist_ok=True)
    stories_map = get_stories_with_agents()

    try:
        group_index = (int(group_id) - 1) % NUM_GROUPS
    except ValueError:
        # The built-in hash() of a str is randomized per process, so it would
        # assign the same group ID to different partitions after a restart.
        # A SHA-256 digest is stable across processes and machines.
        digest = hashlib.sha256(group_id.encode("utf-8")).digest()
        group_index = int.from_bytes(digest[:8], "big") % NUM_GROUPS

    # Same seed for every group -> every group sees the same global partition.
    unique_stories = sorted(stories_map.keys())
    story_rng = random.Random("moviebench_pref_story_partition")
    story_rng.shuffle(unique_stories)
    story_chunks = _partition_list(unique_stories, NUM_GROUPS)
    selected_stories = story_chunks[group_index]

    # Agents that appear in at least one of this group's stories.
    all_agents_sorted = sorted({
        entry["agent"]
        for story_id in selected_stories
        for entry in stories_map.get(story_id, [])
    })

    # Per-group seed -> label assignment differs between groups.
    method_rng = random.Random(f"moviebench_pref_group_{group_id}")
    shuffled_agents = list(all_agents_sorted)
    method_rng.shuffle(shuffled_agents)
    method_display_map = {
        f"Method {chr(ord('A') + i)}": agent
        for i, agent in enumerate(shuffled_agents)
    }

    presentation_rng = random.Random(f"moviebench_pref_order_{group_id}")
    story_order = list(selected_stories)
    presentation_rng.shuffle(story_order)

    config = {
        "group_id": group_id,
        "group_index": group_index,
        "num_groups": NUM_GROUPS,
        "created_at": datetime.now().isoformat(),
        "stories": story_order,
        "total_stories": len(unique_stories),
        "stories_in_group": len(story_order),
        "agents": all_agents_sorted,
        "method_order": shuffled_agents,
        "method_display_map": method_display_map,
    }
    with _save_lock:
        with open(group_dir / "mapping.json", "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, ensure_ascii=False)
    return config
| def _on_push_done(future): | |
| """Callback to surface push results/errors from the background thread.""" | |
| try: | |
| result = future.result() | |
| if result: | |
| print(f"[CommitScheduler] Push succeeded: {result.commit_url}") | |
| else: | |
| print("[CommitScheduler] Push skipped: no changed files detected") | |
| except Exception as e: | |
| print(f"[CommitScheduler] Push FAILED: {e}") | |
def save_ranking_result(
    group_id: str,
    story_id: str,
    evaluator_id: str,
    method_display_map: Dict[str, str],
    ranking: Dict[str, int],
    comment: str,
) -> str:
    """Save a preference ranking result to JSON and trigger a push.

    The result is written to
    ``<OUTPUT_DIR>/group_<group_id>/<story_id>/<story_id>_<evaluator>.json``;
    an existing file for the same story and evaluator is overwritten.

    Args:
        group_id: Group the evaluator belongs to.
        story_id: Story that was ranked.
        evaluator_id: Free-text evaluator identifier. It is stored verbatim in
            the JSON payload; for the file name, characters that are not valid
            in a file name are replaced by ``_`` and an empty ID falls back to
            ``anonymous``.
        method_display_map: Display label -> agent mapping shown to the user.
        ranking: Display label -> assigned rank.
        comment: Optional free-text comment.

    Returns:
        A status message naming the written file.
    """
    import contextlib
    import re

    # The evaluator ID is typed by the user: a "/" would point the file name
    # at a non-existent (or unintended) directory, and an empty ID would
    # produce "<story>_.json". Only the file name is sanitized.
    safe_evaluator = re.sub(r'[\\/:*?"<>|\x00-\x1f]+', "_", (evaluator_id or "").strip())
    safe_evaluator = safe_evaluator or "anonymous"

    story_dir = Path(OUTPUT_DIR) / f"group_{group_id}" / story_id
    filepath = story_dir / f"{story_id}_{safe_evaluator}.json"
    result_data = {
        "evaluator_id": evaluator_id,
        "group_id": group_id,
        "timestamp": datetime.now().isoformat(),
        "story_id": story_id,
        "method_order": method_display_map,
        "ranking": ranking,
        "comment": comment,
    }

    # Hold the scheduler's own lock while writing so that a background push
    # cannot pick up a half-written file (the pattern recommended for
    # CommitScheduler). Without a scheduler only the module lock is needed.
    push_guard = scheduler.lock if scheduler is not None else contextlib.nullcontext()
    with _save_lock, push_guard:
        story_dir.mkdir(parents=True, exist_ok=True)
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(result_data, f, indent=4, ensure_ascii=False)
    print(f"[Save] Written {filepath}")

    # Triggered after the locks are released: the push itself needs the
    # scheduler lock.
    if scheduler is not None:
        print("[CommitScheduler] Triggering immediate push after save ...")
        future = scheduler.trigger()
        future.add_done_callback(_on_push_done)
    else:
        print("[CommitScheduler] WARNING: scheduler is None — RESULTS_REPO_ID not set?")
    return f"Saved to {filepath}"
| # ============================================================================ | |
| # Gradio Interface | |
| # ============================================================================ | |
# Stylesheet passed to gr.Blocks: caps and centers the page width, renders the
# title (elem class "title-text") with a gradient fill, and centers the
# per-method labels (elem class "method-label").
CUSTOM_CSS = """
.gradio-container {
    max-width: 1600px !important;
    margin-left: auto !important;
    margin-right: auto !important;
}
.title-text {
    text-align: center;
    background: linear-gradient(135deg, #7c5cff 0%, #ff6b9d 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2rem;
    font-weight: 700;
    margin-bottom: 1rem;
}
.method-label {
    text-align: center;
    font-size: 1.1rem;
    font-weight: 600;
    padding: 6px 0;
}
"""
def create_app():
    """Create the Gradio application.

    Lays out two tabs inside a single ``gr.Blocks``:

    * **Setup** - collects the evaluator ID and group ID and loads (or
      creates) the group configuration.
    * **Preference Evaluation** - shows the story script, the reference
      character portraits and up to ``MAX_METHODS`` method videos, each with
      a rank dropdown, and saves the submitted ranking.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    import re  # local import: only used to normalise user-typed group IDs

    # Characters that must never reach a directory name built from user input
    # (path separators, reserved file-name characters, control characters).
    unsafe_id_chars = re.compile(r'[\\/:*?"<>|\x00-\x1f]+')

    with gr.Blocks(
        title="MovieBench: Preference Ranking",
        css=CUSTOM_CSS,
        theme=gr.themes.Soft(
            primary_hue="purple",
            secondary_hue="pink",
            neutral_hue="slate",
        ),
    ) as app:
        # Per-session state shared between the two tabs.
        current_evaluator = gr.State("anonymous")
        current_group = gr.State("")
        group_config_state = gr.State({})
        current_story_idx = gr.State(0)

        gr.Markdown(
            "# MovieBench: Preference Ranking",
            elem_classes=["title-text"],
        )

        # ================================================================
        # Tab 1: Setup
        # ================================================================
        with gr.Tab("Setup", id="tab_setup"):
            gr.Markdown("### Enter your evaluator ID and group ID to begin")
            with gr.Row():
                evaluator_input = gr.Textbox(
                    label="Evaluator ID",
                    placeholder="Enter your name or ID",
                    value="anonymous",
                    scale=2,
                )
                group_input = gr.Textbox(
                    label="Group ID (auto-assigned, you may override)",
                    placeholder=f"Auto-assigned (1-{NUM_GROUPS})",
                    value="",
                    scale=2,
                )
            load_group_btn = gr.Button("Load / Create Group", variant="primary")
            group_info = gr.Markdown("*Enter a Group ID and click 'Load / Create Group'*")

            def load_group(group_id: str, evaluator_id: str):
                """Load (or create) the group config and summarise it."""
                # The group ID becomes part of a directory name, so trim it
                # and neutralise anything that could change the target path.
                group_id = unsafe_id_chars.sub("_", (group_id or "").strip())
                if not group_id:
                    group_id = str(random.randint(1, NUM_GROUPS))
                # A blank evaluator ID would yield an unnamed result file.
                evaluator_id = (evaluator_id or "").strip() or "anonymous"

                config = get_or_create_group_config(group_id)
                stories = config.get("stories", [])
                agents = config.get("agents", [])
                method_map = config.get("method_display_map", {})
                display_lines = ", ".join(sorted(method_map.keys()))
                info_md = (
                    f"### Group `{group_id}` loaded "
                    f"(partition {config.get('group_index', 0) + 1}/{config.get('num_groups', NUM_GROUPS)})\n\n"
                    f"**Stories in group:** {len(stories)}/{config.get('total_stories', '?')}\n\n"
                    f"**Agents:** {len(agents)} ({', '.join(agents)})\n\n"
                    f"**Display labels:** {display_lines}\n\n"
                    f"**Story order:** {', '.join(stories)}\n\n"
                    f"**Created:** {config.get('created_at', 'N/A')}\n\n"
                    f"Go to the **Preference Evaluation** tab to start ranking."
                )
                return info_md, evaluator_id, group_id, config, gr.update(value=group_id)

            load_group_btn.click(
                load_group,
                inputs=[group_input, evaluator_input],
                outputs=[group_info, current_evaluator, current_group, group_config_state, group_input],
            )

        # ================================================================
        # Tab 2: Preference Evaluation
        # ================================================================
        with gr.Tab("Preference Evaluation", id="tab_eval"):
            gr.Markdown("### Rank the methods by preference for each story")
            gr.Markdown(
                "> **Note:** 不需要考虑音频质量、音画同步,重点关注**视觉一致性**、"
                "**空间连贯性**、**叙事连贯性**、**剧本忠实度**、**视觉吸引力**。"
            )
            with gr.Row():
                story_progress = gr.Markdown("**Progress:** Load a group first")
                story_nav_prev = gr.Button("Previous Story", size="sm")
                story_nav_next = gr.Button("Next Story", size="sm")
            with gr.Accordion("Story Script", open=True):
                story_script_display = gr.Markdown(
                    "*Load a group and go to this tab to see stories*"
                )
            with gr.Accordion("Character References", open=True):
                char_gallery = gr.Gallery(
                    # Derived from the constant so the label cannot drift from
                    # the agent whose portraits are actually shown.
                    label=f"Characters (from {REFERENCE_AGENT})",
                    columns=6,
                    height=180,
                    object_fit="contain",
                )
            gr.Markdown("---")
            gr.Markdown("### Method Videos")

            # MAX_METHODS columns are built up front and toggled visible per
            # group, because the component tree is fixed once the app is built.
            method_cols: List[gr.Column] = []
            method_videos: List[gr.Video] = []
            method_labels: List[gr.Markdown] = []
            method_ranks: List[gr.Dropdown] = []
            with gr.Row():
                for i in range(MAX_METHODS):
                    with gr.Column(visible=False) as col:
                        lbl = gr.Markdown(
                            f"**Method {chr(ord('A') + i)}**",
                            elem_classes=["method-label"],
                        )
                        vid = gr.Video(
                            label=f"Method {chr(ord('A') + i)}",
                            height=300,
                        )
                        rank = gr.Dropdown(
                            label="Rank",
                            choices=[],
                            value=None,
                            interactive=True,
                        )
                    method_cols.append(col)
                    method_videos.append(vid)
                    method_labels.append(lbl)
                    method_ranks.append(rank)
            gr.Markdown("---")
            rank_comment = gr.Textbox(
                label="Comment (optional)",
                placeholder="Any additional notes about your ranking decision...",
                lines=2,
            )
            with gr.Row():
                submit_btn = gr.Button("Submit & Next Story", variant="primary")
            eval_status = gr.Markdown("")

            # ============================================================
            # Helper functions
            # ============================================================
            def _visible_labels(config: Dict[str, Any]) -> List[str]:
                """Display labels that have a column on screen, in column order.

                Capped at MAX_METHODS so that the rank choices offered and the
                ranks required on submit always match the dropdowns that exist.
                """
                return sorted(config.get("method_display_map", {}))[:MAX_METHODS]

            def _hidden_slot() -> list:
                """Output values that hide and clear one method column."""
                return [
                    gr.update(visible=False),
                    None,
                    "",
                    gr.update(choices=[], value=None),
                ]

            def _build_story_display(story_idx: int, config: Dict[str, Any]):
                """Build all output values for displaying a given story.

                Returns a flat list matching the outputs wired to the UI:
                [progress_md, script_md, gallery_items,
                 col_0_visible, vid_0, lbl_0, rank_0_choices,
                 col_1_visible, vid_1, lbl_1, rank_1_choices,
                 ... (MAX_METHODS times)]
                """
                stories = config.get("stories", [])
                if not stories or story_idx >= len(stories):
                    outputs: list = [
                        "**Progress:** No stories loaded",
                        "*Load a group first*",
                        [],
                    ]
                    for _ in range(MAX_METHODS):
                        outputs.extend(_hidden_slot())
                    return outputs

                method_display_map: Dict[str, str] = config.get("method_display_map", {})
                labels = _visible_labels(config)
                rank_choices = [str(r) for r in range(1, len(labels) + 1)]

                story_id = stories[story_idx]
                script_text = STORY_SCRIPTS.get(story_id, "(Script not available)")
                progress_md = f"**Progress:** Story {story_idx + 1}/{len(stories)} (`{story_id}`)"
                script_md = f"**Story ID:** `{story_id}`\n\n{script_text}"
                gallery_items = get_reference_portraits(story_id)

                # Looked up only once a story is known to exist.
                stories_map = get_stories_with_agents()
                agent_to_sid: Dict[str, str] = {
                    entry["agent"]: entry["shuffled_id"]
                    for entry in stories_map.get(story_id, [])
                }

                outputs = [progress_md, script_md, gallery_items]
                for i in range(MAX_METHODS):
                    if i >= len(labels):
                        outputs.extend(_hidden_slot())
                        continue
                    label = labels[i]
                    sid = agent_to_sid.get(method_display_map[label], "")
                    video_path = get_movie_video_path(sid) if sid else ""
                    outputs.extend([
                        gr.update(visible=True),
                        video_path if video_path else None,
                        f"**{label}**",
                        gr.update(choices=rank_choices, value=None),
                    ])
                return outputs

            def update_story_display(story_idx: int, config: Dict[str, Any]):
                return _build_story_display(story_idx, config)

            def go_prev_story(story_idx: int):
                return max(0, story_idx - 1)

            def go_next_story(story_idx: int, config: Dict[str, Any]):
                stories = config.get("stories", [])
                return min(len(stories) - 1, story_idx + 1) if stories else 0

            def submit_ranking(
                story_idx: int,
                evaluator_id: str,
                group_id: str,
                config: Dict[str, Any],
                comment: str,
                *rank_values,
            ):
                """Validate and save the ranking, then advance to the next story.

                Returns ``(status_markdown, next_story_idx, comment_box_update)``.
                On a validation error the story index is returned unchanged and
                the comment box is left as it is.
                """
                if not group_id or not config:
                    return "Please load a group first", story_idx, gr.update()
                stories = config.get("stories", [])
                if not stories or story_idx >= len(stories):
                    return "No stories available", story_idx, gr.update()

                method_display_map = config.get("method_display_map", {})
                # Only labels that have a dropdown can be ranked.
                labels = _visible_labels(config)

                ranking: Dict[str, int] = {}
                used_ranks = set()
                for label, val in zip(labels, rank_values):
                    if val is None or val == "":
                        return (
                            f"Please assign a rank to **{label}**",
                            story_idx,
                            gr.update(),
                        )
                    r = int(val)
                    if r in used_ranks:
                        return (
                            f"Duplicate rank {r} — each method must have a unique rank",
                            story_idx,
                            gr.update(),
                        )
                    used_ranks.add(r)
                    ranking[label] = r

                story_id = stories[story_idx]
                status = save_ranking_result(
                    group_id=group_id,
                    story_id=story_id,
                    evaluator_id=evaluator_id or "anonymous",
                    method_display_map=method_display_map,
                    ranking=ranking,
                    comment=comment or "",
                )
                next_idx = min(len(stories) - 1, story_idx + 1)
                if next_idx == story_idx:
                    # Already on the last story: nothing left to advance to.
                    return (
                        f"{status}\n\nAll stories evaluated! Thank you!",
                        next_idx,
                        "",
                    )
                return (
                    f"{status} | Moving to next story...",
                    next_idx,
                    "",
                )

            # ============================================================
            # Wire up events
            # ============================================================
            # Order must match the flat list built by _build_story_display.
            display_outputs = [story_progress, story_script_display, char_gallery]
            for i in range(MAX_METHODS):
                display_outputs.extend([
                    method_cols[i],
                    method_videos[i],
                    method_labels[i],
                    method_ranks[i],
                ])

            # When group config changes, reset to story 0
            group_config_state.change(
                lambda cfg: [0] + _build_story_display(0, cfg),
                inputs=[group_config_state],
                outputs=[current_story_idx] + display_outputs,
            )
            # When story idx changes, update display
            current_story_idx.change(
                update_story_display,
                inputs=[current_story_idx, group_config_state],
                outputs=display_outputs,
            )
            story_nav_prev.click(
                go_prev_story,
                inputs=[current_story_idx],
                outputs=[current_story_idx],
            )
            story_nav_next.click(
                go_next_story,
                inputs=[current_story_idx, group_config_state],
                outputs=[current_story_idx],
            )
            submit_inputs = [
                current_story_idx,
                current_evaluator,
                current_group,
                group_config_state,
                rank_comment,
            ] + method_ranks
            submit_btn.click(
                submit_ranking,
                inputs=submit_inputs,
                outputs=[eval_status, current_story_idx, rank_comment],
            )

        def _assign_random_group():
            """Pre-fill the group box with a random group on page load."""
            return str(random.randint(1, NUM_GROUPS))

        app.load(_assign_random_group, outputs=[group_input])
    return app
| # ============================================================================ | |
| # Main Entry Point | |
| # ============================================================================ | |
# Built at import time so the module-level name `demo` is always available.
demo = create_app()

if __name__ == "__main__":
    # The data directory must be allow-listed explicitly so the app may serve
    # the videos and portraits stored there.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "allowed_paths": [str(Path(DATA_DIR).resolve())],
    }
    demo.launch(**launch_options)