#!/usr/bin/env python
"""Human-preference study (2AFC pairwise) — SD3.5-Medium ImageReward-guided vs vanilla, 4 styles.
Accepted T2I protocol (Imagen/Parti/Emu/Diffusion-DPO): two questions/pair (overall preference + prompt
alignment), tie option, neutral labels, left/right randomized, gold catch trials, email-gated (one
submission/person). SQLite backend, JSON API + single-page UI.
Run: ./env/bin/python -m uvicorn human_study.app.main:app --host 0.0.0.0 --port 8000
DB: human_study/data/study.db Manifest: human_study/data/pairs.json
Admin: GET /api/stats?key=$STUDY_ADMIN_KEY GET /api/export?key=... (CSV)
"""
import json, os, sqlite3, time, uuid, random, csv, io, shutil
from pathlib import Path
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
HERE = Path(__file__).resolve().parent
# STUDY_DATA (pairs.json + study.db) and STUDY_STATIC (index.html + imgs/) are env-overridable so the same
# app code can host multiple independent studies (e.g. the 4-style study and a photo-only study).
DATA = Path(os.environ.get("STUDY_DATA", str(HERE.parent / "data")))
STATIC = Path(os.environ.get("STUDY_STATIC", str(HERE / "static")))
DB = DATA / "study.db"
# Pair manifest indexed by pair_id. Read via read_text() so the file handle is closed immediately and the
# decoding does not depend on the host locale.
PAIRS = {p["pair_id"]: p for p in json.loads((DATA / "pairs.json").read_text(encoding="utf-8"))}
# NOTE: when STUDY_ADMIN_KEY is unset the admin endpoints fall back to this fixed, source-visible key;
# set the env var in any public deployment.
ADMIN_KEY = os.environ.get("STUDY_ADMIN_KEY", "rsfg-admin")
# Durable storage on Hugging Face Spaces (ephemeral disk): restore the response DB from a private Dataset on
# startup, back it up after every submission. No-op when the env vars are unset (e.g. local runs).
PERSIST_REPO = os.environ.get("HF_PERSIST_REPO")  # e.g. "bs82/image-study-db"
STUDY_NAME = os.environ.get("STUDY_NAME", "study")  # subfolder in the Dataset (one per study)
HF_TOKEN = os.environ.get("HF_TOKEN")
_DB_IN_REPO = f"{STUDY_NAME}/study.db"  # path of the DB file inside the Dataset repo
def restore_db():
    """Best-effort restore of the response DB from the persistence Dataset.

    Does nothing unless both PERSIST_REPO and HF_TOKEN are set. On success the downloaded file is copied
    over DB; on any failure a message is printed and the app continues with whatever is on disk.
    """
    if not PERSIST_REPO or not HF_TOKEN:
        return
    try:
        # Imported lazily (inside the try) so a missing huggingface_hub is handled like any other failure.
        from huggingface_hub import hf_hub_download
        cached_copy = hf_hub_download(
            repo_id=PERSIST_REPO,
            filename=_DB_IN_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
        )
        shutil.copy(cached_copy, DB)
        print(f"[persist] restored DB from {PERSIST_REPO}:{_DB_IN_REPO}")
    except Exception as e:
        # NOTE(review): every error (not only "file not found") lands here, so a transient download
        # failure also starts from a fresh DB — confirm that is acceptable given later backups.
        print(f"[persist] no prior DB ({e}); starting fresh")
def backup_db():
    """Best-effort upload of the local DB file to the persistence Dataset.

    Does nothing unless both PERSIST_REPO and HF_TOKEN are set. Failures are printed, never raised.
    """
    if not PERSIST_REPO or not HF_TOKEN:
        return
    try:
        # Lazy import inside the try: a missing huggingface_hub is reported as a failed backup.
        from huggingface_hub import HfApi
        client = HfApi(token=HF_TOKEN)
        client.upload_file(
            path_or_fileobj=str(DB),
            path_in_repo=_DB_IN_REPO,
            repo_id=PERSIST_REPO,
            repo_type="dataset",
            commit_message=f"backup {STUDY_NAME}",
        )
    except Exception as e:
        print(f"[persist] backup failed: {e}")
class _ClosingConnection(sqlite3.Connection):
    """SQLite connection whose ``with`` block also closes it.

    The stock ``sqlite3.Connection`` context manager only commits (or rolls back on error); it leaves
    the connection open. Every caller in this module uses ``with db() as c:`` once per request, so the
    connection is closed here after the commit/rollback has run.
    """

    def __exit__(self, exc_type, exc, tb):
        try:
            return super().__exit__(exc_type, exc, tb)
        finally:
            self.close()


def db():
    """Open a connection to the study DB with dict-like (``sqlite3.Row``) rows.

    Returns:
        A ``sqlite3.Connection`` (subclass). Used as a context manager it commits on success, rolls
        back on exception, and is closed when the block exits.
    """
    conn = sqlite3.connect(DB, factory=_ClosingConnection)
    conn.row_factory = sqlite3.Row
    return conn
def init_db():
    """Ensure the data directory exists and create the participants/votes tables if missing."""
    participants_ddl = """CREATE TABLE IF NOT EXISTS participants(
token TEXT PRIMARY KEY, email TEXT UNIQUE, created REAL, completed INTEGER DEFAULT 0,
n_catch INTEGER DEFAULT 0, n_catch_pass INTEGER DEFAULT 0, record INTEGER DEFAULT 1)"""
    # guided_chosen: 1 guided, 0 vanilla, NULL tie
    votes_ddl = """CREATE TABLE IF NOT EXISTS votes(
token TEXT, email TEXT, pair_id INTEGER, kind TEXT, style TEXT, question TEXT,
choice TEXT, guided_chosen INTEGER, ts REAL)"""
    DATA.mkdir(parents=True, exist_ok=True)
    with db() as conn:
        for ddl in (participants_ddl, votes_ddl):
            conn.execute(ddl)
# Startup order matters: pull any previously backed-up DB first, then create missing tables on top of it.
restore_db()
init_db()
# ASGI application; everything under STATIC is served at /static.
app = FastAPI(title="Image preference study")
app.mount("/static", StaticFiles(directory=str(STATIC)), name="static")
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page UI (STATIC/index.html).

    The file is decoded explicitly as UTF-8 so the page does not depend on the host's locale encoding.
    """
    return (STATIC / "index.html").read_text(encoding="utf-8")
@app.get("/api/info")
def info():
    """Report how many pairs the study has and a rough duration estimate.

    The estimate assumes 3.5 seconds per pair, rounded to whole minutes, and never goes below 1.
    """
    pair_count = len(PAIRS)
    minutes = round(pair_count * 3.5 / 60)
    return {"n_pairs": pair_count, "est_minutes": max(1, minutes)}
def norm_email(e):
    """Normalize an email address (trim + lowercase) after a light sanity check.

    Args:
        e: Value taken from the request body; may be missing or not a string.

    Returns:
        The stripped, lower-cased address.

    Raises:
        HTTPException: 400 when the value is not a string, has no "@", has an empty local part,
            has no "." in the part after the last "@", or is shorter than 5 characters.
    """
    if not isinstance(e, str):
        e = ""  # missing / non-string input is rejected below instead of crashing on .strip()
    e = e.strip().lower()
    local, at, domain = e.rpartition("@")
    if not at or not local or "." not in domain or len(e) < 5:
        raise HTTPException(400, "Please enter a valid email address.")
    return e
@app.post("/api/register")
async def register(req: Request):
    """Start (or resume) a session for an email and hand out the pair list.

    Request body: JSON object with ``email`` and optional ``record`` (``false`` = opt out of storage).

    Returns:
        ``{"token": ..., "pairs": [...]}`` where each pair carries its id, prompt and left/right image
        URLs. The pair order is a shuffle seeded by the token.

    Raises:
        HTTPException: 400 for a malformed body or invalid email, 409 if this email already completed
            the study.
    """
    body = await req.json()
    if not isinstance(body, dict):
        raise HTTPException(400, "Malformed request.")
    email = norm_email(body.get("email"))
    record = 0 if body.get("record") is False else 1  # opt-out checkbox; default = record
    with db() as c:
        row = c.execute("SELECT token, completed FROM participants WHERE email=?", (email,)).fetchone()
        if row and row["completed"]:
            raise HTTPException(409, "This email has already completed the study. Thank you!")
        if row:
            # Returning, unfinished participant: keep the token, refresh the opt-out choice.
            token = row["token"]
            c.execute("UPDATE participants SET record=? WHERE token=?", (record, token))
        else:
            token = uuid.uuid4().hex
            c.execute("INSERT INTO participants(token,email,created,record) VALUES(?,?,?,?)",
                      (token, email, time.time(), record))
    order = list(PAIRS.values())
    # Seed with the token string itself: built-in hash() of a str is salted per process, so seeding
    # with hash(token) would give a resuming participant a different order after every restart.
    random.Random(token).shuffle(order)
    pairs = [{"pair_id": p["pair_id"], "prompt": p["prompt"],
              "left": f"/static/imgs/{p['left']}", "right": f"/static/imgs/{p['right']}"} for p in order]
    return {"token": token, "pairs": pairs}
def _summarize(votes):
"""Per-participant own results (thank-you page): three-way win/tie/lose vs our method, overall + per style.
win = chose ours, tie = 'cannot decide', lose = chose the other; rates are over ALL comparisons (incl. ties)."""
def tri():
return [0, 0, 0] # [ours_win, tie, other_win]
overall = tri(); by_style = {}
for v in votes:
p = PAIRS.get(int(v.get("pair_id", -1)))
ch = v.get("choice")
if not p or p["kind"] != "real" or ch not in ("left", "right", "cannot"):
continue
k = 1 if ch == "cannot" else (0 if ch == p["guided_side"] else 2)
overall[k] += 1
by_style.setdefault(p["style"], tri())[k] += 1
def pack(t):
n = sum(t)
return {"win": t[0], "tie": t[1], "lose": t[2], "n": n,
"win_rate": (t[0] / n) if n else None,
"tie_rate": (t[1] / n) if n else None,
"lose_rate": (t[2] / n) if n else None}
return {"pref": {**pack(overall), "by_style": {k: pack(t) for k, t in by_style.items()}}}
def _parse_vote(v):
    """Return ``(pair, choice)`` for a well-formed vote on a known pair, else ``None``."""
    if not isinstance(v, dict):
        return None
    try:
        pair_id = int(v.get("pair_id", -1))
    except (TypeError, ValueError, OverflowError):
        return None
    p = PAIRS.get(pair_id)
    ch = v.get("choice")
    if not p or ch not in ("left", "right", "cannot"):
        return None
    return p, ch


@app.post("/api/submit")
async def submit(req: Request):
    """Accept a participant's votes, mark the session completed and return their own summary.

    Request body: JSON object with ``token`` and ``votes`` (list of ``{"pair_id", "choice"}``).
    Malformed votes and votes for unknown pairs are ignored; only the first vote per pair counts, so
    a client cannot inflate the tallies by repeating a pair.

    Returns:
        ``{"ok": True, "recorded": bool, "results": ...}``; ``recorded`` is False for participants who
        opted out of storage (their votes are summarized but not persisted).

    Raises:
        HTTPException: 400 for a malformed body, 404 for an unknown token, 409 if already submitted.
    """
    body = await req.json()
    if not isinstance(body, dict):
        raise HTTPException(400, "Malformed request.")
    token, votes = body.get("token", ""), body.get("votes", [])
    if not isinstance(token, str):
        raise HTTPException(404, "Unknown session.")
    if not isinstance(votes, list):
        raise HTTPException(400, "Malformed request.")
    accepted = []  # cleaned votes, reused for the participant's summary
    with db() as c:
        row = c.execute("SELECT email, completed, record FROM participants WHERE token=?", (token,)).fetchone()
        if not row:
            raise HTTPException(404, "Unknown session.")
        if row["completed"]:
            raise HTTPException(409, "Already submitted.")
        email, record = row["email"], row["record"]
        n_catch = n_pass = 0
        seen = set()
        for v in votes:
            parsed = _parse_vote(v)
            if parsed is None:
                continue
            p, ch = parsed
            if p["pair_id"] in seen:
                continue  # duplicate vote for the same pair
            seen.add(p["pair_id"])
            accepted.append({"pair_id": p["pair_id"], "choice": ch})
            if p["kind"] == "catch":
                guided_chosen = int(ch == p["correct_side"])
                n_catch += 1
                n_pass += guided_chosen
            elif ch == "cannot":
                guided_chosen = None  # "Cannot decide" → tie / no preference (excluded from win-rate denominator)
            else:
                guided_chosen = int(ch == p["guided_side"])
            if record:  # opt-out: skip persisting votes (results still shown to the participant)
                c.execute("INSERT INTO votes VALUES(?,?,?,?,?,?,?,?,?)",
                          (token, email, p["pair_id"], p["kind"], p["style"], "pref", ch,
                           guided_chosen, time.time()))
        c.execute("UPDATE participants SET completed=1, n_catch=?, n_catch_pass=? WHERE token=?",
                  (n_catch, n_pass, token))
    # Back up only after the ``with`` block has committed, so the uploaded file contains this submission.
    backup_db()
    return {"ok": True, "recorded": bool(record), "results": _summarize(accepted)}
@app.get("/api/stats")
def stats(key: str = "", attentive: bool = False):
    """Admin summary: win/tie/lose counts of guided vs vanilla on real pairs, overall and per style.

    Args:
        key: Must equal ADMIN_KEY.
        attentive: When true, only count votes from participants who passed every catch trial
            (and saw at least one). Defaults to all completed participants.

    Returns:
        Dict with ``completed_participants``, ``attentive_participants``, ``attentive_only`` and
        ``by_question`` (per question: n/win/tie/lose, their rates, ``decisive_win_rate`` with ties
        dropped, and per-style counts).

    Raises:
        HTTPException: 403 on a wrong key.
    """
    if key != ADMIN_KEY:
        raise HTTPException(403, "bad key")
    # attentive = passed all catch preference trials
    attn_cond = "completed=1 AND n_catch_pass=n_catch AND n_catch>0"
    who_cond = attn_cond if attentive else "completed=1"
    # The SQL below is assembled only from the fixed fragments above; user input goes through "?".
    agg = ("COUNT(*) n, SUM(guided_chosen) g, "
           "SUM(CASE WHEN guided_chosen IS NULL THEN 1 ELSE 0 END) ties")
    where = ("kind='real' AND question=? AND token IN "
             f"(SELECT token FROM participants WHERE {who_cond})")

    def tally(row):
        # SUM() yields NULL on an empty set, hence the "or 0" fallbacks.
        n = row["n"] or 0
        win = row["g"] or 0
        tie = row["ties"] or 0
        return {"n": n, "win": win, "tie": tie, "lose": n - win - tie}

    with db() as c:
        np_ = c.execute("SELECT COUNT(*) n FROM participants WHERE completed=1").fetchone()["n"]
        n_attn = c.execute(f"SELECT COUNT(*) n FROM participants WHERE {attn_cond}").fetchone()["n"]
        out = {"completed_participants": np_, "attentive_participants": n_attn,
               "attentive_only": attentive, "by_question": {}}
        # NOTE: only "pref" rows are written by /api/submit; "align" is kept for output compatibility.
        for q in ("pref", "align"):
            rows = c.execute(f"SELECT style, {agg} FROM votes WHERE {where} GROUP BY style", (q,)).fetchall()
            total = tally(c.execute(f"SELECT {agg} FROM votes WHERE {where}", (q,)).fetchone())
            n, win, tie, lose = total["n"], total["win"], total["tie"], total["lose"]
            out["by_question"][q] = {
                **total,
                "win_rate": (win / n) if n else None,
                "tie_rate": (tie / n) if n else None,
                "lose_rate": (lose / n) if n else None,
                # ties dropped (sign-test denominator)
                "decisive_win_rate": (win / (win + lose)) if (win + lose) else None,
                "by_style": {r["style"]: tally(r) for r in rows}}
    return out
@app.get("/api/export")
def export(key: str = ""):
    """Admin download of every stored vote as CSV, ordered by timestamp.

    A NULL ``guided_chosen`` (tie) is written as an empty cell. Responds 403 when ``key`` does not
    match ADMIN_KEY.
    """
    if key != ADMIN_KEY:
        raise HTTPException(403, "bad key")

    def as_cells(row):
        guided = row["guided_chosen"]
        return [row["email"], row["pair_id"], row["kind"], row["style"], row["question"],
                row["choice"], "" if guided is None else guided, row["ts"]]

    out = io.StringIO()
    writer = csv.writer(out)
    writer.writerow(["email", "pair_id", "kind", "style", "question", "choice", "guided_chosen", "ts"])
    with db() as conn:
        cursor = conn.execute("SELECT email,pair_id,kind,style,question,choice,guided_chosen,ts FROM votes ORDER BY ts")
        # Rows are consumed here, while the connection is still in use.
        writer.writerows(as_cells(row) for row in cursor)
    return PlainTextResponse(out.getvalue(), media_type="text/csv")
|