#!/usr/bin/env python
"""Human-preference study (2AFC pairwise) — SD3.5-Medium ImageReward-guided vs vanilla, 4 styles.

Accepted T2I protocol (Imagen/Parti/Emu/Diffusion-DPO): two questions/pair (overall preference +
prompt alignment), tie option, neutral labels, left/right randomized, gold catch trials, email-gated
(one submission/person). SQLite backend, JSON API + single-page UI.

NOTE: /api/submit currently stores only the overall-preference question ("pref"); the "align"
section of /api/stats stays empty until the UI/API records it.

Run:      ./env/bin/python -m uvicorn human_study.app.main:app --host 0.0.0.0 --port 8000
DB:       human_study/data/study.db
Manifest: human_study/data/pairs.json
Admin:    GET /api/stats?key=$STUDY_ADMIN_KEY
          GET /api/export?key=...   (CSV)
"""
import asyncio
import csv
import io
import json
import os
import random
import secrets
import shutil
import sqlite3
import tempfile
import threading
import time
import uuid
from contextlib import closing, contextmanager
from pathlib import Path

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles

HERE = Path(__file__).resolve().parent
# STUDY_DATA (pairs.json + study.db) and STUDY_STATIC (index.html + imgs/) are env-overridable so the same
# app code can host multiple independent studies (e.g. the 4-style study and a photo-only study).
DATA = Path(os.environ.get("STUDY_DATA", str(HERE.parent / "data")))
STATIC = Path(os.environ.get("STUDY_STATIC", str(HERE / "static")))
DB = DATA / "study.db"
# Read as UTF-8 explicitly: prompts may contain non-ASCII text and the locale default may not be UTF-8.
PAIRS = {p["pair_id"]: p for p in json.loads((DATA / "pairs.json").read_text(encoding="utf-8"))}
ADMIN_KEY = os.environ.get("STUDY_ADMIN_KEY", "rsfg-admin")
if "STUDY_ADMIN_KEY" not in os.environ:
    print("[warn] STUDY_ADMIN_KEY is unset; using the built-in default admin key. "
          "Set it before deploying publicly (/api/export exposes participant emails).")

VALID_CHOICES = ("left", "right", "cannot")  # "cannot" = the "Cannot decide" tie option
QUESTIONS = ("pref", "align")  # questions reported by /api/stats

# Durable storage on Hugging Face Spaces (ephemeral disk): restore the response DB from a private Dataset on
# startup, back it up after every submission. No-op when the env vars are unset (e.g. local runs).
PERSIST_REPO = os.environ.get("HF_PERSIST_REPO")  # e.g. "bs82/image-study-db"
STUDY_NAME = os.environ.get("STUDY_NAME", "study")  # subfolder in the Dataset (one per study)
HF_TOKEN = os.environ.get("HF_TOKEN")
_DB_IN_REPO = f"{STUDY_NAME}/study.db"
# Serializes snapshot+upload: concurrent backups never race each other's commits, and uploads happen in
# snapshot order, so the last upload always carries the newest state.
_BACKUP_LOCK = threading.Lock()


def restore_db():
    """Copy the last backed-up DB from the HF Dataset over DB (best-effort; no-op without env vars)."""
    if not (PERSIST_REPO and HF_TOKEN):
        return
    try:
        from huggingface_hub import hf_hub_download
        p = hf_hub_download(PERSIST_REPO, _DB_IN_REPO, repo_type="dataset", token=HF_TOKEN)
        shutil.copy(p, DB)
        print(f"[persist] restored DB from {PERSIST_REPO}:{_DB_IN_REPO}")
    except Exception as e:  # best-effort: a missing backup just means a fresh study
        print(f"[persist] no prior DB ({e}); starting fresh")


def backup_db():
    """Upload a consistent snapshot of DB to the HF Dataset (best-effort; no-op without env vars).

    Blocking (network I/O): call it from a worker thread, not directly on the event loop.
    """
    if not (PERSIST_REPO and HF_TOKEN):
        return
    with _BACKUP_LOCK:
        try:
            from huggingface_hub import HfApi
            with tempfile.TemporaryDirectory() as tmp:
                snapshot = Path(tmp) / "study.db"
                # SQLite's online-backup API gives a transactionally consistent copy even while other
                # requests keep writing; uploading the live file could capture a half-written state.
                with closing(sqlite3.connect(DB)) as src, closing(sqlite3.connect(snapshot)) as dst:
                    src.backup(dst)
                HfApi(token=HF_TOKEN).upload_file(path_or_fileobj=str(snapshot), path_in_repo=_DB_IN_REPO,
                                                  repo_id=PERSIST_REPO, repo_type="dataset",
                                                  commit_message=f"backup {STUDY_NAME}")
        except Exception as e:  # best-effort: never fail a participant's submission over a backup
            print(f"[persist] backup failed: {e}")


def db():
    """Open a new connection to the study DB with name-addressable rows. The caller must close it."""
    c = sqlite3.connect(DB)
    c.row_factory = sqlite3.Row
    return c


@contextmanager
def _tx():
    """Yield a connection inside one transaction: commit on success, roll back on error, always close.

    A bare ``with sqlite3.connect(...)`` only commits or rolls back; it never closes the connection.
    """
    conn = db()
    try:
        with conn:
            yield conn
    finally:
        conn.close()


def init_db():
    """Create the data directory and the participants/votes tables if they do not exist yet."""
    DATA.mkdir(parents=True, exist_ok=True)
    with _tx() as c:
        c.execute("""CREATE TABLE IF NOT EXISTS participants(
            token TEXT PRIMARY KEY, email TEXT UNIQUE, created REAL, completed INTEGER DEFAULT 0,
            n_catch INTEGER DEFAULT 0, n_catch_pass INTEGER DEFAULT 0, record INTEGER DEFAULT 1)""")
        c.execute("""CREATE TABLE IF NOT EXISTS votes(
            token TEXT, email TEXT, pair_id INTEGER, kind TEXT, style TEXT, question TEXT,
            choice TEXT, guided_chosen INTEGER, ts REAL)""")  # guided_chosen: 1 guided, 0 vanilla, NULL tie


restore_db()
init_db()

app = FastAPI(title="Image preference study")
app.mount("/static", StaticFiles(directory=str(STATIC)), name="static")


@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page study UI."""
    return (STATIC / "index.html").read_text(encoding="utf-8")


@app.get("/api/info")
def info():
    """Number of pairs and a rough duration estimate (~3.5 s per pair, at least 1 minute)."""
    n = len(PAIRS)
    return {"n_pairs": n, "est_minutes": max(1, round(n * 3.5 / 60))}


async def _json_object(req):
    """Parse the request body as a JSON object; a malformed or non-object body raises HTTP 400."""
    try:
        body = await req.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError
        raise HTTPException(400, "Invalid request body.") from None
    if not isinstance(body, dict):
        raise HTTPException(400, "Invalid request body.")
    return body


def _pair_for(vote):
    """Return the manifest entry a submitted vote refers to, or None if the vote is malformed/unknown."""
    if not isinstance(vote, dict):
        return None
    try:
        return PAIRS.get(int(vote.get("pair_id", -1)))
    except (TypeError, ValueError, OverflowError):  # e.g. null, "abc", Infinity from untrusted JSON
        return None


def _clean_votes(votes):
    """Return ``(pair, choice)`` for each valid vote on a known pair, keeping the FIRST valid vote per pair.

    Deduplicating stops one submission from counting the same comparison (or catch trial) many times.
    A non-list ``votes`` yields no votes.
    """
    if not isinstance(votes, list):
        return []
    seen, cleaned = set(), []
    for v in votes:
        p = _pair_for(v)
        if p is None or p["pair_id"] in seen:
            continue
        ch = v.get("choice")
        if ch not in VALID_CHOICES:
            continue
        seen.add(p["pair_id"])
        cleaned.append((p, ch))
    return cleaned


def norm_email(e):
    """Lower-case/strip an email and apply a minimal sanity check; invalid input raises HTTP 400."""
    e = (e if isinstance(e, str) else "").strip().lower()
    if "@" not in e or "." not in e.split("@")[-1] or len(e) < 5:
        raise HTTPException(400, "Please enter a valid email address.")
    return e


@app.post("/api/register")
async def register(req: Request):
    """Create (or resume) a session for an email and return its token plus its shuffled pair list.

    Raises HTTP 409 if this email already completed the study.
    """
    body = await _json_object(req)
    email = norm_email(body.get("email"))
    record = 0 if body.get("record") is False else 1  # opt-out checkbox; default = record
    with _tx() as c:
        row = c.execute("SELECT token, completed FROM participants WHERE email=?", (email,)).fetchone()
        if row and row["completed"]:
            raise HTTPException(409, "This email has already completed the study. Thank you!")
        token = row["token"] if row else uuid.uuid4().hex
        if not row:
            c.execute("INSERT INTO participants(token,email,created,record) VALUES(?,?,?,?)",
                      (token, email, time.time(), record))
        else:
            c.execute("UPDATE participants SET record=? WHERE token=?", (record, token))
    order = list(PAIRS.values())
    # Seed with the token string itself: random hashes str seeds deterministically, whereas the builtin
    # hash() of a str is salted per process (PYTHONHASHSEED), so the order used to change on every restart.
    random.Random(token).shuffle(order)
    pairs = [{"pair_id": p["pair_id"], "prompt": p["prompt"],
              "left": f"/static/imgs/{p['left']}", "right": f"/static/imgs/{p['right']}"} for p in order]
    return {"token": token, "pairs": pairs}


def _summarize(votes):
    """Per-participant own results (thank-you page): three-way win/tie/lose vs our method, overall + per style.

    win = chose ours, tie = 'cannot decide', lose = chose the other; rates are over ALL comparisons (incl. ties).
    Only real (non-catch) pairs count; malformed and duplicate votes are dropped as in /api/submit.
    """
    def tri():
        return [0, 0, 0]  # [ours_win, tie, other_win]

    overall, by_style = tri(), {}
    for p, ch in _clean_votes(votes):
        if p["kind"] != "real":
            continue
        k = 1 if ch == "cannot" else (0 if ch == p["guided_side"] else 2)
        overall[k] += 1
        by_style.setdefault(p["style"], tri())[k] += 1

    def pack(t):
        n = sum(t)
        return {"win": t[0], "tie": t[1], "lose": t[2], "n": n,
                "win_rate": (t[0] / n) if n else None,
                "tie_rate": (t[1] / n) if n else None,
                "lose_rate": (t[2] / n) if n else None}

    return {"pref": {**pack(overall), "by_style": {k: pack(t) for k, t in by_style.items()}}}


@app.post("/api/submit")
async def submit(req: Request):
    """Record a participant's votes once, mark them completed, back up the DB and return their summary.

    Raises HTTP 404 for an unknown token and HTTP 409 if the session was already submitted.
    """
    body = await _json_object(req)
    token = body.get("token", "")
    if not isinstance(token, str):
        raise HTTPException(404, "Unknown session.")
    votes = _clean_votes(body.get("votes", []))
    with _tx() as c:
        row = c.execute("SELECT email, completed, record FROM participants WHERE token=?", (token,)).fetchone()
        if not row:
            raise HTTPException(404, "Unknown session.")
        if row["completed"]:
            raise HTTPException(409, "Already submitted.")
        email, record = row["email"], row["record"]
        n_catch = n_pass = 0
        rows = []
        for p, ch in votes:
            if p["kind"] == "catch":
                gc = int(ch == p["correct_side"])
                n_catch += 1
                n_pass += gc
            elif ch == "cannot":
                gc = None  # "Cannot decide" → tie / no preference (excluded from win-rate denominator)
            else:
                gc = int(ch == p["guided_side"])
            rows.append((token, email, p["pair_id"], p["kind"], p.get("style"), "pref", ch, gc, time.time()))
        # Claim completion atomically: the completed=0 guard makes a concurrent duplicate submission (e.g. a
        # double-click served by another worker) update nothing, so it is rejected and the whole transaction,
        # including its vote inserts, rolls back.
        claimed = c.execute("UPDATE participants SET completed=1, n_catch=?, n_catch_pass=? "
                            "WHERE token=? AND completed=0", (n_catch, n_pass, token)).rowcount
        if not claimed:
            raise HTTPException(409, "Already submitted.")
        if record:  # opt-out: skip persisting votes (results still shown to the participant)
            c.executemany("INSERT INTO votes VALUES(?,?,?,?,?,?,?,?,?)", rows)
    # The upload is blocking network I/O; run it off the event loop so other participants are not stalled.
    await asyncio.get_running_loop().run_in_executor(None, backup_db)
    return {"ok": True, "recorded": bool(record), "results": _summarize(body.get("votes", []))}


def _require_admin(key):
    """Raise HTTP 403 unless ``key`` matches ADMIN_KEY (constant-time; an empty ADMIN_KEY never matches)."""
    ok = bool(ADMIN_KEY) and isinstance(key, str) and secrets.compare_digest(
        key.encode("utf-8"), ADMIN_KEY.encode("utf-8"))
    if not ok:
        raise HTTPException(403, "bad key")


# Participant-population subqueries for the stats. They are fixed module constants (never user input),
# so interpolating them into SQL is safe.
_COMPLETED = "SELECT token FROM participants WHERE completed=1"
# attentive = completed AND passed all catch preference trials (and saw at least one)
_ATTENTIVE = _COMPLETED + " AND n_catch>0 AND n_catch_pass=n_catch"


def _question_summary(c, question, population):
    """Win/tie/lose counts and rates for one question over real pairs, restricted to voters in ``population``."""
    where = f"kind='real' AND question=? AND token IN ({population})"
    agg = "COUNT(*) n, SUM(guided_chosen) g, SUM(CASE WHEN guided_chosen IS NULL THEN 1 ELSE 0 END) ties"
    rows = c.execute(f"SELECT style, {agg} FROM votes WHERE {where} GROUP BY style", (question,)).fetchall()
    tot = c.execute(f"SELECT {agg} FROM votes WHERE {where}", (question,)).fetchone()
    N = tot["n"] or 0
    W = tot["g"] or 0
    T = tot["ties"] or 0
    L = N - W - T  # win / tie / lose
    return {
        "n": N, "win": W, "tie": T, "lose": L,
        "win_rate": (W / N) if N else None,
        "tie_rate": (T / N) if N else None,
        "lose_rate": (L / N) if N else None,
        "decisive_win_rate": (W / (W + L)) if (W + L) else None,  # ties dropped (sign-test denominator)
        "by_style": {r["style"]: {"n": r["n"], "win": r["g"] or 0, "tie": r["ties"] or 0,
                                  "lose": (r["n"] - (r["g"] or 0) - (r["ties"] or 0))} for r in rows}}


@app.get("/api/stats")
def stats(key: str = ""):
    """Admin: aggregate results over completed participants, plus the attentive (all-catch-passed) subset."""
    _require_admin(key)
    with _tx() as c:
        out = {
            "completed_participants": c.execute(f"SELECT COUNT(*) n FROM ({_COMPLETED})").fetchone()["n"],
            "by_question": {q: _question_summary(c, q, _COMPLETED) for q in QUESTIONS},
            "attentive_participants": c.execute(f"SELECT COUNT(*) n FROM ({_ATTENTIVE})").fetchone()["n"],
            "by_question_attentive": {q: _question_summary(c, q, _ATTENTIVE) for q in QUESTIONS},
        }
    return out


@app.get("/api/export")
def export(key: str = ""):
    """Admin: all stored votes as CSV in timestamp order (guided_chosen is empty for ties)."""
    _require_admin(key)
    buf = io.StringIO()
    w = csv.writer(buf)
    w.writerow(["email", "pair_id", "kind", "style", "question", "choice", "guided_chosen", "ts"])
    with _tx() as c:
        for r in c.execute("SELECT email,pair_id,kind,style,question,choice,guided_chosen,ts FROM votes ORDER BY ts"):
            w.writerow([r["email"], r["pair_id"], r["kind"], r["style"], r["question"], r["choice"],
                        "" if r["guided_chosen"] is None else r["guided_chosen"], r["ts"]])
    return PlainTextResponse(buf.getvalue(), media_type="text/csv")