Spaces:
Runtime error
Runtime error
| import argparse | |
| import html | |
| import json | |
| import random | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Sequence, Tuple | |
| import gradio as gr | |
# A task record: one JSON object from the tasks file, copied and extended by
# load_tasks (original_* keys) and prepare_display_plan (display_* keys).
Task = Dict[str, Any]
# One recorded vote, as written to the results JSON file by the vote handler.
Result = Dict[str, Any]
# Stylesheet passed to gr.Blocks(css=...). The selectors match the
# elem_classes and the HTML snippets emitted by the UI code.
CUSTOM_CSS = """
.app-title {
text-align: center;
font-size: 2.2rem;
font-weight: 600;
margin-bottom: 0.8rem;
}
.prompt-container {
display: flex;
justify-content: center;
align-items: center;
margin-bottom: 1.2rem;
}
.prompt-text {
font-size: 1.6rem;
font-weight: 500;
text-align: center;
max-width: 960px;
}
.image-row {
display: flex;
justify-content: center;
gap: 1.5rem;
}
.button-row {
display: flex;
justify-content: center;
gap: 1.0rem;
}
.status-text {
text-align: center;
}
.progress-text {
text-align: center;
}
"""
def resolve_path(path_str: str, base_dir: Path) -> Path:
    """Interpret *path_str* relative to *base_dir* unless it is already absolute.

    Relative paths are joined onto *base_dir* and normalised with
    ``Path.resolve()``; absolute paths are returned as given (not resolved).
    """
    candidate = Path(path_str)
    if candidate.is_absolute():
        return candidate
    return (base_dir / candidate).resolve()
def load_tasks(tasks_path: Path, image_root: Optional[Path]) -> List[Task]:
    """Load comparison tasks from *tasks_path* and resolve their image paths.

    Each entry is copied and extended with ``original_*`` keys:
    - its position in the file;
    - the raw ``left_image``/``right_image`` values;
    - the ``model_left``/``model_right`` values;
    - the ``left_meta``/``right_meta`` values;
    - the resolved image paths.

    Relative image paths are resolved against *image_root* when it is given,
    otherwise against the directory containing the tasks file.

    Args:
        tasks_path: JSON file containing a list of task objects.
        image_root: Optional directory that overrides the tasks file's
            directory as the base for relative image paths.

    Returns:
        The augmented tasks, in file order.

    Raises:
        ValueError: if the file is not a JSON list of objects, or an entry
            lacks a non-empty string ``left_image`` / ``right_image``.
    """
    with tasks_path.open("r", encoding="utf-8") as f:
        raw_tasks = json.load(f)
    if not isinstance(raw_tasks, list):
        raise ValueError(f"Tasks file must contain a JSON list of tasks: {tasks_path}")
    base_dir = image_root if image_root is not None else tasks_path.parent
    tasks: List[Task] = []
    for index, entry in enumerate(raw_tasks):
        if not isinstance(entry, dict):
            raise ValueError(f"Task #{index} in {tasks_path} is not a JSON object.")
        original_left_image = entry.get("left_image")
        original_right_image = entry.get("right_image")
        # Validate before resolving. Path(None) fails with an opaque
        # TypeError, and "" would silently resolve to base_dir itself.
        for key, value in (
            ("left_image", original_left_image),
            ("right_image", original_right_image),
        ):
            if not isinstance(value, str) or not value:
                raise ValueError(
                    f"Task #{index} in {tasks_path} has no valid '{key}' path."
                )
        task = dict(entry)
        task["original_index"] = index
        task["original_left_image"] = original_left_image
        task["original_right_image"] = original_right_image
        task["original_left_model"] = entry.get("model_left")
        task["original_right_model"] = entry.get("model_right")
        task["original_left_meta"] = entry.get("left_meta")
        task["original_right_meta"] = entry.get("right_meta")
        task["original_left_image_path"] = str(resolve_path(original_left_image, base_dir))
        task["original_right_image_path"] = str(resolve_path(original_right_image, base_dir))
        tasks.append(task)
    return tasks
def prepare_display_plan(tasks: List[Task], seed: Optional[int] = None) -> None:
    """Shuffle *tasks* in place and randomly assign each pair to the screen sides.

    The same ``random.Random(seed)`` generator drives both the shuffle and
    the per-task coin flip. Passing the same *seed* over the same input
    therefore reproduces the plan exactly.

    Each task gains these keys:
    - ``display_order``: its position after shuffling;
    - ``display_swap``: True when the original right side is shown on the left;
    - ``display_{left,right}_{image,image_path,model}``: copied from the
      matching ``original_*`` keys.
    """
    rng = random.Random(seed)
    rng.shuffle(tasks)
    # (display key, source when not swapped, source when swapped)
    routing = (
        ("display_left_image", "original_left_image", "original_right_image"),
        ("display_right_image", "original_right_image", "original_left_image"),
        ("display_left_image_path", "original_left_image_path", "original_right_image_path"),
        ("display_right_image_path", "original_right_image_path", "original_left_image_path"),
        ("display_left_model", "original_left_model", "original_right_model"),
        ("display_right_model", "original_right_model", "original_left_model"),
    )
    for position, task in enumerate(tasks):
        flipped = rng.choice([False, True])
        task["display_order"] = position
        task["display_swap"] = flipped
        for target_key, straight_key, crossed_key in routing:
            task[target_key] = task.get(crossed_key if flipped else straight_key)
def load_results(results_path: Path) -> List[Result]:
    """Load previously recorded votes from *results_path*.

    Args:
        results_path: JSON file written by ``save_results``.

    Returns:
        The list of vote records. The list is empty when the file does not
        exist or contains only whitespace.

    Raises:
        ValueError: if the file is not valid JSON, or is not a JSON list of
            objects.
    """
    if not results_path.exists():
        return []
    content = results_path.read_text(encoding="utf-8").strip()
    if not content:
        return []
    try:
        loaded = json.loads(content)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Unable to parse result file: {results_path}") from exc
    # Callers iterate the entries and call .get() on each one. Anything
    # other than a list of objects would fail later with a confusing
    # AttributeError, so reject it here.
    if not isinstance(loaded, list) or not all(isinstance(item, dict) for item in loaded):
        raise ValueError(f"Result file must contain a JSON list of objects: {results_path}")
    return loaded
def ensure_parent_dir(path: Path) -> None:
    """Create the directory that will hold *path*, including any missing ancestors.

    Directories that already exist are left untouched; *path* itself is not created.
    """
    parent_dir = path.parent
    parent_dir.mkdir(exist_ok=True, parents=True)
def save_results(results_path: Path, results: Sequence[Result]) -> None:
    """Write all votes to *results_path* as pretty-printed UTF-8 JSON.

    The data goes to a sibling temporary file first, which then replaces the
    target via ``os.replace``. An interrupted write therefore leaves the
    previous file intact instead of a truncated JSON document, which
    ``load_results`` would refuse to parse (losing every recorded vote).

    Args:
        results_path: Destination JSON file; missing parent dirs are created.
        results: Vote records to write (the whole list, not a delta).
    """
    import contextlib
    import os
    import threading

    ensure_parent_dir(results_path)
    # Unique per process and thread, so concurrent writers never share a
    # temp file.
    tmp_path = results_path.with_name(
        f".{results_path.name}.{os.getpid()}.{threading.get_ident()}.tmp"
    )
    try:
        with tmp_path.open("w", encoding="utf-8") as f:
            json.dump(list(results), f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, results_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error
        # is re-raised.
        with contextlib.suppress(OSError):
            tmp_path.unlink()
        raise
def build_prompt_html(task: Task) -> str:
    """Render the task's prompt as an HTML-escaped, centred snippet.

    A missing, ``None`` or whitespace-only prompt is replaced by a placeholder.
    """
    text = (task.get("prompt") or "").strip() or "No prompt provided."
    escaped = html.escape(text)
    return f"<div class='prompt-container'><div class='prompt-text'>{escaped}</div></div>"
def format_progress(completed: int, total: int) -> str:
    """Return the progress label, e.g. ``Progress: 3/10``."""
    fraction = f"{completed}/{total}"
    return f"Progress: {fraction}"
def find_next_index(tasks: Sequence[Task], completed_ids: Sequence[int]) -> Optional[int]:
    """Return the position of the first task that has not been voted on yet.

    A task's identity is its ``"id"`` value. When the key is missing, its
    position in *tasks* is used instead, which is the same fallback the vote
    handler uses for ``task_id`` (``task.get("id", current_index)``).

    Args:
        tasks: Tasks in display order.
        completed_ids: Ids already voted on; ``None`` entries are ignored.

    Returns:
        Index into *tasks*, or ``None`` when every task is completed.
    """
    # Filter out only None. Falsy ids such as 0 are valid and must count as
    # done; filter(None, ...) would drop them and repeat that task forever.
    completed_set = {cid for cid in completed_ids if cid is not None}
    for idx, task in enumerate(tasks):
        task_id = task.get("id", idx)
        # An explicit "id": None can never be marked done, so it stays pending.
        if task_id is None or task_id not in completed_set:
            return idx
    return None
def vote_handler_factory(
    tasks: Sequence[Task],
    results_path: Path,
) -> Any:
    """Build the callback that records one vote and advances to the next task.

    Args:
        tasks: Tasks in display order; each must carry the ``display_*`` keys.
        results_path: JSON file that accumulates every recorded vote.

    Returns:
        ``handler(choice, current_index, completed_ids, results)``. It returns
        an 8-tuple:
        (prompt HTML, left image path, right image path, progress text,
        status text, next index, completed ids, results).
        The image slots are ``None`` once no task remains.
    """
    import threading

    total = len(tasks)
    # Every browser session appends to the same results file. Serialise the
    # read-modify-write below.
    write_lock = threading.Lock()
    finished_html = (
        "<div class='prompt-container'><div class='prompt-text'>All tasks completed.</div></div>"
    )

    def _finished(
        completed_set: set,
        results_list: List[Result],
        status_text: str,
    ) -> Tuple[str, Optional[str], Optional[str], str, str, Optional[int], List[int], List[Result]]:
        """Build the UI update shown once no task remains."""
        # None is Gradio's "no image" value. An empty string would be
        # processed as a (non-existent) file path.
        return (
            finished_html,
            None,
            None,
            format_progress(len(completed_set), total),
            status_text,
            None,
            list(completed_set),
            results_list,
        )

    def handler(
        choice: str,
        current_index: Optional[int],
        completed_ids: Sequence[int],
        results: Sequence[Result],
    ) -> Tuple[str, Optional[str], Optional[str], str, str, Optional[int], List[int], List[Result]]:
        # Drop only None. A plain filter(None, ...) would also discard a
        # legitimate id of 0.
        completed_set = {cid for cid in completed_ids if cid is not None}
        if current_index is None or current_index >= total:
            return _finished(completed_set, list(results), "All tasks completed.")

        task = tasks[current_index]
        # Tasks without an "id" fall back to their display position.
        task_id = task.get("id", current_index)
        result_entry: Result = {
            "task_id": task_id,
            "display_order": task.get("display_order"),
            "original_index": task.get("original_index"),
            "prompt_id": task.get("prompt_id"),
            "pair_index": task.get("pair_index"),
            "original_left_model": task.get("original_left_model"),
            "original_right_model": task.get("original_right_model"),
            "original_left_image": task.get("original_left_image"),
            "original_right_image": task.get("original_right_image"),
            "display_left_model": task.get("display_left_model"),
            "display_right_model": task.get("display_right_model"),
            "display_left_image": task.get("display_left_image"),
            "display_right_image": task.get("display_right_image"),
            "display_swap": task.get("display_swap"),
            "choice": choice,
            "timestamp": datetime.now().isoformat(timespec="seconds"),
            "prompt": task.get("prompt"),
        }
        with write_lock:
            # The file, not this session's `results` copy, is the source of
            # truth. Re-reading it keeps votes that other sessions saved since
            # this one started, instead of overwriting them with a stale list.
            results_list = load_results(results_path)
            results_list.append(result_entry)
            save_results(results_path, results_list)

        completed_set.add(task_id)
        next_index = find_next_index(tasks, completed_set)
        if next_index is None:
            return _finished(
                completed_set,
                results_list,
                f"Recorded choice `{choice}`. All tasks completed.",
            )

        next_task = tasks[next_index]
        return (
            build_prompt_html(next_task),
            next_task["display_left_image_path"],
            next_task["display_right_image_path"],
            format_progress(len(completed_set), total),
            f"Recorded choice `{choice}`.",
            next_index,
            list(completed_set),
            results_list,
        )

    return handler
def launch_interface(
    tasks: Sequence[Task],
    results_path: Path,
    initial_results: Sequence[Result],
    *,
    share: bool,
    server_name: Optional[str],
    server_port: Optional[int],
) -> None:
    """Build the Gradio voting UI and launch it.

    Args:
        tasks: Tasks in display order; each must carry the ``display_*``
            keys used below.
        results_path: JSON file that votes are written to.
        initial_results: Votes loaded at startup; their task ids count as
            already completed.
        share: Forwarded to ``demo.launch(share=...)``.
        server_name: Host to bind; forwarded to ``demo.launch``.
        server_port: Port to bind; forwarded to ``demo.launch``.
    """
    # Ids already voted on. The vote handler writes "task_id"; an "id" key
    # is accepted as a fallback.
    completed_ids = {entry.get("task_id", entry.get("id")) for entry in initial_results}
    completed_ids.discard(None)
    total = len(tasks)
    initial_index = find_next_index(tasks, list(completed_ids))

    with gr.Blocks(
        title="Human Preference Voting",
        css=CUSTOM_CSS,
    ) as demo:
        gr.HTML("<div class='app-title'>Human Preference Voting</div>")
        prompt_html = gr.HTML()
        with gr.Row(elem_classes="image-row"):
            left_image = gr.Image(label="Left Image", type="filepath", show_label=True)
            right_image = gr.Image(label="Right Image", type="filepath", show_label=True)
        progress_md = gr.Markdown(elem_classes="progress-text")
        status_md = gr.Markdown(elem_classes="status-text")
        with gr.Row(elem_classes="button-row"):
            left_btn = gr.Button("Prefer Left", variant="primary")
            tie_btn = gr.Button("No Preference", variant="secondary")
            right_btn = gr.Button("Prefer Right", variant="primary")
        current_index_state = gr.State(initial_index)
        completed_state = gr.State(list(completed_ids))
        results_state = gr.State(list(initial_results))

        # Shared by the initial load and every vote button. The order must
        # match the tuples returned by init_view and the vote handler.
        view_inputs = [current_index_state, completed_state, results_state]
        view_outputs = [
            prompt_html,
            left_image,
            right_image,
            progress_md,
            status_md,
            current_index_state,
            completed_state,
            results_state,
        ]

        def _message_view(
            message: str,
            completed: Sequence[int],
            results_data: Sequence[Result],
        ) -> Tuple[str, Optional[str], Optional[str], str, str, Optional[int], List[int], List[Result]]:
            """Show *message* in place of the prompt, with both image slots cleared."""
            # None is Gradio's "no image" value. An empty string would be
            # processed as a (non-existent) file path.
            return (
                f"<div class='prompt-container'><div class='prompt-text'>{message}</div></div>",
                None,
                None,
                format_progress(len(completed), total),
                message,
                None,
                list(completed),
                list(results_data),
            )

        def init_view(
            index: Optional[int],
            completed: Sequence[int],
            results_data: Sequence[Result],
        ) -> Tuple[str, Optional[str], Optional[str], str, str, Optional[int], List[int], List[Result]]:
            """Populate the UI for a new page load from the session state."""
            if total == 0:
                return _message_view("No tasks available.", completed, results_data)
            if index is None:
                return _message_view("All tasks completed.", completed, results_data)
            task = tasks[index]
            return (
                build_prompt_html(task),
                task["display_left_image_path"],
                task["display_right_image_path"],
                format_progress(len(completed), total),
                "Select a preference to continue.",
                index,
                list(completed),
                list(results_data),
            )

        demo.load(init_view, inputs=view_inputs, outputs=view_outputs)

        handler = vote_handler_factory(tasks, results_path)

        def make_click_fn(choice_value: str):
            # Bind the choice through a factory so each button keeps its own
            # value (avoids the late-binding closure pitfall in the loop).
            def on_click(
                current_index: Optional[int],
                completed_ids_value: Sequence[int],
                results_data: Sequence[Result],
            ) -> Tuple[str, Optional[str], Optional[str], str, str, Optional[int], List[int], List[Result]]:
                return handler(choice_value, current_index, completed_ids_value, results_data)

            return on_click

        for button, vote_choice in (
            (left_btn, "left"),
            (tie_btn, "tie"),
            (right_btn, "right"),
        ):
            button.click(
                make_click_fn(vote_choice),
                inputs=view_inputs,
                outputs=view_outputs,
            )

    demo.queue()
    demo.launch(
        share=share,
        server_name=server_name,
        server_port=server_port,
    )
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the voting app.

    ``--tasks`` defaults to ``tasks.json`` in the working directory. This
    lets the app start when a hosting platform runs the script without
    arguments; previously argparse exited because the option was required.
    ``main`` still raises FileNotFoundError if that file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Gradio interface for human preference voting between image pairs.",
    )
    parser.add_argument(
        "--tasks",
        type=Path,
        default=Path("tasks.json"),
        help="Path to tasks JSON file (default: ./tasks.json).",
    )
    parser.add_argument(
        "--out",
        type=Path,
        default=Path("votes.json"),
        help="Path to output JSON file for votes (default: ./votes.json).",
    )
    parser.add_argument(
        "--image-root",
        type=Path,
        default=None,
        help="Optional image root directory overriding task file location.",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Enable Gradio share link for external access.",
    )
    parser.add_argument(
        "--server-name",
        type=str,
        default=None,
        help="Host address to bind, e.g. 0.0.0.0 for LAN access.",
    )
    parser.add_argument(
        "--server-port",
        type=int,
        default=None,
        help="Port number to bind; defaults to Gradio's choice.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Optional random seed to make task order reproducible.",
    )
    return parser.parse_args()
def main() -> None:
    """Parse CLI options, validate inputs, and start the voting interface.

    Raises:
        FileNotFoundError: if the tasks file, or an explicitly given image
            root, does not exist.
        ValueError: if the tasks file contains no tasks.
    """
    options = parse_args()
    tasks_path: Path = options.tasks
    image_root: Optional[Path] = options.image_root
    output_path: Path = options.out

    # Fail fast on bad paths before loading anything.
    if not tasks_path.exists():
        raise FileNotFoundError(f"Tasks file not found: {tasks_path}")
    if image_root is not None and not image_root.exists():
        raise FileNotFoundError(f"Image root not found: {image_root}")

    task_list = load_tasks(tasks_path, image_root)
    if not task_list:
        raise ValueError("Task list is empty.")
    # Randomise task order and left/right placement (reproducible via --seed).
    prepare_display_plan(task_list, seed=options.seed)
    existing_votes = load_results(output_path)

    launch_interface(
        task_list,
        output_path,
        existing_votes,
        share=bool(options.share),
        server_name=options.server_name,
        server_port=options.server_port,
    )
# Start the app only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()