File size: 18,753 Bytes
871ff87
 
 
 
acdd723
871ff87
b8cd5c3
871ff87
b8cd5c3
 
871ff87
 
b8cd5c3
871ff87
 
 
 
 
 
 
b8cd5c3
000a5ee
 
 
 
b8cd5c3
 
 
 
 
 
871ff87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8cd5c3
 
 
 
 
871ff87
 
 
 
 
 
 
 
acdd723
b8cd5c3
 
871ff87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acdd723
 
 
 
b8cd5c3
 
 
 
 
 
 
 
871ff87
 
 
 
 
 
 
 
 
 
 
 
 
 
acdd723
b8cd5c3
 
 
871ff87
1e767b9
 
 
 
 
 
 
 
 
 
 
871ff87
 
 
 
 
 
 
 
 
 
 
 
083fb75
 
 
 
b8cd5c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
871ff87
083fb75
871ff87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
083fb75
b8cd5c3
083fb75
acdd723
b8cd5c3
 
083fb75
 
 
 
 
 
 
acdd723
b8cd5c3
 
 
 
 
 
 
acdd723
 
083fb75
871ff87
 
acdd723
 
 
 
b8cd5c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
871ff87
 
 
 
 
 
 
 
 
 
 
b8cd5c3
 
 
 
 
 
 
 
 
 
 
 
 
 
871ff87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8cd5c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
000a5ee
b8cd5c3
000a5ee
 
 
b8cd5c3
 
000a5ee
 
 
b8cd5c3
000a5ee
 
 
 
 
 
 
b8cd5c3
 
 
 
 
 
 
 
 
 
 
 
 
 
000a5ee
b8cd5c3
000a5ee
b8cd5c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
000a5ee
b8cd5c3
 
871ff87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
"""FastAPI app serving the annotator web UI + /api/task, /api/label, /api/progress."""

from __future__ import annotations

import hmac
import os
import re
import sqlite3
import time
from collections import deque
from pathlib import Path

from fastapi import Depends, FastAPI, HTTPException, Query, Request
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, conint

from aamcq.annotation import db as dbmod
from aamcq.annotation.assignment import bootstrap_annotators

# Default minimum session accuracy (fraction of labels matching the item's
# `correct_index`); used as the default for create_app(acc_threshold=...).
DEFAULT_ACC_THRESHOLD = 0.40
# Round 1 (cap=20) + up to 3 bonus rounds (cap=10 each) = 50 labels max per
# annotator. Each label that lives in a passing session counts as one
# lottery entry; more labels = better odds.
MAX_LOTTERY_ROUND = 4
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")


def _is_valid_email(email: str) -> bool:
    return bool(email) and len(email) <= 254 and bool(_EMAIL_RE.match(email))

# Default filesystem locations, all derived from this file's resolved path.
# parents[3] is three directories above the one containing this file
# (presumably the repository root — depends on the package layout).
REPO_ROOT = Path(__file__).resolve().parents[3]
DEFAULT_DB = REPO_ROOT / "data" / "annotations.sqlite"
DEFAULT_IMAGE_DIR = REPO_ROOT / "data" / "images"
DEFAULT_STATIC_DIR = REPO_ROOT / "labeling" / "static"


class LabelPayload(BaseModel):
    """Request body for ``POST /api/label``."""

    # Annotator token; resolved to an annotator_id by the endpoint.
    token: str = Field(min_length=8, max_length=128)
    # Id of the item being labeled.
    item_id: str = Field(min_length=1, max_length=128)
    # Chosen option index, restricted to 0..3 inclusive.
    chosen_index: conint(ge=0, le=3)  # type: ignore[valid-type]
    # Optional, 0..3600 (presumably seconds spent on the item — confirm with the frontend).
    seconds: float | None = Field(default=None, ge=0, le=3600)
    # Optional self-reported confidence on a 1..5 scale.
    confidence: int | None = Field(default=None, ge=1, le=5)


def _sanitize_item(payload: dict) -> dict:
    """Strip `correct_index` before sending to annotator."""
    return {k: v for k, v in payload.items() if k != "correct_index"}


class SubmitEmailPayload(BaseModel):
    """Request body for ``POST /api/submit_email``."""

    # Annotator token; resolved to an annotator_id by the endpoint.
    token: str = Field(min_length=8, max_length=128)
    # Email to attach; the endpoint additionally checks it with _is_valid_email.
    email: str = Field(min_length=3, max_length=254)


def create_app(
    db_path: str | os.PathLike[str] | None = None,
    image_dir: str | os.PathLike[str] | None = None,
    static_dir: str | os.PathLike[str] | None = None,
    pool_mode: bool = False,
    anonymous_register: bool = False,
    max_labels_per_item: int = 3,
    max_labels_per_annotator: int | None = None,
    access_password: str | None = None,
    acc_threshold: float = DEFAULT_ACC_THRESHOLD,
    register_rate_limit: tuple[int, int] | None = (5, 3600),
) -> FastAPI:
    """Labeling server.

    `db_path`, `image_dir`, `static_dir`: filesystem locations; each falls
    back to the module-level DEFAULT_* path when omitted. The static dir is
    mounted at "/" only if it exists; otherwise "/" returns a JSON notice.

    `pool_mode=False` (default): annotators see only items pre-assigned to them
    (round-robin). Requires bootstrap_annotators() + assign_items_round_robin()
    before serving.

    `pool_mode=True`: ignore pre-assignment; dispatch any item that still needs
    labels and this annotator hasn't labeled. Items are handed out breadth-first
    over existing-label-count — every item gets one label before anyone gets a
    second. Unfinished work from one annotator is naturally picked up by the
    next person who logs in. Cap each session with `max_labels_per_annotator`.

    `max_labels_per_item`: passed to the pooled dispatcher as the per-item
    label target (pool mode only).

    `anonymous_register=True`: `POST /api/register` mints a fresh annotator_id
    + token on demand, so a single public URL can serve any number of
    concurrent anonymous annotators (each browser session = one annotator).
    Intended for public-URL crowdsourcing.

    `access_password`: if set, `/api/register` requires a matching
    `?password=` query param (constant-time compared). Cheap anti-spam
    gate for public Spaces — existing tokens keep working regardless.

    `acc_threshold`: session accuracy (vs `correct_index`) required for
    email submission + lottery bonus credit. Defaults to 0.40. At
    p_random=0.25 and n=20, false-positive rate is ~10%, so pair with
    IP rate-limiting + UNIQUE-email-per-round checks.

    `register_rate_limit`: (max_requests, window_seconds) tuple. Defaults
    to (5, 3600) — per-IP rolling window. Pass None to disable.

    Returns the configured FastAPI application. The SQLite connection is
    opened (and the schema initialised) eagerly and stored on `app.state`.
    """
    db_path = Path(db_path or DEFAULT_DB)
    image_dir = Path(image_dir or DEFAULT_IMAGE_DIR)
    static_dir = Path(static_dir or DEFAULT_STATIC_DIR)

    app = FastAPI(title="AestheticMCQ Annotation")
    conn = dbmod.connect(db_path)
    dbmod.init_schema(conn)
    app.state.conn = conn
    app.state.image_dir = image_dir
    app.state.pool_mode = pool_mode
    app.state.anonymous_register = anonymous_register
    app.state.max_labels_per_item = max_labels_per_item
    app.state.max_labels_per_annotator = max_labels_per_annotator
    app.state.access_password = access_password
    app.state.acc_threshold = float(acc_threshold)
    app.state.register_rate_limit = register_rate_limit
    # Per-IP timestamps of recent /api/register calls (in-process only).
    app.state.register_hits: dict[str, deque[float]] = {}

    @app.middleware("http")
    async def _deny_framing(request, call_next):
        # Block any browser from rendering us inside an iframe. The HF
        # Spaces outer page (huggingface.co/spaces/...) embeds us that
        # way and its script-load cycle double-fires our password
        # prompt. Users should visit the direct *.hf.space URL instead.
        response = await call_next(request)
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["Content-Security-Policy"] = "frame-ancestors 'none'"
        return response

    def get_conn() -> sqlite3.Connection:
        """FastAPI dependency returning the shared SQLite connection."""
        return app.state.conn

    def resolve_annotator(
        token: str,
        conn: sqlite3.Connection = Depends(get_conn),
    ) -> str:
        """Map *token* to its annotator_id, or raise 401 if it is unknown."""
        annotator_id = dbmod.get_annotator_by_token(conn, token)
        if not annotator_id:
            raise HTTPException(status_code=401, detail="invalid token")
        return annotator_id

    def _effective_cap(conn: sqlite3.Connection, annotator_id: str) -> int | None:
        """Per-annotator cap if one is stored, else the server-wide default."""
        per = dbmod.get_annotator_cap(conn, annotator_id)
        return per if per is not None else app.state.max_labels_per_annotator

    def _client_ip(request: Request) -> str:
        # HF Spaces sit behind a proxy — take the first entry of
        # X-Forwarded-For if present, else the direct peer address.
        # NOTE(review): the header is trusted as-is; if the fronting proxy
        # does not overwrite it, clients can spoof it to dodge the rate limit.
        xff = request.headers.get("x-forwarded-for")
        if xff:
            return xff.split(",")[0].strip()
        return request.client.host if request.client else "unknown"

    def _check_rate_limit(ip: str) -> None:
        """Record a register hit for *ip*; raise 429 if its window is full."""
        limit = app.state.register_rate_limit
        if limit is None:
            return
        max_req, window = limit
        now = time.time()
        q = app.state.register_hits.setdefault(ip, deque())
        # Drop timestamps that have aged out of the rolling window.
        while q and q[0] < now - window:
            q.popleft()
        if len(q) >= max_req:
            raise HTTPException(
                status_code=429,
                detail=f"too many register requests (limit {max_req}/{window}s)",
            )
        q.append(now)

    def _next_task_payload(annotator_id: str, conn: sqlite3.Connection, n_done: int) -> dict:
        """Build the /api/task response: the next item, or a `done` marker."""
        cap = _effective_cap(conn, annotator_id)
        if cap is not None and n_done >= cap:
            return {"done": True, "reason": "cap_reached", "labeled": n_done, "cap": cap}
        if app.state.pool_mode:
            item = dbmod.next_pooled_item(conn, annotator_id, app.state.max_labels_per_item)
        else:
            item = dbmod.next_unlabeled_item(conn, annotator_id)
        if item is None:
            return {"done": True, "reason": "pool_empty", "labeled": n_done}
        return {
            "done": False,
            "item_id": item.item_id,
            # Never leak the answer key to the annotator.
            "payload": _sanitize_item(item.payload),
            "image_url": f"/images/{item.item_id}.png",
            "labeled": n_done,
            "cap": cap,
        }

    @app.post("/api/register")
    def api_register(
        request: Request,
        cap: int | None = Query(default=None, ge=1, le=10000),
        password: str | None = Query(default=None, max_length=256),
        email: str | None = Query(default=None, max_length=254),
        round: int = Query(default=1, ge=1, le=MAX_LOTTERY_ROUND),
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Mint a fresh anonymous annotator. Only enabled when anonymous_register.

        Optional `?cap=N` sets a per-annotator label cap that overrides the
        server default (used by the frontend to give the first session a
        larger quota than subsequent ones).

        `?email=` + `?round=N` attach this new annotator to an existing
        email-based chain for lottery-multiplier rounds (rounds 2 through
        MAX_LOTTERY_ROUND). For round > 1 an email is mandatory and we
        verify the prior round was passed by an earlier annotator with the
        same email; otherwise reject. round==1 can include email but is
        rarer (email is normally attached later via /api/submit_email once
        session acc is known).

        If `access_password` was set at startup, `?password=` must match
        (constant-time compared) or we return 403.

        Raises 404 (registration disabled), 403 (wrong password), 429 (rate
        limit), 400 (bad email / invalid round chain), 500 (no unique id).
        """
        if not app.state.anonymous_register:
            raise HTTPException(status_code=404, detail="anonymous register disabled")
        expected = app.state.access_password
        if expected:
            # Compare as bytes: hmac.compare_digest() raises TypeError on str
            # arguments containing non-ASCII characters, which would surface
            # as a 500 instead of a clean 403.
            if not password or not hmac.compare_digest(
                password.encode("utf-8"), expected.encode("utf-8")
            ):
                raise HTTPException(status_code=403, detail="wrong access password")

        _check_rate_limit(_client_ip(request))

        if email is None:
            # Bonus rounds only exist as part of an email chain; without an
            # email the "prior round passed" check below cannot run at all.
            if round > 1:
                raise HTTPException(
                    status_code=400,
                    detail=f"round {round} requires an email",
                )
        else:
            # NOTE(review): emails are compared exactly as submitted (no case
            # folding here); confirm whether dbmod normalises them.
            if not _is_valid_email(email):
                raise HTTPException(status_code=400, detail="bad email format")
            # round > 1 must chain off a PASSED prior round for this email.
            if round > 1:
                passed = dbmod.email_passed_rounds(
                    conn, email, app.state.acc_threshold
                )
                if (round - 1) not in passed:
                    raise HTTPException(
                        status_code=400,
                        detail=f"round {round} requires passed round {round-1} for this email",
                    )
                if round in passed:
                    raise HTTPException(
                        status_code=400,
                        detail=f"round {round} already passed for this email",
                    )

        # Try a handful of random ids; check each with a single lookup
        # instead of loading every annotator id into memory.
        for _ in range(9):
            candidate = f"anon_{dbmod.mint_token()[:10]}"
            taken = conn.execute(
                "SELECT 1 FROM annotators WHERE annotator_id = ? LIMIT 1",
                (candidate,),
            ).fetchone()
            if taken is None:
                break
        else:
            raise HTTPException(status_code=500, detail="could not mint unique id")

        # bootstrap_annotators wasn't set up for email/round; inline the
        # same effect so we can pass those through.
        token = dbmod.mint_token()
        dbmod.insert_annotator(
            conn, candidate, token, cap=cap, email=email, round_number=round
        )
        return {
            "annotator_id": candidate,
            "token": token,
            "cap": cap,
            "email": email,
            "round_number": round,
        }

    @app.get("/api/task")
    def api_task(
        token: str = Query(min_length=8, max_length=128),
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Return the next item for this annotator, or a `done` marker."""
        annotator_id = resolve_annotator(token, conn)
        n_done = dbmod.count_annotator_labels(conn, annotator_id)
        return _next_task_payload(annotator_id, conn, n_done)

    @app.post("/api/label")
    def api_label(
        payload: LabelPayload,
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Record one label. 401 bad token, 404 unknown item, 403 unassigned."""
        annotator_id = resolve_annotator(payload.token, conn)
        item_row = dbmod.get_item(conn, payload.item_id)
        if item_row is None:
            raise HTTPException(status_code=404, detail="unknown item_id")
        if not app.state.pool_mode:
            # Pre-assigned mode: require an assignment row.
            assigned = conn.execute(
                "SELECT 1 FROM assignments WHERE item_id = ? AND annotator_id = ? LIMIT 1",
                (payload.item_id, annotator_id),
            ).fetchone()
            if assigned is None:
                raise HTTPException(status_code=403, detail="item not assigned to annotator")
        # NOTE(review): neither the per-annotator cap nor max_labels_per_item
        # is re-checked here; both are only applied when dispatching in
        # /api/task. Confirm dbmod.record_label (or the schema) prevents
        # over-submission, since label counts feed the lottery.
        dbmod.record_label(
            conn,
            payload.item_id,
            annotator_id,
            int(payload.chosen_index),
            payload.seconds,
            payload.confidence,
        )
        return {"ok": True}

    @app.get("/api/session_status")
    def api_session_status(
        token: str = Query(min_length=8, max_length=128),
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Tell the frontend what the done-page should render.

        Returns pass/fail flags — never the raw accuracy number — so
        annotators can't binary-search the threshold by reloading.
        """
        annotator_id = resolve_annotator(token, conn)
        row = dbmod.get_annotator_row(conn, annotator_id)
        if row is None:
            raise HTTPException(status_code=401, detail="invalid token")
        round_number = int(row["round_number"])
        cap = _effective_cap(conn, annotator_id)
        n_correct, n = dbmod.session_accuracy(conn, annotator_id)
        cap_reached = cap is not None and n >= cap
        # A session only "passes" once it is complete AND accurate enough.
        acc_pass = cap_reached and n > 0 and (n_correct / n) >= app.state.acc_threshold

        # Lottery state: look at ALL annotators for this email (if any)
        # to compute the current multiplier + whether more rounds are
        # available.
        email = row["email"]
        labels_in_lottery = 0
        if email:
            # Count labels across ALL this email's passing annotators;
            # each label is one lottery entry.
            labels_in_lottery = dbmod.email_passed_label_count(
                conn, email, app.state.acc_threshold
            )
            # Include THIS session optimistically if it just passed — the
            # helper's scan excludes uncommitted state, and we want the
            # UI to reflect the new total immediately.
            if acc_pass:
                # Avoid double-counting if this annotator's already
                # persisted labels would have been counted.
                passed_rounds = dbmod.email_passed_rounds(
                    conn, email, app.state.acc_threshold
                )
                if round_number not in passed_rounds:
                    labels_in_lottery += int(n)

        can_extend = (
            acc_pass
            and email is not None
            and round_number < MAX_LOTTERY_ROUND
        )

        return {
            "cap_reached": bool(cap_reached),
            "n_labeled": int(n),
            "cap": cap,
            "acc_pass": bool(acc_pass),
            "round_number": round_number,
            "email": email,
            "labels_in_lottery": int(labels_in_lottery),
            "can_extend": bool(can_extend),
            "next_round_cap": 10,
        }

    @app.post("/api/submit_email")
    def api_submit_email(
        payload: SubmitEmailPayload,
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Attach the email to a round-1 annotator whose session passed.

        Guards:
          - Annotator exists (valid token).
          - round_number == 1 (email is introduced on round 1 only).
          - Current email is NULL (can't overwrite once set).
          - Session complete: n_labeled >= cap.
          - acc >= threshold.
          - Email hasn't already passed round 1 via another annotator
            (prevents multiple lottery entries per email).
        """
        annotator_id = resolve_annotator(payload.token, conn)
        row = dbmod.get_annotator_row(conn, annotator_id)
        if row is None:
            raise HTTPException(status_code=401, detail="invalid token")
        if int(row["round_number"]) != 1:
            raise HTTPException(
                status_code=400, detail="email only submitted on round 1"
            )
        if row["email"]:
            raise HTTPException(
                status_code=409, detail="email already set for this annotator"
            )
        if not _is_valid_email(payload.email):
            raise HTTPException(status_code=400, detail="bad email format")

        cap = _effective_cap(conn, annotator_id)
        n_correct, n = dbmod.session_accuracy(conn, annotator_id)
        # Without a cap there is no notion of a "complete" session.
        if cap is None or n < cap:
            raise HTTPException(
                status_code=400,
                detail=f"session not complete ({n}/{cap})",
            )
        if n == 0 or (n_correct / n) < app.state.acc_threshold:
            raise HTTPException(
                status_code=403,
                detail=f"accuracy below threshold ({app.state.acc_threshold})",
            )

        # Prevent double-credit: this email must not have a passed
        # round 1 on another annotator.
        if 1 in dbmod.email_passed_rounds(conn, payload.email, app.state.acc_threshold):
            raise HTTPException(
                status_code=409,
                detail="email already credited for round 1",
            )

        dbmod.set_annotator_email(conn, annotator_id, payload.email)
        return {
            "ok": True,
            "email": payload.email,
            "round_number": 1,
            "labels_in_lottery": int(n),
        }

    @app.get("/api/progress")
    def api_progress(
        token: str = Query(min_length=8, max_length=128),
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Return labeled/assigned counts for this annotator."""
        annotator_id = resolve_annotator(token, conn)
        n_done = dbmod.count_annotator_labels(conn, annotator_id)
        if app.state.pool_mode:
            # Honour a per-annotator cap override (e.g. set via
            # /api/register?cap=N) so this agrees with /api/task and
            # /api/session_status; falls back to the server default.
            cap = _effective_cap(conn, annotator_id)
            return {
                "labeled": n_done,
                "assigned": cap if cap is not None else 0,
            }
        return dbmod.progress(conn, annotator_id)

    @app.get("/images/{item_id}.png")
    def serve_image(item_id: str):
        # Defense against path traversal — allow only [A-Za-z0-9_-.] in item_id.
        if not item_id or any(c not in _ALLOWED_ITEM_CHARS for c in item_id):
            raise HTTPException(status_code=400, detail="bad item_id")
        path = (app.state.image_dir / f"{item_id}.png").resolve()
        # Second line of defense: the resolved path must stay inside image_dir.
        if app.state.image_dir.resolve() not in path.parents:
            raise HTTPException(status_code=400, detail="bad path")
        # is_file() rather than exists(): a directory that happens to be
        # named "<id>.png" must not be handed to FileResponse.
        if not path.is_file():
            raise HTTPException(status_code=404, detail="image missing")
        return FileResponse(path, media_type="image/png")

    @app.get("/healthz")
    def healthz():
        return {"ok": True}

    # Mounted last so the catch-all static mount does not shadow the API routes.
    if static_dir.exists():
        app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")
    else:
        @app.get("/")
        def root():
            return JSONResponse({"detail": "static dir missing; /api/* still usable"})

    return app


# Characters permitted in an item_id requested via /images/{item_id}.png
# (ASCII letters, digits, "_", "-", "."); serve_image answers 400 for any
# other character. Looked up at request time, so defining it below
# create_app is fine.
_ALLOWED_ITEM_CHARS = frozenset(
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-."
)