# Spaces: Sleeping (hosting-page status banner captured with this file; not part of the module)
"""FastAPI app serving the annotator web UI + /api/task, /api/label, /api/progress."""
| from __future__ import annotations | |
| import hmac | |
| import os | |
| import re | |
| import sqlite3 | |
| import time | |
| from collections import deque | |
| from pathlib import Path | |
| from fastapi import Depends, FastAPI, HTTPException, Query, Request | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel, Field, conint | |
| from aamcq.annotation import db as dbmod | |
| from aamcq.annotation.assignment import bootstrap_annotators | |
# Default minimum session accuracy (measured against `correct_index`) that
# `create_app` uses for its `acc_threshold` parameter.
DEFAULT_ACC_THRESHOLD: float = 0.40

# Round 1 (cap=20) + up to 3 bonus rounds (cap=10 each) = 50 labels max per
# annotator. Every label that sits in a passing session is one lottery
# entry, so more labels means better odds.
MAX_LOTTERY_ROUND: int = 4
# Deliberately loose shape check: one "@", no whitespace, and a dot somewhere
# in the domain part. Anchored with \Z (not $) so that a trailing newline is
# NOT accepted -- "$" would match just before a final "\n".
_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+\Z")


def _is_valid_email(email: str) -> bool:
    """Return True if *email* is non-empty, at most 254 chars, and email-shaped.

    The whole string must match; values such as ``"a@b.co\\n"`` are rejected
    so that two spellings of the same address cannot both pass validation.
    """
    if not email or len(email) > 254:
        return False
    return _EMAIL_RE.fullmatch(email) is not None
# Repository root, taken as the directory three levels above the one that
# contains this file; the default locations below hang off it.
REPO_ROOT = Path(__file__).resolve().parents[3]

_DATA_DIR = REPO_ROOT / "data"
DEFAULT_DB = _DATA_DIR / "annotations.sqlite"
DEFAULT_IMAGE_DIR = _DATA_DIR / "images"
DEFAULT_STATIC_DIR = REPO_ROOT.joinpath("labeling", "static")
class LabelPayload(BaseModel):
    """Request body for submitting one label.

    All bounds are enforced by pydantic before the handler runs.
    """

    # Annotator token; resolved to an annotator_id by the handler.
    token: str = Field(min_length=8, max_length=128)
    # Identifier of the item being labeled.
    item_id: str = Field(min_length=1, max_length=128)
    # Index of the chosen option (0-3). Declared with `Field` like the other
    # constrained fields instead of `conint`, which needed a type-ignore.
    chosen_index: int = Field(ge=0, le=3)
    # Optional time spent on the item; bounded to [0, 3600].
    seconds: float | None = Field(default=None, ge=0, le=3600)
    # Optional self-reported confidence on a 1-5 scale.
    confidence: int | None = Field(default=None, ge=1, le=5)
def _sanitize_item(payload: dict) -> dict:
    """Return a shallow copy of *payload* without the `correct_index` key.

    Used before an item is sent to an annotator so the answer is not leaked.
    The input dict is left unmodified.
    """
    visible = dict(payload)
    visible.pop("correct_index", None)
    return visible
class SubmitEmailPayload(BaseModel):
    """Request body for attaching an email address to an annotator."""

    # Annotator token; same length bounds as on the label payload.
    token: str = Field(min_length=8, max_length=128)
    # Only length-bounded here; the handler applies the format check.
    email: str = Field(min_length=3, max_length=254)
def create_app(
    db_path: str | os.PathLike[str] | None = None,
    image_dir: str | os.PathLike[str] | None = None,
    static_dir: str | os.PathLike[str] | None = None,
    pool_mode: bool = False,
    anonymous_register: bool = False,
    max_labels_per_item: int = 3,
    max_labels_per_annotator: int | None = None,
    access_password: str | None = None,
    acc_threshold: float = DEFAULT_ACC_THRESHOLD,
    register_rate_limit: tuple[int, int] | None = (5, 3600),
) -> FastAPI:
    """Labeling server.

    `pool_mode=False` (default): annotators see only items pre-assigned to them
    (round-robin). Requires bootstrap_annotators() + assign_items_round_robin()
    before serving.

    `pool_mode=True`: ignore pre-assignment; dispatch any item that still needs
    labels and this annotator hasn't labeled. Items are handed out breadth-first
    over existing-label-count — every item gets one label before anyone gets a
    second. Unfinished work from one annotator is naturally picked up by the
    next person who logs in. Cap each session with `max_labels_per_annotator`.

    `anonymous_register=True`: `POST /api/register` mints a fresh annotator_id
    + token on demand, so a single public URL can serve any number of
    concurrent anonymous annotators (each browser session = one annotator).
    Intended for public-URL crowdsourcing.

    `access_password`: if set, `/api/register` requires a matching
    `?password=` query param (constant-time compared). Cheap anti-spam
    gate for public Spaces — existing tokens keep working regardless.

    `acc_threshold`: session accuracy (vs `correct_index`) required for
    email submission + lottery bonus credit. Defaults to 0.40. At
    p_random=0.25 and n=20, false-positive rate is ~10%, so pair with
    IP rate-limiting + UNIQUE-email-per-round checks.

    `register_rate_limit`: (max_requests, window_seconds) tuple. Defaults
    to (5, 3600) — per-IP rolling window. Pass None to disable.

    Returns the configured FastAPI application. The database connection is
    opened (and its schema initialised) eagerly and stored on `app.state`.
    """
    db_path = Path(db_path or DEFAULT_DB)
    image_dir = Path(image_dir or DEFAULT_IMAGE_DIR)
    static_dir = Path(static_dir or DEFAULT_STATIC_DIR)

    app = FastAPI(title="AestheticMCQ Annotation")
    conn = dbmod.connect(db_path)
    dbmod.init_schema(conn)

    app.state.conn = conn
    app.state.image_dir = image_dir
    app.state.pool_mode = pool_mode
    app.state.anonymous_register = anonymous_register
    app.state.max_labels_per_item = max_labels_per_item
    app.state.max_labels_per_annotator = max_labels_per_annotator
    app.state.access_password = access_password
    app.state.acc_threshold = float(acc_threshold)
    app.state.register_rate_limit = register_rate_limit
    # ip -> timestamps of recent /api/register hits (in-memory, per process).
    app.state.register_hits: dict[str, deque[float]] = {}

    # Once the hit table tracks more than this many IPs, entries whose newest
    # hit has left the window are dropped so the table cannot grow forever.
    max_tracked_ips = 10_000
    # Attempts at minting a non-colliding anonymous id before giving up.
    max_id_attempts = 9

    # NOTE(review): the route/middleware decorators in this function were
    # absent from the reviewed text. Paths follow the docstrings where they
    # are stated (/api/task, /api/label, /api/progress, POST /api/register,
    # /api/submit_email, /images/<id>.png); /api/session_status and /healthz
    # are inferred from the api_<name> -> /api/<name> naming pattern —
    # confirm against the frontend.

    @app.middleware("http")
    async def _deny_framing(request: Request, call_next):
        # Block any browser from rendering us inside an iframe. The HF
        # Spaces outer page (huggingface.co/spaces/...) embeds us that
        # way and its script-load cycle double-fires our password
        # prompt. Users should visit the direct *.hf.space URL instead.
        response = await call_next(request)
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["Content-Security-Policy"] = "frame-ancestors 'none'"
        return response

    def get_conn() -> sqlite3.Connection:
        """Dependency: the single connection stored on app.state."""
        return app.state.conn

    def resolve_annotator(
        token: str,
        conn: sqlite3.Connection = Depends(get_conn),
    ) -> str:
        """Map *token* to an annotator_id, or raise 401 if it is unknown."""
        annotator_id = dbmod.get_annotator_by_token(conn, token)
        if not annotator_id:
            raise HTTPException(status_code=401, detail="invalid token")
        return annotator_id

    def _effective_cap(conn: sqlite3.Connection, annotator_id: str) -> int | None:
        """Per-annotator cap if one is stored, else the server-wide default."""
        per = dbmod.get_annotator_cap(conn, annotator_id)
        return per if per is not None else app.state.max_labels_per_annotator

    def _client_ip(request: Request) -> str:
        """Best-effort client address used as the rate-limit key."""
        # HF Spaces sit behind a proxy — take the first entry of
        # X-Forwarded-For if present, else the direct peer address.
        # NOTE(review): the first XFF entry is client-supplied, so a caller
        # can vary it to dodge the per-IP limit; tighten once the number of
        # trusted proxy hops is known.
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:  # ignore a header like ", 10.0.0.1"
                return first_hop
        return request.client.host if request.client else "unknown"

    def _check_rate_limit(ip: str) -> None:
        """Record one register hit for *ip*; raise 429 if over the limit."""
        limit = app.state.register_rate_limit
        if limit is None:
            return
        max_req, window = limit
        now = time.time()
        cutoff = now - window
        hits = app.state.register_hits

        # Opportunistic sweep: forget IPs whose newest hit is out of window.
        if len(hits) > max_tracked_ips:
            for key, stamps in list(hits.items()):
                if not stamps or stamps[-1] < cutoff:
                    hits.pop(key, None)

        q = hits.setdefault(ip, deque())
        while q and q[0] < cutoff:
            q.popleft()
        if len(q) >= max_req:
            raise HTTPException(
                status_code=429,
                detail=f"too many register requests (limit {max_req}/{window}s)",
            )
        q.append(now)

    def _next_task_payload(annotator_id: str, conn: sqlite3.Connection, n_done: int) -> dict:
        """Build the /api/task response: the next item, or a `done` marker."""
        cap = _effective_cap(conn, annotator_id)
        if cap is not None and n_done >= cap:
            return {"done": True, "reason": "cap_reached", "labeled": n_done, "cap": cap}
        if app.state.pool_mode:
            item = dbmod.next_pooled_item(conn, annotator_id, app.state.max_labels_per_item)
        else:
            item = dbmod.next_unlabeled_item(conn, annotator_id)
        if item is None:
            return {"done": True, "reason": "pool_empty", "labeled": n_done}
        return {
            "done": False,
            "item_id": item.item_id,
            "payload": _sanitize_item(item.payload),
            "image_url": f"/images/{item.item_id}.png",
            "labeled": n_done,
            "cap": cap,
        }

    @app.post("/api/register")
    def api_register(
        request: Request,
        cap: int | None = Query(default=None, ge=1, le=10000),
        password: str | None = Query(default=None, max_length=256),
        email: str | None = Query(default=None, max_length=254),
        round: int = Query(default=1, ge=1, le=MAX_LOTTERY_ROUND),
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Mint a fresh anonymous annotator. Only enabled when anonymous_register.

        Optional `?cap=N` sets a per-annotator label cap that overrides the
        server default (used by the frontend to give the first session a
        larger quota than subsequent ones).

        `?email=` + `?round=N` attach this new annotator to an existing
        email-based chain for lottery-multiplier rounds (round 2/3). For
        round > 1 we verify the prior round was passed by an earlier
        annotator with the same email; otherwise reject. round==1 can
        include email but is rarer (email is normally attached later
        via /api/submit_email once session acc is known).

        If `access_password` was set at startup, `?password=` must match
        (constant-time compared) or we return 403.

        Raises 404 when registration is disabled, 403 on a wrong password,
        429 when rate-limited, 400 on a bad email or round chain, and 500
        if no unused id could be minted.
        """
        # `round` shadows the builtin, but it is the public query-param name.
        if not app.state.anonymous_register:
            raise HTTPException(status_code=404, detail="anonymous register disabled")

        expected = app.state.access_password
        if expected:
            # Compare as bytes: compare_digest() raises TypeError on str
            # arguments that contain non-ASCII characters.
            supplied = (password or "").encode("utf-8")
            if not password or not hmac.compare_digest(supplied, expected.encode("utf-8")):
                raise HTTPException(status_code=403, detail="wrong access password")

        # NOTE(review): the limiter runs after the password check, so wrong
        # password guesses are not rate-limited.
        _check_rate_limit(_client_ip(request))

        if email is not None:
            if not _is_valid_email(email):
                raise HTTPException(status_code=400, detail="bad email format")
            # round > 1 must chain off a PASSED prior round for this email.
            if round > 1:
                passed = dbmod.email_passed_rounds(
                    conn, email, app.state.acc_threshold
                )
                if (round - 1) not in passed:
                    raise HTTPException(
                        status_code=400,
                        detail=f"round {round} requires passed round {round-1} for this email",
                    )
                if round in passed:
                    raise HTTPException(
                        status_code=400,
                        detail=f"round {round} already passed for this email",
                    )
        # NOTE(review): round > 1 without an email skips the chain check
        # above; confirm whether that combination should be rejected.

        # Probe only the candidate id instead of loading every annotator row.
        for _ in range(max_id_attempts):
            candidate = f"anon_{dbmod.mint_token()[:10]}"
            taken = conn.execute(
                "SELECT 1 FROM annotators WHERE annotator_id = ? LIMIT 1",
                (candidate,),
            ).fetchone()
            if taken is None:
                break
        else:
            raise HTTPException(status_code=500, detail="could not mint unique id")

        # bootstrap_annotators wasn't set up for email/round; inline the
        # same effect so we can pass those through.
        token = dbmod.mint_token()
        dbmod.insert_annotator(
            conn, candidate, token, cap=cap, email=email, round_number=round
        )
        return {
            "annotator_id": candidate,
            "token": token,
            "cap": cap,
            "email": email,
            "round_number": round,
        }

    @app.get("/api/task")
    def api_task(
        token: str = Query(min_length=8, max_length=128),
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Return the next item for this annotator, or a `done` payload."""
        annotator_id = resolve_annotator(token, conn)
        n_done = dbmod.count_annotator_labels(conn, annotator_id)
        return _next_task_payload(annotator_id, conn, n_done)

    @app.post("/api/label")
    def api_label(
        payload: LabelPayload,
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Record one label. 401 bad token, 404 unknown item, 403 unassigned."""
        annotator_id = resolve_annotator(payload.token, conn)
        item_row = dbmod.get_item(conn, payload.item_id)
        if item_row is None:
            raise HTTPException(status_code=404, detail="unknown item_id")
        if not app.state.pool_mode:
            # Pre-assigned mode: require an assignment row.
            assigned = conn.execute(
                "SELECT 1 FROM assignments WHERE item_id = ? AND annotator_id = ? LIMIT 1",
                (payload.item_id, annotator_id),
            ).fetchone()
            if assigned is None:
                raise HTTPException(status_code=403, detail="item not assigned to annotator")
        # NOTE(review): the label cap is enforced only when handing out
        # tasks; this handler accepts labels past the cap. Confirm whether
        # record_label or the schema guards against that.
        dbmod.record_label(
            conn,
            payload.item_id,
            annotator_id,
            int(payload.chosen_index),
            payload.seconds,
            payload.confidence,
        )
        return {"ok": True}

    @app.get("/api/session_status")
    def api_session_status(
        token: str = Query(min_length=8, max_length=128),
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Tell the frontend what the done-page should render.

        Returns pass/fail flags — never the raw accuracy number — so
        annotators can't binary-search the threshold by reloading.
        """
        annotator_id = resolve_annotator(token, conn)
        row = dbmod.get_annotator_row(conn, annotator_id)
        if row is None:
            raise HTTPException(status_code=401, detail="invalid token")

        round_number = int(row["round_number"])
        threshold = app.state.acc_threshold
        cap = _effective_cap(conn, annotator_id)
        n_correct, n = dbmod.session_accuracy(conn, annotator_id)
        cap_reached = cap is not None and n >= cap
        # A session can only pass once it is complete.
        acc_pass = cap_reached and n > 0 and (n_correct / n) >= threshold

        # Lottery state: look at ALL annotators for this email (if any)
        # to compute the current multiplier + whether more rounds are
        # available.
        email = row["email"]
        labels_in_lottery = 0
        if email:
            # Count labels across ALL this email's passing annotators;
            # each label is one lottery entry.
            labels_in_lottery = dbmod.email_passed_label_count(conn, email, threshold)
            # Include THIS session optimistically if it just passed — the
            # helper's scan excludes uncommitted state, and we want the
            # UI to reflect the new total immediately.
            if acc_pass:
                # Avoid double-counting if this annotator's already
                # persisted labels would have been counted.
                passed_rounds = dbmod.email_passed_rounds(conn, email, threshold)
                if round_number not in passed_rounds:
                    labels_in_lottery += int(n)

        # Truthiness (not `is not None`) so an empty email never offers a
        # bonus round, matching the `if email:` test above.
        can_extend = acc_pass and bool(email) and round_number < MAX_LOTTERY_ROUND
        return {
            "cap_reached": bool(cap_reached),
            "n_labeled": int(n),
            "cap": cap,
            "acc_pass": bool(acc_pass),
            "round_number": round_number,
            "email": email,
            "labels_in_lottery": int(labels_in_lottery),
            "can_extend": bool(can_extend),
            "next_round_cap": 10,
        }

    @app.post("/api/submit_email")
    def api_submit_email(
        payload: SubmitEmailPayload,
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Attach the email to a round-1 annotator whose session passed.

        Guards:
        - Annotator exists (valid token).
        - round_number == 1 (email is introduced on round 1 only).
        - Current email is NULL (can't overwrite once set).
        - Session complete: n_labeled >= cap.
        - acc >= threshold.
        - Email hasn't already passed round 1 via another annotator
          (prevents multiple lottery entries per email).
        """
        annotator_id = resolve_annotator(payload.token, conn)
        row = dbmod.get_annotator_row(conn, annotator_id)
        if row is None:
            raise HTTPException(status_code=401, detail="invalid token")
        if int(row["round_number"]) != 1:
            raise HTTPException(
                status_code=400, detail="email only submitted on round 1"
            )
        if row["email"]:
            raise HTTPException(
                status_code=409, detail="email already set for this annotator"
            )
        if not _is_valid_email(payload.email):
            raise HTTPException(status_code=400, detail="bad email format")

        cap = _effective_cap(conn, annotator_id)
        n_correct, n = dbmod.session_accuracy(conn, annotator_id)
        if cap is None or n < cap:
            raise HTTPException(
                status_code=400,
                detail=f"session not complete ({n}/{cap})",
            )
        if n == 0 or (n_correct / n) < app.state.acc_threshold:
            raise HTTPException(
                status_code=403,
                detail=f"accuracy below threshold ({app.state.acc_threshold})",
            )
        # Prevent double-credit: this email must not have a passed
        # round 1 on another annotator.
        if 1 in dbmod.email_passed_rounds(conn, payload.email, app.state.acc_threshold):
            raise HTTPException(
                status_code=409,
                detail="email already credited for round 1",
            )
        dbmod.set_annotator_email(conn, annotator_id, payload.email)
        return {
            "ok": True,
            "email": payload.email,
            "round_number": 1,
            "labels_in_lottery": int(n),
        }

    @app.get("/api/progress")
    def api_progress(
        token: str = Query(min_length=8, max_length=128),
        conn: sqlite3.Connection = Depends(get_conn),
    ):
        """Report how many labels this annotator has done and their quota."""
        annotator_id = resolve_annotator(token, conn)
        n_done = dbmod.count_annotator_labels(conn, annotator_id)
        if app.state.pool_mode:
            # Honour a per-annotator cap override, as /api/task does.
            cap = _effective_cap(conn, annotator_id)
            return {
                "labeled": n_done,
                "assigned": cap if cap is not None else 0,
            }
        return dbmod.progress(conn, annotator_id)

    @app.get("/images/{item_id}.png")
    def serve_image(item_id: str):
        """Serve `<image_dir>/<item_id>.png`; 400 on a bad id, 404 if absent."""
        # Defense against path traversal — allow only [A-Za-z0-9_-.] in item_id.
        if not item_id or any(c not in _ALLOWED_ITEM_CHARS for c in item_id):
            raise HTTPException(status_code=400, detail="bad item_id")
        path = (app.state.image_dir / f"{item_id}.png").resolve()
        # Second line of defense: the resolved file must sit under image_dir.
        if app.state.image_dir.resolve() not in path.parents:
            raise HTTPException(status_code=400, detail="bad path")
        if not path.exists():
            raise HTTPException(status_code=404, detail="image missing")
        return FileResponse(path, media_type="image/png")

    @app.get("/healthz")
    def healthz():
        """Liveness probe."""
        return {"ok": True}

    # Mount static files last so the API routes above take precedence.
    if static_dir.exists():
        app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")
    else:

        @app.get("/")
        def root():
            return JSONResponse({"detail": "static dir missing; /api/* still usable"})

    return app
# Characters permitted in an item_id used to build an image path:
# ASCII letters, digits, underscore, hyphen and dot (no path separators).
_ALLOWED_ITEM_CHARS = frozenset(
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789"
    "_-."
)