humanpreference / vote_interface.py
hongweiyi's picture
Upload folder using huggingface_hub
89b5bcd verified
import argparse
import html
import json
import random
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
import gradio as gr
# One comparison task: a JSON object taken from the tasks file, which this
# module augments with "original_*" and "display_*" bookkeeping keys.
Task = Dict[str, Any]
# One recorded vote, as stored in (and read back from) the results JSON file.
Result = Dict[str, Any]
# Stylesheet handed to gr.Blocks(css=...). The class names below are the ones
# this module attaches through elem_classes and through its inline HTML
# snippets (title, prompt container, image row, button row, status, progress).
CUSTOM_CSS = """
.app-title {
text-align: center;
font-size: 2.2rem;
font-weight: 600;
margin-bottom: 0.8rem;
}
.prompt-container {
display: flex;
justify-content: center;
align-items: center;
margin-bottom: 1.2rem;
}
.prompt-text {
font-size: 1.6rem;
font-weight: 500;
text-align: center;
max-width: 960px;
}
.image-row {
display: flex;
justify-content: center;
gap: 1.5rem;
}
.button-row {
display: flex;
justify-content: center;
gap: 1.0rem;
}
.status-text {
text-align: center;
}
.progress-text {
text-align: center;
}
"""
def resolve_path(path_str: str, base_dir: Path) -> Path:
    """Turn *path_str* into a ``Path``, anchoring relative paths at *base_dir*.

    An absolute path is returned exactly as given (it is not normalised).
    A relative path is joined onto *base_dir* and then resolved.
    """
    candidate = Path(path_str)
    if candidate.is_absolute():
        return candidate
    return (base_dir / candidate).resolve()
def load_tasks(tasks_path: Path, image_root: Optional[Path]) -> List[Task]:
    """Load the comparison tasks from a JSON file.

    The file must contain a JSON list of objects. Each object is copied and
    the copy is extended with:

    * ``original_index`` - position of the entry in the file;
    * ``original_left_image`` / ``original_right_image`` - the image values
      exactly as written in the file;
    * ``original_left_model`` / ``original_right_model`` - taken from
      ``model_left`` / ``model_right`` (``None`` when absent);
    * ``original_left_meta`` / ``original_right_meta`` - taken from
      ``left_meta`` / ``right_meta`` (``None`` when absent);
    * ``original_left_image_path`` / ``original_right_image_path`` - the image
      values resolved against *image_root*, or against the directory of
      *tasks_path* when *image_root* is ``None``.

    Raises:
        ValueError: if the document is not a list, an entry is not an object,
            or an entry lacks a string ``left_image`` / ``right_image``.
    """
    with tasks_path.open("r", encoding="utf-8") as f:
        raw_tasks = json.load(f)
    if not isinstance(raw_tasks, list):
        raise ValueError(f"Tasks file must contain a JSON list: {tasks_path}")

    base_dir = image_root if image_root is not None else tasks_path.parent
    tasks: List[Task] = []
    for index, entry in enumerate(raw_tasks):
        if not isinstance(entry, dict):
            raise ValueError(f"Task #{index} in {tasks_path} is not a JSON object.")
        # Validate up front: a missing image would otherwise surface as an
        # opaque TypeError raised from inside Path().
        for key in ("left_image", "right_image"):
            if not isinstance(entry.get(key), str):
                raise ValueError(
                    f"Task #{index} in {tasks_path} has no valid '{key}' path."
                )

        original_left_image = entry["left_image"]
        original_right_image = entry["right_image"]
        task = dict(entry)
        task["original_index"] = index
        task["original_left_image"] = original_left_image
        task["original_right_image"] = original_right_image
        task["original_left_model"] = entry.get("model_left")
        task["original_right_model"] = entry.get("model_right")
        task["original_left_meta"] = entry.get("left_meta")
        task["original_right_meta"] = entry.get("right_meta")
        task["original_left_image_path"] = str(
            resolve_path(original_left_image, base_dir)
        )
        task["original_right_image_path"] = str(
            resolve_path(original_right_image, base_dir)
        )
        tasks.append(task)
    return tasks
def prepare_display_plan(tasks: List[Task], seed: Optional[int] = None) -> None:
    """Shuffle *tasks* in place and decide, per task, which side shows which image.

    A ``random.Random(seed)`` instance first shuffles the list, then draws one
    left/right swap decision per task in the new order. Every task receives
    ``display_order``, ``display_swap`` and the ``display_left_*`` /
    ``display_right_*`` keys (image, image_path, model) copied from the
    matching ``original_*`` keys; a missing original yields ``None``.
    """
    rng = random.Random(seed)
    rng.shuffle(tasks)
    mirrored_fields = ("image", "image_path", "model")
    for position, task in enumerate(tasks):
        swapped = rng.choice([False, True])
        task["display_order"] = position
        task["display_swap"] = swapped
        # When swapped, the original right-hand item is shown on the left.
        if swapped:
            shown_left, shown_right = "right", "left"
        else:
            shown_left, shown_right = "left", "right"
        for field in mirrored_fields:
            task[f"display_left_{field}"] = task.get(f"original_{shown_left}_{field}")
            task[f"display_right_{field}"] = task.get(f"original_{shown_right}_{field}")
def load_results(results_path: Path) -> List[Result]:
    """Read previously saved votes from *results_path*.

    Returns an empty list when the file does not exist or contains only
    whitespace; otherwise returns whatever the JSON document decodes to.

    Raises:
        ValueError: if the file content is not valid JSON (the original
            ``json.JSONDecodeError`` is chained as the cause).
    """
    if not results_path.exists():
        return []
    raw_text = results_path.read_text(encoding="utf-8").strip()
    if raw_text == "":
        return []
    try:
        parsed: List[Result] = json.loads(raw_text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Unable to parse result file: {results_path}") from exc
    return parsed
def ensure_parent_dir(path: Path) -> None:
    """Create the directory that will contain *path*, including any ancestors.

    Does nothing when the directory already exists.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
def save_results(results_path: Path, results: Sequence[Result]) -> None:
    """Write *results* to *results_path* as indented UTF-8 JSON.

    The parent directory is created first if needed, and any existing file is
    overwritten. Non-ASCII characters are written as-is, not escaped.
    """
    ensure_parent_dir(results_path)
    with results_path.open("w", encoding="utf-8") as out_file:
        json.dump(list(results), out_file, indent=2, ensure_ascii=False)
def build_prompt_html(task: Task) -> str:
    """Render the task's prompt as the centred HTML banner shown above the images.

    The prompt is stripped of surrounding whitespace and HTML-escaped. A
    missing, ``None`` or blank prompt is replaced by "No prompt provided.".
    """
    text = (task.get("prompt") or "").strip() or "No prompt provided."
    return (
        "<div class='prompt-container'><div class='prompt-text'>"
        + html.escape(text)
        + "</div></div>"
    )
def format_progress(completed: int, total: int) -> str:
    """Return the progress label, e.g. ``"Progress: 3/10"``."""
    return "Progress: {}/{}".format(completed, total)
def find_next_index(tasks: Sequence[Task], completed_ids: Sequence[int]) -> Optional[int]:
    """Return the index of the first task that has not been voted on yet.

    Args:
        tasks: Tasks in display order.
        completed_ids: Ids of the tasks already voted on. Any iterable works;
            ``None`` entries are ignored.

    Returns:
        The position in *tasks* of the first pending task, or ``None`` when
        every task is completed (or *tasks* is empty).

    A task whose ``id`` is missing or ``None`` is identified by its position
    in *tasks*, which is the same fallback the vote handler uses when it
    records a vote for such a task.
    """
    # Only None means "no id". Falsy ids such as 0 or "" are real ids and must
    # not be discarded, otherwise that task could never be marked complete.
    completed_set = {task_id for task_id in completed_ids if task_id is not None}
    for idx, task in enumerate(tasks):
        task_id = task.get("id")
        if task_id is None:
            task_id = idx
        if task_id not in completed_set:
            return idx
    return None
def vote_handler_factory(
    tasks: Sequence[Task],
    results_path: Path,
) -> Any:
    """Build the callback that records one vote and advances to the next task.

    Args:
        tasks: Tasks in display order (already prepared with ``display_*`` keys).
        results_path: JSON file that is rewritten with all votes after each vote.

    Returns:
        ``handler(choice, current_index, completed_ids, results)``, which
        returns an 8-tuple: prompt HTML, left image path, right image path,
        progress text, status text, next task index, completed ids, results.
        The image paths and the next index are ``None`` once nothing is left.
    """
    total = len(tasks)
    finished_html = (
        "<div class='prompt-container'><div class='prompt-text'>All tasks completed.</div></div>"
    )

    def handler(
        choice: str,
        current_index: Optional[int],
        completed_ids: Sequence[int],
        results: Sequence[Result],
    ) -> Tuple[str, Optional[str], Optional[str], str, str, Optional[int], List[int], List[Result]]:
        # Only None means "no id"; a falsy id such as 0 is a real task id and
        # must stay in the completed set.
        completed_set = {task_id for task_id in completed_ids if task_id is not None}
        results_list = list(results)

        def finished(status_text: str):
            # Image outputs are cleared with None; an empty string is not a
            # usable image path.
            return (
                finished_html,
                None,
                None,
                format_progress(len(completed_set), total),
                status_text,
                None,
                list(completed_set),
                results_list,
            )

        if current_index is None or current_index >= total:
            return finished("All tasks completed.")

        task = tasks[current_index]
        # Fall back to the display position when the task has no usable id
        # (key absent or explicitly null).
        task_id = task.get("id")
        if task_id is None:
            task_id = current_index

        result_entry: Result = {
            "task_id": task_id,
            "display_order": task.get("display_order"),
            "original_index": task.get("original_index"),
            "prompt_id": task.get("prompt_id"),
            "pair_index": task.get("pair_index"),
            "original_left_model": task.get("original_left_model"),
            "original_right_model": task.get("original_right_model"),
            "original_left_image": task.get("original_left_image"),
            "original_right_image": task.get("original_right_image"),
            "display_left_model": task.get("display_left_model"),
            "display_right_model": task.get("display_right_model"),
            "display_left_image": task.get("display_left_image"),
            "display_right_image": task.get("display_right_image"),
            "display_swap": task.get("display_swap"),
            "choice": choice,
            "timestamp": datetime.now().isoformat(timespec="seconds"),
            "prompt": task.get("prompt"),
        }
        results_list.append(result_entry)
        # Persist immediately so a vote is on disk before the UI advances.
        save_results(results_path, results_list)
        completed_set.add(task_id)

        next_index = find_next_index(tasks, completed_set)
        if next_index is None:
            return finished(f"Recorded choice `{choice}`. All tasks completed.")

        next_task = tasks[next_index]
        return (
            build_prompt_html(next_task),
            next_task["display_left_image_path"],
            next_task["display_right_image_path"],
            format_progress(len(completed_set), total),
            f"Recorded choice `{choice}`.",
            next_index,
            list(completed_set),
            results_list,
        )

    return handler
def launch_interface(
    tasks: Sequence[Task],
    results_path: Path,
    initial_results: Sequence[Result],
    *,
    share: bool,
    server_name: Optional[str],
    server_port: Optional[int],
) -> None:
    """Build the Gradio voting UI and start serving it.

    Args:
        tasks: Tasks in display order (already prepared with ``display_*`` keys).
        results_path: JSON file the votes are written to.
        initial_results: Votes recorded before this launch; their ``task_id``
            (or ``id``) values mark tasks as already completed.
        share: Passed to ``demo.launch`` as ``share``.
        server_name: Passed to ``demo.launch`` as ``server_name``.
        server_port: Passed to ``demo.launch`` as ``server_port``.
    """
    total = len(tasks)

    def completed_ids_of(records: Sequence[Result]) -> List[Any]:
        # A stored vote identifies its task by "task_id", with "id" as fallback.
        ids = {entry.get("task_id", entry.get("id")) for entry in records}
        ids.discard(None)
        return list(ids)

    completed_ids = completed_ids_of(initial_results)
    initial_index = find_next_index(tasks, completed_ids)

    with gr.Blocks(
        title="Human Preference Voting",
        css=CUSTOM_CSS,
    ) as demo:
        gr.HTML("<div class='app-title'>Human Preference Voting</div>")
        prompt_html = gr.HTML()
        with gr.Row(elem_classes="image-row"):
            left_image = gr.Image(label="Left Image", type="filepath", show_label=True)
            right_image = gr.Image(label="Right Image", type="filepath", show_label=True)
        progress_md = gr.Markdown(elem_classes="progress-text")
        status_md = gr.Markdown(elem_classes="status-text")
        with gr.Row(elem_classes="button-row"):
            left_btn = gr.Button("Prefer Left", variant="primary")
            tie_btn = gr.Button("No Preference", variant="secondary")
            right_btn = gr.Button("Prefer Right", variant="primary")
        current_index_state = gr.State(initial_index)
        completed_state = gr.State(list(completed_ids))
        results_state = gr.State(list(initial_results))

        # Every callback updates these eight components, in this order.
        view_outputs = [
            prompt_html,
            left_image,
            right_image,
            progress_md,
            status_md,
            current_index_state,
            completed_state,
            results_state,
        ]
        state_inputs = [current_index_state, completed_state, results_state]

        def init_view(
            index: Optional[int],
            completed: Sequence[int],
            results_data: Sequence[Result],
        ) -> Tuple[str, Optional[str], Optional[str], str, str, Optional[int], List[int], List[Result]]:
            # Runs on every page load. Image outputs are cleared with None.
            if total == 0:
                return (
                    "<div class='prompt-container'><div class='prompt-text'>No tasks available.</div></div>",
                    None,
                    None,
                    format_progress(len(completed), total),
                    "No tasks available.",
                    None,
                    list(completed),
                    list(results_data),
                )

            completed_list = list(completed)
            results_list = list(results_data)
            # A page reload starts again from the State values captured at
            # launch. If votes were saved since then, continuing from that
            # stale state would make the next vote overwrite them, so resume
            # from the file whenever it holds more votes than the session.
            try:
                stored = load_results(results_path)
            except (OSError, ValueError):
                # Best effort: keep the session state if the file is unreadable.
                stored = None
            if isinstance(stored, list) and len(stored) > len(results_list):
                results_list = stored
                completed_list = completed_ids_of(stored)
                index = find_next_index(tasks, completed_list)

            if index is None:
                return (
                    "<div class='prompt-container'><div class='prompt-text'>All tasks completed.</div></div>",
                    None,
                    None,
                    format_progress(len(completed_list), total),
                    "All tasks completed.",
                    None,
                    completed_list,
                    results_list,
                )
            task = tasks[index]
            return (
                build_prompt_html(task),
                task["display_left_image_path"],
                task["display_right_image_path"],
                format_progress(len(completed_list), total),
                "Select a preference to continue.",
                index,
                completed_list,
                results_list,
            )

        demo.load(
            init_view,
            inputs=state_inputs,
            outputs=view_outputs,
        )

        handler = vote_handler_factory(tasks, results_path)

        def make_click_fn(choice_value: str):
            # Bind the choice per button; a plain closure over the loop
            # variable below would see only its final value.
            def on_click(
                current_index: Optional[int],
                completed_ids_value: Sequence[int],
                results_data: Sequence[Result],
            ) -> Tuple[str, Optional[str], Optional[str], str, str, Optional[int], List[int], List[Result]]:
                return handler(choice_value, current_index, completed_ids_value, results_data)

            return on_click

        for button, vote_choice in [
            (left_btn, "left"),
            (tie_btn, "tie"),
            (right_btn, "right"),
        ]:
            button.click(
                make_click_fn(vote_choice),
                inputs=state_inputs,
                outputs=view_outputs,
            )

    demo.queue()
    demo.launch(
        share=share,
        server_name=server_name,
        server_port=server_port,
    )
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the voting interface.

    Options: ``--tasks`` (required), ``--out``, ``--image-root``, ``--share``,
    ``--server-name``, ``--server-port`` and ``--seed``.
    """
    parser = argparse.ArgumentParser(
        description="Gradio interface for human preference voting between image pairs.",
    )
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = [
        (
            "--tasks",
            {
                "required": True,
                "type": Path,
                "help": "Path to tasks JSON file, e.g. ./tasks.json",
            },
        ),
        (
            "--out",
            {
                "type": Path,
                "default": Path("votes.json"),
                "help": "Path to output JSON file for votes (default: ./votes.json).",
            },
        ),
        (
            "--image-root",
            {
                "type": Path,
                "default": None,
                "help": "Optional image root directory overriding task file location.",
            },
        ),
        (
            "--share",
            {
                "action": "store_true",
                "help": "Enable Gradio share link for external access.",
            },
        ),
        (
            "--server-name",
            {
                "type": str,
                "default": None,
                "help": "Host address to bind, e.g. 0.0.0.0 for LAN access.",
            },
        ),
        (
            "--server-port",
            {
                "type": int,
                "default": None,
                "help": "Port number to bind; defaults to Gradio's choice.",
            },
        ),
        (
            "--seed",
            {
                "type": int,
                "default": None,
                "help": "Optional random seed to make task order reproducible.",
            },
        ),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def main() -> None:
    """Command-line entry point: validate inputs, load data, launch the UI.

    Raises:
        FileNotFoundError: if the tasks file or the given image root is missing.
        ValueError: if the tasks file yields no tasks.
    """
    cli = parse_args()
    tasks_file: Path = cli.tasks
    votes_file: Path = cli.out
    images_dir: Optional[Path] = cli.image_root

    if not tasks_file.exists():
        raise FileNotFoundError(f"Tasks file not found: {tasks_file}")
    if images_dir is not None and not images_dir.exists():
        raise FileNotFoundError(f"Image root not found: {images_dir}")

    task_list = load_tasks(tasks_file, images_dir)
    if len(task_list) == 0:
        raise ValueError("Task list is empty.")

    # Shuffle and pick left/right placement before reading earlier votes.
    prepare_display_plan(task_list, seed=cli.seed)
    previous_votes = load_results(votes_file)

    launch_interface(
        task_list,
        votes_file,
        previous_votes,
        share=bool(cli.share),
        server_name=cli.server_name,
        server_port=cli.server_port,
    )


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()