import argparse
import html
import json
import random
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
import gradio as gr
# Loosely-typed records: a task is one entry of the tasks JSON file plus the
# bookkeeping keys added while loading/shuffling; a result is one recorded vote.
Task = Dict[str, Any]
Result = Dict[str, Any]

# Stylesheet handed to ``gr.Blocks(css=...)``. The class names below are the
# ones referenced through ``elem_classes`` and in the HTML snippets of the UI.
CUSTOM_CSS = """
.app-title {
text-align: center;
font-size: 2.2rem;
font-weight: 600;
margin-bottom: 0.8rem;
}
.prompt-container {
display: flex;
justify-content: center;
align-items: center;
margin-bottom: 1.2rem;
}
.prompt-text {
font-size: 1.6rem;
font-weight: 500;
text-align: center;
max-width: 960px;
}
.image-row {
display: flex;
justify-content: center;
gap: 1.5rem;
}
.button-row {
display: flex;
justify-content: center;
gap: 1.0rem;
}
.status-text {
text-align: center;
}
.progress-text {
text-align: center;
}
"""
def resolve_path(path_str: str, base_dir: Path) -> Path:
    """Convert *path_str* to a ``Path``, anchoring relative paths at *base_dir*.

    Absolute inputs are returned unchanged (they are not resolved); relative
    inputs are joined onto *base_dir* and then resolved.
    """
    candidate = Path(path_str)
    if candidate.is_absolute():
        return candidate
    return (base_dir / candidate).resolve()
def load_tasks(tasks_path: Path, image_root: Optional[Path]) -> List[Task]:
    """Load the task list from a JSON file and attach provenance fields.

    Each returned task is a copy of the JSON entry extended with
    ``original_index`` (position in the file) and ``original_*`` copies of the
    left/right image names, model names, metadata and resolved image paths.

    Args:
        tasks_path: JSON file containing a list of task objects.
        image_root: Directory that relative image paths are resolved against;
            when ``None`` the directory containing *tasks_path* is used.

    Returns:
        The tasks in file order.

    Raises:
        ValueError: If the file's top-level value is not a list, an entry is
            not an object, or an entry lacks a usable ``left_image`` /
            ``right_image`` string.
    """
    with tasks_path.open("r", encoding="utf-8") as f:
        raw_tasks = json.load(f)
    if not isinstance(raw_tasks, list):
        raise ValueError(f"Tasks file must contain a JSON list: {tasks_path}")

    base_dir = image_root if image_root is not None else tasks_path.parent
    tasks: List[Task] = []
    for index, entry in enumerate(raw_tasks):
        if not isinstance(entry, dict):
            raise ValueError(f"Task #{index} is not a JSON object.")

        # Validate up front: a missing image would otherwise surface as an
        # opaque TypeError from Path(None) with no hint about which task.
        images = {}
        for key in ("left_image", "right_image"):
            value = entry.get(key)
            if not isinstance(value, str) or not value:
                raise ValueError(
                    f"Task #{index} is missing a valid '{key}' entry."
                )
            images[key] = value

        task = dict(entry)
        task["original_index"] = index
        task["original_left_image"] = images["left_image"]
        task["original_right_image"] = images["right_image"]
        task["original_left_model"] = entry.get("model_left")
        task["original_right_model"] = entry.get("model_right")
        task["original_left_meta"] = entry.get("left_meta")
        task["original_right_meta"] = entry.get("right_meta")
        task["original_left_image_path"] = str(
            resolve_path(images["left_image"], base_dir)
        )
        task["original_right_image_path"] = str(
            resolve_path(images["right_image"], base_dir)
        )
        tasks.append(task)
    return tasks
def prepare_display_plan(tasks: List[Task], seed: Optional[int] = None) -> None:
    """Shuffle *tasks* in place and decide, per task, which side shows which image.

    A ``random.Random(seed)`` instance first shuffles the list, then draws one
    boolean per task. Every task receives ``display_order`` (its position after
    shuffling), ``display_swap`` and ``display_left_*`` / ``display_right_*``
    values copied from the ``original_*`` fields, with left and right exchanged
    when the swap flag is set.
    """
    rng = random.Random(seed)
    rng.shuffle(tasks)

    # Suffixes of the fields that are mirrored between the two sides.
    mirrored_fields = ("image", "image_path", "model")

    for position, item in enumerate(tasks):
        flipped = rng.choice([False, True])
        item["display_order"] = position
        item["display_swap"] = flipped

        if flipped:
            left_source, right_source = "right", "left"
        else:
            left_source, right_source = "left", "right"

        for field in mirrored_fields:
            item[f"display_left_{field}"] = item.get(f"original_{left_source}_{field}")
            item[f"display_right_{field}"] = item.get(f"original_{right_source}_{field}")
def load_results(results_path: Path) -> List[Result]:
    """Read previously saved votes from *results_path*.

    Returns an empty list when the file does not exist or contains only
    whitespace; otherwise returns the decoded JSON content.

    Raises:
        ValueError: If the file content is not valid JSON.
    """
    if results_path.exists():
        payload = results_path.read_text(encoding="utf-8").strip()
    else:
        payload = ""

    if payload == "":
        return []

    try:
        return json.loads(payload)
    except json.JSONDecodeError as err:
        raise ValueError(f"Unable to parse result file: {results_path}") from err
def ensure_parent_dir(path: Path) -> None:
    """Create the directory that will contain *path*, including any ancestors.

    Does nothing if the directory already exists.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
def save_results(results_path: Path, results: Sequence[Result]) -> None:
    """Write *results* to *results_path* as pretty-printed UTF-8 JSON.

    The data is first written to a sibling ``<name>.tmp`` file and then moved
    over the destination. Opening the destination directly with ``"w"``
    truncates it immediately, so an interruption in the middle of the dump
    would destroy every previously recorded vote and leave a file that can no
    longer be parsed.

    Args:
        results_path: Destination file; its parent directory is created if
            needed.
        results: Vote records to serialise.
    """
    ensure_parent_dir(results_path)
    tmp_path = results_path.with_name(results_path.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as f:
            json.dump(list(results), f, ensure_ascii=False, indent=2)
        # Path.replace overwrites the destination in a single rename step.
        tmp_path.replace(results_path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up the partial file left behind by a failed write.
        if tmp_path.exists():
            tmp_path.unlink()
def build_prompt_html(task: Task) -> str:
    """Render the task's prompt as a centred, HTML-escaped snippet.

    Args:
        task: Task record; its ``"prompt"`` value is displayed. A missing,
            empty or whitespace-only prompt is replaced by a placeholder.

    Returns:
        HTML wrapping the escaped prompt in the ``prompt-container`` /
        ``prompt-text`` elements styled by the app's stylesheet.
    """
    # str() guards against non-string prompts (e.g. numbers) in the task file.
    prompt = str(task.get("prompt") or "").strip()
    if not prompt:
        prompt = "No prompt provided."
    # Escape so that prompt text can never be interpreted as markup.
    escaped = html.escape(prompt)
    # NOTE(review): the wrapper markup was reconstructed from the CSS class
    # names because the original literal was unterminated; confirm it matches
    # the intended layout.
    return (
        '<div class="prompt-container">'
        f'<div class="prompt-text">{escaped}</div>'
        "</div>"
    )
def format_progress(completed: int, total: int) -> str:
    """Return the progress label in the form ``Progress: <completed>/<total>``."""
    return "Progress: {}/{}".format(completed, total)
def find_next_index(tasks: Sequence[Task], completed_ids: Sequence[int]) -> Optional[int]:
    """Return the index of the first task that has not been completed yet.

    A task is identified by its ``"id"`` value. When that is missing or
    ``None`` the task's index in *tasks* is used as its id, which is the same
    fallback the vote handler applies when it records a vote.

    Args:
        tasks: Tasks in display order.
        completed_ids: Ids of tasks that already have a vote; ``None`` entries
            are ignored.

    Returns:
        The index of the first pending task, or ``None`` if every task is done.
    """
    # Only None is discarded: a truthiness filter would also drop the valid
    # id 0 and make that task come up again forever.
    done = {task_id for task_id in completed_ids if task_id is not None}
    for idx, task in enumerate(tasks):
        task_id = task.get("id")
        if task_id is None:
            # Without this fallback an id-less task could never be matched
            # against the completed set and would be returned every time.
            task_id = idx
        if task_id not in done:
            return idx
    return None
def vote_handler_factory(
    tasks: Sequence[Task],
    results_path: Path,
) -> Any:
    """Build the callback that records one vote and advances to the next task.

    Args:
        tasks: Tasks in display order (after the display plan was prepared).
        results_path: File that the full list of results is rewritten to after
            every vote.

    Returns:
        ``handler(choice, current_index, completed_ids, results)``, which
        returns an 8-tuple: prompt HTML, left image path, right image path,
        progress text, status text, next task index, completed ids, results.
        The image paths and the next index are ``None`` once no task is left.
    """
    total = len(tasks)

    def handler(
        choice: str,
        current_index: Optional[int],
        completed_ids: Sequence[int],
        results: Sequence[Result],
    ) -> Tuple[
        str,
        Optional[str],
        Optional[str],
        str,
        str,
        Optional[int],
        List[int],
        List[Result],
    ]:
        # Drop only None: a truthiness filter would also discard the valid
        # id 0, so that task would never count as completed.
        completed_set = {cid for cid in completed_ids if cid is not None}
        results_list = list(results)

        def finished(status_text: str):
            """Outputs for the 'nothing left to show' state."""
            return (
                "",
                None,  # None clears an image component; "" is not a file path
                None,
                format_progress(len(completed_set), total),
                status_text,
                None,
                list(completed_set),
                results_list,
            )

        if current_index is None or not 0 <= current_index < total:
            return finished("All tasks completed.")

        task = tasks[current_index]
        task_id = task.get("id")
        if task_id is None:
            # Covers both a missing key and an explicit null id.
            task_id = current_index

        result_entry: Result = {
            "task_id": task_id,
            "display_order": task.get("display_order"),
            "original_index": task.get("original_index"),
            "prompt_id": task.get("prompt_id"),
            "pair_index": task.get("pair_index"),
            "original_left_model": task.get("original_left_model"),
            "original_right_model": task.get("original_right_model"),
            "original_left_image": task.get("original_left_image"),
            "original_right_image": task.get("original_right_image"),
            "display_left_model": task.get("display_left_model"),
            "display_right_model": task.get("display_right_model"),
            "display_left_image": task.get("display_left_image"),
            "display_right_image": task.get("display_right_image"),
            "display_swap": task.get("display_swap"),
            "choice": choice,
            "timestamp": datetime.now().isoformat(timespec="seconds"),
            "prompt": task.get("prompt"),
        }
        results_list.append(result_entry)
        # Persist after every vote so an interrupted session loses nothing.
        # NOTE(review): the list written here comes from the caller-supplied
        # state; if several sessions vote at once, each rewrites the file with
        # its own list - confirm single-annotator use.
        save_results(results_path, results_list)
        completed_set.add(task_id)

        next_index = find_next_index(tasks, completed_set)
        if next_index is None:
            return finished(f"Recorded choice `{choice}`. All tasks completed.")

        next_task = tasks[next_index]
        return (
            build_prompt_html(next_task),
            next_task["display_left_image_path"],
            next_task["display_right_image_path"],
            format_progress(len(completed_set), total),
            f"Recorded choice `{choice}`.",
            next_index,
            list(completed_set),
            results_list,
        )

    return handler
def launch_interface(
    tasks: Sequence[Task],
    results_path: Path,
    initial_results: Sequence[Result],
    *,
    share: bool,
    server_name: Optional[str],
    server_port: Optional[int],
) -> None:
    """Build the Gradio voting UI and start the server.

    Args:
        tasks: Tasks in display order.
        results_path: File that votes are written to.
        initial_results: Votes loaded from a previous run; tasks whose id
            appears in them (``task_id``, falling back to ``id``) are skipped.
        share: Passed to ``demo.launch(share=...)``.
        server_name: Passed to ``demo.launch(server_name=...)``.
        server_port: Passed to ``demo.launch(server_port=...)``.
    """
    completed_ids = {entry.get("task_id", entry.get("id")) for entry in initial_results}
    completed_ids.discard(None)
    total = len(tasks)
    initial_index = find_next_index(tasks, list(completed_ids))

    with gr.Blocks(
        title="Human Preference Voting",
        css=CUSTOM_CSS,
    ) as demo:
        # NOTE(review): the wrapper element was reconstructed from the
        # ``.app-title`` CSS class because the original literal was
        # unterminated; confirm the intended markup.
        gr.HTML("<div class='app-title'>Human Preference Voting</div>")
        prompt_html = gr.HTML()
        with gr.Row(elem_classes="image-row"):
            left_image = gr.Image(label="Left Image", type="filepath", show_label=True)
            right_image = gr.Image(label="Right Image", type="filepath", show_label=True)
        progress_md = gr.Markdown(elem_classes="progress-text")
        status_md = gr.Markdown(elem_classes="status-text")
        with gr.Row(elem_classes="button-row"):
            left_btn = gr.Button("Prefer Left", variant="primary")
            tie_btn = gr.Button("No Preference", variant="secondary")
            right_btn = gr.Button("Prefer Right", variant="primary")

        current_index_state = gr.State(initial_index)
        completed_state = gr.State(list(completed_ids))
        results_state = gr.State(list(initial_results))

        # Every callback reads the same state and refreshes the same
        # components, so the wiring lists are defined once.
        state_inputs = [current_index_state, completed_state, results_state]
        view_outputs = [
            prompt_html,
            left_image,
            right_image,
            progress_md,
            status_md,
            current_index_state,
            completed_state,
            results_state,
        ]

        def init_view(
            index: Optional[int],
            completed: Sequence[int],
            results_data: Sequence[Result],
        ) -> Tuple[
            str,
            Optional[str],
            Optional[str],
            str,
            str,
            Optional[int],
            List[int],
            List[Result],
        ]:
            """Populate the page on load from the stored state."""
            progress = format_progress(len(completed), total)
            if total == 0 or index is None:
                status = "No tasks available." if total == 0 else "All tasks completed."
                return (
                    "",
                    None,  # None clears an image component; "" is not a file path
                    None,
                    progress,
                    status,
                    None,
                    list(completed),
                    list(results_data),
                )
            task = tasks[index]
            return (
                build_prompt_html(task),
                task["display_left_image_path"],
                task["display_right_image_path"],
                progress,
                "Select a preference to continue.",
                index,
                list(completed),
                list(results_data),
            )

        demo.load(init_view, inputs=state_inputs, outputs=view_outputs)

        handler = vote_handler_factory(tasks, results_path)

        def make_click_fn(choice_value: str):
            """Bind *choice_value* now so each button reports its own choice."""

            def on_click(
                current_index: Optional[int],
                completed_ids_value: Sequence[int],
                results_data: Sequence[Result],
            ):
                return handler(choice_value, current_index, completed_ids_value, results_data)

            return on_click

        for button, vote_choice in [
            (left_btn, "left"),
            (tie_btn, "tie"),
            (right_btn, "right"),
        ]:
            button.click(
                make_click_fn(vote_choice),
                inputs=state_inputs,
                outputs=view_outputs,
            )

    demo.queue()
    demo.launch(
        share=share,
        server_name=server_name,
        server_port=server_port,
    )
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the voting app from ``sys.argv``.

    Options: ``--tasks`` (required), ``--out``, ``--image-root``, ``--share``,
    ``--server-name``, ``--server-port`` and ``--seed``.
    """
    parser = argparse.ArgumentParser(
        description="Gradio interface for human preference voting between image pairs.",
    )

    # One (flag, keyword-arguments) pair per option, in declaration order.
    option_specs = [
        (
            "--tasks",
            {
                "required": True,
                "type": Path,
                "help": "Path to tasks JSON file, e.g. ./tasks.json",
            },
        ),
        (
            "--out",
            {
                "type": Path,
                "default": Path("votes.json"),
                "help": "Path to output JSON file for votes (default: ./votes.json).",
            },
        ),
        (
            "--image-root",
            {
                "type": Path,
                "default": None,
                "help": "Optional image root directory overriding task file location.",
            },
        ),
        (
            "--share",
            {
                "action": "store_true",
                "help": "Enable Gradio share link for external access.",
            },
        ),
        (
            "--server-name",
            {
                "type": str,
                "default": None,
                "help": "Host address to bind, e.g. 0.0.0.0 for LAN access.",
            },
        ),
        (
            "--server-port",
            {
                "type": int,
                "default": None,
                "help": "Port number to bind; defaults to Gradio's choice.",
            },
        ),
        (
            "--seed",
            {
                "type": int,
                "default": None,
                "help": "Optional random seed to make task order reproducible.",
            },
        ),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()
def main() -> None:
    """Command-line entry point: validate inputs, load data, start the UI.

    Raises:
        FileNotFoundError: If the tasks file or the given image root is missing.
        ValueError: If the tasks file yields no tasks.
    """
    cli = parse_args()
    tasks_file: Path = cli.tasks
    votes_file: Path = cli.out
    images_dir: Optional[Path] = cli.image_root

    # Fail early on missing inputs, before anything is loaded.
    if not tasks_file.exists():
        raise FileNotFoundError(f"Tasks file not found: {tasks_file}")
    if images_dir is not None and not images_dir.exists():
        raise FileNotFoundError(f"Image root not found: {images_dir}")

    loaded_tasks = load_tasks(tasks_file, images_dir)
    if not loaded_tasks:
        raise ValueError("Task list is empty.")

    prepare_display_plan(loaded_tasks, seed=cli.seed)
    previous_votes = load_results(votes_file)

    launch_options = {
        "share": bool(cli.share),
        "server_name": cli.server_name,
        "server_port": cli.server_port,
    }
    launch_interface(loaded_tasks, votes_file, previous_votes, **launch_options)
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()