| """Caption Preference Study — Gradio Space. |
| |
| Participants enter an access code (validated against a private 1000-code list |
| on HF), then see an image and two captions (human vs. model) and pick a |
| preference. Per-participant results are stored as ``<ACCESS_CODE>.csv`` in a |
| private HF dataset. If a participant returns later their session resumes from |
| wherever they left off, and if they have already completed the study they are |
| told so. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import io |
| import json |
| import os |
| import random |
| import re |
| import threading |
| import time |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from typing import Any |
|
|
| import gradio as gr |
| import pandas as pd |
| from huggingface_hub import HfApi, hf_hub_download, snapshot_download |
| from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError |
|
|
|
|
# Hugging Face Hub dataset repos used by this Space (all calls below pass
# repo_type="dataset").
HF_USER = "pmadinei"
IMAGES_REPO = f"{HF_USER}/caption-preference-images"
RESULTS_REPO = f"{HF_USER}/caption-preference-results"


# Token for every Hub call; most loaders/savers in this file become no-ops
# (or log a warning) when it is missing.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Per-trial response times are clipped to this many seconds before saving.
RESPONSE_TIME_CAP = 100.0
# Stimulus CSV shipped next to this file; the code below reads its id,
# image_id, filename, type, human_caption and model_caption columns.
CSV_PATH = Path(__file__).parent / "Qwen3-VL-8B-Instruct.csv"
# Local mirror of IMAGES_REPO (overridable via the IMAGE_DIR env var).
IMAGE_DIR = Path(os.environ.get("IMAGE_DIR", "/tmp/caption_experiment_images"))
IMAGE_DIR.mkdir(parents=True, exist_ok=True)


# Column order of every results/<CODE>.csv written by _save_results.
RESULTS_COLUMNS = [
    "id",
    "image_id",
    "filename",
    "type",
    "human_caption",
    "model_caption",
    "preference",
    "response_time",
]
# JSON list of valid access codes stored in RESULTS_REPO.
ACCESS_CODES_FILE = "access_codes.json"
# Upper-case alphanumeric access-code pattern.
# NOTE(review): not referenced anywhere in this file.
ACCESS_CODE_RE = re.compile(r"^[A-Z0-9]+$")


# Shared Hub client for listing files and uploading results/state.
api = HfApi(token=HF_TOKEN)
|
|
|
|
| |
| |
| |
|
|
def _clean_caption(value: Any) -> str:
    """Normalise one caption cell from the stimulus CSV to a display string.

    Missing cells (``None`` or a float NaN, which is how pandas represents an
    empty CSV cell) become ``""`` instead of the literal text ``"nan"``.
    One pair of matching surrounding quotes (``"..."`` or ``'...'``) is
    stripped; any other value is returned as ``str(value)``.
    """
    # NaN is the only float that is not equal to itself.
    if value is None or (isinstance(value, float) and value != value):
        return ""
    text = str(value)
    if len(text) >= 2 and text[0] == text[-1] and text[0] in ('"', "'"):
        text = text[1:-1]
    return text
|
|
|
|
print(f"[startup] Loading CSV from {CSV_PATH}")
df = pd.read_csv(CSV_PATH)
df["human_caption"] = df["human_caption"].map(_clean_caption)
df["model_caption"] = df["model_caption"].map(_clean_caption)


# Rows whose image_id contains "test" (case-insensitive) are shown to every
# participant; all remaining rows are balanced across participants.
_test_mask = df["image_id"].astype(str).str.contains("test", case=False, na=False)
TEST_DF = df[_test_mask].reset_index(drop=True)
NONTEST_DF = df[~_test_mask].reset_index(drop=True)


NONTEST_IMAGE_IDS: list = list(NONTEST_DF["image_id"].unique())
NONTEST_IMAGE_ID_SET = set(NONTEST_IMAGE_IDS)

# Build the lookup tables with one groupby pass each instead of re-filtering
# NONTEST_DF once per key (O(keys x rows) at startup). Ordering matches the
# old per-key filtering: sort=False keeps first-appearance group order and
# SeriesGroupBy.unique() keeps first-appearance order inside each group. Keys
# with no group (e.g. NaN ids, dropped by groupby) map to an empty list, just
# as an equality filter against NaN matched nothing.
_filenames_by_image_id = (
    NONTEST_DF.groupby("image_id", sort=False)["filename"].unique().to_dict()
)
IMAGE_ID_TO_FILENAMES: dict = {
    img_id: list(_filenames_by_image_id.get(img_id, []))
    for img_id in NONTEST_IMAGE_IDS
}


# Caption types per filename. Keyed by filename alone, so a filename lists
# every type it has anywhere in the non-test rows.
_types_by_filename = (
    NONTEST_DF.groupby("filename", sort=False)["type"].unique().to_dict()
)
FILENAME_TO_TYPES: dict = {
    fn: list(_types_by_filename.get(fn, []))
    for fn in NONTEST_DF["filename"].unique()
}


# Every legitimate (image_id, filename, type) triple, stringified so it can be
# matched against values read back from results CSVs.
VALID_TRIAL_KEYS: set = {
    (str(iid), str(fn), str(ty))
    for iid, fn, ty in zip(
        NONTEST_DF["image_id"], NONTEST_DF["filename"], NONTEST_DF["type"]
    )
}
|
|
|
|
def _empty_counts() -> dict:
    """Return a fresh all-zero exposure tree ``{image_id: {filename: {type: 0}}}``.

    image_id keys are stringified; each filename only lists the caption types
    recorded for it in ``FILENAME_TO_TYPES``.
    """

    def _zeroed_filenames(img_id) -> dict:
        # One {type: 0} mapping per filename belonging to this image_id.
        return {
            filename: dict.fromkeys(FILENAME_TO_TYPES[filename], 0)
            for filename in IMAGE_ID_TO_FILENAMES[img_id]
        }

    return {str(img_id): _zeroed_filenames(img_id) for img_id in NONTEST_IMAGE_IDS}
|
|
# Row ids of the test trials every participant sees.
TEST_ROW_IDS = set() if TEST_DF.empty else {int(x) for x in TEST_DF["id"]}
# One trial per non-test image_id plus every test row.
TOTAL_TRIALS_PER_PARTICIPANT = len(TEST_DF) + len(NONTEST_IMAGE_IDS)
print(
    f"[startup] {len(df)} rows | {len(NONTEST_IMAGE_IDS)} non-test image_ids | "
    f"{len(TEST_DF)} test rows | {TOTAL_TRIALS_PER_PARTICIPANT} trials per participant"
)
|
|
|
|
| |
| |
| |
|
|
def _ensure_images_downloaded() -> None:
    """Mirror IMAGES_REPO into IMAGE_DIR; only logs a warning without a token.

    Download errors from ``snapshot_download`` are not caught here.
    """
    if HF_TOKEN:
        print(f"[startup] Downloading images from {IMAGES_REPO} to {IMAGE_DIR}...")
        snapshot_download(
            repo_id=IMAGES_REPO,
            repo_type="dataset",
            local_dir=str(IMAGE_DIR),
            token=HF_TOKEN,
            max_workers=16,
        )
        print("[startup] Image download complete.")
    else:
        print("[startup] WARNING: HF_TOKEN is not set; cannot download images.")


# Runs at import time, before the Gradio UI is built.
_ensure_images_downloaded()
|
|
|
|
| |
| |
| |
|
|
# Valid, normalised access codes; populated by _load_access_codes().
_ACCESS_CODES: set = set()


def _normalize_code(code: Any) -> str:
    """Return the canonical form of an access code: trimmed and upper-cased.

    ``None`` maps to ``""``; any other value is stringified first.
    """
    if code is None:
        return ""
    return str(code).strip().upper()
|
|
|
|
def _load_access_codes() -> None:
    """Replace ``_ACCESS_CODES`` with the codes listed in ACCESS_CODES_FILE.

    The JSON payload is iterated and every entry is passed through
    ``_normalize_code``. If the file is missing or anything fails, the set ends
    up empty, and start_session then refuses every code with a
    "server isn't ready" message. Without HF_TOKEN the current set is left
    untouched.
    """
    global _ACCESS_CODES
    if not HF_TOKEN:
        print("[access] WARNING: HF_TOKEN not set; cannot load access codes.")
        return
    codes: set = set()
    try:
        local_path = hf_hub_download(
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            filename=ACCESS_CODES_FILE,
            token=HF_TOKEN,
            force_download=True,
        )
        with open(local_path) as handle:
            codes = {_normalize_code(entry) for entry in json.load(handle)}
    except (EntryNotFoundError, RepositoryNotFoundError, FileNotFoundError):
        print(f"[access] ERROR: {ACCESS_CODES_FILE} not found in {RESULTS_REPO}.")
    except Exception as exc:
        print(f"[access] ERROR loading access codes: {exc}")
    else:
        print(f"[access] Loaded {len(codes)} access codes.")
    _ACCESS_CODES = codes


# Load once at import time.
_load_access_codes()
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Held around read-modify-write access to _STATE (trial assignment and the
# swap-in of recomputed counts).
_STATE_LOCK = threading.Lock()
# In-memory exposure counts, {str(image_id): {filename: {type: count}}}.
# Starts at zero; _load_state / _refresh_counts_from_results rebind it.
_STATE: dict = _empty_counts()
|
|
|
|
def _get_count(image_id: Any, filename: str, caption_type: str) -> int:
    """Return the exposure count of one trial, treating missing levels as 0.

    ``image_id`` is looked up by its string form, matching the tree's keys.
    """
    by_filename = _STATE.get(str(image_id), {})
    by_type = by_filename.get(filename, {})
    return int(by_type.get(caption_type, 0))
|
|
|
|
def _incr_count(
    image_id: Any, filename: str, caption_type: str, amount: int = 1
) -> None:
    """Add ``amount`` to one trial's count, creating missing tree levels.

    Callers are expected to hold ``_STATE_LOCK`` when other threads may touch
    ``_STATE``.
    """
    leaf = _STATE.setdefault(str(image_id), {}).setdefault(filename, {})
    leaf[caption_type] = amount + int(leaf.get(caption_type, 0))
|
|
|
|
def _counts_from_results() -> dict | None:
    """Tally (image_id, filename, type) exposures over every results/*.csv.

    Returns an ``{image_id: {filename: {type: count}}}`` tree that starts from
    ``_empty_counts()``, or ``None`` when there is no token or the repo listing
    fails (so the caller keeps whatever it already has). Files that can't be
    downloaded or parsed, or that lack the needed columns, are skipped. Rows
    whose triple is not in VALID_TRIAL_KEYS are ignored.
    """
    if not HF_TOKEN:
        return None
    try:
        repo_files = api.list_repo_files(repo_id=RESULTS_REPO, repo_type="dataset")
    except Exception as exc:
        print(f"[state] Could not list results files ({exc}).")
        return None

    result_files = [
        name
        for name in repo_files
        if name.startswith("results/") and name.endswith(".csv")
    ]
    required_columns = {"image_id", "filename", "type"}
    tally: dict = _empty_counts()
    tallied = 0
    for remote_name in result_files:
        try:
            local_path = hf_hub_download(
                repo_id=RESULTS_REPO,
                repo_type="dataset",
                filename=remote_name,
                token=HF_TOKEN,
                force_download=True,
            )
            frame = pd.read_csv(local_path)
        except Exception as exc:
            print(f"[state] Skipping unreadable results file {remote_name} ({exc}).")
            continue
        if not required_columns.issubset(frame.columns):
            continue
        # Compare as strings, the same form VALID_TRIAL_KEYS is stored in.
        triples = zip(
            frame["image_id"].astype(str),
            frame["filename"].astype(str),
            frame["type"].astype(str),
        )
        for iid, fn, ty in triples:
            if (iid, fn, ty) in VALID_TRIAL_KEYS:
                tally[iid][fn][ty] += 1
                tallied += 1
    print(
        f"[state] Tallied {tallied} exposures from {len(result_files)} "
        f"results file(s)."
    )
    return tally
|
|
|
|
def _load_state() -> None:
    """Seed ``_STATE`` from the cached ``state.json`` in the results repo.

    This is the fallback used before the counts are rebuilt from the results
    CSVs. A payload that is not a dict, or that has a top-level
    ``"type_counts"`` key (presumably an older layout), is replaced by a zero
    tree. A missing file or any load error also yields a zero tree. Without
    HF_TOKEN, ``_STATE`` is left untouched.
    """
    global _STATE
    if not HF_TOKEN:
        return
    try:
        cached_path = hf_hub_download(
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            filename="state.json",
            token=HF_TOKEN,
            force_download=True,
        )
        with open(cached_path) as handle:
            cached = json.load(handle)
        unusable = not isinstance(cached, dict) or "type_counts" in cached
        _STATE = _empty_counts() if unusable else cached
        print(f"[state] Loaded cached exposure tree for {len(_STATE)} image_id(s).")
    except (EntryNotFoundError, RepositoryNotFoundError, FileNotFoundError):
        print("[state] No existing state.json found, starting fresh.")
        _STATE = _empty_counts()
    except Exception as exc:
        print(f"[state] Could not load state.json ({exc}); starting fresh.")
        _STATE = _empty_counts()
|
|
|
|
def _save_state() -> None:
    """Upload the current ``_STATE`` tree as ``state.json`` in the results repo.

    No-op without HF_TOKEN. Upload errors propagate; both callers in this
    file catch them and log a warning.
    """
    if HF_TOKEN:
        encoded = json.dumps(_STATE, indent=2).encode()
        api.upload_file(
            path_or_fileobj=io.BytesIO(encoded),
            path_in_repo="state.json",
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            commit_message="Update exposure counts",
        )
|
|
|
|
def _refresh_counts_from_results() -> None:
    """Swap in counts recomputed from the results CSVs and persist them.

    If the results listing can't be read, the current ``_STATE`` is kept.
    Persisting is best-effort: an upload failure is only logged.
    """
    global _STATE
    fresh = _counts_from_results()
    if fresh is not None:
        with _STATE_LOCK:
            _STATE = fresh
        try:
            _save_state()
        except Exception as exc:
            print(f"[state] WARNING: could not persist state.json ({exc}).")


# Startup: seed from the cached state.json, then prefer the rebuilt counts.
_load_state()
_refresh_counts_from_results()
|
|
|
|
def _assign_trials(image_ids_to_assign: list) -> dict:
    """Reserve the least-exposed ``(filename, type)`` trial for each image_id.

    For every image_id, all ``(filename, type)`` combinations it has (only the
    caption types each filename actually has) are compared by their current
    count in the in-memory ``_STATE`` tree. The lowest count wins; ties go to
    the first candidate in ``IMAGE_ID_TO_FILENAMES`` / ``FILENAME_TO_TYPES``
    order. Each pick is reserved immediately (its count is incremented), so
    later picks in this call and in other sessions see it. The updated tree is
    then persisted to ``state.json`` while ``_STATE_LOCK`` is still held, so
    the serialised snapshot is consistent. A persistence failure is only logged.

    Returns:
        ``{image_id: (filename, type)}``. An image_id with no candidate
        trials maps to ``(None, None)``, and nothing is reserved for it.
    """
    assignments: dict = {}
    with _STATE_LOCK:
        reserved_any = False
        for img_id in image_ids_to_assign:
            candidates = [
                (fn, caption_type)
                for fn in IMAGE_ID_TO_FILENAMES[img_id]
                for caption_type in FILENAME_TO_TYPES[fn]
            ]
            if not candidates:
                # Don't record a bogus (None, None) exposure in _STATE; the
                # caller finds no stimulus row for (None, None) and skips it.
                assignments[img_id] = (None, None)
                continue
            # min() returns the first minimal candidate, i.e. the first trial
            # that reaches the lowest count wins ties.
            best_fn, best_type = min(
                candidates,
                key=lambda cand, iid=img_id: _get_count(iid, cand[0], cand[1]),
            )
            _incr_count(img_id, best_fn, best_type)
            assignments[img_id] = (best_fn, best_type)
            reserved_any = True
        if reserved_any:
            try:
                _save_state()
            except Exception as exc:
                print(f"[state] WARNING: could not persist state.json ({exc}).")
    return assignments
|
|
|
|
| |
| |
| |
|
|
def _participant_filename(code: str) -> str:
    """Return the path, inside RESULTS_REPO, of one participant's results CSV."""
    return "results/{}.csv".format(code)
|
|
|
|
def _load_participant_results(participant_file: str) -> list[dict]:
    """Fetch one participant's saved rows from the results repo.

    Args:
        participant_file: Path inside RESULTS_REPO, e.g. ``results/<CODE>.csv``.

    Returns:
        The CSV rows as a list of dicts, or ``[]`` when there is no HF_TOKEN,
        when the file (or repo) does not exist yet, or when the file is empty.

    Raises:
        Exception: any other download or parse failure is logged and
            re-raised. Returning ``[]`` there would make a returning
            participant look brand new, and the next save would overwrite
            their existing CSV with fewer rows.
    """
    if not HF_TOKEN:
        return []
    try:
        path = hf_hub_download(
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            filename=participant_file,
            token=HF_TOKEN,
            force_download=True,
        )
        frame = pd.read_csv(path)
    except (EntryNotFoundError, RepositoryNotFoundError, FileNotFoundError):
        # Nothing has been saved for this participant yet.
        return []
    except pd.errors.EmptyDataError:
        # The file exists but is empty: nothing to resume from.
        return []
    except Exception as exc:
        print(f"[participant] Could not load {participant_file} ({exc})")
        raise
    return frame.to_dict(orient="records")
|
|
|
|
def _match_nontest_image_id(value: Any, by_text: dict) -> Any:
    """Return the NONTEST image_id that ``value`` refers to, or ``None``.

    Tries, in order: direct membership, the string form, and the string form
    of ``int(value)``. A CSV round-trip can change the dtype (``"5"`` vs ``5``
    vs ``5.0``), so direct membership alone is not enough.
    """
    if value in NONTEST_IMAGE_ID_SET:
        return value
    found = by_text.get(str(value))
    if found is not None:
        return found
    try:
        coerced = int(value)
    except (TypeError, ValueError, OverflowError):
        return None
    if coerced in NONTEST_IMAGE_ID_SET:
        return coerced
    return by_text.get(str(coerced))


def _completed_keys(prior_results: list[dict]) -> tuple[set, set]:
    """Return ``(done_nontest_image_ids, done_test_row_ids)`` for saved rows.

    Rows without a usable integer ``id`` are ignored. A row counts as a test
    trial if its id is a TEST_ROW_IDS member or its image_id contains "test"
    (case-insensitive). Any other row marks its image_id as done when it
    matches a non-test image_id, regardless of whether the CSV read it back as
    a string, int or float.
    """
    # Canonical non-test ids keyed by their string form.
    by_text = {str(iid): iid for iid in NONTEST_IMAGE_IDS}
    done_image_ids: set = set()
    done_test_ids: set = set()
    for r in prior_results:
        try:
            row_id = int(r["id"])
        except (KeyError, TypeError, ValueError, OverflowError):
            continue
        img_id_val = r.get("image_id")
        if row_id in TEST_ROW_IDS or "test" in str(img_id_val).lower():
            done_test_ids.add(row_id)
            continue
        matched = _match_nontest_image_id(img_id_val, by_text)
        if matched is not None:
            done_image_ids.add(matched)
    return done_image_ids, done_test_ids
|
|
|
|
def _is_complete(prior_results: list[dict]) -> bool:
    """True when the saved rows cover every non-test image_id and test row."""
    finished_images, finished_tests = _completed_keys(prior_results)
    return NONTEST_IMAGE_ID_SET.issubset(finished_images) and TEST_ROW_IDS.issubset(
        finished_tests
    )
|
|
|
|
def _build_remaining_trials(prior_results: list[dict]) -> list[dict]:
    """Build the shuffled list of trials a participant still has to answer.

    Every non-test image_id not yet answered gets one freshly assigned (and
    reserved) ``(filename, type)`` trial via ``_assign_trials``. Every test
    row not yet answered is added as-is. Assignments with no matching stimulus
    row are skipped.
    """
    done_image_ids, done_test_ids = _completed_keys(prior_results)

    remaining_image_ids = [
        iid for iid in NONTEST_IMAGE_IDS if iid not in done_image_ids
    ]
    assignments = _assign_trials(remaining_image_ids)

    # Map each (image_id, filename, type) to its first row position in one
    # pass, instead of building three boolean masks over NONTEST_DF per image
    # (O(images x rows) on every session start).
    first_row_for: dict = {}
    triples = zip(NONTEST_DF["image_id"], NONTEST_DF["filename"], NONTEST_DF["type"])
    for pos, key in enumerate(triples):
        first_row_for.setdefault(key, pos)

    trials: list[dict] = []
    for img_id in remaining_image_ids:
        fn, caption_type = assignments[img_id]
        pos = first_row_for.get((img_id, fn, caption_type))
        if pos is None:
            continue
        trials.append(_row_to_trial(NONTEST_DF.iloc[pos]))

    for _, row in TEST_DF.iterrows():
        if int(row["id"]) in done_test_ids:
            continue
        trials.append(_row_to_trial(row))

    random.shuffle(trials)
    return trials
|
|
|
|
def _row_to_trial(row: pd.Series) -> dict:
    """Convert one stimulus row into the trial dict kept in the session state.

    ``image_id`` becomes an ``int`` when it is an integer (NumPy integer
    scalars included) or a string of the form ``-?digits`` such as ``"42"`` or
    ``"-3"``. Anything else (e.g. ``"test_1"`` or ``"--5"``) is kept as a
    string. ``human_on_left`` randomises which side shows the human caption.
    """
    import numbers

    raw_image_id = row["image_id"]
    # numbers.Integral also covers numpy.int64, which isinstance(x, int) misses.
    # fullmatch guarantees int() will accept the string (unlike the old
    # lstrip("-").isdigit() check, which let "--5" through and then crashed).
    if isinstance(raw_image_id, numbers.Integral) or (
        isinstance(raw_image_id, str) and re.fullmatch(r"-?\d+", raw_image_id)
    ):
        image_id_out: Any = int(raw_image_id)
    else:
        image_id_out = str(raw_image_id)
    return {
        "id": int(row["id"]),
        "image_id": image_id_out,
        "filename": str(row["filename"]),
        "type": str(row["type"]),
        "human_caption": str(row["human_caption"]),
        "model_caption": str(row["model_caption"]),
        "human_on_left": random.choice([True, False]),
    }
|
|
|
|
| |
| |
| |
| |
| |
# Per-results-file save bookkeeping: {"lock": Lock, "saved_count": int}.
# _SAVE_REGISTRY_LOCK only guards creation of entries in _SAVE_ENTRIES.
_SAVE_REGISTRY_LOCK = threading.Lock()
_SAVE_ENTRIES: dict[str, dict] = {}


def _save_entry(participant_file: str) -> dict:
    """Return the bookkeeping entry for ``participant_file``, creating it once."""
    with _SAVE_REGISTRY_LOCK:
        if participant_file not in _SAVE_ENTRIES:
            _SAVE_ENTRIES[participant_file] = {
                "lock": threading.Lock(),
                "saved_count": 0,
            }
        return _SAVE_ENTRIES[participant_file]


def _reset_save_baseline(participant_file: str, count: int) -> None:
    """Set the never-shrink guard to ``count`` rows (what HF holds right now)."""
    entry = _save_entry(participant_file)
    with entry["lock"]:
        entry.update(saved_count=count)
|
|
|
|
# Total upload attempts per save (initial try plus retries).
_SAVE_MAX_RETRIES = 6


def _save_results(participant_file: str, results: list[dict]) -> None:
    """Upload a participant's complete results CSV, retrying with backoff.

    Saves of the same file are serialised by its per-file lock. A snapshot is
    dropped when it has no more rows than were last saved, so a late, older
    background save can never shrink the remote file. Retries use
    exponential backoff with jitter. After the last failed attempt the save
    gives up immediately and logs an error, without sleeping.

    Args:
        participant_file: Path inside RESULTS_REPO (``results/<CODE>.csv``).
        results: All rows for the participant (prior + new); copied up front.
    """
    if not HF_TOKEN or not results:
        return
    snapshot = list(results)
    entry = _save_entry(participant_file)

    with entry["lock"]:
        if len(snapshot) <= entry["saved_count"]:
            return
        frame = pd.DataFrame(snapshot, columns=RESULTS_COLUMNS)
        csv_bytes = frame.to_csv(index=False).encode()

        for attempt in range(1, _SAVE_MAX_RETRIES + 1):
            try:
                api.upload_file(
                    path_or_fileobj=io.BytesIO(csv_bytes),
                    path_in_repo=participant_file,
                    repo_id=RESULTS_REPO,
                    repo_type="dataset",
                    commit_message=f"Update {participant_file} (n={len(snapshot)})",
                )
            except Exception as exc:
                if attempt == _SAVE_MAX_RETRIES:
                    # Last attempt: don't sleep while holding the lock for no
                    # reason, and don't claim we're retrying.
                    print(
                        f"[save] upload attempt {attempt}/{_SAVE_MAX_RETRIES} "
                        f"failed for {participant_file} ({exc})."
                    )
                    break
                wait = 0.5 * (2 ** (attempt - 1)) + random.uniform(0, 0.4)
                print(
                    f"[save] upload attempt {attempt}/{_SAVE_MAX_RETRIES} "
                    f"failed for {participant_file} ({exc}); retrying in {wait:.1f}s."
                )
                time.sleep(wait)
            else:
                entry["saved_count"] = len(snapshot)
                return
        print(
            f"[save] ERROR: gave up saving {participant_file} after "
            f"{_SAVE_MAX_RETRIES} attempts (n={len(snapshot)})."
        )
|
|
|
|
| |
| |
| |
|
|
# Banner shown above the access-code form.
WELCOME_HTML = """
<div style="text-align:center; padding: 12px 16px 4px;">
<h2 style="margin-bottom: 8px;">Caption Preference Study</h2>
<p style="font-size: 1.05em; margin: 0;">
You will see images with two captions. Click the caption that better
describes the image.
</p>
</div>
"""


# Shown right after a participant answers their final trial.
DONE_NEW_HTML = """
<div style="text-align:center; padding: 32px;">
<h2>All done — thank you for participating!</h2>
<p>You can close this tab now.</p>
</div>
"""


# Shown when a code has nothing left to do; filled via
# str.format(code=..., total=...). The code has already been validated
# against the access-code list when this is rendered.
DONE_ALREADY_HTML_TMPL = """
<div style="text-align:center; padding: 32px;">
<h2>You've already completed this study.</h2>
<p>Our records show access code <code>{code}</code> has finished all
{total} trials. There's nothing more to do — feel free to close this tab.</p>
</div>
"""
|
|
|
|
def _validation_error(message: str):
    """start_session outputs that keep the intro form up and show ``message``.

    Order matches the start_btn outputs: state, intro, trial_group,
    done_panel, image, left_btn, right_btn, progress, error_md.
    """
    show_intro = gr.update(visible=True)
    hide_trials = gr.update(visible=False)
    hide_done = gr.update(visible=False, value="")
    clear_left = gr.update(value="")
    clear_right = gr.update(value="")
    show_error = gr.update(value=message, visible=True)
    no_state = None
    no_image = None
    no_progress = ""
    return (
        no_state,
        show_intro,
        hide_trials,
        hide_done,
        no_image,
        clear_left,
        clear_right,
        no_progress,
        show_error,
    )
|
|
|
|
def _already_completed_outputs(code: str) -> tuple:
    """start_session outputs that hide everything but the "already done" panel."""
    msg = DONE_ALREADY_HTML_TMPL.format(
        code=code, total=TOTAL_TRIALS_PER_PARTICIPANT
    )
    return (
        None,
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(value=msg, visible=True),
        None,
        gr.update(value=""),
        gr.update(value=""),
        "",
        gr.update(value="", visible=False),
    )


def start_session(access_code: str):
    """Validate an access code and start (or resume) that participant's session.

    Returns a 9-tuple for the start_btn outputs ``[state, intro, trial_group,
    done_panel, image, left_btn, right_btn, progress, error_md]``:

    * empty/invalid code, codes not loaded, or saved progress that could not
      be read: the form stays up with an error message;
    * every trial already answered: the "already completed" panel;
    * otherwise: the first remaining trial, with ``state`` holding the
      shuffled trial list and the participant's prior rows.
    """
    code = _normalize_code(access_code)

    if not code:
        return _validation_error("Please enter your **access code**.")
    if not _ACCESS_CODES:
        return _validation_error(
            "Server isn't ready (access codes not loaded). Please try again "
            "in a minute."
        )
    if code not in _ACCESS_CODES:
        return _validation_error(
            "That access code isn't valid. Please double-check and try again."
        )

    participant_file = _participant_filename(code)
    try:
        prior = _load_participant_results(participant_file)
    except Exception as exc:
        # Starting with no prior rows would let the first save overwrite the
        # participant's existing CSV, so refuse to start instead.
        print(f"[participant] Not starting session for {participant_file} ({exc})")
        return _validation_error(
            "We couldn't load your saved progress. Please try again in a minute."
        )

    # Align the never-shrink save guard with what is actually on HF now.
    _reset_save_baseline(participant_file, len(prior))

    if _is_complete(prior):
        return _already_completed_outputs(code)

    trials = _build_remaining_trials(prior)
    if not trials:
        # Nothing left to show even though _is_complete() was False (e.g. the
        # remaining image_ids had no matching stimulus row); treat it as done.
        return _already_completed_outputs(code)

    state = {
        "participant_file": participant_file,
        "trials": trials,
        "current_idx": 0,
        "trial_start_time": time.time(),
        "results": list(prior),
        "prior_count": len(prior),
        "total_trials": TOTAL_TRIALS_PER_PARTICIPANT,
    }
    img_path, left, right, progress = _current_display(state)
    return (
        state,
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(value="", visible=False),
        img_path,
        gr.update(value=left),
        gr.update(value=right),
        progress,
        gr.update(value="", visible=False),
    )
|
|
|
|
def _current_display(state: dict) -> tuple:
    """Return ``(image_path, left_caption, right_caption, progress_text)``.

    Yields ``(None, "", "", "")`` when there is no state or no trial left.
    Progress counts previously saved rows, so resumed sessions continue the
    numbering.
    """
    finished = state is None or state["current_idx"] >= len(state["trials"])
    if finished:
        return None, "", "", ""
    idx = state["current_idx"]
    trial = state["trials"][idx]
    human, model = trial["human_caption"], trial["model_caption"]
    left, right = (human, model) if trial["human_on_left"] else (model, human)
    shown_before = state["prior_count"] + idx
    return (
        str(IMAGE_DIR / trial["filename"]),
        left,
        right,
        f"Trial {shown_before + 1} of {state['total_trials']}",
    )
|
|
|
|
def _choice_done_outputs(state: dict) -> tuple:
    """_make_choice outputs that hide the trial UI and show the thank-you panel."""
    total = state["total_trials"]
    return (
        state,
        gr.update(visible=False),
        gr.update(value=DONE_NEW_HTML, visible=True),
        None,
        gr.update(value=""),
        gr.update(value=""),
        f"Done — {total} / {total}",
    )


def _make_choice(state: dict, side: str):
    """Record the pick for the current trial and advance to the next one.

    Args:
        state: Session dict built by start_session (``None`` before a code is
            accepted).
        side: ``"left"`` or ``"right"``, the button that was clicked.

    Returns a 7-tuple for the outputs ``[state, trial_group, done_panel,
    image, left_btn, right_btn, progress]``. Intermediate saves run on a
    daemon thread. The save after the last trial runs synchronously before
    the thank-you panel is returned. A click that arrives after the last
    trial has already been recorded only re-shows the thank-you panel.
    """
    if state is None:
        return (
            state,
            gr.update(visible=False),
            gr.update(visible=False),
            None,
            gr.update(value=""),
            gr.update(value=""),
            "",
        )
    if state["current_idx"] >= len(state["trials"]):
        # Nothing left to answer (e.g. a queued extra click): avoid IndexError.
        return _choice_done_outputs(state)

    elapsed = min(time.time() - state["trial_start_time"], RESPONSE_TIME_CAP)
    trial = state["trials"][state["current_idx"]]
    chose_human = trial["human_on_left"] if side == "left" else not trial["human_on_left"]
    state["results"].append(
        {
            "id": trial["id"],
            "image_id": trial["image_id"],
            "filename": trial["filename"],
            "type": trial["type"],
            "human_caption": trial["human_caption"],
            "model_caption": trial["model_caption"],
            "preference": "H" if chose_human else "M",
            "response_time": round(elapsed, 3),
        }
    )

    state["current_idx"] += 1
    is_done = state["current_idx"] >= len(state["trials"])

    if is_done:
        # Final save is synchronous so the thank-you panel only appears after
        # the last upload attempt has finished.
        _save_results(state["participant_file"], list(state["results"]))
        return _choice_done_outputs(state)

    threading.Thread(
        target=_save_results,
        args=(state["participant_file"], list(state["results"])),
        daemon=True,
    ).start()

    state["trial_start_time"] = time.time()
    img_path, left, right, progress = _current_display(state)
    return (
        state,
        gr.update(visible=True),
        gr.update(visible=False),
        img_path,
        gr.update(value=left),
        gr.update(value=right),
        progress,
    )
|
|
|
|
| |
| |
| |
|
|
# Styling for the caption buttons, the centred image and the access-code field.
custom_css = """
.caption-btn {
min-height: 140px !important;
font-size: 1.05em !important;
white-space: normal !important;
line-height: 1.4 !important;
padding: 16px !important;
text-align: left !important;
}
.center-img img { max-height: 60vh !important; object-fit: contain !important; }
.form-error { color: #b91c1c !important; }
.access-code-input input {
text-align: center !important;
font-size: 1.4em !important;
letter-spacing: 0.15em !important;
font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace !important;
}
"""


with gr.Blocks(title="Caption Preference Study", css=custom_css) as demo:
    # Per-session dict produced by start_session (None until a code is accepted).
    state = gr.State()

    # Intro: welcome text plus the access-code form in a centred middle column
    # (the scale=1 side columns are empty).
    intro = gr.Group(visible=True)
    with intro:
        gr.HTML(WELCOME_HTML)
        with gr.Row():
            with gr.Column(scale=1):
                pass
            with gr.Column(scale=2):
                code_input = gr.Textbox(
                    label="Access code",
                    placeholder="Enter your 8-character access code",
                    max_lines=1,
                    elem_classes=["access-code-input"],
                )
                start_btn = gr.Button("Start", variant="primary", size="lg")
                error_md = gr.Markdown("", visible=False, elem_classes=["form-error"])
            with gr.Column(scale=1):
                pass

    # Trial view: progress line, image, and the two clickable captions.
    trial_group = gr.Group(visible=False)
    with trial_group:
        progress = gr.Markdown("")
        image = gr.Image(
            label=None,
            show_label=False,
            interactive=False,
            elem_classes=["center-img"],
        )
        with gr.Row():
            left_btn = gr.Button("", elem_classes=["caption-btn"])
            right_btn = gr.Button("", elem_classes=["caption-btn"])

    # Completion / "already completed" message.
    done_panel = gr.HTML(visible=False)

    # Output order must match the 9-tuples returned by start_session.
    start_btn.click(
        start_session,
        inputs=[code_input],
        outputs=[
            state,
            intro,
            trial_group,
            done_panel,
            image,
            left_btn,
            right_btn,
            progress,
            error_md,
        ],
    )

    # Output order must match the 7-tuples returned by _make_choice.
    left_btn.click(
        lambda s: _make_choice(s, "left"),
        inputs=[state],
        outputs=[state, trial_group, done_panel, image, left_btn, right_btn, progress],
    )
    right_btn.click(
        lambda s: _make_choice(s, "right"),
        inputs=[state],
        outputs=[state, trial_group, done_panel, image, left_btn, right_btn, progress],
    )


if __name__ == "__main__":
    # allowed_paths lets Gradio serve the local image files under IMAGE_DIR
    # whose paths _current_display returns.
    demo.queue(default_concurrency_limit=8).launch(allowed_paths=[str(IMAGE_DIR)])
|
|