"""Caption Preference Study — Gradio Space. Participants enter an access code (validated against a private 1000-code list on HF), then see an image and two captions (human vs. model) and pick a preference. Per-participant results are stored as ``.csv`` in a private HF dataset. If a participant returns later their session resumes from wherever they left off, and if they have already completed the study they are told so. """ from __future__ import annotations import io import json import os import random import re import threading import time from datetime import datetime, timezone from pathlib import Path from typing import Any import gradio as gr import pandas as pd from huggingface_hub import HfApi, hf_hub_download, snapshot_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError HF_USER = "pmadinei" IMAGES_REPO = f"{HF_USER}/caption-preference-images" RESULTS_REPO = f"{HF_USER}/caption-preference-results" HF_TOKEN = os.environ.get("HF_TOKEN") RESPONSE_TIME_CAP = 100.0 CSV_PATH = Path(__file__).parent / "Qwen3-VL-8B-Instruct.csv" IMAGE_DIR = Path(os.environ.get("IMAGE_DIR", "/tmp/caption_experiment_images")) IMAGE_DIR.mkdir(parents=True, exist_ok=True) RESULTS_COLUMNS = [ "id", "image_id", "filename", "type", "human_caption", "model_caption", "preference", "response_time", ] ACCESS_CODES_FILE = "access_codes.json" ACCESS_CODE_RE = re.compile(r"^[A-Z0-9]+$") api = HfApi(token=HF_TOKEN) # --------------------------------------------------------------------------- # Data loading # --------------------------------------------------------------------------- def _clean_caption(value: Any) -> str: if value is None: return "" text = str(value) if len(text) >= 2 and text[0] == text[-1] and text[0] in ('"', "'"): text = text[1:-1] return text print(f"[startup] Loading CSV from {CSV_PATH}") df = pd.read_csv(CSV_PATH) df["human_caption"] = df["human_caption"].map(_clean_caption) df["model_caption"] = df["model_caption"].map(_clean_caption) 
_test_mask = df["image_id"].astype(str).str.contains("test", case=False, na=False)
TEST_DF = df[_test_mask].reset_index(drop=True)
NONTEST_DF = df[~_test_mask].reset_index(drop=True)

NONTEST_IMAGE_IDS: list = list(NONTEST_DF["image_id"].unique())
NONTEST_IMAGE_ID_SET = set(NONTEST_IMAGE_IDS)

# ``str(image_id)`` -> canonical non-test image_id. When the source CSV mixes
# numeric ids with "test" ids, pandas typically reads the whole column as
# strings ("7"), while ``_row_to_trial`` writes them back out as ints (7). A
# participant's results CSV can therefore re-load the same id as int, str or
# float depending on what else is in that column; matching through the string
# form keeps resume/tally logic independent of that dtype drift.
_NONTEST_IMAGE_ID_BY_STR: dict = {str(iid): iid for iid in NONTEST_IMAGE_IDS}

IMAGE_ID_TO_FILENAMES: dict = {
    img_id: list(NONTEST_DF[NONTEST_DF["image_id"] == img_id]["filename"].unique())
    for img_id in NONTEST_IMAGE_IDS
}

# Caption types available per filename. Some filenames only have 2 of the 3
# possible types (e.g. no ``min_sim2model``), so we never assume all 3 exist.
FILENAME_TO_TYPES: dict = {
    fn: list(NONTEST_DF[NONTEST_DF["filename"] == fn]["type"].unique())
    for fn in NONTEST_DF["filename"].unique()
}

# Every legitimate (image_id, filename, type) triple in the non-test pool. Used
# to ignore unrelated/test rows when tallying usage counts from results CSVs.
VALID_TRIAL_KEYS: set = {
    (str(iid), str(fn), str(ty))
    for iid, fn, ty in zip(
        NONTEST_DF["image_id"], NONTEST_DF["filename"], NONTEST_DF["type"]
    )
}


def _match_nontest_image_id(value: Any) -> Any | None:
    """Map a raw ``image_id`` read back from a results CSV onto NONTEST_IMAGE_IDS.

    Accepts the id as int (``7``), str (``"7"``) or integral float (``7.0``)
    and returns the canonical value stored in ``NONTEST_IMAGE_IDS``, or
    ``None`` when it does not correspond to any non-test image.
    """
    if value is None:
        return None
    key = str(value)
    if key in _NONTEST_IMAGE_ID_BY_STR:
        return _NONTEST_IMAGE_ID_BY_STR[key]
    # ``7.0`` (float column, e.g. when the column had a NaN) -> ``"7"``.
    try:
        as_float = float(key)
    except ValueError:
        return None
    if as_float.is_integer():
        return _NONTEST_IMAGE_ID_BY_STR.get(str(int(as_float)))
    return None


def _empty_counts() -> dict:
    """A fully zero-initialised ``{image_id: {filename: {type: 0}}}`` tree.

    Only the caption types each filename actually has are included.
    """
    tree: dict = {}
    for img_id in NONTEST_IMAGE_IDS:
        tree[str(img_id)] = {
            fn: {t: 0 for t in FILENAME_TO_TYPES[fn]}
            for fn in IMAGE_ID_TO_FILENAMES[img_id]
        }
    return tree


TEST_ROW_IDS = {int(x) for x in TEST_DF["id"]}
TOTAL_TRIALS_PER_PARTICIPANT = len(NONTEST_IMAGE_IDS) + len(TEST_DF)
print(
    f"[startup] {len(df)} rows | {len(NONTEST_IMAGE_IDS)} non-test image_ids | "
    f"{len(TEST_DF)} test rows | {TOTAL_TRIALS_PER_PARTICIPANT} trials per participant"
)


# ---------------------------------------------------------------------------
# Image download
# ---------------------------------------------------------------------------


def _ensure_images_downloaded() -> None:
    """Mirror the private images dataset into ``IMAGE_DIR`` (needs HF_TOKEN)."""
    if not HF_TOKEN:
        print("[startup] WARNING: HF_TOKEN is not set; cannot download images.")
        return
    print(f"[startup] Downloading images from {IMAGES_REPO} to {IMAGE_DIR}...")
    snapshot_download(
        repo_id=IMAGES_REPO,
        repo_type="dataset",
        local_dir=str(IMAGE_DIR),
        token=HF_TOKEN,
        max_workers=16,
    )
    print("[startup] Image download complete.")


_ensure_images_downloaded()


# ---------------------------------------------------------------------------
# Access codes
# ---------------------------------------------------------------------------

_ACCESS_CODES: set = set()


def _normalize_code(code: Any) -> str:
    """Strip and upper-case an access code; ``None`` becomes ``""``."""
    return (str(code) if code is not None else "").strip().upper()


def _load_access_codes() -> None:
    """Populate ``_ACCESS_CODES`` from ``ACCESS_CODES_FILE`` in RESULTS_REPO.

    Any failure leaves ``_ACCESS_CODES`` empty, which ``start_session`` reports
    as "server isn't ready".
    """
    global _ACCESS_CODES
    if not HF_TOKEN:
        print("[access] WARNING: HF_TOKEN not set; cannot load access codes.")
        return
    try:
        path = hf_hub_download(
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            filename=ACCESS_CODES_FILE,
            token=HF_TOKEN,
            force_download=True,
        )
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        codes = {_normalize_code(c) for c in data}
        # Blank/null entries would normalise to "", which is never a usable code.
        codes.discard("")
        _ACCESS_CODES = codes
        print(f"[access] Loaded {len(_ACCESS_CODES)} access codes.")
    except (EntryNotFoundError, RepositoryNotFoundError, FileNotFoundError):
        print(f"[access] ERROR: {ACCESS_CODES_FILE} not found in {RESULTS_REPO}.")
        _ACCESS_CODES = set()
    except Exception as exc:  # noqa: BLE001
        print(f"[access] ERROR loading access codes: {exc}")
        _ACCESS_CODES = set()


_load_access_codes()


# ---------------------------------------------------------------------------
# Exposure state (persisted to RESULTS_REPO/state.json)
#
# We balance two things across all participants:
#   1. How often each ``filename`` is shown within its ``image_id``.
#   2. How often each caption ``type`` is shown within a given ``filename``.
#
# The authoritative source of truth is the set of per-participant result CSVs
# already stored in the results dataset: every recorded trial there is an
# (image_id, filename, type) triple that was actually shown. ``state.json`` is
# a {image_id: {filename: {type: count}}} cache of those tallies plus any
# in-flight reservations made during the current run, so concurrent sessions
# stay balanced even before their results are uploaded.
# ---------------------------------------------------------------------------

_STATE_LOCK = threading.Lock()

# ``_STATE`` is the exposure tree: {image_id: {filename: {type: times_shown}}}.
_STATE: dict = _empty_counts()


def _get_count(image_id: Any, filename: str, caption_type: str) -> int:
    """Exposure count for one trial; missing branches count as 0."""
    return int(_STATE.get(str(image_id), {}).get(filename, {}).get(caption_type, 0))


def _incr_count(
    image_id: Any, filename: str, caption_type: str, amount: int = 1
) -> None:
    """Add ``amount`` to one trial's exposure count, creating branches as needed."""
    per_image = _STATE.setdefault(str(image_id), {})
    per_filename = per_image.setdefault(filename, {})
    per_filename[caption_type] = int(per_filename.get(caption_type, 0)) + amount


def _counts_from_results() -> dict | None:
    """Tally (image_id, filename, type) exposures across every results/*.csv.

    Returns a zero-initialised ``{image_id: {filename: {type: count}}}`` tree,
    or ``None`` if the results listing could not be read (so the caller can
    fall back to the cache). Unreadable files and rows that are not a valid
    non-test trial are skipped.
    """
    if not HF_TOKEN:
        return None
    try:
        files = api.list_repo_files(repo_id=RESULTS_REPO, repo_type="dataset")
    except Exception as exc:  # noqa: BLE001
        print(f"[state] Could not list results files ({exc}).")
        return None

    result_files = [f for f in files if f.startswith("results/") and f.endswith(".csv")]
    counts: dict = _empty_counts()
    needed = {"image_id", "filename", "type"}
    n_rows = 0
    for rf in result_files:
        try:
            path = hf_hub_download(
                repo_id=RESULTS_REPO,
                repo_type="dataset",
                filename=rf,
                token=HF_TOKEN,
                force_download=True,
            )
            frame = pd.read_csv(path)
        except Exception as exc:  # noqa: BLE001
            print(f"[state] Skipping unreadable results file {rf} ({exc}).")
            continue
        if not needed.issubset(frame.columns):
            continue
        for raw_iid, fn, ty in zip(
            frame["image_id"],
            frame["filename"].astype(str),
            frame["type"].astype(str),
        ):
            # Resolve dtype drift (7 / "7" / 7.0) before keying into the tree.
            canonical = _match_nontest_image_id(raw_iid)
            if canonical is None:
                continue
            iid = str(canonical)
            if (iid, fn, ty) not in VALID_TRIAL_KEYS:
                continue
            counts[iid][fn][ty] += 1
            n_rows += 1
    print(
        f"[state] Tallied {n_rows} exposures from {len(result_files)} "
        f"results file(s)."
    )
    return counts


def _load_state() -> None:
    """Seed ``_STATE`` from the cached state.json (fallback before refresh).

    Only the current nested ``{image_id: {filename: {type: count}}}`` tree is
    reused; a legacy ``{"type_counts": ...}`` wrapper or any non-dict payload
    is discarded in favour of a zero-initialised tree.
    """
    global _STATE
    if not HF_TOKEN:
        return
    try:
        path = hf_hub_download(
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            filename="state.json",
            token=HF_TOKEN,
            force_download=True,
        )
        with open(path, encoding="utf-8") as f:
            loaded = json.load(f)
        # Reuse the current nested tree as-is; the legacy ``{"type_counts": ...}``
        # wrapper (and anything that is not a dict) is rebuilt from zero.
        if isinstance(loaded, dict) and "type_counts" not in loaded:
            _STATE = loaded
        else:
            _STATE = _empty_counts()
        print(f"[state] Loaded cached exposure tree for {len(_STATE)} image_id(s).")
    except (EntryNotFoundError, RepositoryNotFoundError, FileNotFoundError):
        print("[state] No existing state.json found, starting fresh.")
        _STATE = _empty_counts()
    except Exception as exc:  # noqa: BLE001
        print(f"[state] Could not load state.json ({exc}); starting fresh.")
        _STATE = _empty_counts()


def _save_state() -> None:
    """Upload the current ``_STATE`` tree as state.json (raises on failure)."""
    if not HF_TOKEN:
        return
    payload = json.dumps(_STATE, indent=2).encode()
    api.upload_file(
        path_or_fileobj=io.BytesIO(payload),
        path_in_repo="state.json",
        repo_id=RESULTS_REPO,
        repo_type="dataset",
        commit_message="Update exposure counts",
    )


def _refresh_counts_from_results() -> None:
    """Rebuild counts from the authoritative results CSVs and persist them."""
    global _STATE
    counts = _counts_from_results()
    if counts is None:
        return
    with _STATE_LOCK:
        _STATE = counts
    try:
        _save_state()
    except Exception as exc:  # noqa: BLE001
        print(f"[state] WARNING: could not persist state.json ({exc}).")


_load_state()
_refresh_counts_from_results()


def _assign_trials(image_ids_to_assign: list) -> dict:
    """Pick the lowest-occurrence (filename, type) trial per image_id.

    For each image_id we scan every ``(filename, type)`` trial it has (only the
    caption types each filename actually has) and pick the one with the lowest
    count in the in-memory exposure tree (``_STATE``). Ties are broken by
    order, i.e. the first trial that reaches the minimum count wins.

    Picks are reserved immediately (count incremented + persisted) so the next
    assignment sees the update. Returns ``{image_id: (filename, type)}``;
    image_ids with no candidate trials are omitted.
    """
    with _STATE_LOCK:
        assignments: dict = {}
        for img_id in image_ids_to_assign:
            best_count: int | None = None
            best_pick: tuple | None = None
            for fn in IMAGE_ID_TO_FILENAMES[img_id]:
                for caption_type in FILENAME_TO_TYPES[fn]:
                    count = _get_count(img_id, fn, caption_type)
                    if best_count is None or count < best_count:
                        best_count = count
                        best_pick = (fn, caption_type)
            if best_pick is None:
                continue
            _incr_count(img_id, *best_pick)
            assignments[img_id] = best_pick
        if assignments:
            # Note: this upload runs while _STATE_LOCK is held, so concurrent
            # session starts wait for it to finish.
            try:
                _save_state()
            except Exception as exc:  # noqa: BLE001
                print(f"[state] WARNING: could not persist state.json ({exc}).")
        return assignments


# ---------------------------------------------------------------------------
# Per-participant CSV
# ---------------------------------------------------------------------------


def _participant_filename(code: str) -> str:
    """Path of a participant's results CSV inside RESULTS_REPO."""
    return f"results/{code}.csv"


def _load_participant_results(participant_file: str) -> list[dict]:
    """Download a participant's results CSV as a list of row dicts ([] if absent)."""
    if not HF_TOKEN:
        return []
    try:
        path = hf_hub_download(
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            filename=participant_file,
            token=HF_TOKEN,
            force_download=True,
        )
        frame = pd.read_csv(path)
        return frame.to_dict(orient="records")
    except (EntryNotFoundError, RepositoryNotFoundError, FileNotFoundError):
        return []
    except Exception as exc:  # noqa: BLE001
        print(f"[participant] Could not load {participant_file} ({exc})")
        return []


def _completed_keys(prior_results: list[dict]) -> tuple[set, set]:
    """Return (done_nontest_image_ids, done_test_row_ids) from a CSV-loaded list.

    Non-test image ids are returned in their canonical ``NONTEST_IMAGE_IDS``
    form regardless of whether the CSV re-loaded them as int, str or float.
    Rows without a parseable ``id`` are ignored.
    """
    done_image_ids: set = set()
    done_test_ids: set = set()
    for r in prior_results:
        try:
            row_id = int(r["id"])
        except (KeyError, TypeError, ValueError):
            continue
        raw_image_id = r.get("image_id")
        if row_id in TEST_ROW_IDS or "test" in str(raw_image_id).lower():
            done_test_ids.add(row_id)
            continue
        canonical = _match_nontest_image_id(raw_image_id)
        if canonical is not None:
            done_image_ids.add(canonical)
    return done_image_ids, done_test_ids


def _is_complete(prior_results: list[dict]) -> bool:
    """True when every non-test image_id and every test row has been answered."""
    done_image_ids, done_test_ids = _completed_keys(prior_results)
    return done_image_ids >= NONTEST_IMAGE_ID_SET and done_test_ids >= TEST_ROW_IDS


def _build_remaining_trials(prior_results: list[dict]) -> list[dict]:
    """Build the shuffled list of trials this participant still has to do.

    One balanced (filename, type) trial is reserved per unanswered non-test
    image_id, plus every unanswered test row.
    """
    done_image_ids, done_test_ids = _completed_keys(prior_results)
    remaining_image_ids = [iid for iid in NONTEST_IMAGE_IDS if iid not in done_image_ids]
    assignments = _assign_trials(remaining_image_ids)

    trials: list[dict] = []
    for img_id in remaining_image_ids:
        picked = assignments.get(img_id)
        if picked is None:
            continue
        fn, caption_type = picked
        match = NONTEST_DF[
            (NONTEST_DF["image_id"] == img_id)
            & (NONTEST_DF["filename"] == fn)
            & (NONTEST_DF["type"] == caption_type)
        ]
        if match.empty:
            continue
        trials.append(_row_to_trial(match.iloc[0]))

    for _, row in TEST_DF.iterrows():
        if int(row["id"]) in done_test_ids:
            continue
        trials.append(_row_to_trial(row))

    random.shuffle(trials)
    return trials


def _row_to_trial(row: pd.Series) -> dict:
    """Convert one CSV row into a trial dict with a random left/right placement.

    Integer-like image ids (Python ints or strings of decimal digits with an
    optional leading ``-``) are emitted as ``int``; anything else as ``str``.
    """
    raw_image_id = row["image_id"]
    if isinstance(raw_image_id, int) or (
        isinstance(raw_image_id, str) and re.fullmatch(r"-?\d+", raw_image_id)
    ):
        image_id_out: Any = int(raw_image_id)
    else:
        image_id_out = str(raw_image_id)
    return {
        "id": int(row["id"]),
        "image_id": image_id_out,
        "filename": str(row["filename"]),
        "type": str(row["type"]),
        "human_caption": str(row["human_caption"]),
        "model_caption": str(row["model_caption"]),
        "human_on_left": random.choice([True, False]),
    }


# Per-participant save coordination. Uploads for a given participant file are
# serialized through one lock, and we never overwrite a larger file with a
# smaller (stale) snapshot. This prevents the out-of-order/last-writer-wins race
# that previously truncated participant files when clicks were saved from
# unsynchronized background threads.
_SAVE_REGISTRY_LOCK = threading.Lock()
_SAVE_ENTRIES: dict[str, dict] = {}


def _save_entry(participant_file: str) -> dict:
    """Return (creating on first use) the save-coordination entry for a file.

    Each entry holds a per-file ``lock`` and ``saved_count``, the number of
    rows most recently uploaded (or baselined) for that file.
    """
    with _SAVE_REGISTRY_LOCK:
        entry = _SAVE_ENTRIES.get(participant_file)
        if entry is None:
            entry = {"lock": threading.Lock(), "saved_count": 0}
            _SAVE_ENTRIES[participant_file] = entry
        return entry


def _reset_save_baseline(participant_file: str, count: int) -> None:
    """Align the never-shrink guard with what's actually on HF at session start."""
    entry = _save_entry(participant_file)
    with entry["lock"]:
        entry["saved_count"] = count


_SAVE_MAX_RETRIES = 6


def _save_results(participant_file: str, results: list[dict]) -> None:
    """Upload ``results`` as the participant's CSV, retrying with backoff.

    Uploads for one file are serialized through its entry lock, and a snapshot
    that is not larger than the last saved/baselined row count is skipped, so
    a stale snapshot never replaces a more complete file. Failures are logged;
    after ``_SAVE_MAX_RETRIES`` attempts the save is abandoned (no exception).
    """
    if not HF_TOKEN or not results:
        return
    snapshot = list(results)
    entry = _save_entry(participant_file)
    # Serialize all uploads for this participant so they can't race each other.
    with entry["lock"]:
        # Never replace a more-complete file with a stale/smaller snapshot.
        if len(snapshot) <= entry["saved_count"]:
            return
        frame = pd.DataFrame(snapshot, columns=RESULTS_COLUMNS)
        csv_bytes = frame.to_csv(index=False).encode()
        # Different participants commit to the same repo concurrently, so an
        # individual upload can still be rejected with a revision conflict.
        # Retry with backoff so no answer is silently dropped.
        for attempt in range(1, _SAVE_MAX_RETRIES + 1):
            try:
                api.upload_file(
                    path_or_fileobj=io.BytesIO(csv_bytes),
                    path_in_repo=participant_file,
                    repo_id=RESULTS_REPO,
                    repo_type="dataset",
                    commit_message=f"Update {participant_file} (n={len(snapshot)})",
                )
            except Exception as exc:  # noqa: BLE001
                if attempt == _SAVE_MAX_RETRIES:
                    # Last attempt: don't sleep before giving up — the final
                    # synchronous save would otherwise block the user for nothing.
                    print(
                        f"[save] upload attempt {attempt}/{_SAVE_MAX_RETRIES} "
                        f"failed for {participant_file} ({exc})."
                    )
                    break
                wait = 0.5 * (2 ** (attempt - 1)) + random.uniform(0, 0.4)
                print(
                    f"[save] upload attempt {attempt}/{_SAVE_MAX_RETRIES} "
                    f"failed for {participant_file} ({exc}); retrying in {wait:.1f}s."
                )
                time.sleep(wait)
            else:
                entry["saved_count"] = len(snapshot)
                return
        print(
            f"[save] ERROR: gave up saving {participant_file} after "
            f"{_SAVE_MAX_RETRIES} attempts (n={len(snapshot)})."
        )


# ---------------------------------------------------------------------------
# Gradio handlers
# ---------------------------------------------------------------------------

# NOTE(review): as seen here these templates contain no HTML markup, only
# plain text — confirm no tags were lost.
WELCOME_HTML = """

Caption Preference Study

You will see images with two captions. Click the caption that better describes the image.

"""

DONE_NEW_HTML = """

All done — thank you for participating!

You can close this tab now.

"""

DONE_ALREADY_HTML_TMPL = """

You've already completed this study.

Our records show access code {code} has finished all {total} trials. There's nothing more to do — feel free to close this tab.

"""


def _validation_error(message: str):
    """Outputs for ``start_session`` that keep the intro visible with an error."""
    return (
        None,  # state
        gr.update(visible=True),  # intro
        gr.update(visible=False),  # trial group
        gr.update(visible=False, value=""),  # done panel
        None,  # image
        gr.update(value=""),  # left button
        gr.update(value=""),  # right button
        "",  # progress
        gr.update(value=message, visible=True),  # error markdown
    )


def _already_completed_outputs(code: str) -> tuple:
    """Outputs for ``start_session`` when the participant has nothing left to do."""
    msg = DONE_ALREADY_HTML_TMPL.format(code=code, total=TOTAL_TRIALS_PER_PARTICIPANT)
    return (
        None,  # state
        gr.update(visible=False),  # intro
        gr.update(visible=False),  # trial group
        gr.update(value=msg, visible=True),  # done panel
        None,  # image
        gr.update(value=""),  # left button
        gr.update(value=""),  # right button
        "",  # progress
        gr.update(value="", visible=False),  # error markdown
    )


def start_session(access_code: str):
    """Validate the access code and start (or resume) the participant's session.

    Returns the 9 outputs wired to ``start_btn.click``: state, intro, trial
    group, done panel, image, left button, right button, progress, error.
    """
    code = _normalize_code(access_code)
    if not code:
        return _validation_error("Please enter your **access code**.")
    if not _ACCESS_CODES:
        return _validation_error(
            "Server isn't ready (access codes not loaded). Please try again "
            "in a minute."
        )
    if code not in _ACCESS_CODES:
        return _validation_error(
            "That access code isn't valid. Please double-check and try again."
        )

    participant_file = _participant_filename(code)
    prior = _load_participant_results(participant_file)
    # Baseline the never-shrink save guard to the file that's actually on HF,
    # so a returning participant's saves grow from their real prior progress.
    _reset_save_baseline(participant_file, len(prior))

    if _is_complete(prior):
        return _already_completed_outputs(code)

    trials = _build_remaining_trials(prior)
    if not trials:
        # Defensive: nothing left to do but the strict completeness check did
        # not return True. Treat as done so the participant isn't stuck.
        return _already_completed_outputs(code)

    state = {
        "participant_file": participant_file,
        "trials": trials,
        "current_idx": 0,
        "trial_start_time": time.time(),
        "results": list(prior),
        "prior_count": len(prior),
        "total_trials": TOTAL_TRIALS_PER_PARTICIPANT,
    }
    img_path, left, right, progress = _current_display(state)
    return (
        state,
        gr.update(visible=False),  # intro
        gr.update(visible=True),  # trial group
        gr.update(value="", visible=False),  # done panel
        img_path,  # image
        gr.update(value=left),  # left button
        gr.update(value=right),  # right button
        progress,  # progress
        gr.update(value="", visible=False),  # error
    )


def _current_display(state: dict) -> tuple:
    """Return (image path, left caption, right caption, progress) for the current trial.

    Returns ``(None, "", "", "")`` when there is no state or no trial left.
    """
    if state is None or state["current_idx"] >= len(state["trials"]):
        return None, "", "", ""
    trial = state["trials"][state["current_idx"]]
    img_path = str(IMAGE_DIR / trial["filename"])
    if trial["human_on_left"]:
        left, right = trial["human_caption"], trial["model_caption"]
    else:
        left, right = trial["model_caption"], trial["human_caption"]
    completed = state["prior_count"] + state["current_idx"]
    total = state["total_trials"]
    progress = f"Trial {completed + 1} of {total}"
    return img_path, left, right, progress


def _finished_outputs(state: dict) -> tuple:
    """Outputs for ``_make_choice`` once every trial has been answered."""
    total = state["total_trials"]
    return (
        state,
        gr.update(visible=False),
        gr.update(value=DONE_NEW_HTML, visible=True),
        None,
        gr.update(value=""),
        gr.update(value=""),
        f"Done — {total} / {total}",
    )


def _make_choice(state: dict, side: str):
    """Record a click on ``side`` ("left"/"right") and advance to the next trial.

    Returns the 7 outputs wired to the caption buttons: state, trial group,
    done panel, image, left button, right button, progress.
    """
    if state is None:
        return (
            state,
            gr.update(visible=False),
            gr.update(visible=False),
            None,
            gr.update(value=""),
            gr.update(value=""),
            "",
        )
    if state["current_idx"] >= len(state["trials"]):
        # A stray click after the last trial: nothing left to record.
        return _finished_outputs(state)

    elapsed = min(time.time() - state["trial_start_time"], RESPONSE_TIME_CAP)
    trial = state["trials"][state["current_idx"]]
    chose_human = trial["human_on_left"] if side == "left" else not trial["human_on_left"]
    state["results"].append(
        {
            "id": trial["id"],
            "image_id": trial["image_id"],
            "filename": trial["filename"],
            "type": trial["type"],
            "human_caption": trial["human_caption"],
            "model_caption": trial["model_caption"],
            "preference": "H" if chose_human else "M",
            "response_time": round(elapsed, 3),
        }
    )
    state["current_idx"] += 1
    is_done = state["current_idx"] >= len(state["trials"])

    if is_done:
        # Final trial: save synchronously so completion is guaranteed persisted
        # (all trials written) before we show the "done" panel.
        _save_results(state["participant_file"], list(state["results"]))
        return _finished_outputs(state)

    threading.Thread(
        target=_save_results,
        args=(state["participant_file"], list(state["results"])),
        daemon=True,
    ).start()

    state["trial_start_time"] = time.time()
    img_path, left, right, progress = _current_display(state)
    return (
        state,
        gr.update(visible=True),
        gr.update(visible=False),
        img_path,
        gr.update(value=left),
        gr.update(value=right),
        progress,
    )


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

custom_css = """
.caption-btn {
    min-height: 140px !important;
    font-size: 1.05em !important;
    white-space: normal !important;
    line-height: 1.4 !important;
    padding: 16px !important;
    text-align: left !important;
}
.center-img img {
    max-height: 60vh !important;
    object-fit: contain !important;
}
.form-error {
    color: #b91c1c !important;
}
.access-code-input input {
    text-align: center !important;
    font-size: 1.4em !important;
    letter-spacing: 0.15em !important;
    font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace !important;
}
"""

with gr.Blocks(title="Caption Preference Study", css=custom_css) as demo:
    state = gr.State()

    intro = gr.Group(visible=True)
    with intro:
        gr.HTML(WELCOME_HTML)
        with gr.Row():
            with gr.Column(scale=1):
                pass
            with gr.Column(scale=2):
                code_input = gr.Textbox(
                    label="Access code",
                    placeholder="Enter your 8-character access code",
                    max_lines=1,
                    elem_classes=["access-code-input"],
                )
                start_btn = gr.Button("Start", variant="primary", size="lg")
                error_md = gr.Markdown("", visible=False, elem_classes=["form-error"])
            with gr.Column(scale=1):
                pass

    trial_group = gr.Group(visible=False)
    with trial_group:
        progress = gr.Markdown("")
        image = gr.Image(
            label=None,
            show_label=False,
            interactive=False,
            elem_classes=["center-img"],
        )
        with gr.Row():
            left_btn = gr.Button("", elem_classes=["caption-btn"])
            right_btn = gr.Button("", elem_classes=["caption-btn"])

    # Lives outside both groups so it stays visible when they are hidden.
    done_panel = gr.HTML(visible=False)

    start_btn.click(
        start_session,
        inputs=[code_input],
        outputs=[
            state,
            intro,
            trial_group,
            done_panel,
            image,
            left_btn,
            right_btn,
            progress,
            error_md,
        ],
    )
    left_btn.click(
        lambda s: _make_choice(s, "left"),
        inputs=[state],
        outputs=[state, trial_group, done_panel, image, left_btn, right_btn, progress],
    )
    right_btn.click(
        lambda s: _make_choice(s, "right"),
        inputs=[state],
        outputs=[state, trial_group, done_panel, image, left_btn, right_btn, progress],
    )


if __name__ == "__main__":
    demo.queue(default_concurrency_limit=8).launch(allowed_paths=[str(IMAGE_DIR)])