Spaces:
Sleeping
Sleeping
| """SQLite schema + DAO for the annotation backend. | |
| Single-writer model: annotations are append-only, assignments are created up | |
| front. We use stdlib sqlite3 rather than SQLAlchemy to keep the install | |
| footprint small. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import secrets | |
| import sqlite3 | |
| import time | |
| from contextlib import contextmanager | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Iterator | |
# DDL for the annotation database. Every statement uses IF NOT EXISTS, so
# running the script against an already-initialised database is harmless.
#
# Tables:
#   items        - one row per item; payload stored as JSON text.
#   annotators   - one row per annotator; `token` is unique.
#   assignments  - at most one row per (item_id, annotator_id) pair.
#   labels       - at most one row per (item_id, annotator_id) pair.
#
# The index on annotators(email) is deliberately NOT part of this script;
# init_schema() creates it after the column migrations have run.
SCHEMA = """
CREATE TABLE IF NOT EXISTS items (
    item_id TEXT PRIMARY KEY,
    payload_json TEXT NOT NULL,
    is_gold INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS annotators (
    annotator_id TEXT PRIMARY KEY,
    token TEXT NOT NULL UNIQUE,
    created_at REAL NOT NULL,
    cap INTEGER,
    email TEXT,
    round_number INTEGER NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS assignments (
    item_id TEXT NOT NULL,
    annotator_id TEXT NOT NULL,
    assigned_at REAL NOT NULL,
    PRIMARY KEY (item_id, annotator_id),
    FOREIGN KEY (item_id) REFERENCES items(item_id),
    FOREIGN KEY (annotator_id) REFERENCES annotators(annotator_id)
);
CREATE TABLE IF NOT EXISTS labels (
    item_id TEXT NOT NULL,
    annotator_id TEXT NOT NULL,
    chosen_index INTEGER NOT NULL,
    seconds REAL,
    confidence INTEGER,
    submitted_at REAL NOT NULL,
    PRIMARY KEY (item_id, annotator_id),
    FOREIGN KEY (item_id) REFERENCES items(item_id),
    FOREIGN KEY (annotator_id) REFERENCES annotators(annotator_id)
);
CREATE INDEX IF NOT EXISTS idx_assignments_annotator
    ON assignments(annotator_id);
CREATE INDEX IF NOT EXISTS idx_labels_annotator
    ON labels(annotator_id);
"""
@dataclass
class ItemRow:
    """One row of the ``items`` table with its JSON payload decoded.

    The ``@dataclass`` decorator is required: the DAO functions in this
    module construct instances with keyword arguments.

    Attributes:
        item_id: Primary key of the item.
        payload: Item content decoded from the ``payload_json`` column.
        is_gold: The ``is_gold`` column converted to a bool.
    """

    item_id: str
    payload: dict
    is_gold: bool
@dataclass
class LabelRow:
    """One row of the ``labels`` table.

    The ``@dataclass`` decorator is required: ``iter_labels`` constructs
    instances with keyword arguments.

    Attributes:
        item_id: Item the label refers to.
        annotator_id: Annotator who submitted the label.
        chosen_index: Value of the ``chosen_index`` column.
        seconds: Value of the nullable ``seconds`` column.
        confidence: Value of the nullable ``confidence`` column.
        submitted_at: Value of the ``submitted_at`` column as a float.
    """

    item_id: str
    annotator_id: str
    chosen_index: int
    seconds: float | None
    confidence: int | None
    submitted_at: float
def connect(db_path: str | Path) -> sqlite3.Connection:
    """Open the SQLite database at *db_path*, creating parent directories.

    The returned connection has the sqlite3 same-thread check disabled, is
    opened with ``isolation_level=None``, yields ``sqlite3.Row`` rows, and
    has foreign-key enforcement switched on.
    """
    target = Path(db_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    connection = sqlite3.connect(
        target,
        check_same_thread=False,
        isolation_level=None,
    )
    connection.row_factory = sqlite3.Row

    # The rollback-journal (DELETE) mode keeps the main .sqlite file
    # consistent after every write, so plain file-copy backups always see
    # the latest state. WAL would need an explicit checkpoint before each
    # backup, which bit us in production on HF Spaces.
    for pragma in (
        "PRAGMA foreign_keys = ON",
        "PRAGMA journal_mode = DELETE",
    ):
        connection.execute(pragma)
    return connection
def init_schema(conn: sqlite3.Connection) -> None:
    """Create tables and indexes, and upgrade older databases in place.

    Runs the ``SCHEMA`` script, then adds the ``cap``, ``email`` and
    ``round_number`` columns to ``annotators`` when they are missing, and
    finally creates the index on ``annotators(email)``.

    Raises:
        sqlite3.OperationalError: If a migration fails for any reason other
            than the column already being present.
    """
    conn.executescript(SCHEMA)

    # Forward-compatible migrations for older DBs. We simply attempt each
    # ALTER: SQLite raises OperationalError ("duplicate column name: ...")
    # when the column is already there, which is the one failure we expect.
    migrations = (
        "ALTER TABLE annotators ADD COLUMN cap INTEGER",
        "ALTER TABLE annotators ADD COLUMN email TEXT",
        "ALTER TABLE annotators ADD COLUMN round_number INTEGER NOT NULL DEFAULT 1",
    )
    for ddl in migrations:
        try:
            conn.execute(ddl)
        except sqlite3.OperationalError as exc:
            # Anything else (locked database, I/O error, read-only file)
            # is a real problem and must not be hidden from the caller.
            if "duplicate column name" not in str(exc).lower():
                raise

    # CREATE INDEX must run AFTER the ALTER TABLE for email: the SCHEMA's
    # CREATE TABLE is a no-op when the table already exists (e.g. restored
    # from an older backup), so it does not add the column on its own.
    conn.execute("CREATE INDEX IF NOT EXISTS idx_annotators_email ON annotators(email)")
def mint_token() -> str:
    """Return a new random URL-safe token generated from 16 random bytes."""
    n_random_bytes = 16
    token = secrets.token_urlsafe(n_random_bytes)
    return token
def insert_item(conn: sqlite3.Connection, item_id: str, payload: dict, is_gold: bool = False) -> None:
    """Insert an item, replacing any existing row with the same ``item_id``.

    The payload is stored as JSON text with sorted keys; ``is_gold`` is
    stored as 0 or 1.
    """
    encoded_payload = json.dumps(payload, sort_keys=True)
    gold_flag = int(is_gold)
    statement = "INSERT OR REPLACE INTO items(item_id, payload_json, is_gold) VALUES (?, ?, ?)"
    conn.execute(statement, (item_id, encoded_payload, gold_flag))
def insert_annotator(
    conn: sqlite3.Connection,
    annotator_id: str,
    token: str,
    cap: int | None = None,
    email: str | None = None,
    round_number: int = 1,
) -> None:
    """Insert an annotator row, replacing any conflicting existing row.

    ``created_at`` is set to the current ``time.time()`` on every call,
    including when an existing row is replaced.
    """
    statement = (
        "INSERT OR REPLACE INTO annotators"
        "(annotator_id, token, created_at, cap, email, round_number) "
        "VALUES (?, ?, ?, ?, ?, ?)"
    )
    created_at = time.time()
    values = (annotator_id, token, created_at, cap, email, round_number)
    conn.execute(statement, values)
def get_annotator_cap(conn: sqlite3.Connection, annotator_id: str) -> int | None:
    """Return the annotator's ``cap`` as an int.

    Returns ``None`` when the annotator does not exist or has no cap set.
    """
    cursor = conn.execute(
        "SELECT cap FROM annotators WHERE annotator_id = ?", (annotator_id,)
    )
    found = cursor.fetchone()
    if found is None:
        return None
    stored_cap = found["cap"]
    if stored_cap is None:
        return None
    return int(stored_cap)
def insert_assignment(conn: sqlite3.Connection, item_id: str, annotator_id: str) -> None:
    """Record that *item_id* is assigned to *annotator_id*.

    Uses ``INSERT OR IGNORE``, so a pair that is already assigned is left
    untouched (including its original ``assigned_at``).
    """
    statement = (
        "INSERT OR IGNORE INTO assignments(item_id, annotator_id, assigned_at) "
        "VALUES (?, ?, ?)"
    )
    assigned_at = time.time()
    conn.execute(statement, (item_id, annotator_id, assigned_at))
def get_annotator_by_token(conn: sqlite3.Connection, token: str) -> str | None:
    """Return the ``annotator_id`` owning *token*, or ``None`` if unknown."""
    cursor = conn.execute(
        "SELECT annotator_id FROM annotators WHERE token = ?", (token,)
    )
    match = cursor.fetchone()
    if not match:
        return None
    return match["annotator_id"]
def get_item(conn: sqlite3.Connection, item_id: str) -> ItemRow | None:
    """Fetch one item by id, decoding its JSON payload.

    Returns ``None`` when no item has that id.
    """
    query = "SELECT item_id, payload_json, is_gold FROM items WHERE item_id = ?"
    record = conn.execute(query, (item_id,)).fetchone()
    if record:
        decoded = json.loads(record["payload_json"])
        return ItemRow(
            item_id=record["item_id"],
            payload=decoded,
            is_gold=bool(record["is_gold"]),
        )
    return None
def next_unlabeled_item(
    conn: sqlite3.Connection, annotator_id: str
) -> ItemRow | None:
    """Pre-assigned dispatch: return the annotator's next pending item.

    Picks, among the items assigned to *annotator_id* that this annotator
    has not labeled yet, the one with the earliest ``assigned_at``. Returns
    ``None`` when nothing is pending.
    """
    # The LEFT JOIN + "IS NULL" filter keeps only assignments with no
    # matching label from the same annotator.
    query = """
        SELECT items.item_id, items.payload_json, items.is_gold
        FROM assignments
        JOIN items ON items.item_id = assignments.item_id
        LEFT JOIN labels
            ON labels.item_id = assignments.item_id
            AND labels.annotator_id = assignments.annotator_id
        WHERE assignments.annotator_id = ?
            AND labels.item_id IS NULL
        ORDER BY assignments.assigned_at ASC
        LIMIT 1
        """
    pending = conn.execute(query, (annotator_id,)).fetchone()
    if pending:
        return ItemRow(
            item_id=pending["item_id"],
            payload=json.loads(pending["payload_json"]),
            is_gold=bool(pending["is_gold"]),
        )
    return None
def next_pooled_item(
    conn: sqlite3.Connection,
    annotator_id: str,
    max_labels_per_item: int,
) -> ItemRow | None:
    """Pull-based dispatch: return an item that still needs labels.

    Candidates are items this annotator has not labeled and that have
    fewer than *max_labels_per_item* labels in total. Selection is
    breadth-first over coverage, so the least-labeled items go first.
    Ties within a coverage level are broken with ``RANDOM()`` so that
    annotators starting at similar times do not walk the same item_id
    sequence, which would correlate stimulus order with annotator fatigue.

    Returns ``None`` when no candidate exists.
    """
    query = """
        SELECT items.item_id, items.payload_json, items.is_gold,
            COALESCE(counts.n, 0) AS n_labels
        FROM items
        LEFT JOIN (
            SELECT item_id, COUNT(*) AS n
            FROM labels
            GROUP BY item_id
        ) AS counts ON counts.item_id = items.item_id
        LEFT JOIN labels mine
            ON mine.item_id = items.item_id AND mine.annotator_id = ?
        WHERE mine.item_id IS NULL
            AND COALESCE(counts.n, 0) < ?
        ORDER BY n_labels ASC, RANDOM()
        LIMIT 1
        """
    params = (annotator_id, max_labels_per_item)
    candidate = conn.execute(query, params).fetchone()
    if candidate:
        return ItemRow(
            item_id=candidate["item_id"],
            payload=json.loads(candidate["payload_json"]),
            is_gold=bool(candidate["is_gold"]),
        )
    return None
def get_annotator_row(conn: sqlite3.Connection, annotator_id: str) -> sqlite3.Row | None:
    """Return the annotator's row, or ``None`` if the id is unknown.

    Selected columns: annotator_id, token, cap, email, round_number,
    created_at.
    """
    query = (
        "SELECT annotator_id, token, cap, email, round_number, created_at "
        "FROM annotators WHERE annotator_id = ?"
    )
    cursor = conn.execute(query, (annotator_id,))
    return cursor.fetchone()
def session_accuracy(
    conn: sqlite3.Connection, annotator_id: str
) -> tuple[int, int]:
    """Return ``(n_correct, n_total)`` over this annotator's labels.

    A label counts as correct when ``chosen_index`` equals the integer at
    ``$.correct_index`` in the item's ``payload_json``. An annotator with
    no labels yields ``(0, 0)``.
    """
    query = """
        SELECT
            COUNT(*) AS n,
            SUM(CASE WHEN l.chosen_index =
                    CAST(json_extract(i.payload_json, '$.correct_index') AS INTEGER)
                THEN 1 ELSE 0 END) AS n_correct
        FROM labels l
        JOIN items i USING(item_id)
        WHERE l.annotator_id = ?
        """
    tally = conn.execute(query, (annotator_id,)).fetchone()
    # SUM() over zero rows is NULL, hence the "or 0" fallbacks.
    total = int(tally["n"] or 0)
    correct = int(tally["n_correct"] or 0)
    return (correct, total)
def set_annotator_email(
    conn: sqlite3.Connection, annotator_id: str, email: str
) -> None:
    """Set the ``email`` column of the given annotator.

    Does nothing when no annotator has that id.
    """
    statement = "UPDATE annotators SET email = ? WHERE annotator_id = ?"
    params = (email, annotator_id)
    conn.execute(statement, params)
def email_passed_rounds(
    conn: sqlite3.Connection, email: str, acc_threshold: float, target_cap: int | None = None
) -> set[int]:
    """Return the round numbers this email has passed.

    A round is passed when at least one annotator registered under *email*
    for that round (a) has at least ``cap`` labels and (b) has accuracy
    ``n_correct / n`` of at least *acc_threshold*.

    When *target_cap* is ``None`` each annotator's own ``cap`` is used;
    annotators without a cap are then skipped. Passing *target_cap*
    overrides the per-annotator cap, which is useful when the caller
    wants to ignore annotators whose session did not finish.
    """
    annotators = conn.execute(
        "SELECT annotator_id, cap, round_number FROM annotators "
        "WHERE email = ?",
        (email,),
    ).fetchall()

    passed_rounds: set[int] = set()
    for annotator in annotators:
        correct, labeled = session_accuracy(conn, annotator["annotator_id"])
        required = target_cap if target_cap is not None else annotator["cap"]
        reached_cap = required is not None and not labeled < required
        # labeled == 0 is excluded explicitly to avoid dividing by zero.
        if reached_cap and labeled != 0 and (correct / labeled) >= acc_threshold:
            passed_rounds.add(int(annotator["round_number"]))
    return passed_rounds
def email_passed_label_count(
    conn: sqlite3.Connection, email: str, acc_threshold: float
) -> int:
    """Count the labels from this email's passing annotator sessions.

    A session passes when the annotator (a) has at least ``cap`` labels
    and (b) has accuracy of at least *acc_threshold*. Annotators without
    a cap never pass. Each counted label is one lottery entry.
    """
    annotators = conn.execute(
        "SELECT annotator_id, cap FROM annotators WHERE email = ?", (email,)
    ).fetchall()

    entries = 0
    for annotator in annotators:
        correct, labeled = session_accuracy(conn, annotator["annotator_id"])
        required = annotator["cap"]
        reached_cap = required is not None and not labeled < required
        if not reached_cap or labeled == 0:
            continue
        if (correct / labeled) >= acc_threshold:
            entries += labeled
    return entries
def count_annotator_labels(conn: sqlite3.Connection, annotator_id: str) -> int:
    """Return how many labels the given annotator has submitted."""
    query = "SELECT COUNT(*) AS n FROM labels WHERE annotator_id = ?"
    result = conn.execute(query, (annotator_id,)).fetchone()
    return int(result["n"])
def record_label(
    conn: sqlite3.Connection,
    item_id: str,
    annotator_id: str,
    chosen_index: int,
    seconds: float | None,
    confidence: int | None,
) -> None:
    """Store a label for the (item, annotator) pair.

    Uses ``INSERT OR REPLACE``: a second submission for the same pair
    overwrites the first. ``submitted_at`` is the current ``time.time()``.
    """
    statement = """
        INSERT OR REPLACE INTO labels
            (item_id, annotator_id, chosen_index, seconds, confidence, submitted_at)
        VALUES (?, ?, ?, ?, ?, ?)
        """
    submitted_at = time.time()
    values = (item_id, annotator_id, chosen_index, seconds, confidence, submitted_at)
    conn.execute(statement, values)
def progress(conn: sqlite3.Connection, annotator_id: str) -> dict[str, int]:
    """Return ``{"assigned": ..., "labeled": ...}`` counts for the annotator.

    ``assigned`` counts rows in ``assignments`` and ``labeled`` counts
    rows in ``labels`` for this annotator.
    """
    params = (annotator_id,)
    assigned_row = conn.execute(
        "SELECT COUNT(*) AS n FROM assignments WHERE annotator_id = ?", params
    ).fetchone()
    labeled_row = conn.execute(
        "SELECT COUNT(*) AS n FROM labels WHERE annotator_id = ?", params
    ).fetchone()
    return {
        "assigned": int(assigned_row["n"]),
        "labeled": int(labeled_row["n"]),
    }
def iter_labels(conn: sqlite3.Connection) -> Iterator[LabelRow]:
    """Lazily yield every row of the ``labels`` table as a ``LabelRow``.

    ``chosen_index`` is converted to int and ``submitted_at`` to float;
    the nullable ``seconds`` and ``confidence`` values are passed through
    unchanged. No ordering is requested from the database.
    """
    query = "SELECT item_id, annotator_id, chosen_index, seconds, confidence, submitted_at FROM labels"
    cursor = conn.execute(query)
    for record in cursor:
        label = LabelRow(
            item_id=record["item_id"],
            annotator_id=record["annotator_id"],
            chosen_index=int(record["chosen_index"]),
            seconds=record["seconds"],
            confidence=record["confidence"],
            submitted_at=float(record["submitted_at"]),
        )
        yield label
@contextmanager
def open_db(db_path: str | Path) -> Iterator[sqlite3.Connection]:
    """Context manager yielding an initialised connection to *db_path*.

    Opens the database with ``connect``, runs ``init_schema`` on it, and
    yields the connection. The connection is closed when the ``with``
    block exits, and also when schema initialisation or the block body
    raises.

    The ``@contextmanager`` decorator is required: without it this is a
    plain generator function and cannot be used in a ``with`` statement.
    """
    conn = connect(db_path)
    try:
        init_schema(conn)
        yield conn
    finally:
        conn.close()