pmadinei's picture
Harden saves: retry on commit conflict so answers are never dropped
383b834 verified
Raw
History Blame Contribute Delete
28.8 kB
"""Caption Preference Study — Gradio Space.
Participants enter an access code (validated against a private 1000-code list
on HF), then see an image and two captions (human vs. model) and pick a
preference. Per-participant results are stored as ``results/<ACCESS_CODE>.csv`` in a
private HF dataset. If a participant returns later their session resumes from
wherever they left off, and if they have already completed the study they are
told so.
"""
from __future__ import annotations
import io
import json
import os
import random
import re
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
# Hugging Face repositories used by the study (both are dataset repos).
HF_USER = "pmadinei"
IMAGES_REPO = f"{HF_USER}/caption-preference-images"
RESULTS_REPO = f"{HF_USER}/caption-preference-results"
# Token for every Hub call; the loaders/savers below are no-ops when it is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Upper bound (seconds) applied to each recorded response time.
RESPONSE_TIME_CAP = 100.0
# Trial table shipped next to this file; each row carries an image filename,
# its image_id, a caption ``type`` and a human/model caption pair.
CSV_PATH = Path(__file__).parent / "Qwen3-VL-8B-Instruct.csv"
# Local directory the image dataset is mirrored into (overridable via env).
IMAGE_DIR = Path(os.environ.get("IMAGE_DIR", "/tmp/caption_experiment_images"))
IMAGE_DIR.mkdir(parents=True, exist_ok=True)
# Column order of every per-participant results CSV.
RESULTS_COLUMNS = [
    "id",
    "image_id",
    "filename",
    "type",
    "human_caption",
    "model_caption",
    "preference",
    "response_time",
]
# JSON file in RESULTS_REPO listing the valid access codes.
ACCESS_CODES_FILE = "access_codes.json"
# NOTE(review): not referenced in the visible code; access codes are validated
# by set membership instead.
ACCESS_CODE_RE = re.compile(r"^[A-Z0-9]+$")
# Shared Hub client used for listing and uploads.
api = HfApi(token=HF_TOKEN)
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def _clean_caption(value: Any) -> str:
    """Normalise one caption cell from the trial CSV into display text.

    Missing cells become ``""``. ``pandas.read_csv`` represents an empty cell
    as a float ``NaN`` rather than ``None``, so both are treated as missing;
    otherwise the literal text ``"nan"`` would be shown as a caption.
    One pair of matching surrounding quotes (``"..."`` or ``'...'``) is
    stripped.

    Args:
        value: Raw cell value.

    Returns:
        The cleaned caption text.
    """
    if value is None or (isinstance(value, float) and pd.isna(value)):
        return ""
    text = str(value)
    # Drop one layer of wrapping quotes, but only when both ends agree.
    if len(text) >= 2 and text[0] == text[-1] and text[0] in ('"', "'"):
        text = text[1:-1]
    return text
# Load the trial table and normalise both caption columns.
print(f"[startup] Loading CSV from {CSV_PATH}")
df = pd.read_csv(CSV_PATH)
df["human_caption"] = df["human_caption"].map(_clean_caption)
df["model_caption"] = df["model_caption"].map(_clean_caption)
# Rows whose image_id contains "test" (case-insensitive) are test trials; every
# participant is given all of them. The remaining rows form the non-test pool.
_test_mask = df["image_id"].astype(str).str.contains("test", case=False, na=False)
TEST_DF = df[_test_mask].reset_index(drop=True)
NONTEST_DF = df[~_test_mask].reset_index(drop=True)
# Distinct non-test image_ids, in first-appearance order.
# NOTE(review): if the image_id column mixes numbers and "test..." ids, pandas
# likely reads it as strings, so these ids would be str rather than int —
# confirm against the CSV.
NONTEST_IMAGE_IDS: list = list(NONTEST_DF["image_id"].unique())
NONTEST_IMAGE_ID_SET = set(NONTEST_IMAGE_IDS)
# Candidate image filenames for each non-test image_id.
IMAGE_ID_TO_FILENAMES: dict = {
    img_id: list(NONTEST_DF[NONTEST_DF["image_id"] == img_id]["filename"].unique())
    for img_id in NONTEST_IMAGE_IDS
}
# Caption types available per filename. Some filenames only have 2 of the 3
# possible types (e.g. no ``min_sim2model``), so we never assume all 3 exist.
FILENAME_TO_TYPES: dict = {
    fn: list(NONTEST_DF[NONTEST_DF["filename"] == fn]["type"].unique())
    for fn in NONTEST_DF["filename"].unique()
}
# Every legitimate (image_id, filename, type) triple in the non-test pool. Used
# to ignore unrelated/test rows when tallying usage counts from results CSVs.
VALID_TRIAL_KEYS: set = {
    (str(iid), str(fn), str(ty))
    for iid, fn, ty in zip(
        NONTEST_DF["image_id"], NONTEST_DF["filename"], NONTEST_DF["type"]
    )
}
def _empty_counts() -> dict:
    """Build a fresh exposure tree in which every count is zero.

    Layout: ``{str(image_id): {filename: {caption_type: 0}}}``. A filename
    lists only the caption types it really has in the trial table.
    """
    def _zeroed_types(filename):
        return dict.fromkeys(FILENAME_TO_TYPES[filename], 0)

    return {
        str(image_id): {
            filename: _zeroed_types(filename)
            for filename in IMAGE_ID_TO_FILENAMES[image_id]
        }
        for image_id in NONTEST_IMAGE_IDS
    }
# ``id`` values of all test rows; every participant answers each of them.
TEST_ROW_IDS = set(int(x) for x in TEST_DF["id"]) if len(TEST_DF) else set()
# One trial per non-test image_id plus one per test row.
TOTAL_TRIALS_PER_PARTICIPANT = len(NONTEST_IMAGE_IDS) + len(TEST_DF)
print(
    f"[startup] {len(df)} rows | {len(NONTEST_IMAGE_IDS)} non-test image_ids | "
    f"{len(TEST_DF)} test rows | {TOTAL_TRIALS_PER_PARTICIPANT} trials per participant"
)
# ---------------------------------------------------------------------------
# Image download
# ---------------------------------------------------------------------------
def _ensure_images_downloaded() -> None:
    """Mirror the private image dataset into ``IMAGE_DIR``.

    Requires ``HF_TOKEN``; without it only a warning is printed.
    """
    if HF_TOKEN:
        print(f"[startup] Downloading images from {IMAGES_REPO} to {IMAGE_DIR}...")
        download_options = dict(
            repo_id=IMAGES_REPO,
            repo_type="dataset",
            local_dir=str(IMAGE_DIR),
            token=HF_TOKEN,
            max_workers=16,
        )
        snapshot_download(**download_options)
        print("[startup] Image download complete.")
    else:
        print("[startup] WARNING: HF_TOKEN is not set; cannot download images.")


_ensure_images_downloaded()
# ---------------------------------------------------------------------------
# Access codes
# ---------------------------------------------------------------------------
# Normalised valid access codes; populated by ``_load_access_codes``.
_ACCESS_CODES: set = set()


def _normalize_code(code: Any) -> str:
    """Canonicalise an access code: ``None`` -> ``""``, else trimmed and upper-cased."""
    if code is None:
        return ""
    return str(code).strip().upper()
def _load_access_codes() -> None:
    """Fetch ``ACCESS_CODES_FILE`` from the results repo into ``_ACCESS_CODES``.

    The file is always re-downloaded. On any failure ``_ACCESS_CODES`` is left
    empty, which ``start_session`` reports as "server isn't ready".
    """
    global _ACCESS_CODES
    if not HF_TOKEN:
        print("[access] WARNING: HF_TOKEN not set; cannot load access codes.")
        return
    try:
        local_copy = hf_hub_download(
            filename=ACCESS_CODES_FILE,
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            force_download=True,
        )
        with open(local_copy) as handle:
            raw_codes = json.load(handle)
        _ACCESS_CODES = {_normalize_code(entry) for entry in raw_codes}
        print(f"[access] Loaded {len(_ACCESS_CODES)} access codes.")
    except (EntryNotFoundError, RepositoryNotFoundError, FileNotFoundError):
        print(f"[access] ERROR: {ACCESS_CODES_FILE} not found in {RESULTS_REPO}.")
        _ACCESS_CODES = set()
    except Exception as exc:  # noqa: BLE001
        print(f"[access] ERROR loading access codes: {exc}")
        _ACCESS_CODES = set()


_load_access_codes()
# ---------------------------------------------------------------------------
# Exposure state (persisted to RESULTS_REPO/state.json)
#
# We balance two things across all participants:
# 1. How often each ``filename`` is shown within its ``image_id``.
# 2. How often each caption ``type`` is shown within a given ``filename``.
#
# The authoritative source of truth is the set of per-participant result CSVs
# already stored in the results dataset: every recorded trial there is an
# (image_id, filename, type) triple that was actually shown. ``state.json`` is
# a {image_id: {filename: {type: count}}} cache of those tallies plus any
# in-flight reservations made during the current run, so concurrent sessions
# stay balanced even before their results are uploaded.
# ---------------------------------------------------------------------------
# Serialises trial assignment (``_assign_trials``) and wholesale replacement of
# the tree (``_refresh_counts_from_results``).
_STATE_LOCK = threading.Lock()
# ``_STATE`` is the exposure tree: {image_id: {filename: {type: times_shown}}}.
_STATE: dict = _empty_counts()
def _get_count(image_id: Any, filename: str, caption_type: str) -> int:
    """Return how often this (image_id, filename, type) was shown; 0 if unknown."""
    per_filename = _STATE.get(str(image_id), {}).get(filename, {})
    return int(per_filename.get(caption_type, 0))
def _incr_count(
    image_id: Any, filename: str, caption_type: str, amount: int = 1
) -> None:
    """Add ``amount`` to one exposure count, creating missing tree levels."""
    by_type = _STATE.setdefault(str(image_id), {}).setdefault(filename, {})
    current = int(by_type.get(caption_type, 0))
    by_type[caption_type] = current + amount
def _counts_from_results() -> dict | None:
    """Tally (image_id, filename, type) exposures across every results/*.csv.

    Returns a zero-initialised ``{image_id: {filename: {type: count}}}`` tree,
    or ``None`` if the results listing could not be read (so the caller can
    fall back to the cache). ``None`` is also returned when ``HF_TOKEN`` is
    unset.

    Each ``results/*.csv`` file is re-downloaded (``force_download=True``).
    Files that fail to download or parse, or that lack the ``image_id``,
    ``filename`` and ``type`` columns, are skipped. Rows whose triple is not
    in ``VALID_TRIAL_KEYS`` (e.g. test trials) are ignored.
    """
    if not HF_TOKEN:
        return None
    try:
        files = api.list_repo_files(repo_id=RESULTS_REPO, repo_type="dataset")
    except Exception as exc:  # noqa: BLE001
        print(f"[state] Could not list results files ({exc}).")
        return None
    result_files = [
        f for f in files if f.startswith("results/") and f.endswith(".csv")
    ]
    counts: dict = _empty_counts()
    n_rows = 0
    for rf in result_files:
        try:
            path = hf_hub_download(
                repo_id=RESULTS_REPO,
                repo_type="dataset",
                filename=rf,
                token=HF_TOKEN,
                force_download=True,
            )
            frame = pd.read_csv(path)
        except Exception as exc:  # noqa: BLE001
            print(f"[state] Skipping unreadable results file {rf} ({exc}).")
            continue
        needed = {"image_id", "filename", "type"}
        if not needed.issubset(frame.columns):
            continue
        # Compare as text: VALID_TRIAL_KEYS stores str() forms, and the CSV
        # round-trip may read image_id back as int or str.
        # NOTE(review): indexing ``counts`` with these str keys assumes the
        # trial table's filename/type values are already strings (the tree
        # stores them unconverted) — confirm.
        for iid, fn, ty in zip(
            frame["image_id"].astype(str),
            frame["filename"].astype(str),
            frame["type"].astype(str),
        ):
            if (iid, fn, ty) not in VALID_TRIAL_KEYS:
                continue
            counts[iid][fn][ty] += 1
            n_rows += 1
    print(
        f"[state] Tallied {n_rows} exposures from {len(result_files)} "
        f"results file(s)."
    )
    return counts
def _load_state() -> None:
    """Seed ``_STATE`` from the cached state.json (fallback before refresh).

    No-op without ``HF_TOKEN``. A missing file or any load error resets
    ``_STATE`` to a fresh zero tree.
    """
    global _STATE
    if not HF_TOKEN:
        return
    try:
        path = hf_hub_download(
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            filename="state.json",
            token=HF_TOKEN,
            force_download=True,
        )
        with open(path) as f:
            loaded = json.load(f)
        # Only the current nested tree is kept. The legacy
        # ``{"type_counts": ...}`` wrapper and any non-dict payload are
        # discarded in favour of a fresh zero tree. The nested values of an
        # accepted dict are not validated further.
        if isinstance(loaded, dict) and "type_counts" not in loaded:
            _STATE = loaded
        else:
            _STATE = _empty_counts()
        print(f"[state] Loaded cached exposure tree for {len(_STATE)} image_id(s).")
    except (EntryNotFoundError, RepositoryNotFoundError, FileNotFoundError):
        print("[state] No existing state.json found, starting fresh.")
        _STATE = _empty_counts()
    except Exception as exc:  # noqa: BLE001
        print(f"[state] Could not load state.json ({exc}); starting fresh.")
        _STATE = _empty_counts()
def _save_state() -> None:
    """Upload the current ``_STATE`` tree to ``RESULTS_REPO/state.json``.

    No-op without ``HF_TOKEN``. Upload errors propagate to the caller.
    """
    if HF_TOKEN:
        encoded_tree = json.dumps(_STATE, indent=2).encode()
        api.upload_file(
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            path_in_repo="state.json",
            path_or_fileobj=io.BytesIO(encoded_tree),
            commit_message="Update exposure counts",
        )
def _refresh_counts_from_results() -> None:
    """Rebuild counts from the authoritative results CSVs and persist them.

    ``_STATE`` is left untouched when the results listing cannot be read.
    """
    global _STATE
    counts = _counts_from_results()
    if counts is None:
        return
    with _STATE_LOCK:
        _STATE = counts
        # Best-effort: state.json is only a cache of these tallies.
        try:
            _save_state()
        except Exception as exc:  # noqa: BLE001
            print(f"[state] WARNING: could not persist state.json ({exc}).")


# Startup: seed from the cached state.json, then replace it with fresh tallies
# from the results CSVs when they can be listed.
_load_state()
_refresh_counts_from_results()
def _assign_trials(image_ids_to_assign: list) -> dict:
    """Reserve the least-shown (filename, caption type) trial for each image_id.

    Candidates are the caption types each of the image's filenames actually
    has, visited in table order. Among equally low counts the earliest
    candidate wins. Each pick bumps its count in ``_STATE`` right away, so
    later picks see it. The tree is then pushed to state.json on a best-effort
    basis.

    Returns:
        ``{image_id: (filename, caption_type)}``.
    """
    with _STATE_LOCK:
        picks: dict = {}
        for image_id in image_ids_to_assign:
            candidates = [
                (filename, caption_type)
                for filename in IMAGE_ID_TO_FILENAMES[image_id]
                for caption_type in FILENAME_TO_TYPES[filename]
            ]
            # min() returns the first of several equal minima, i.e. the first
            # candidate that reaches the lowest count.
            chosen_fn, chosen_type = min(
                candidates,
                key=lambda pair, iid=image_id: _get_count(iid, pair[0], pair[1]),
                default=(None, None),
            )
            _incr_count(image_id, chosen_fn, chosen_type)
            picks[image_id] = (chosen_fn, chosen_type)
        if picks:
            try:
                _save_state()
            except Exception as exc:  # noqa: BLE001
                print(f"[state] WARNING: could not persist state.json ({exc}).")
        return picks
# ---------------------------------------------------------------------------
# Per-participant CSV
# ---------------------------------------------------------------------------
def _participant_filename(code: str) -> str:
    """Repo path of the results CSV belonging to access code ``code``."""
    return "results/{}.csv".format(code)
def _load_participant_results(participant_file: str) -> list[dict]:
    """Download a participant's results CSV from the results repo as records.

    Returns ``[]`` when no token is configured or when the file (or repo) does
    not exist, i.e. for a genuinely new participant.

    Any other failure (network error, Hub outage, unreadable CSV) is retried a
    few times and then raised as ``gr.Error``. Returning ``[]`` there would
    restart the participant from scratch, and their next save would overwrite
    the existing, larger CSV on the Hub with a much smaller one.

    Args:
        participant_file: Repo path, e.g. ``results/<CODE>.csv``.

    Raises:
        gr.Error: the file could not be loaded for a reason other than
            "not found"; the message asks the participant to retry.
    """
    if not HF_TOKEN:
        return []
    max_attempts = 3
    last_exc: Exception | None = None
    for attempt in range(max_attempts):
        try:
            path = hf_hub_download(
                repo_id=RESULTS_REPO,
                repo_type="dataset",
                filename=participant_file,
                token=HF_TOKEN,
                force_download=True,
            )
            frame = pd.read_csv(path)
            return frame.to_dict(orient="records")
        except (EntryNotFoundError, RepositoryNotFoundError, FileNotFoundError):
            # Nothing saved yet: a brand-new participant.
            return []
        except Exception as exc:  # noqa: BLE001
            last_exc = exc
            print(
                f"[participant] Could not load {participant_file} ({exc}) "
                f"[attempt {attempt + 1}/{max_attempts}]"
            )
            if attempt + 1 < max_attempts:
                time.sleep(0.5 * (attempt + 1))
    raise gr.Error(
        "We couldn't load your saved progress right now. Please wait a moment "
        "and press Start again."
    ) from last_exc
def _canonical_nontest_image_id(raw: Any, by_text: dict) -> Any:
    """Map a results-CSV ``image_id`` value onto its ``NONTEST_IMAGE_ID_SET`` member.

    A CSV round-trip can change the id's type. A column holding only numbers
    is read back as ints, while one that also holds "test..." ids is read back
    as strings, so a non-test id may come back as ``101`` or ``"101"``
    independently of how the trial table stores it. Matching is tried on the
    value itself, on its text form, and on its integer form.

    Returns:
        The canonical id from ``NONTEST_IMAGE_ID_SET``, or ``None`` if none
        matches.
    """
    if raw in NONTEST_IMAGE_ID_SET:
        return raw
    text = str(raw)
    if text in by_text:
        return by_text[text]
    try:
        as_int = int(raw)
    except (TypeError, ValueError, OverflowError):
        return None
    if as_int in NONTEST_IMAGE_ID_SET:
        return as_int
    return by_text.get(str(as_int))


def _completed_keys(prior_results: list[dict]) -> tuple[set, set]:
    """Return ``(done_nontest_image_ids, done_test_row_ids)`` from CSV records.

    Rows without a usable integer ``id`` are ignored. A row counts as a test
    trial when its ``id`` is a known test row id or its ``image_id`` contains
    "test" (case-insensitive). Otherwise its ``image_id`` is matched against
    the non-test pool whether the CSV round-trip produced an int or a string.
    The returned non-test ids are always members of ``NONTEST_IMAGE_ID_SET``,
    so set comparisons against it are exact.
    """
    by_text = {str(iid): iid for iid in NONTEST_IMAGE_ID_SET}
    done_image_ids: set = set()
    done_test_ids: set = set()
    for record in prior_results:
        try:
            row_id = int(record["id"])
        except (KeyError, TypeError, ValueError, OverflowError):
            continue
        raw_image_id = record.get("image_id")
        if row_id in TEST_ROW_IDS or "test" in str(raw_image_id).lower():
            done_test_ids.add(row_id)
            continue
        canonical = _canonical_nontest_image_id(raw_image_id, by_text)
        if canonical is not None:
            done_image_ids.add(canonical)
    return done_image_ids, done_test_ids
def _is_complete(prior_results: list[dict]) -> bool:
    """True once every non-test image_id and every test row has an answer."""
    finished_images, finished_tests = _completed_keys(prior_results)
    return NONTEST_IMAGE_ID_SET.issubset(finished_images) and TEST_ROW_IDS.issubset(
        finished_tests
    )
def _build_remaining_trials(prior_results: list[dict]) -> list[dict]:
    """Build the shuffled list of trials a participant still has to answer.

    Each unanswered non-test image_id gets one (filename, caption type)
    reserved through ``_assign_trials``. Every unanswered test row is
    appended. The combined list is shuffled in place before being returned.
    """
    done_image_ids, done_test_ids = _completed_keys(prior_results)
    pending_ids = [iid for iid in NONTEST_IMAGE_IDS if iid not in done_image_ids]
    assignments = _assign_trials(pending_ids)

    trials: list[dict] = []
    for img_id in pending_ids:
        chosen_fn, chosen_type = assignments[img_id]
        row_mask = (
            (NONTEST_DF["image_id"] == img_id)
            & (NONTEST_DF["filename"] == chosen_fn)
            & (NONTEST_DF["type"] == chosen_type)
        )
        candidates = NONTEST_DF[row_mask]
        if not candidates.empty:
            trials.append(_row_to_trial(candidates.iloc[0]))

    trials.extend(
        _row_to_trial(test_row)
        for _, test_row in TEST_DF.iterrows()
        if int(test_row["id"]) not in done_test_ids
    )
    random.shuffle(trials)
    return trials
def _row_to_trial(row: pd.Series) -> dict:
    """Convert one trial-table row into the per-session trial dict.

    ``image_id`` becomes a Python ``int`` when it is integral: either an
    integer scalar (Python or NumPy — pandas yields ``numpy.int64`` from
    integer columns, which is not a subclass of ``int``) or a string of ASCII
    digits with an optional leading ``-``. Anything else is kept as ``str``.
    The side the human caption is shown on is drawn at random.
    """
    raw_image_id = row["image_id"]
    is_integral = isinstance(raw_image_id, int) or pd.api.types.is_integer(
        raw_image_id
    )
    # fullmatch rejects inputs like "--5" that pass ``lstrip("-").isdigit()``
    # but make ``int()`` raise.
    is_digit_text = isinstance(raw_image_id, str) and bool(
        re.fullmatch(r"-?[0-9]+", raw_image_id)
    )
    image_id_out: Any = (
        int(raw_image_id) if (is_integral or is_digit_text) else str(raw_image_id)
    )
    return {
        "id": int(row["id"]),
        "image_id": image_id_out,
        "filename": str(row["filename"]),
        "type": str(row["type"]),
        "human_caption": str(row["human_caption"]),
        "model_caption": str(row["model_caption"]),
        "human_on_left": random.choice([True, False]),
    }
# Per-participant save coordination. Uploads for a given participant file are
# serialized through one lock, and we never overwrite a larger file with a
# smaller (stale) snapshot. This prevents the out-of-order/last-writer-wins race
# that previously truncated participant files when clicks were saved from
# unsynchronized background threads.
_SAVE_REGISTRY_LOCK = threading.Lock()
_SAVE_ENTRIES: dict[str, dict] = {}


def _save_entry(participant_file: str) -> dict:
    """Return the save-coordination entry for a file, creating it on first use.

    The entry holds the file's upload ``lock`` and ``saved_count``: the row
    count last uploaded successfully, or the baseline set at session start.
    """
    with _SAVE_REGISTRY_LOCK:
        if participant_file not in _SAVE_ENTRIES:
            _SAVE_ENTRIES[participant_file] = {
                "lock": threading.Lock(),
                "saved_count": 0,
            }
        return _SAVE_ENTRIES[participant_file]
def _reset_save_baseline(participant_file: str, count: int) -> None:
    """Set the never-shrink guard to ``count`` rows (what is on HF right now)."""
    save_entry = _save_entry(participant_file)
    save_entry["lock"].acquire()
    try:
        save_entry["saved_count"] = count
    finally:
        save_entry["lock"].release()
_SAVE_MAX_RETRIES = 6


def _save_results(participant_file: str, results: list[dict]) -> None:
    """Upload a participant's results snapshot, retrying on failure.

    Uploads for the same file are serialised through that file's save lock.
    A snapshot no larger than the last uploaded (or baselined) one is skipped,
    so an out-of-order, stale snapshot never replaces a more complete file.

    A failed upload is retried with exponential backoff plus jitter, up to
    ``_SAVE_MAX_RETRIES`` attempts in total. No sleep follows the final
    attempt: the last answer of a session is saved synchronously before the
    "done" panel appears, so a trailing sleep would only delay that screen.

    Args:
        participant_file: Repo path, e.g. ``results/<CODE>.csv``.
        results: Every row for the participant (prior rows + this session).
    """
    if not HF_TOKEN or not results:
        return
    snapshot = list(results)
    entry = _save_entry(participant_file)
    # Serialize all uploads for this participant so they can't race each other.
    with entry["lock"]:
        # Never replace a more-complete file with a stale/smaller snapshot.
        if len(snapshot) <= entry["saved_count"]:
            return
        frame = pd.DataFrame(snapshot, columns=RESULTS_COLUMNS)
        csv_bytes = frame.to_csv(index=False).encode()
        # Different participants commit to the same repo concurrently, so an
        # individual upload can still be rejected with a revision conflict.
        # Retry with backoff so no answer is silently dropped.
        for attempt in range(_SAVE_MAX_RETRIES):
            try:
                api.upload_file(
                    path_or_fileobj=io.BytesIO(csv_bytes),
                    path_in_repo=participant_file,
                    repo_id=RESULTS_REPO,
                    repo_type="dataset",
                    commit_message=f"Update {participant_file} (n={len(snapshot)})",
                )
            except Exception as exc:  # noqa: BLE001
                if attempt + 1 >= _SAVE_MAX_RETRIES:
                    # Out of attempts: report and give up without sleeping.
                    print(
                        f"[save] upload attempt {attempt + 1}/{_SAVE_MAX_RETRIES} "
                        f"failed for {participant_file} ({exc})."
                    )
                    break
                wait = 0.5 * (2**attempt) + random.uniform(0, 0.4)
                print(
                    f"[save] upload attempt {attempt + 1}/{_SAVE_MAX_RETRIES} "
                    f"failed for {participant_file} ({exc}); retrying in {wait:.1f}s."
                )
                time.sleep(wait)
            else:
                entry["saved_count"] = len(snapshot)
                return
        print(
            f"[save] ERROR: gave up saving {participant_file} after "
            f"{_SAVE_MAX_RETRIES} attempts (n={len(snapshot)})."
        )
# ---------------------------------------------------------------------------
# Gradio handlers
# ---------------------------------------------------------------------------
# Header shown above the access-code form.
WELCOME_HTML = """
<div style="text-align:center; padding: 12px 16px 4px;">
<h2 style="margin-bottom: 8px;">Caption Preference Study</h2>
<p style="font-size: 1.05em; margin: 0;">
You will see images with two captions. Click the caption that better
describes the image.
</p>
</div>
"""
# Shown right after the participant answers their final trial.
DONE_NEW_HTML = """
<div style="text-align:center; padding: 32px;">
<h2>All done — thank you for participating!</h2>
<p>You can close this tab now.</p>
</div>
"""
# Shown when an access code has nothing left to answer; formatted with
# ``code`` and ``total``.
DONE_ALREADY_HTML_TMPL = """
<div style="text-align:center; padding: 32px;">
<h2>You've already completed this study.</h2>
<p>Our records show access code <code>{code}</code> has finished all
{total} trials. There's nothing more to do — feel free to close this tab.</p>
</div>
"""
def _validation_error(message: str):
    """Outputs that keep the intro form visible and show ``message`` below it.

    The nine values follow the ``start_btn.click`` output order: state, intro,
    trial group, done panel, image, left button, right button, progress,
    error markdown.
    """
    show_intro = gr.update(visible=True)
    hide_trials = gr.update(visible=False)
    clear_done = gr.update(visible=False, value="")
    blank_left = gr.update(value="")
    blank_right = gr.update(value="")
    error_box = gr.update(value=message, visible=True)
    return (
        None,
        show_intro,
        hide_trials,
        clear_done,
        None,
        blank_left,
        blank_right,
        "",
        error_box,
    )
def _already_done_outputs(code: str) -> tuple:
    """Outputs telling the holder of ``code`` that the study is already finished."""
    message = DONE_ALREADY_HTML_TMPL.format(
        code=code, total=TOTAL_TRIALS_PER_PARTICIPANT
    )
    return (
        None,  # state
        gr.update(visible=False),  # intro
        gr.update(visible=False),  # trial group
        gr.update(value=message, visible=True),  # done panel
        None,  # image
        gr.update(value=""),  # left button
        gr.update(value=""),  # right button
        "",  # progress
        gr.update(value="", visible=False),  # error
    )


def start_session(access_code: str):
    """Validate the access code and open (or resume) the participant's session.

    Returns the nine outputs wired to ``start_btn.click``: session state,
    intro, trial group, done panel, image, left button, right button,
    progress text and error markdown.
    """
    code = _normalize_code(access_code)
    if not code:
        return _validation_error("Please enter your **access code**.")
    if not _ACCESS_CODES:
        return _validation_error(
            "Server isn't ready (access codes not loaded). Please try again "
            "in a minute."
        )
    if code not in _ACCESS_CODES:
        return _validation_error(
            "That access code isn't valid. Please double-check and try again."
        )

    participant_file = _participant_filename(code)
    prior = _load_participant_results(participant_file)
    # Align the never-shrink save guard with the file actually on HF, so a
    # returning participant's saves grow from their real prior progress.
    _reset_save_baseline(participant_file, len(prior))

    if _is_complete(prior):
        return _already_done_outputs(code)
    trials = _build_remaining_trials(prior)
    if not trials:
        # Defensive: nothing left to do even though the strict completeness
        # check said otherwise. Treat as done so the participant isn't stuck.
        return _already_done_outputs(code)

    session = {
        "participant_file": participant_file,
        "trials": trials,
        "current_idx": 0,
        "trial_start_time": time.time(),
        "results": list(prior),
        "prior_count": len(prior),
        "total_trials": TOTAL_TRIALS_PER_PARTICIPANT,
    }
    img_path, left_text, right_text, progress_text = _current_display(session)
    return (
        session,
        gr.update(visible=False),  # intro
        gr.update(visible=True),  # trial group
        gr.update(value="", visible=False),  # done panel
        img_path,  # image
        gr.update(value=left_text),  # left button
        gr.update(value=right_text),  # right button
        progress_text,  # progress
        gr.update(value="", visible=False),  # error
    )
def _current_display(state: dict) -> tuple:
    """Return ``(image_path, left_text, right_text, progress_text)`` for the current trial.

    All four are ``None``/empty when there is no state or no trial left. The
    trial number counts prior answers too, so resumed sessions keep numbering.
    """
    if state is None or state["current_idx"] >= len(state["trials"]):
        return None, "", "", ""
    trial = state["trials"][state["current_idx"]]
    human, model = trial["human_caption"], trial["model_caption"]
    left, right = (human, model) if trial["human_on_left"] else (model, human)
    trial_number = state["prior_count"] + state["current_idx"] + 1
    return (
        str(IMAGE_DIR / trial["filename"]),
        left,
        right,
        f"Trial {trial_number} of {state['total_trials']}",
    )
def _finished_outputs(state: dict) -> tuple:
    """Caption-button outputs that hide the trial group and show the thank-you panel."""
    total = state["total_trials"]
    return (
        state,
        gr.update(visible=False),
        gr.update(value=DONE_NEW_HTML, visible=True),
        None,
        gr.update(value=""),
        gr.update(value=""),
        f"Done — {total} / {total}",
    )


def _make_choice(state: dict, side: str):
    """Record the participant's pick for the current trial and advance.

    Args:
        state: Session state built by ``start_session`` (``None`` before start).
        side: ``"left"`` for the left button; any other value is treated as
            the right button.

    Returns:
        The seven outputs wired to the caption buttons: state, trial group,
        done panel, image, left button, right button, progress text.

    A click that arrives after the last trial has already been answered (for
    example a double-click on the final caption, queued before the trial group
    is hidden) records nothing and just re-shows the completion panel, instead
    of indexing past the end of the trial list.
    """
    if state is None:
        return (
            state,
            gr.update(visible=False),
            gr.update(visible=False),
            None,
            gr.update(value=""),
            gr.update(value=""),
            "",
        )
    if state["current_idx"] >= len(state["trials"]):
        # Stale click after the final answer; the final save already ran.
        return _finished_outputs(state)

    elapsed = min(time.time() - state["trial_start_time"], RESPONSE_TIME_CAP)
    trial = state["trials"][state["current_idx"]]
    # The human caption was clicked iff it sat on the clicked side.
    chose_human = trial["human_on_left"] if side == "left" else not trial["human_on_left"]
    state["results"].append(
        {
            "id": trial["id"],
            "image_id": trial["image_id"],
            "filename": trial["filename"],
            "type": trial["type"],
            "human_caption": trial["human_caption"],
            "model_caption": trial["model_caption"],
            "preference": "H" if chose_human else "M",
            "response_time": round(elapsed, 3),
        }
    )
    state["current_idx"] += 1

    if state["current_idx"] >= len(state["trials"]):
        # Final trial: save synchronously so every answer has been written
        # before the "done" panel is shown.
        _save_results(state["participant_file"], list(state["results"]))
        return _finished_outputs(state)

    # Intermediate trial: save in the background so the next trial appears
    # immediately; _save_results serialises uploads per participant file.
    threading.Thread(
        target=_save_results,
        args=(state["participant_file"], list(state["results"])),
        daemon=True,
    ).start()
    state["trial_start_time"] = time.time()
    img_path, left, right, progress = _current_display(state)
    return (
        state,
        gr.update(visible=True),
        gr.update(visible=False),
        img_path,
        gr.update(value=left),
        gr.update(value=right),
        progress,
    )
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Styling: tall, wrapping, left-aligned caption buttons; image capped at 60% of
# the viewport height; red form errors; large monospace access-code input.
custom_css = """
.caption-btn {
min-height: 140px !important;
font-size: 1.05em !important;
white-space: normal !important;
line-height: 1.4 !important;
padding: 16px !important;
text-align: left !important;
}
.center-img img { max-height: 60vh !important; object-fit: contain !important; }
.form-error { color: #b91c1c !important; }
.access-code-input input {
text-align: center !important;
font-size: 1.4em !important;
letter-spacing: 0.15em !important;
font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace !important;
}
"""
with gr.Blocks(title="Caption Preference Study", css=custom_css) as demo:
    # Per-session trial state; None before a session starts or when no trials
    # remain.
    state = gr.State()
    # Intro screen: welcome text plus the access-code form, centred between
    # two empty spacer columns.
    intro = gr.Group(visible=True)
    with intro:
        gr.HTML(WELCOME_HTML)
        with gr.Row():
            with gr.Column(scale=1):
                pass
            with gr.Column(scale=2):
                code_input = gr.Textbox(
                    label="Access code",
                    placeholder="Enter your 8-character access code",
                    max_lines=1,
                    elem_classes=["access-code-input"],
                )
                start_btn = gr.Button("Start", variant="primary", size="lg")
                error_md = gr.Markdown("", visible=False, elem_classes=["form-error"])
            with gr.Column(scale=1):
                pass
    # Trial screen: progress line, the image, and the two caption buttons.
    trial_group = gr.Group(visible=False)
    with trial_group:
        progress = gr.Markdown("")
        image = gr.Image(
            label=None,
            show_label=False,
            interactive=False,
            elem_classes=["center-img"],
        )
        with gr.Row():
            left_btn = gr.Button("", elem_classes=["caption-btn"])
            right_btn = gr.Button("", elem_classes=["caption-btn"])
    # Completion / already-completed message.
    done_panel = gr.HTML(visible=False)
    # Output order must match the 9-tuple returned by start_session.
    start_btn.click(
        start_session,
        inputs=[code_input],
        outputs=[
            state,
            intro,
            trial_group,
            done_panel,
            image,
            left_btn,
            right_btn,
            progress,
            error_md,
        ],
    )
    # Both caption buttons go through _make_choice; outputs match its 7-tuple.
    left_btn.click(
        lambda s: _make_choice(s, "left"),
        inputs=[state],
        outputs=[state, trial_group, done_panel, image, left_btn, right_btn, progress],
    )
    right_btn.click(
        lambda s: _make_choice(s, "right"),
        inputs=[state],
        outputs=[state, trial_group, done_panel, image, left_btn, right_btn, progress],
    )
if __name__ == "__main__":
    # Enable the event queue (default per-listener concurrency of 8) and let
    # Gradio serve files from IMAGE_DIR, where _current_display's image paths
    # point.
    demo.queue(default_concurrency_limit=8).launch(allowed_paths=[str(IMAGE_DIR)])