File size: 10,233 Bytes
a9d7edb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edb73aa
 
 
 
 
a9d7edb
 
 
edb73aa
a9d7edb
edb73aa
 
 
 
 
 
 
 
 
 
a9d7edb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b74e85b
a9d7edb
 
 
b74e85b
 
a9d7edb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edb73aa
 
 
a9d7edb
edb73aa
a9d7edb
edb73aa
 
 
 
 
a9d7edb
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
#!/usr/bin/env python
"""Human-preference study (2AFC pairwise) — SD3.5-Medium ImageReward-guided vs vanilla, 4 styles.
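Accepted T2I protocol (Imagen/Parti/Emu/Diffusion-DPO): two questions/pair (overall preference + prompt
alignment), tie option, neutral labels, left/right randomized, gold catch trials, email-gated (one
submission/person). SQLite backend, JSON API + single-page UI. Note: /api/submit currently stores only the
overall-preference answer (question='pref'); the 'align' entries in /api/stats stay empty until it records them.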

Run:  ./env/bin/python -m uvicorn human_study.app.main:app --host 0.0.0.0 --port 8000
DB:   human_study/data/study.db    Manifest: human_study/data/pairs.json
Admin: GET /api/stats?key=$STUDY_ADMIN_KEY   GET /api/export?key=...  (CSV)
"""
import json, os, sqlite3, time, uuid, random, csv, io, shutil
from pathlib import Path
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles

HERE = Path(__file__).resolve().parent
# STUDY_DATA (pairs.json + study.db) and STUDY_STATIC (index.html + imgs/) are env-overridable so the same
# app code can host multiple independent studies (e.g. the 4-style study and a photo-only study).
DATA = Path(os.environ.get("STUDY_DATA", str(HERE.parent / "data")))
STATIC = Path(os.environ.get("STUDY_STATIC", str(HERE / "static")))
DB = DATA / "study.db"
# pair_id -> manifest entry. read_text closes the file immediately (json.load(open(...)) left the handle to
# the garbage collector) and decodes UTF-8 regardless of the host locale.
PAIRS = {p["pair_id"]: p for p in json.loads((DATA / "pairs.json").read_text(encoding="utf-8"))}
# Admin key for /api/stats and /api/export (which exposes participant emails). `or` also covers an env var
# that is set but empty: with .get(name, default) that yielded ADMIN_KEY == "", which equals the endpoints'
# default `key=""`, i.e. admin access with no key at all.
ADMIN_KEY = os.environ.get("STUDY_ADMIN_KEY") or "rsfg-admin"
if ADMIN_KEY == "rsfg-admin":
    # The fallback is visible to anyone who can read this source; set STUDY_ADMIN_KEY in real deployments.
    print("[security] STUDY_ADMIN_KEY not set; using the built-in default admin key")

# Durable storage on Hugging Face Spaces (ephemeral disk): restore the response DB from a private Dataset on
# startup, back it up after every submission. No-op when the env vars are unset (e.g. local runs).
PERSIST_REPO = os.environ.get("HF_PERSIST_REPO")          # e.g. "bs82/image-study-db"
STUDY_NAME = os.environ.get("STUDY_NAME", "study")        # subfolder in the Dataset (one per study)
HF_TOKEN = os.environ.get("HF_TOKEN")
_DB_IN_REPO = f"{STUDY_NAME}/study.db"

def restore_db():
    """Best effort: copy the persisted study.db from the HF Dataset over the local DB file.

    Skipped unless HF_PERSIST_REPO and HF_TOKEN are both set. Any exception is printed and swallowed, so
    startup continues with whatever DB is on local disk (or a fresh one created by init_db).
    NOTE(review): every failure — not only "file not in repo" but also auth/network errors — lands in the
    same "starting fresh" branch, and the next backup_db() uploads the local DB to the same repo path.
    """
    if PERSIST_REPO and HF_TOKEN:
        try:
            from huggingface_hub import hf_hub_download
            cached = hf_hub_download(PERSIST_REPO, _DB_IN_REPO, repo_type="dataset", token=HF_TOKEN)
            shutil.copy(cached, DB)
            print(f"[persist] restored DB from {PERSIST_REPO}:{_DB_IN_REPO}")
        except Exception as err:
            print(f"[persist] no prior DB ({err}); starting fresh")

def backup_db():
    """Best effort: upload the local study.db to the HF Dataset under <STUDY_NAME>/study.db.

    Skipped unless HF_PERSIST_REPO and HF_TOKEN are both set. Any exception (missing package, auth, network)
    is printed and swallowed, so a failed backup never turns a submission into an error response.
    NOTE(review): this is a blocking network call, and the async `submit` handler invokes it directly, so
    the event loop is held for the duration of the upload.
    """
    if PERSIST_REPO and HF_TOKEN:
        try:
            from huggingface_hub import HfApi
            api = HfApi(token=HF_TOKEN)
            api.upload_file(path_or_fileobj=str(DB), path_in_repo=_DB_IN_REPO,
                            repo_id=PERSIST_REPO, repo_type="dataset",
                            commit_message=f"backup {STUDY_NAME}")
        except Exception as err:
            print(f"[persist] backup failed: {err}")

class _ClosingConnection(sqlite3.Connection):
    """sqlite3 connection whose `with` block ends the transaction AND closes the connection.

    A plain sqlite3.Connection used as a context manager only commits (or rolls back on error); the
    connection itself stays open until garbage-collected. Every caller in this module uses
    `with db() as c:` and does not touch `c` afterwards, so closing on exit is safe.
    """

    def __exit__(self, exc_type, exc, tb):
        try:
            return super().__exit__(exc_type, exc, tb)   # commit on success / rollback on error
        finally:
            self.close()

def db():
    """Open a connection to the study DB with name-addressable rows (sqlite3.Row).

    Use as `with db() as c:` — the transaction is committed on success, rolled back on error, and the
    connection is closed either way.
    """
    c = sqlite3.connect(DB, factory=_ClosingConnection)
    c.row_factory = sqlite3.Row
    return c

def init_db():
    """Create the data directory and the participants/votes tables, adding participant columns that are
    missing from an older DB file.
    """
    DATA.mkdir(parents=True, exist_ok=True)
    with db() as c:
        c.execute("""CREATE TABLE IF NOT EXISTS participants(
            token TEXT PRIMARY KEY, email TEXT UNIQUE, created REAL, completed INTEGER DEFAULT 0,
            n_catch INTEGER DEFAULT 0, n_catch_pass INTEGER DEFAULT 0, record INTEGER DEFAULT 1)""")
        # guided_chosen: 1 guided, 0 vanilla, NULL tie (for kind='catch' rows: 1 = correct side chosen)
        c.execute("""CREATE TABLE IF NOT EXISTS votes(
            token TEXT, email TEXT, pair_id INTEGER, kind TEXT, style TEXT, question TEXT,
            choice TEXT, guided_chosen INTEGER, ts REAL)""")
        # CREATE TABLE IF NOT EXISTS never alters an existing table, so a DB restored from an older deploy
        # could lack later-added columns that register/submit read and write (e.g. `record`). Add any that
        # are missing, with the same defaults as above. Column names/decls are constants, not user input.
        have = {r[1] for r in c.execute("PRAGMA table_info(participants)")}
        for col, decl in (("completed", "INTEGER DEFAULT 0"), ("n_catch", "INTEGER DEFAULT 0"),
                          ("n_catch_pass", "INTEGER DEFAULT 0"), ("record", "INTEGER DEFAULT 1")):
            if col not in have:
                c.execute(f"ALTER TABLE participants ADD COLUMN {col} {decl}")
# Import-time setup, in this order: first pull the persisted DB (a no-op unless HF persistence is configured),
# so that init_db's CREATE TABLE IF NOT EXISTS then runs against the restored file rather than a fresh one.
restore_db()
init_db()

# ASGI application. /static serves the STATIC directory: index.html (served by "/") and the pair images under
# imgs/ whose URLs /api/register hands to the client.
app = FastAPI(title="Image preference study")
app.mount("/static", StaticFiles(directory=str(STATIC)), name="static")

@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page study UI (STATIC/index.html).

    Decoded explicitly as UTF-8: Path.read_text() without an encoding uses the host locale, which can
    garble non-ASCII characters in the page on hosts whose default encoding is not UTF-8.
    """
    return (STATIC / "index.html").read_text(encoding="utf-8")

@app.get("/api/info")
def info():
    """Report how many pairs the study has and a rough duration (3.5 s per pair, never under 1 minute)."""
    total = len(PAIRS)
    minutes = round(total * 3.5 / 60)
    return {"n_pairs": total, "est_minutes": minutes if minutes > 1 else 1}

def norm_email(e):
    """Normalize (strip + lowercase) and loosely validate a participant email.

    Accepts any string of length >= 5 containing '@' whose part after the last '@' contains a '.'; this is
    a sanity check, not RFC 5322 validation.

    Raises:
        HTTPException(400): value missing, not a string (e.g. a JSON number, which previously crashed with
        AttributeError -> 500), or implausible.
    """
    e = e or ""
    if not isinstance(e, str):
        raise HTTPException(400, "Please enter a valid email address.")
    e = e.strip().lower()
    if "@" not in e or "." not in e.split("@")[-1] or len(e) < 5:
        raise HTTPException(400, "Please enter a valid email address.")
    return e

@app.post("/api/register")
async def register(req: Request):
    """Start (or resume) a participant session for an email; return its token and personal pair order.

    Body: {"email": str, "record": bool (optional — only an explicit false opts out of storing votes)}.
    A new email gets a fresh random token; an email that registered but never submitted gets its existing
    token back (and its opt-out flag updated). Pairs are returned in an order shuffled per token.

    Raises:
        HTTPException(400): body is not a JSON object, or the email is invalid.
        HTTPException(409): this email already completed the study.
    """
    try:
        body = await req.json()
    except ValueError:   # invalid JSON / bad encoding (JSONDecodeError and UnicodeDecodeError are ValueErrors)
        raise HTTPException(400, "Malformed request.") from None
    if not isinstance(body, dict):
        raise HTTPException(400, "Malformed request.")
    email = norm_email(body.get("email"))
    record = 0 if body.get("record") is False else 1   # opt-out checkbox; default = record
    with db() as c:
        row = c.execute("SELECT token, completed FROM participants WHERE email=?", (email,)).fetchone()
        if row and row["completed"]:
            raise HTTPException(409, "This email has already completed the study. Thank you!")
        if row:
            token = row["token"]
            c.execute("UPDATE participants SET record=? WHERE token=?", (record, token))
        else:
            token = uuid.uuid4().hex
            c.execute("INSERT INTO participants(token,email,created,record) VALUES(?,?,?,?)",
                      (token, email, time.time(), record))
    order = list(PAIRS.values())
    # Seed with the token string itself: random.Random hashes str seeds deterministically, whereas the
    # built-in hash() of a str is salted per process (PYTHONHASHSEED), so the previous seed reshuffled a
    # returning participant's order after every restart and differed between worker processes.
    random.Random(token).shuffle(order)
    pairs = [{"pair_id": p["pair_id"], "prompt": p["prompt"],
              "left": f"/static/imgs/{p['left']}", "right": f"/static/imgs/{p['right']}"} for p in order]
    return {"token": token, "pairs": pairs}

def _summarize(votes):
    """Per-participant own results (thank-you page): three-way win/tie/lose vs our method, overall + per style.

    win = chose ours, tie = 'cannot decide', lose = chose the other; rates are over ALL real comparisons
    (ties included). Catch trials, unknown pairs, unknown choices, and malformed entries (not a dict, or a
    pair_id that is not int-like) are skipped instead of raising.
    """
    def tri():
        return [0, 0, 0]   # [ours_win, tie, other_win]
    overall = tri(); by_style = {}
    for v in votes:
        try:
            p = PAIRS.get(int(v.get("pair_id", -1)))
        except (AttributeError, TypeError, ValueError, OverflowError):
            continue   # v is not a dict, or pair_id is None / non-numeric / non-finite
        ch = v.get("choice")
        if not p or p["kind"] != "real" or ch not in ("left", "right", "cannot"):
            continue
        k = 1 if ch == "cannot" else (0 if ch == p["guided_side"] else 2)
        overall[k] += 1
        by_style.setdefault(p["style"], tri())[k] += 1
    def pack(t):
        n = sum(t)
        return {"win": t[0], "tie": t[1], "lose": t[2], "n": n,
                "win_rate": (t[0] / n) if n else None,
                "tie_rate": (t[1] / n) if n else None,
                "lose_rate": (t[2] / n) if n else None}
    return {"pref": {**pack(overall), "by_style": {k: pack(t) for k, t in by_style.items()}}}

def _submitted_votes(raw):
    """Return (vote, pair, choice) for each usable entry of a submitted vote list, keeping the first per pair.

    Skipped: non-dict entries, pair_ids that are not int-like or not in the manifest, choices other than
    left/right/cannot, and repeats of a pair_id already seen earlier in the list.
    """
    kept, seen = [], set()
    for v in raw:
        try:
            p = PAIRS.get(int(v.get("pair_id", -1)))
        except (AttributeError, TypeError, ValueError, OverflowError):
            continue
        ch = v.get("choice")
        if not p or ch not in ("left", "right", "cannot") or p["pair_id"] in seen:
            continue
        seen.add(p["pair_id"])
        kept.append((v, p, ch))
    return kept

@app.post("/api/submit")
async def submit(req: Request):
    """Store a participant's votes (once per token), score catch trials, and return their own summary.

    Body: {"token": str, "votes": [{"pair_id": int, "choice": "left" | "right" | "cannot"}, ...]}.
    Unusable entries are ignored (see _submitted_votes); in particular a repeated pair_id is counted once,
    so a duplicated entry cannot inflate the tallies or the catch-trial counts. Votes are written only if
    the participant did not opt out (participants.record); the completed flag and catch counts always are.

    Raises:
        HTTPException(400): body is not a JSON object or "votes" is not a list.
        HTTPException(404): unknown token.
        HTTPException(409): this token already submitted.
    """
    try:
        body = await req.json()
    except ValueError:   # invalid JSON / bad encoding
        raise HTTPException(400, "Malformed request.") from None
    if not isinstance(body, dict) or not isinstance(body.get("votes", []), list):
        raise HTTPException(400, "Malformed request.")
    token = body.get("token", "")
    if not isinstance(token, str):
        raise HTTPException(404, "Unknown session.")
    votes = _submitted_votes(body.get("votes", []))
    with db() as c:
        row = c.execute("SELECT email, completed, record FROM participants WHERE token=?", (token,)).fetchone()
        if not row:
            raise HTTPException(404, "Unknown session.")
        if row["completed"]:
            raise HTTPException(409, "Already submitted.")
        email, record = row["email"], row["record"]
        n_catch = n_pass = 0
        for _vote, p, ch in votes:
            if p["kind"] == "catch":
                gc = int(ch == p["correct_side"]); n_catch += 1; n_pass += gc
            elif ch == "cannot":
                gc = None   # "Cannot decide" → tie / no preference (excluded from win-rate denominator)
            else:
                gc = int(ch == p["guided_side"])
            if record:   # opt-out: skip persisting votes (results still shown to the participant)
                c.execute("INSERT INTO votes VALUES(?,?,?,?,?,?,?,?,?)",
                          (token, email, p["pair_id"], p["kind"], p["style"], "pref", ch, gc, time.time()))
        c.execute("UPDATE participants SET completed=1, n_catch=?, n_catch_pass=? WHERE token=?",
                  (n_catch, n_pass, token))
    backup_db()
    return {"ok": True, "recorded": bool(record), "results": _summarize([v for v, _p, _ch in votes])}

def _question_stats(c, q, who):
    """Win/tie/lose tallies of real (non-catch) votes on question `q`, for the tokens selected by `who`.

    `who` is one of the constant token subqueries built in `stats` (never user input), so interpolating it
    is safe. Totals are sums over the per-style groups, which equal COUNT/SUM over the same rows.
    """
    rows = c.execute(f"""SELECT style, COUNT(*) n, SUM(guided_chosen) g,
        SUM(CASE WHEN guided_chosen IS NULL THEN 1 ELSE 0 END) ties
        FROM votes WHERE kind='real' AND question=? AND token IN ({who}) GROUP BY style""", (q,)).fetchall()
    by_style = {}
    n_all = win = tie = 0
    for r in rows:
        n, g, t = r["n"], r["g"] or 0, r["ties"] or 0
        by_style[r["style"]] = {"n": n, "win": g, "tie": t, "lose": n - g - t}
        n_all += n; win += g; tie += t
    lose = n_all - win - tie
    return {
        "n": n_all, "win": win, "tie": tie, "lose": lose,
        "win_rate": (win / n_all) if n_all else None, "tie_rate": (tie / n_all) if n_all else None,
        "lose_rate": (lose / n_all) if n_all else None,
        "decisive_win_rate": (win / (win + lose)) if (win + lose) else None,   # ties dropped (sign-test denominator)
        "by_style": by_style}

@app.get("/api/stats")
def stats(key: str = ""):
    """Admin summary of real-pair votes from completed participants, per question and per style.

    Returns "completed_participants" and "by_question" as before, plus "attentive_participants" and
    "attentive_by_question": the same tallies restricted to participants who passed every catch trial they
    saw (and saw at least one).

    Raises:
        HTTPException(403): `key` is empty or does not match ADMIN_KEY (constant-time comparison).
    """
    import secrets
    if not key or not secrets.compare_digest(key.encode(), ADMIN_KEY.encode()):
        raise HTTPException(403, "bad key")
    completed = "SELECT token FROM participants WHERE completed=1"
    # attentive = completed and passed all catch preference trials
    attentive = completed + " AND n_catch>0 AND n_catch_pass=n_catch"
    with db() as c:
        n_done = c.execute("SELECT COUNT(*) n FROM participants WHERE completed=1").fetchone()["n"]
        n_attn = c.execute("SELECT COUNT(*) n FROM participants "
                           "WHERE completed=1 AND n_catch>0 AND n_catch_pass=n_catch").fetchone()["n"]
        out = {"completed_participants": n_done, "by_question": {},
               "attentive_participants": n_attn, "attentive_by_question": {}}
        for q in ("pref", "align"):
            out["by_question"][q] = _question_stats(c, q, completed)
            out["attentive_by_question"][q] = _question_stats(c, q, attentive)
    return out

@app.get("/api/export")
def export(key: str = ""):
    """Admin CSV dump of every stored vote, one row per vote, ordered by timestamp.

    Columns: email, pair_id, kind, style, question, choice, guided_chosen, ts. guided_chosen is written as
    an empty cell when NULL (a 'cannot decide' on a real pair).

    Raises:
        HTTPException(403): `key` is empty or does not match ADMIN_KEY (constant-time comparison, since this
        endpoint exposes participant emails).
    """
    import secrets
    if not key or not secrets.compare_digest(key.encode(), ADMIN_KEY.encode()):
        raise HTTPException(403, "bad key")
    buf = io.StringIO(); w = csv.writer(buf)
    w.writerow(["email", "pair_id", "kind", "style", "question", "choice", "guided_chosen", "ts"])
    with db() as c:
        for r in c.execute("SELECT email,pair_id,kind,style,question,choice,guided_chosen,ts FROM votes ORDER BY ts"):
            w.writerow([r["email"], r["pair_id"], r["kind"], r["style"], r["question"], r["choice"],
                        "" if r["guided_chosen"] is None else r["guided_chosen"], r["ts"]])
    return PlainTextResponse(buf.getvalue(), media_type="text/csv")