Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """Human-preference study (2AFC pairwise) — SD3.5-Medium ImageReward-guided vs vanilla, 4 styles. | |
| Accepted T2I protocol (Imagen/Parti/Emu/Diffusion-DPO): two questions/pair (overall preference + prompt | |
| alignment), tie option, neutral labels, left/right randomized, gold catch trials, email-gated (one | |
| submission/person). SQLite backend, JSON API + single-page UI. | |
| Run: ./env/bin/python -m uvicorn human_study.app.main:app --host 0.0.0.0 --port 8000 | |
| DB: human_study/data/study.db Manifest: human_study/data/pairs.json | |
| Admin: GET /api/stats?key=$STUDY_ADMIN_KEY GET /api/export?key=... (CSV) | |
| """ | |
| import json, os, sqlite3, time, uuid, random, csv, io, shutil | |
| from pathlib import Path | |
| from fastapi import FastAPI, Request, HTTPException | |
| from fastapi.responses import HTMLResponse, PlainTextResponse | |
| from fastapi.staticfiles import StaticFiles | |
HERE = Path(__file__).resolve().parent
# STUDY_DATA (pairs.json + study.db) and STUDY_STATIC (index.html + imgs/) are env-overridable so the same
# app code can host multiple independent studies (e.g. the 4-style study and a photo-only study).
DATA = Path(os.environ.get("STUDY_DATA", str(HERE.parent / "data")))
STATIC = Path(os.environ.get("STUDY_STATIC", str(HERE / "static")))
DB = DATA / "study.db"
# Pair manifest indexed by pair_id. read_text() opens and closes the file itself, and the JSON is decoded
# as UTF-8 explicitly rather than with the platform's locale encoding.
PAIRS = {p["pair_id"]: p for p in json.loads((DATA / "pairs.json").read_text(encoding="utf-8"))}
# `or` rather than a .get() default: an empty STUDY_ADMIN_KEY must not become the admin key, because the
# admin endpoints default to key="" and would then accept a request with no key at all.
ADMIN_KEY = os.environ.get("STUDY_ADMIN_KEY") or "rsfg-admin"
if not os.environ.get("STUDY_ADMIN_KEY"):
    print("[security] STUDY_ADMIN_KEY is unset or empty; using the built-in default admin key. "
          "Set it for any public deployment (the CSV export contains participant emails).")
# Durable storage for Hugging Face Spaces, whose disk is ephemeral: the response DB is restored from a
# private Dataset at startup and backed up after every submission. Nothing happens when the env vars
# below are unset (e.g. local runs).
HF_TOKEN = os.environ.get("HF_TOKEN")
PERSIST_REPO = os.environ.get("HF_PERSIST_REPO")  # Dataset repo id, e.g. "bs82/image-study-db"
STUDY_NAME = os.environ.get("STUDY_NAME", "study")  # one subfolder per study inside the Dataset
_DB_IN_REPO = f"{STUDY_NAME}/study.db"  # this study's DB file path inside the Dataset
def restore_db():
    """Best-effort restore of the response DB file from the HF Dataset ``PERSIST_REPO``.

    No-op unless both HF_PERSIST_REPO and HF_TOKEN are set. Never raises: every failure is printed
    and the app continues with whatever local DB exists (or a fresh one created later by init_db()).
    """
    if not (PERSIST_REPO and HF_TOKEN):
        return
    try:
        from huggingface_hub import hf_hub_download
        downloaded = hf_hub_download(PERSIST_REPO, _DB_IN_REPO, repo_type="dataset", token=HF_TOKEN)
    except Exception as e:  # e.g. first run with nothing backed up yet, or an auth/network problem
        print(f"[persist] no prior DB ({e}); starting fresh")
        return
    try:
        DB.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(downloaded, DB)
    except OSError as e:
        # Report this separately: a backup exists but was NOT restored. backup_db() uploads the local DB
        # to the same path after the next submission, replacing that remote copy.
        print(f"[persist] WARNING: downloaded {PERSIST_REPO}:{_DB_IN_REPO} but could not copy it to "
              f"{DB} ({e}); the next backup will overwrite the remote DB")
        return
    print(f"[persist] restored DB from {PERSIST_REPO}:{_DB_IN_REPO}")
def backup_db():
    """Upload the local response DB to the HF Dataset as ``_DB_IN_REPO`` (best effort).

    Only acts when both HF_PERSIST_REPO and HF_TOKEN are set; any upload error is printed, never raised.
    """
    configured = bool(PERSIST_REPO) and bool(HF_TOKEN)
    if configured:
        try:
            from huggingface_hub import HfApi
            api = HfApi(token=HF_TOKEN)
            api.upload_file(
                path_or_fileobj=str(DB),
                path_in_repo=_DB_IN_REPO,
                repo_id=PERSIST_REPO,
                repo_type="dataset",
                commit_message=f"backup {STUDY_NAME}",
            )
        except Exception as err:
            print(f"[persist] backup failed: {err}")
class _ClosingConnection(sqlite3.Connection):
    """sqlite3 connection whose ``with`` block ends the transaction AND then closes the connection.

    The stock ``sqlite3.Connection.__exit__`` only commits (or rolls back on an exception) and leaves the
    connection open until it is garbage-collected. No caller visible in this file uses ``c`` after its
    ``with db() as c:`` block, so closing on exit is safe.
    """

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            return super().__exit__(exc_type, exc_value, traceback)
        finally:
            self.close()


def db(path=None):
    """Open a connection to the study DB whose rows are ``sqlite3.Row`` (indexable by column name).

    Args:
        path: database to open instead of ``DB`` (e.g. ``":memory:"``); defaults to ``DB``.

    Use as ``with db() as c:`` -- the block commits on success, rolls back on an exception, and
    closes the connection afterwards.
    """
    conn = sqlite3.connect(DB if path is None else path, factory=_ClosingConnection)
    conn.row_factory = sqlite3.Row
    return conn
def init_db():
    """Ensure the data directory and the two study tables exist (idempotent).

    participants: one row per email/token with completion flag, catch-trial tallies and the
    record (opt-in to storing votes) flag.
    votes: one row per stored comparison. guided_chosen is 1 guided / 0 vanilla / NULL tie for real
    pairs; for catch pairs submit() stores 1 pass / 0 fail there.
    """
    DATA.mkdir(exist_ok=True, parents=True)
    schema = (
        """CREATE TABLE IF NOT EXISTS participants(
token TEXT PRIMARY KEY, email TEXT UNIQUE, created REAL, completed INTEGER DEFAULT 0,
n_catch INTEGER DEFAULT 0, n_catch_pass INTEGER DEFAULT 0, record INTEGER DEFAULT 1)""",
        """CREATE TABLE IF NOT EXISTS votes(
token TEXT, email TEXT, pair_id INTEGER, kind TEXT, style TEXT, question TEXT,
choice TEXT, guided_chosen INTEGER, ts REAL)""",
    )
    with db() as conn:
        for ddl in schema:
            conn.execute(ddl)
# Startup. Restore the persisted DB (when configured) BEFORE init_db(), so init_db()'s
# CREATE TABLE IF NOT EXISTS statements run against the restored file.
restore_db()
init_db()
app = FastAPI(title="Image preference study")
# Serves STATIC (index.html assets and imgs/) under /static; register() hands out /static/imgs/... URLs.
app.mount("/static", StaticFiles(directory=str(STATIC)), name="static")
def index():
    """Return the single-page study UI (``STATIC/index.html``) as a string.

    NOTE(review): no route decorator is visible in this view. HTMLResponse is imported but otherwise
    unused, so this is presumably registered with ``response_class=HTMLResponse`` -- confirm.
    """
    # Decode explicitly as UTF-8; read_text() would otherwise use the locale's preferred encoding.
    return (STATIC / "index.html").read_text(encoding="utf-8")
def info():
    """Study metadata: number of pairs and an estimated duration in whole minutes.

    The estimate budgets 3.5 seconds per pair, rounded to the nearest minute, and is never below 1.
    """
    pair_count = len(PAIRS)
    minutes = round(pair_count * 3.5 / 60)
    if minutes < 1:
        minutes = 1
    return {"n_pairs": pair_count, "est_minutes": minutes}
def norm_email(e):
    """Normalize and sanity-check a participant email; return it stripped and lower-cased.

    The normalized value is the participants.email key (UNIQUE), so normalizing here makes the
    one-submission-per-email rule case- and surrounding-whitespace-insensitive.

    Raises:
        HTTPException(400): if *e* is not a string or does not look like ``local@domain.tld``
            (exactly one "@", non-empty local part, a dot strictly inside the domain, no whitespace).
    """
    # The value comes straight from client JSON: a number/list/etc. would crash .strip() with a 500.
    e = e.strip().lower() if isinstance(e, str) else ""
    local, _, domain = e.partition("@")
    looks_valid = (
        bool(local)
        and "@" not in domain
        and "." in domain[1:-1]  # rejects "a@com", "a@.com", "a@com."
        and not any(ch.isspace() for ch in e)
    )
    if not looks_valid:
        raise HTTPException(400, "Please enter a valid email address.")
    return e
async def register(req: Request):
    """Register (or resume) a participant by email; return their session token and pair order.

    Body: JSON object ``{"email": str, "record": bool}``; ``"record": false`` opts out of storing
    votes (any other value, or absence, means record). An email that registered but has not completed
    gets its existing token back and its record flag updated.

    Returns:
        ``{"token": str, "pairs": [{"pair_id", "prompt", "left", "right"}, ...]}`` with image URLs
        under ``/static/imgs/`` and a pair order that is fixed per token.

    Raises:
        HTTPException(400): malformed body or invalid email.
        HTTPException(409): this email already completed the study.
    """
    try:
        body = await req.json()
    except ValueError:  # invalid JSON / encoding (json.JSONDecodeError is a ValueError)
        raise HTTPException(400, "Malformed request.") from None
    if not isinstance(body, dict):
        raise HTTPException(400, "Malformed request.")
    email = norm_email(body.get("email"))
    record = 0 if body.get("record") is False else 1  # opt-out checkbox; default = record
    with db() as c:
        row = c.execute("SELECT token, completed FROM participants WHERE email=?", (email,)).fetchone()
        if row and row["completed"]:
            raise HTTPException(409, "This email has already completed the study. Thank you!")
        if row:
            token = row["token"]
            c.execute("UPDATE participants SET record=? WHERE token=?", (record, token))
        else:
            token = uuid.uuid4().hex
            c.execute("INSERT INTO participants(token,email,created,record) VALUES(?,?,?,?)",
                      (token, email, time.time(), record))
    order = list(PAIRS.values())
    # Seed with the token string itself: random.Random hashes str seeds deterministically, whereas the
    # built-in hash() is salted per process (PYTHONHASHSEED), which made a resumed session's order
    # change after every restart.
    random.Random(token).shuffle(order)
    pairs = [{"pair_id": p["pair_id"], "prompt": p["prompt"],
              "left": f"/static/imgs/{p['left']}", "right": f"/static/imgs/{p['right']}"} for p in order]
    return {"token": token, "pairs": pairs}
| def _summarize(votes): | |
| """Per-participant own results (thank-you page): three-way win/tie/lose vs our method, overall + per style. | |
| win = chose ours, tie = 'cannot decide', lose = chose the other; rates are over ALL comparisons (incl. ties).""" | |
| def tri(): | |
| return [0, 0, 0] # [ours_win, tie, other_win] | |
| overall = tri(); by_style = {} | |
| for v in votes: | |
| p = PAIRS.get(int(v.get("pair_id", -1))) | |
| ch = v.get("choice") | |
| if not p or p["kind"] != "real" or ch not in ("left", "right", "cannot"): | |
| continue | |
| k = 1 if ch == "cannot" else (0 if ch == p["guided_side"] else 2) | |
| overall[k] += 1 | |
| by_style.setdefault(p["style"], tri())[k] += 1 | |
| def pack(t): | |
| n = sum(t) | |
| return {"win": t[0], "tie": t[1], "lose": t[2], "n": n, | |
| "win_rate": (t[0] / n) if n else None, | |
| "tie_rate": (t[1] / n) if n else None, | |
| "lose_rate": (t[2] / n) if n else None} | |
| return {"pref": {**pack(overall), "by_style": {k: pack(t) for k, t in by_style.items()}}} | |
async def submit(req: Request):
    """Store a participant's votes, mark them completed, and return their own summary.

    Body: JSON object ``{"token": str, "votes": [{"pair_id": ..., "choice": "left"|"right"|"cannot"}]}``.
    Malformed votes (non-object entries, pair ids ``int()`` rejects), unknown pairs and unknown
    choices are skipped rather than failing the whole submission. Catch-trial votes update the
    participant's n_catch / n_catch_pass tallies. Votes are written only when the participant's record
    flag is set; the summary is returned either way.

    Raises:
        HTTPException(400): body is not a JSON object, or "votes" is present but not a list.
        HTTPException(404): unknown token.
        HTTPException(409): this token already submitted.
    """
    try:
        body = await req.json()
    except ValueError:  # invalid JSON / encoding (json.JSONDecodeError is a ValueError)
        raise HTTPException(400, "Malformed submission.") from None
    if not isinstance(body, dict):
        raise HTTPException(400, "Malformed submission.")
    token, votes = body.get("token", ""), body.get("votes", [])
    if not isinstance(votes, list):
        raise HTTPException(400, "Malformed submission.")
    if not isinstance(token, str):  # tokens are hex strings; anything else cannot match a session
        raise HTTPException(404, "Unknown session.")
    with db() as c:
        row = c.execute("SELECT email, completed, record FROM participants WHERE token=?", (token,)).fetchone()
        if not row:
            raise HTTPException(404, "Unknown session.")
        if row["completed"]:
            raise HTTPException(409, "Already submitted.")
        email, record = row["email"], row["record"]
        n_catch = n_pass = 0
        for v in votes:
            if not isinstance(v, dict):
                continue
            try:
                pair_id = int(v.get("pair_id", -1))
            except (TypeError, ValueError, OverflowError):  # None, "abc", NaN, Infinity, ...
                continue
            p = PAIRS.get(pair_id)
            ch = v.get("choice")
            if not p or ch not in ("left", "right", "cannot"):
                continue
            if p["kind"] == "catch":
                # Catch trials have a known correct side; guided_chosen stores pass (1) / fail (0).
                gc = int(ch == p["correct_side"])
                n_catch += 1
                n_pass += gc
            elif ch == "cannot":
                gc = None  # "Cannot decide" -> tie / no preference (NULL)
            else:
                gc = int(ch == p["guided_side"])
            if record:  # opt-out: skip persisting votes (results still shown to the participant)
                c.execute("INSERT INTO votes VALUES(?,?,?,?,?,?,?,?,?)",
                          (token, email, p["pair_id"], p["kind"], p["style"], "pref", ch, gc, time.time()))
        c.execute("UPDATE participants SET completed=1, n_catch=?, n_catch_pass=? WHERE token=?",
                  (n_catch, n_pass, token))
    # Back up only after the `with` block has committed, so the uploaded file includes this submission.
    # NOTE(review): backup_db() performs a synchronous upload inside this async handler, which blocks
    # the event loop for the duration of the upload.
    backup_db()
    return {"ok": True, "recorded": bool(record), "results": _summarize(votes)}
def stats(key: str = ""):
    """Admin-only aggregate preference results (GET /api/stats?key=... per the module docstring).

    Counts only ``kind='real'`` votes. For each question it reports:
    - n / win / tie / lose, where win means guided_chosen=1 and tie means guided_chosen is NULL;
    - win / tie / lose rates over all n;
    - ``decisive_win_rate`` = win / (win + lose), i.e. ties dropped;
    - a per-style breakdown.

    ``by_question`` covers all completed participants. ``by_question_attentive`` gives the same
    numbers restricted to completed participants who saw at least one catch trial and passed every one.

    Raises:
        HTTPException(403): if *key* does not match ADMIN_KEY.
    """
    import hmac  # lazy import, like the huggingface_hub imports elsewhere in this module

    # Constant-time comparison. Encode first: compare_digest accepts str only if it is ASCII.
    if not hmac.compare_digest(key.encode(), ADMIN_KEY.encode()):
        raise HTTPException(403, "bad key")

    completed = "token IN (SELECT token FROM participants WHERE completed=1)"
    # attentive = completed AND passed all catch preference trials (and saw at least one)
    attentive = ("token IN (SELECT token FROM participants"
                 " WHERE completed=1 AND n_catch_pass=n_catch AND n_catch>0)")

    def question_stats(c, question, who):
        """Win/tie/lose counts and rates for one question, over the participants selected by `who`."""
        # `who` is always one of the two fixed filters above (never request data), so interpolating it
        # into the SQL is safe; the question itself is still a bound parameter.
        rows = c.execute(f"""SELECT style, COUNT(*) n, SUM(guided_chosen) g,
            SUM(CASE WHEN guided_chosen IS NULL THEN 1 ELSE 0 END) ties
            FROM votes WHERE kind='real' AND question=? AND {who} GROUP BY style""", (question,)).fetchall()
        tot = c.execute(f"""SELECT COUNT(*) n, SUM(guided_chosen) g,
            SUM(CASE WHEN guided_chosen IS NULL THEN 1 ELSE 0 END) ties
            FROM votes WHERE kind='real' AND question=? AND {who}""", (question,)).fetchone()
        # SUM() is NULL when no rows match, hence the `or 0`.
        n = tot["n"] or 0
        wins = tot["g"] or 0
        ties = tot["ties"] or 0
        losses = n - wins - ties
        by_style = {}
        for r in rows:
            s_wins, s_ties = r["g"] or 0, r["ties"] or 0
            by_style[r["style"]] = {"n": r["n"], "win": s_wins, "tie": s_ties,
                                    "lose": r["n"] - s_wins - s_ties}
        return {
            "n": n, "win": wins, "tie": ties, "lose": losses,
            "win_rate": (wins / n) if n else None,
            "tie_rate": (ties / n) if n else None,
            "lose_rate": (losses / n) if n else None,
            # ties dropped (sign-test denominator)
            "decisive_win_rate": (wins / (wins + losses)) if (wins + losses) else None,
            "by_style": by_style,
        }

    with db() as c:
        n_completed = c.execute("SELECT COUNT(*) n FROM participants WHERE completed=1").fetchone()["n"]
        n_attentive = c.execute(
            "SELECT COUNT(*) n FROM participants WHERE completed=1 AND n_catch_pass=n_catch AND n_catch>0"
        ).fetchone()["n"]
        out = {"completed_participants": n_completed, "by_question": {},
               "attentive_participants": n_attentive, "by_question_attentive": {}}
        # "align" is kept so the output shape is unchanged; submit() in this file records only question="pref".
        for q in ("pref", "align"):
            out["by_question"][q] = question_stats(c, q, completed)
            out["by_question_attentive"][q] = question_stats(c, q, attentive)
    return out
def export(key: str = ""):
    """Admin-only CSV dump of every stored vote, ordered by timestamp (GET /api/export?key=...).

    Columns: email, pair_id, kind, style, question, choice, guided_chosen, ts. guided_chosen is
    written as an empty cell when NULL (a 'cannot decide' tie). Participants who opted out of
    recording have no rows, since submit() inserts votes only when their record flag is set.

    Raises:
        HTTPException(403): if *key* does not match ADMIN_KEY.
    """
    import hmac  # lazy import, like the huggingface_hub imports elsewhere in this module

    # Constant-time comparison. Encode first: compare_digest accepts str only if it is ASCII.
    if not hmac.compare_digest(key.encode(), ADMIN_KEY.encode()):
        raise HTTPException(403, "bad key")
    buf = io.StringIO()
    w = csv.writer(buf)
    w.writerow(["email", "pair_id", "kind", "style", "question", "choice", "guided_chosen", "ts"])
    # NOTE(review): emails are participant-supplied; a value starting with "=", "+", "-" or "@" may be
    # evaluated as a formula when this CSV is opened in a spreadsheet. Consider sanitizing on export.
    with db() as c:
        for r in c.execute("SELECT email,pair_id,kind,style,question,choice,guided_chosen,ts FROM votes ORDER BY ts"):
            w.writerow([r["email"], r["pair_id"], r["kind"], r["style"], r["question"], r["choice"],
                        "" if r["guided_chosen"] is None else r["guided_chosen"], r["ts"]])
    return PlainTextResponse(buf.getvalue(), media_type="text/csv")