bs82's picture
deploy sd35m: 40 pairs
edb73aa verified
Raw
History Blame Contribute Delete
10.2 kB
#!/usr/bin/env python
"""Human-preference study (2AFC pairwise) — SD3.5-Medium ImageReward-guided vs vanilla, 4 styles.
Accepted T2I protocol (Imagen/Parti/Emu/Diffusion-DPO): two questions/pair (overall preference + prompt
alignment), tie option, neutral labels, left/right randomized, gold catch trials, email-gated (one
submission/person). SQLite backend, JSON API + single-page UI.
Run: ./env/bin/python -m uvicorn human_study.app.main:app --host 0.0.0.0 --port 8000
DB: human_study/data/study.db Manifest: human_study/data/pairs.json
Admin: GET /api/stats?key=$STUDY_ADMIN_KEY GET /api/export?key=... (CSV)
"""
import json, os, sqlite3, time, uuid, random, csv, io, shutil
from pathlib import Path
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
HERE = Path(__file__).resolve().parent
# STUDY_DATA (pairs.json + study.db) and STUDY_STATIC (index.html + imgs/) are env-overridable so the same
# app code can host multiple independent studies (e.g. the 4-style study and a photo-only study).
DATA = Path(os.environ.get("STUDY_DATA", str(HERE.parent / "data")))
STATIC = Path(os.environ.get("STUDY_STATIC", str(HERE / "static")))
DB = DATA / "study.db"
# Pair manifest keyed by pair_id. The file is opened inside a `with` so the handle is closed right away,
# and in binary mode so json.load detects the encoding itself instead of relying on the process locale.
with open(DATA / "pairs.json", "rb") as _pairs_file:
    PAIRS = {p["pair_id"]: p for p in json.load(_pairs_file)}
# NOTE(review): the fallback admin key is hard-coded in source; deployments should always set
# STUDY_ADMIN_KEY, otherwise /api/stats and /api/export are guarded by a publicly known value.
ADMIN_KEY = os.environ.get("STUDY_ADMIN_KEY", "rsfg-admin")
# Durable storage on Hugging Face Spaces (ephemeral disk): restore the response DB from a private Dataset on
# startup, back it up after every submission. No-op when the env vars are unset (e.g. local runs).
PERSIST_REPO = os.environ.get("HF_PERSIST_REPO")  # e.g. "bs82/image-study-db"
STUDY_NAME = os.environ.get("STUDY_NAME", "study")  # subfolder in the Dataset (one per study)
HF_TOKEN = os.environ.get("HF_TOKEN")
_DB_IN_REPO = f"{STUDY_NAME}/study.db"
def restore_db():
    """Best-effort restore of the SQLite file from the persistence Dataset.

    Does nothing unless both PERSIST_REPO and HF_TOKEN are set. Every failure, including a missing
    huggingface_hub package (the import happens inside the try), is printed and otherwise ignored,
    so the app then continues with whatever DB file is (or is not) on disk.
    """
    persistence_enabled = bool(PERSIST_REPO and HF_TOKEN)
    if not persistence_enabled:
        return
    try:
        from huggingface_hub import hf_hub_download

        downloaded = hf_hub_download(PERSIST_REPO, _DB_IN_REPO, repo_type="dataset", token=HF_TOKEN)
        # Copy the downloaded file over the local DB path.
        shutil.copy(downloaded, DB)
        print(f"[persist] restored DB from {PERSIST_REPO}:{_DB_IN_REPO}")
    except Exception as err:
        print(f"[persist] no prior DB ({err}); starting fresh")
def backup_db():
    """Best-effort upload of the local SQLite file to the persistence Dataset.

    Does nothing unless both PERSIST_REPO and HF_TOKEN are set. Failures are printed and swallowed
    so that a backup problem never propagates to the caller.
    """
    if not PERSIST_REPO or not HF_TOKEN:
        return
    try:
        from huggingface_hub import HfApi

        api = HfApi(token=HF_TOKEN)
        api.upload_file(
            path_or_fileobj=str(DB),
            path_in_repo=_DB_IN_REPO,
            repo_id=PERSIST_REPO,
            repo_type="dataset",
            commit_message=f"backup {STUDY_NAME}",
        )
    except Exception as err:
        print(f"[persist] backup failed: {err}")
def db():
    """Open a new connection to the study database whose rows can be indexed by column name."""
    conn = sqlite3.connect(DB)
    conn.row_factory = sqlite3.Row
    return conn
def init_db():
    """Create the data directory and the `participants` / `votes` tables when they are missing."""
    DATA.mkdir(parents=True, exist_ok=True)
    participants_ddl = (
        "CREATE TABLE IF NOT EXISTS participants(\n"
        "token TEXT PRIMARY KEY, email TEXT UNIQUE, created REAL, completed INTEGER DEFAULT 0,\n"
        "n_catch INTEGER DEFAULT 0, n_catch_pass INTEGER DEFAULT 0, record INTEGER DEFAULT 1)"
    )
    # guided_chosen: 1 guided, 0 vanilla, NULL tie
    votes_ddl = (
        "CREATE TABLE IF NOT EXISTS votes(\n"
        "token TEXT, email TEXT, pair_id INTEGER, kind TEXT, style TEXT, question TEXT,\n"
        "choice TEXT, guided_chosen INTEGER, ts REAL)"
    )
    with db() as conn:
        for ddl in (participants_ddl, votes_ddl):
            conn.execute(ddl)
# Startup order matters: pull any persisted DB first, then create whichever tables are still missing.
restore_db()
init_db()
app = FastAPI(title="Image preference study")
# Expose everything under STATIC at the /static URL prefix.
app.mount("/static", StaticFiles(directory=str(STATIC)), name="static")
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page UI by returning the text of index.html from the static directory."""
    page = STATIC / "index.html"
    return page.read_text()
@app.get("/api/info")
def info():
    """Report the number of pairs in the manifest and a rough duration estimate.

    The estimate budgets 3.5 seconds per pair, rounded to whole minutes, and never drops below one.
    """
    pair_count = len(PAIRS)
    minutes = round(pair_count * 3.5 / 60)
    if minutes < 1:
        minutes = 1
    return {"n_pairs": pair_count, "est_minutes": minutes}
def norm_email(e):
    """Normalise and sanity-check a participant email address.

    The value is stripped and lower-cased so one address always maps to the same participant row.
    This is a plausibility check only, not full RFC validation.

    Args:
        e: Raw value taken from the request body; may be None or a non-string JSON value.

    Returns:
        The normalised address.

    Raises:
        HTTPException: 400 when the value does not look like ``local@domain.tld``.
    """
    # A non-string JSON value (number, list, ...) used to crash on .strip(); treat it as empty instead.
    addr = e.strip().lower() if isinstance(e, str) else ""
    local, at, domain = addr.rpartition("@")
    labels = domain.split(".")
    plausible = (
        bool(at and local)  # an "@" with something in front of it
        and "@" not in local  # exactly one "@"
        and not any(ch.isspace() for ch in addr)  # no embedded whitespace
        and len(labels) >= 2  # the domain contains a dot ...
        and all(labels)  # ... and no empty label (leading, trailing or doubled dot)
    )
    if not plausible:
        raise HTTPException(400, "Please enter a valid email address.")
    return addr
@app.post("/api/register")
async def register(req: Request):
    """Register a participant (or resume an unfinished one) and hand out the pair list.

    The JSON body carries ``email`` and an optional ``record`` flag; only an explicit ``false``
    opts out of recording. An email that is already known but not completed gets its existing
    token back, with the stored ``record`` flag updated.

    Returns:
        ``{"token": ..., "pairs": [...]}`` where each pair has ``pair_id``, ``prompt`` and the
        ``left`` / ``right`` image URLs, in an order that is fixed per token.

    Raises:
        HTTPException: 400 for a malformed body or email, 409 when the email already completed.
    """
    try:
        body = await req.json()
    except ValueError:
        # Invalid JSON / undecodable bytes used to surface as a 500.
        raise HTTPException(400, "Request body must be a JSON object.") from None
    if not isinstance(body, dict):
        raise HTTPException(400, "Request body must be a JSON object.")
    email = norm_email(body.get("email"))
    record = 0 if body.get("record") is False else 1  # opt-out checkbox; default = record
    with db() as c:
        row = c.execute("SELECT token, completed FROM participants WHERE email=?", (email,)).fetchone()
        if row and row["completed"]:
            raise HTTPException(409, "This email has already completed the study. Thank you!")
        token = row["token"] if row else uuid.uuid4().hex
        if not row:
            c.execute("INSERT INTO participants(token,email,created,record) VALUES(?,?,?,?)",
                      (token, email, time.time(), record))
        else:
            c.execute("UPDATE participants SET record=? WHERE token=?", (record, token))
    order = list(PAIRS.values())
    # Seed with the token string itself: `random` hashes str seeds deterministically, whereas the
    # built-in hash() of a str is salted per process, so the old seed gave a returning participant
    # a different order after every restart.
    random.Random(token).shuffle(order)
    pairs = [{"pair_id": p["pair_id"], "prompt": p["prompt"],
              "left": f"/static/imgs/{p['left']}", "right": f"/static/imgs/{p['right']}"} for p in order]
    return {"token": token, "pairs": pairs}
def _summarize(votes):
    """Build the participant's own result summary for the thank-you page.

    Only votes on ``real`` pairs with a valid choice are counted. Each one is a win (the guided
    side was chosen), a tie (``cannot``) or a loss (the other side was chosen). Rates are taken
    over all counted comparisons, ties included, and are ``None`` when nothing was counted.
    Returns ``{"pref": {...overall counts and rates..., "by_style": {style: {...}}}}``.
    """
    WIN, TIE, LOSE = 0, 1, 2
    totals = [0, 0, 0]
    per_style = {}
    for vote in votes:
        pair = PAIRS.get(int(vote.get("pair_id", -1)))
        choice = vote.get("choice")
        if not pair or pair["kind"] != "real":
            continue
        if choice not in ("left", "right", "cannot"):
            continue
        if choice == "cannot":
            slot = TIE
        elif choice == pair["guided_side"]:
            slot = WIN
        else:
            slot = LOSE
        totals[slot] += 1
        style_counts = per_style.setdefault(pair["style"], [0, 0, 0])
        style_counts[slot] += 1

    def as_report(counts):
        wins, ties, losses = counts
        n = wins + ties + losses

        def rate(part):
            return (part / n) if n else None

        return {"win": wins, "tie": ties, "lose": losses, "n": n,
                "win_rate": rate(wins), "tie_rate": rate(ties), "lose_rate": rate(losses)}

    report = as_report(totals)
    report["by_style"] = {style: as_report(counts) for style, counts in per_style.items()}
    return {"pref": report}
def _parse_vote(v):
    """Return ``(pair, choice)`` for a well-formed vote on a known pair, otherwise ``None``."""
    if not isinstance(v, dict):
        return None
    try:
        pair = PAIRS.get(int(v.get("pair_id", -1)))
    except (TypeError, ValueError, OverflowError):
        return None  # pair_id missing the point entirely (null, text, list, infinity, ...)
    choice = v.get("choice")
    if not pair or choice not in ("left", "right", "cannot"):
        return None
    return pair, choice


@app.post("/api/submit")
async def submit(req: Request):
    """Store a participant's votes, mark them completed and return their own results.

    The JSON body carries ``token`` and ``votes`` (a list of ``{"pair_id", "choice"}``). Votes that
    are malformed, reference an unknown pair or carry an invalid choice are skipped. Catch trials
    are scored against ``correct_side``; real trials against ``guided_side``. Votes are persisted
    only when the participant did not opt out, but the participant is marked completed either way.

    Returns:
        ``{"ok": True, "recorded": bool, "results": <summary from _summarize>}``.

    Raises:
        HTTPException: 400 for a malformed body, 404 for an unknown token, 409 if already submitted.
    """
    try:
        body = await req.json()
    except ValueError:
        # Invalid JSON / undecodable bytes used to surface as a 500.
        raise HTTPException(400, "Request body must be a JSON object.") from None
    if not isinstance(body, dict):
        raise HTTPException(400, "Request body must be a JSON object.")
    token, votes = body.get("token", ""), body.get("votes", [])
    if not isinstance(token, str):
        raise HTTPException(404, "Unknown session.")  # cannot match any stored token
    if not isinstance(votes, list):
        raise HTTPException(400, "votes must be a list.")
    # Validate up front so one bad item can neither crash the request nor abort the transaction.
    parsed = [pv for pv in map(_parse_vote, votes) if pv is not None]
    with db() as c:
        row = c.execute("SELECT email, completed, record FROM participants WHERE token=?", (token,)).fetchone()
        if not row:
            raise HTTPException(404, "Unknown session.")
        if row["completed"]:
            raise HTTPException(409, "Already submitted.")
        email, record = row["email"], row["record"]
        n_catch = n_pass = 0
        for p, ch in parsed:
            if p["kind"] == "catch":
                gc = int(ch == p["correct_side"])
                n_catch += 1
                n_pass += gc
            elif ch == "cannot":
                gc = None  # "Cannot decide" is stored as NULL, which /api/stats counts as a tie
            else:
                gc = int(ch == p["guided_side"])
            if record:  # opt-out: skip persisting votes (results still shown to the participant)
                c.execute("INSERT INTO votes VALUES(?,?,?,?,?,?,?,?,?)",
                          (token, email, p["pair_id"], p["kind"], p["style"], "pref", ch, gc, time.time()))
        c.execute("UPDATE participants SET completed=1, n_catch=?, n_catch_pass=? WHERE token=?",
                  (n_catch, n_pass, token))
    # NOTE(review): backup_db() uploads synchronously inside this async handler.
    backup_db()
    # Summarise only the validated votes; the skipped ones would not have been counted anyway.
    accepted = [{"pair_id": p["pair_id"], "choice": ch} for p, ch in parsed]
    return {"ok": True, "recorded": bool(record), "results": _summarize(accepted)}
@app.get("/api/stats")
def stats(key: str = "", attentive: bool = False):
    """Admin endpoint: aggregate win / tie / lose counts for real pairs, overall and per style.

    Args:
        key: Must equal ADMIN_KEY.
        attentive: When true, only count votes from completed participants who passed every catch
            trial (and saw at least one). The default keeps the original behaviour of counting all
            completed participants.

    Returns:
        ``{"completed_participants": int, "by_question": {question: {...}}}`` with counts, rates
        over all votes, ``decisive_win_rate`` with ties dropped, and per-style counts.

    Raises:
        HTTPException: 403 when the key does not match.
    """
    if key != ADMIN_KEY:
        raise HTTPException(403, "bad key")
    completed = "token IN (SELECT token FROM participants WHERE completed=1)"
    # attentive = passed all catch preference trials
    attn = "token IN (SELECT token FROM participants WHERE completed=1 AND n_catch_pass=n_catch AND n_catch>0)"
    scope = attn if attentive else completed
    aggregates = ("COUNT(*) n, SUM(guided_chosen) g, "
                  "SUM(CASE WHEN guided_chosen IS NULL THEN 1 ELSE 0 END) ties")
    # Only the constant fragments above are interpolated; the question is bound as a parameter.
    per_style_sql = (f"SELECT style, {aggregates} FROM votes "
                     f"WHERE kind='real' AND question=? AND {scope} GROUP BY style")
    total_sql = f"SELECT {aggregates} FROM votes WHERE kind='real' AND question=? AND {scope}"
    with db() as c:
        np_ = c.execute("SELECT COUNT(*) n FROM participants WHERE completed=1").fetchone()["n"]
        out = {"completed_participants": np_, "by_question": {}}
        for q in ("pref", "align"):
            rows = c.execute(per_style_sql, (q,)).fetchall()
            tot = c.execute(total_sql, (q,)).fetchone()
            # SUM() over no rows yields NULL, hence the "or 0" fallbacks.
            N = tot["n"] or 0
            W = tot["g"] or 0
            T = tot["ties"] or 0
            lost = N - W - T  # win / tie / lose
            out["by_question"][q] = {
                "n": N, "win": W, "tie": T, "lose": lost,
                "win_rate": (W / N) if N else None,
                "tie_rate": (T / N) if N else None,
                "lose_rate": (lost / N) if N else None,
                "decisive_win_rate": (W / (W + lost)) if (W + lost) else None,  # ties dropped (sign-test denominator)
                "by_style": {r["style"]: {"n": r["n"], "win": r["g"] or 0, "tie": r["ties"] or 0,
                                          "lose": (r["n"] - (r["g"] or 0) - (r["ties"] or 0))} for r in rows}}
    return out
@app.get("/api/export")
def export(key: str = ""):
    """Admin endpoint: dump every stored vote as CSV, oldest first.

    Responds 403 unless ``key`` equals ADMIN_KEY. A NULL ``guided_chosen`` (tie) becomes an empty cell.
    """
    if key != ADMIN_KEY:
        raise HTTPException(403, "bad key")
    columns = ["email", "pair_id", "kind", "style", "question", "choice", "guided_chosen", "ts"]
    tie_column = columns.index("guided_chosen")
    out = io.StringIO()
    writer = csv.writer(out)
    writer.writerow(columns)
    with db() as conn:
        cursor = conn.execute("SELECT email,pair_id,kind,style,question,choice,guided_chosen,ts FROM votes ORDER BY ts")
        for rec in cursor:
            cells = [rec[name] for name in columns]
            if cells[tie_column] is None:
                cells[tie_column] = ""
            writer.writerow(cells)
    return PlainTextResponse(out.getvalue(), media_type="text/csv")