"""SQLite schema + DAO for the annotation backend. Single-writer model: annotations are append-only, assignments are created up front. We use stdlib sqlite3 rather than SQLAlchemy to keep the install footprint small. """ from __future__ import annotations import json import secrets import sqlite3 import time from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path from typing import Iterator SCHEMA = """ CREATE TABLE IF NOT EXISTS items ( item_id TEXT PRIMARY KEY, payload_json TEXT NOT NULL, is_gold INTEGER NOT NULL DEFAULT 0 ); CREATE TABLE IF NOT EXISTS annotators ( annotator_id TEXT PRIMARY KEY, token TEXT NOT NULL UNIQUE, created_at REAL NOT NULL, cap INTEGER, email TEXT, round_number INTEGER NOT NULL DEFAULT 1 ); CREATE TABLE IF NOT EXISTS assignments ( item_id TEXT NOT NULL, annotator_id TEXT NOT NULL, assigned_at REAL NOT NULL, PRIMARY KEY (item_id, annotator_id), FOREIGN KEY (item_id) REFERENCES items(item_id), FOREIGN KEY (annotator_id) REFERENCES annotators(annotator_id) ); CREATE TABLE IF NOT EXISTS labels ( item_id TEXT NOT NULL, annotator_id TEXT NOT NULL, chosen_index INTEGER NOT NULL, seconds REAL, confidence INTEGER, submitted_at REAL NOT NULL, PRIMARY KEY (item_id, annotator_id), FOREIGN KEY (item_id) REFERENCES items(item_id), FOREIGN KEY (annotator_id) REFERENCES annotators(annotator_id) ); CREATE INDEX IF NOT EXISTS idx_assignments_annotator ON assignments(annotator_id); CREATE INDEX IF NOT EXISTS idx_labels_annotator ON labels(annotator_id); """ @dataclass(frozen=True) class ItemRow: item_id: str payload: dict is_gold: bool @dataclass(frozen=True) class LabelRow: item_id: str annotator_id: str chosen_index: int seconds: float | None confidence: int | None submitted_at: float def connect(db_path: str | Path) -> sqlite3.Connection: db_path = Path(db_path) db_path.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(db_path, check_same_thread=False, isolation_level=None) conn.row_factory = sqlite3.Row conn.execute("PRAGMA foreign_keys = ON") # DELETE (default) journal mode keeps the main .sqlite file consistent # after every write, so simple file-copy backups always reflect the # latest state. WAL would require an explicit checkpoint before each # backup, which is a foot-gun we hit in production on HF Spaces. conn.execute("PRAGMA journal_mode = DELETE") return conn def init_schema(conn: sqlite3.Connection) -> None: conn.executescript(SCHEMA) # Forward-compatible migrations for older DBs. Each ALTER is wrapped # in a try/except since sqlite3 throws OperationalError on existing # columns — we can't check sqlite_master cleanly across SQLite # versions, so just-try-it is the pragmatic pattern. for ddl in ( "ALTER TABLE annotators ADD COLUMN cap INTEGER", "ALTER TABLE annotators ADD COLUMN email TEXT", "ALTER TABLE annotators ADD COLUMN round_number INTEGER NOT NULL DEFAULT 1", ): try: conn.execute(ddl) except sqlite3.OperationalError: pass # CREATE INDEX must run AFTER the ALTER TABLE for email, since the # SCHEMA's CREATE TABLE is a no-op when the table already exists # (restored from an older backup) and hence doesn't add the column # on its own. conn.execute("CREATE INDEX IF NOT EXISTS idx_annotators_email ON annotators(email)") def mint_token() -> str: return secrets.token_urlsafe(16) def insert_item(conn: sqlite3.Connection, item_id: str, payload: dict, is_gold: bool = False) -> None: conn.execute( "INSERT OR REPLACE INTO items(item_id, payload_json, is_gold) VALUES (?, ?, ?)", (item_id, json.dumps(payload, sort_keys=True), int(is_gold)), ) def insert_annotator( conn: sqlite3.Connection, annotator_id: str, token: str, cap: int | None = None, email: str | None = None, round_number: int = 1, ) -> None: conn.execute( "INSERT OR REPLACE INTO annotators" "(annotator_id, token, created_at, cap, email, round_number) " "VALUES (?, ?, ?, ?, ?, ?)", (annotator_id, token, time.time(), cap, email, round_number), ) def get_annotator_cap(conn: sqlite3.Connection, annotator_id: str) -> int | None: row = conn.execute( "SELECT cap FROM annotators WHERE annotator_id = ?", (annotator_id,) ).fetchone() return None if row is None or row["cap"] is None else int(row["cap"]) def insert_assignment(conn: sqlite3.Connection, item_id: str, annotator_id: str) -> None: conn.execute( "INSERT OR IGNORE INTO assignments(item_id, annotator_id, assigned_at) " "VALUES (?, ?, ?)", (item_id, annotator_id, time.time()), ) def get_annotator_by_token(conn: sqlite3.Connection, token: str) -> str | None: row = conn.execute( "SELECT annotator_id FROM annotators WHERE token = ?", (token,) ).fetchone() return row["annotator_id"] if row else None def get_item(conn: sqlite3.Connection, item_id: str) -> ItemRow | None: row = conn.execute( "SELECT item_id, payload_json, is_gold FROM items WHERE item_id = ?", (item_id,), ).fetchone() if not row: return None return ItemRow( item_id=row["item_id"], payload=json.loads(row["payload_json"]), is_gold=bool(row["is_gold"]), ) def next_unlabeled_item( conn: sqlite3.Connection, annotator_id: str ) -> ItemRow | None: """Pre-assigned dispatch: hand out the annotator's next un-labeled assignment.""" row = conn.execute( """ SELECT items.item_id, items.payload_json, items.is_gold FROM assignments JOIN items ON items.item_id = assignments.item_id LEFT JOIN labels ON labels.item_id = assignments.item_id AND labels.annotator_id = assignments.annotator_id WHERE assignments.annotator_id = ? AND labels.item_id IS NULL ORDER BY assignments.assigned_at ASC LIMIT 1 """, (annotator_id,), ).fetchone() if not row: return None return ItemRow( item_id=row["item_id"], payload=json.loads(row["payload_json"]), is_gold=bool(row["is_gold"]), ) def next_pooled_item( conn: sqlite3.Connection, annotator_id: str, max_labels_per_item: int, ) -> ItemRow | None: """Pull-based dispatch: pick any item needing more labels that this annotator hasn't seen yet. Breadth-first over coverage so every item gets at least one label before anyone gets a second. Within the same coverage bucket we tie-break with RANDOM() so two annotators opening the UI at similar times don't get the same item_id sequence — otherwise item_id alphabetical order creates stimulus-level bias correlated with annotator fatigue.""" row = conn.execute( """ SELECT items.item_id, items.payload_json, items.is_gold, COALESCE(counts.n, 0) AS n_labels FROM items LEFT JOIN ( SELECT item_id, COUNT(*) AS n FROM labels GROUP BY item_id ) AS counts ON counts.item_id = items.item_id LEFT JOIN labels mine ON mine.item_id = items.item_id AND mine.annotator_id = ? WHERE mine.item_id IS NULL AND COALESCE(counts.n, 0) < ? ORDER BY n_labels ASC, RANDOM() LIMIT 1 """, (annotator_id, max_labels_per_item), ).fetchone() if not row: return None return ItemRow( item_id=row["item_id"], payload=json.loads(row["payload_json"]), is_gold=bool(row["is_gold"]), ) def get_annotator_row(conn: sqlite3.Connection, annotator_id: str) -> sqlite3.Row | None: return conn.execute( "SELECT annotator_id, token, cap, email, round_number, created_at " "FROM annotators WHERE annotator_id = ?", (annotator_id,), ).fetchone() def session_accuracy( conn: sqlite3.Connection, annotator_id: str ) -> tuple[int, int]: """Return (n_correct, n_total) for this annotator's labels. Compares label.chosen_index to items.payload_json->'$.correct_index'. Returns (0, 0) if the annotator has no labels yet. """ row = conn.execute( """ SELECT COUNT(*) AS n, SUM(CASE WHEN l.chosen_index = CAST(json_extract(i.payload_json, '$.correct_index') AS INTEGER) THEN 1 ELSE 0 END) AS n_correct FROM labels l JOIN items i USING(item_id) WHERE l.annotator_id = ? """, (annotator_id,), ).fetchone() n = int(row["n"] or 0) n_correct = int(row["n_correct"] or 0) return n_correct, n def set_annotator_email( conn: sqlite3.Connection, annotator_id: str, email: str ) -> None: conn.execute( "UPDATE annotators SET email = ? WHERE annotator_id = ?", (email, annotator_id), ) def email_passed_rounds( conn: sqlite3.Connection, email: str, acc_threshold: float, target_cap: int | None = None ) -> set[int]: """Return the set of round_numbers for which this email has at least one annotator that (a) hit cap and (b) has acc >= threshold. `target_cap` is optional; None = use each annotator's own cap to decide cap_reached. Passing cap explicitly is useful when the caller wants to ignore annotators whose session didn't finish. """ rows = conn.execute( "SELECT annotator_id, cap, round_number FROM annotators " "WHERE email = ?", (email,), ).fetchall() passed: set[int] = set() for r in rows: n_correct, n = session_accuracy(conn, r["annotator_id"]) cap = r["cap"] if target_cap is None else target_cap if cap is None or n < cap: continue if n == 0: continue if (n_correct / n) >= acc_threshold: passed.add(int(r["round_number"])) return passed def email_passed_label_count( conn: sqlite3.Connection, email: str, acc_threshold: float ) -> int: """Total labels across all of this email's annotators whose session (a) hit their cap and (b) has acc >= threshold. Each label = 1 lottery entry.""" rows = conn.execute( "SELECT annotator_id, cap FROM annotators WHERE email = ?", (email,) ).fetchall() total = 0 for r in rows: n_correct, n = session_accuracy(conn, r["annotator_id"]) cap = r["cap"] if cap is None or n < cap or n == 0: continue if (n_correct / n) >= acc_threshold: total += n return total def count_annotator_labels(conn: sqlite3.Connection, annotator_id: str) -> int: return int(conn.execute( "SELECT COUNT(*) AS n FROM labels WHERE annotator_id = ?", (annotator_id,), ).fetchone()["n"]) def record_label( conn: sqlite3.Connection, item_id: str, annotator_id: str, chosen_index: int, seconds: float | None, confidence: int | None, ) -> None: conn.execute( """ INSERT OR REPLACE INTO labels (item_id, annotator_id, chosen_index, seconds, confidence, submitted_at) VALUES (?, ?, ?, ?, ?, ?) """, (item_id, annotator_id, chosen_index, seconds, confidence, time.time()), ) def progress(conn: sqlite3.Connection, annotator_id: str) -> dict[str, int]: assigned = conn.execute( "SELECT COUNT(*) AS n FROM assignments WHERE annotator_id = ?", (annotator_id,), ).fetchone()["n"] labeled = conn.execute( "SELECT COUNT(*) AS n FROM labels WHERE annotator_id = ?", (annotator_id,), ).fetchone()["n"] return {"assigned": int(assigned), "labeled": int(labeled)} def iter_labels(conn: sqlite3.Connection) -> Iterator[LabelRow]: for row in conn.execute( "SELECT item_id, annotator_id, chosen_index, seconds, confidence, submitted_at FROM labels" ): yield LabelRow( item_id=row["item_id"], annotator_id=row["annotator_id"], chosen_index=int(row["chosen_index"]), seconds=row["seconds"], confidence=row["confidence"], submitted_at=float(row["submitted_at"]), ) @contextmanager def open_db(db_path: str | Path): conn = connect(db_path) try: init_schema(conn) yield conn finally: conn.close()