lanczos's picture
deploy: labeling server
000a5ee verified
"""SQLite schema + DAO for the annotation backend.
Single-writer model: annotations are append-only, assignments are created up
front. We use stdlib sqlite3 rather than SQLAlchemy to keep the install
footprint small.
"""
from __future__ import annotations
import json
import secrets
import sqlite3
import time
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator
# DDL script for the four tables (items, annotators, assignments, labels) and
# the per-annotator lookup indexes on assignments and labels. Every statement
# uses IF NOT EXISTS, so the script can be re-run against an existing database
# without error (it will not, however, add columns to a pre-existing table).
SCHEMA = """
CREATE TABLE IF NOT EXISTS items (
item_id TEXT PRIMARY KEY,
payload_json TEXT NOT NULL,
is_gold INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS annotators (
annotator_id TEXT PRIMARY KEY,
token TEXT NOT NULL UNIQUE,
created_at REAL NOT NULL,
cap INTEGER,
email TEXT,
round_number INTEGER NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS assignments (
item_id TEXT NOT NULL,
annotator_id TEXT NOT NULL,
assigned_at REAL NOT NULL,
PRIMARY KEY (item_id, annotator_id),
FOREIGN KEY (item_id) REFERENCES items(item_id),
FOREIGN KEY (annotator_id) REFERENCES annotators(annotator_id)
);
CREATE TABLE IF NOT EXISTS labels (
item_id TEXT NOT NULL,
annotator_id TEXT NOT NULL,
chosen_index INTEGER NOT NULL,
seconds REAL,
confidence INTEGER,
submitted_at REAL NOT NULL,
PRIMARY KEY (item_id, annotator_id),
FOREIGN KEY (item_id) REFERENCES items(item_id),
FOREIGN KEY (annotator_id) REFERENCES annotators(annotator_id)
);
CREATE INDEX IF NOT EXISTS idx_assignments_annotator
ON assignments(annotator_id);
CREATE INDEX IF NOT EXISTS idx_labels_annotator
ON labels(annotator_id);
"""
@dataclass(frozen=True)
class ItemRow:
    """Immutable, decoded view of one row of the ``items`` table."""

    # Primary key of the item.
    item_id: str
    # The ``payload_json`` column after JSON decoding.
    payload: dict
    # Boolean form of the integer ``is_gold`` column.
    is_gold: bool
@dataclass(frozen=True)
class LabelRow:
    """Immutable view of one row of the ``labels`` table."""

    # Together, item_id and annotator_id form the table's primary key.
    item_id: str
    annotator_id: str
    # Index of the option the annotator picked.
    chosen_index: int
    # Nullable column; presumably time spent on the item — confirm with caller.
    seconds: float | None
    # Nullable column; scale is not defined in this module.
    confidence: int | None
    # ``time.time()`` value captured when the label was written.
    submitted_at: float
def connect(db_path: str | Path) -> sqlite3.Connection:
    """Open the SQLite database at *db_path*, creating parent directories.

    The returned connection runs in autocommit mode
    (``isolation_level=None``), has the same-thread check disabled, produces
    ``sqlite3.Row`` rows, and has foreign-key enforcement switched on.
    """
    target = Path(db_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    connection = sqlite3.connect(
        target,
        isolation_level=None,
        check_same_thread=False,
    )
    connection.row_factory = sqlite3.Row

    # The journal mode is pinned to DELETE (SQLite's default) on purpose: the
    # main .sqlite file is then consistent after every write, so a plain file
    # copy is a valid backup of the latest state. WAL would need an explicit
    # checkpoint before each backup, which bit us in production on HF Spaces.
    for pragma in (
        "PRAGMA foreign_keys = ON",
        "PRAGMA journal_mode = DELETE",
    ):
        connection.execute(pragma)
    return connection
def init_schema(conn: sqlite3.Connection) -> None:
    """Create tables and indexes, then upgrade databases with an older layout.

    Runs the ``SCHEMA`` script, adds the ``cap``, ``email`` and
    ``round_number`` columns to ``annotators`` when they are missing, and
    finally creates the index on ``annotators(email)``.

    Args:
        conn: Open connection to the database to initialise.

    Raises:
        sqlite3.OperationalError: If a migration statement fails for any
            reason other than the column already existing (for example a
            locked or read-only database).
    """
    conn.executescript(SCHEMA)

    # Forward-compatible migrations for older DBs. SCHEMA's CREATE TABLE is a
    # no-op when the table already exists (e.g. restored from an old backup),
    # so missing columns have to be added explicitly. We keep the pragmatic
    # "just try the ALTER" pattern, but only the expected failure is ignored.
    migrations = (
        "ALTER TABLE annotators ADD COLUMN cap INTEGER",
        "ALTER TABLE annotators ADD COLUMN email TEXT",
        "ALTER TABLE annotators ADD COLUMN round_number INTEGER NOT NULL DEFAULT 1",
    )
    for ddl in migrations:
        try:
            conn.execute(ddl)
        except sqlite3.OperationalError as exc:
            # SQLite reports an existing column as "duplicate column name: X".
            # Anything else is a real failure and must not be hidden, or the
            # database would be left half-migrated without any signal.
            if "duplicate column name" not in str(exc).lower():
                raise

    # The index must be created AFTER the ALTER TABLE for email: on an older
    # database the column does not exist until the migration above has run.
    conn.execute("CREATE INDEX IF NOT EXISTS idx_annotators_email ON annotators(email)")
def mint_token() -> str:
    """Return a new URL-safe random token derived from 16 random bytes."""
    entropy_bytes = 16
    return secrets.token_urlsafe(entropy_bytes)
def insert_item(conn: sqlite3.Connection, item_id: str, payload: dict, is_gold: bool = False) -> None:
    """Insert an item, replacing any existing row with the same *item_id*.

    The payload is stored as JSON with sorted keys; *is_gold* is stored as
    an integer flag.
    """
    encoded_payload = json.dumps(payload, sort_keys=True)
    gold_flag = int(is_gold)
    conn.execute(
        "INSERT OR REPLACE INTO items(item_id, payload_json, is_gold) VALUES (?, ?, ?)",
        (item_id, encoded_payload, gold_flag),
    )
def insert_annotator(
    conn: sqlite3.Connection,
    annotator_id: str,
    token: str,
    cap: int | None = None,
    email: str | None = None,
    round_number: int = 1,
) -> None:
    """Insert an annotator row, replacing any conflicting existing row.

    ``created_at`` is set to the current ``time.time()`` on every call.
    """
    statement = (
        "INSERT OR REPLACE INTO annotators"
        "(annotator_id, token, created_at, cap, email, round_number) "
        "VALUES (?, ?, ?, ?, ?, ?)"
    )
    values = (annotator_id, token, time.time(), cap, email, round_number)
    conn.execute(statement, values)
def get_annotator_cap(conn: sqlite3.Connection, annotator_id: str) -> int | None:
    """Return the annotator's ``cap``, or None if the row or value is absent."""
    record = conn.execute(
        "SELECT cap FROM annotators WHERE annotator_id = ?", (annotator_id,)
    ).fetchone()
    if record is None:
        return None
    stored_cap = record["cap"]
    if stored_cap is None:
        return None
    return int(stored_cap)
def insert_assignment(conn: sqlite3.Connection, item_id: str, annotator_id: str) -> None:
    """Assign *item_id* to *annotator_id*, stamped with the current time.

    An existing assignment for the same pair is left untouched.
    """
    statement = (
        "INSERT OR IGNORE INTO assignments(item_id, annotator_id, assigned_at) "
        "VALUES (?, ?, ?)"
    )
    assigned_at = time.time()
    conn.execute(statement, (item_id, annotator_id, assigned_at))
def get_annotator_by_token(conn: sqlite3.Connection, token: str) -> str | None:
    """Return the annotator_id that owns *token*, or None if no row matches."""
    lookup = "SELECT annotator_id FROM annotators WHERE token = ?"
    match = conn.execute(lookup, (token,)).fetchone()
    if not match:
        return None
    return match["annotator_id"]
def get_item(conn: sqlite3.Connection, item_id: str) -> ItemRow | None:
    """Fetch one item by id, returning None when no such row exists."""
    lookup = "SELECT item_id, payload_json, is_gold FROM items WHERE item_id = ?"
    found = conn.execute(lookup, (item_id,)).fetchone()
    if found:
        decoded_payload = json.loads(found["payload_json"])
        return ItemRow(
            item_id=found["item_id"],
            payload=decoded_payload,
            is_gold=bool(found["is_gold"]),
        )
    return None
def next_unlabeled_item(
    conn: sqlite3.Connection, annotator_id: str
) -> ItemRow | None:
    """Pre-assigned dispatch: return the annotator's oldest unlabeled assignment.

    Returns None when every assignment of *annotator_id* already has a label
    (or when there are no assignments at all).
    """
    # The LEFT JOIN + "IS NULL" filter keeps only assignments that have no
    # matching label row for the same (item, annotator) pair.
    pending = conn.execute(
        """
SELECT items.item_id, items.payload_json, items.is_gold
FROM assignments
JOIN items ON items.item_id = assignments.item_id
LEFT JOIN labels
ON labels.item_id = assignments.item_id
AND labels.annotator_id = assignments.annotator_id
WHERE assignments.annotator_id = ?
AND labels.item_id IS NULL
ORDER BY assignments.assigned_at ASC
LIMIT 1
""",
        (annotator_id,),
    ).fetchone()
    if pending:
        return ItemRow(
            item_id=pending["item_id"],
            payload=json.loads(pending["payload_json"]),
            is_gold=bool(pending["is_gold"]),
        )
    return None
def next_pooled_item(
    conn: sqlite3.Connection,
    annotator_id: str,
    max_labels_per_item: int,
) -> ItemRow | None:
    """Pull-based dispatch: choose an item that still needs labels.

    Candidates are items this annotator has not labeled and whose total label
    count is below *max_labels_per_item*. Selection is breadth-first over
    coverage, so the least-labeled items are handed out first. Ties within a
    coverage bucket are broken with RANDOM() so that annotators opening the UI
    at similar times do not walk the same item_id sequence — alphabetical
    order would otherwise create stimulus-level bias correlated with
    annotator fatigue.

    Returns None when no candidate item exists.
    """
    params = (annotator_id, max_labels_per_item)
    candidate = conn.execute(
        """
SELECT items.item_id, items.payload_json, items.is_gold,
COALESCE(counts.n, 0) AS n_labels
FROM items
LEFT JOIN (
SELECT item_id, COUNT(*) AS n
FROM labels
GROUP BY item_id
) AS counts ON counts.item_id = items.item_id
LEFT JOIN labels mine
ON mine.item_id = items.item_id AND mine.annotator_id = ?
WHERE mine.item_id IS NULL
AND COALESCE(counts.n, 0) < ?
ORDER BY n_labels ASC, RANDOM()
LIMIT 1
""",
        params,
    ).fetchone()
    if candidate:
        return ItemRow(
            item_id=candidate["item_id"],
            payload=json.loads(candidate["payload_json"]),
            is_gold=bool(candidate["is_gold"]),
        )
    return None
def get_annotator_row(conn: sqlite3.Connection, annotator_id: str) -> sqlite3.Row | None:
    """Return the raw annotator row for *annotator_id*, or None if absent.

    Selected columns, in order: annotator_id, token, cap, email,
    round_number, created_at.
    """
    lookup = (
        "SELECT annotator_id, token, cap, email, round_number, created_at "
        "FROM annotators WHERE annotator_id = ?"
    )
    cursor = conn.execute(lookup, (annotator_id,))
    return cursor.fetchone()
def session_accuracy(
    conn: sqlite3.Connection, annotator_id: str
) -> tuple[int, int]:
    """Return ``(n_correct, n_total)`` over all labels of *annotator_id*.

    A label counts as correct when its ``chosen_index`` equals the integer at
    ``$.correct_index`` in the item's JSON payload. An annotator without any
    labels yields ``(0, 0)``.
    """
    tally = conn.execute(
        """
SELECT
COUNT(*) AS n,
SUM(CASE WHEN l.chosen_index =
CAST(json_extract(i.payload_json, '$.correct_index') AS INTEGER)
THEN 1 ELSE 0 END) AS n_correct
FROM labels l
JOIN items i USING(item_id)
WHERE l.annotator_id = ?
""",
        (annotator_id,),
    ).fetchone()
    # SUM() over zero rows is NULL, hence the "or 0" fallbacks.
    total = int(tally["n"] or 0)
    correct = int(tally["n_correct"] or 0)
    return correct, total
def set_annotator_email(
    conn: sqlite3.Connection, annotator_id: str, email: str
) -> None:
    """Overwrite the stored email of *annotator_id* (no-op if the row is absent)."""
    statement = "UPDATE annotators SET email = ? WHERE annotator_id = ?"
    conn.execute(statement, (email, annotator_id))
def email_passed_rounds(
    conn: sqlite3.Connection, email: str, acc_threshold: float, target_cap: int | None = None
) -> set[int]:
    """Return the round_numbers that *email* has passed.

    A round is passed when at least one annotator registered under this email
    (a) has at least as many labels as the applicable cap and (b) has an
    accuracy of at least *acc_threshold*.

    The applicable cap is *target_cap* when given, otherwise the annotator's
    own ``cap``; annotators with no applicable cap never pass. Supplying
    *target_cap* is useful when the caller wants to ignore annotators whose
    session didn't finish.
    """
    matches = conn.execute(
        "SELECT annotator_id, cap, round_number FROM annotators "
        "WHERE email = ?",
        (email,),
    ).fetchall()

    passed: set[int] = set()
    for record in matches:
        correct, total = session_accuracy(conn, record["annotator_id"])
        if target_cap is not None:
            required = target_cap
        else:
            required = record["cap"]
        finished = required is not None and total >= required
        # total != 0 also guards the division below.
        if finished and total != 0 and (correct / total) >= acc_threshold:
            passed.add(int(record["round_number"]))
    return passed
def email_passed_label_count(
    conn: sqlite3.Connection, email: str, acc_threshold: float
) -> int:
    """Count labels from this email's passing annotator sessions.

    A session passes when the annotator (a) has at least ``cap`` labels and
    (b) has an accuracy of at least *acc_threshold*; annotators without a cap
    are skipped. Each counted label = 1 lottery entry.
    """
    matches = conn.execute(
        "SELECT annotator_id, cap FROM annotators WHERE email = ?", (email,)
    ).fetchall()

    entries = 0
    for record in matches:
        correct, labeled = session_accuracy(conn, record["annotator_id"])
        required = record["cap"]
        if required is None or labeled == 0 or labeled < required:
            continue
        if (correct / labeled) >= acc_threshold:
            entries += labeled
    return entries
def count_annotator_labels(conn: sqlite3.Connection, annotator_id: str) -> int:
    """Return how many labels *annotator_id* has submitted."""
    counted = conn.execute(
        "SELECT COUNT(*) AS n FROM labels WHERE annotator_id = ?",
        (annotator_id,),
    ).fetchone()
    return int(counted["n"])
def record_label(
    conn: sqlite3.Connection,
    item_id: str,
    annotator_id: str,
    chosen_index: int,
    seconds: float | None,
    confidence: int | None,
) -> None:
    """Store a label for (item_id, annotator_id), stamped with the current time.

    The statement is INSERT OR REPLACE, so a conflicting existing row is
    replaced by the new one.
    """
    submitted_at = time.time()
    values = (item_id, annotator_id, chosen_index, seconds, confidence, submitted_at)
    conn.execute(
        """
INSERT OR REPLACE INTO labels
(item_id, annotator_id, chosen_index, seconds, confidence, submitted_at)
VALUES (?, ?, ?, ?, ?, ?)
""",
        values,
    )
def progress(conn: sqlite3.Connection, annotator_id: str) -> dict[str, int]:
    """Return ``{"assigned": ..., "labeled": ...}`` row counts for the annotator."""
    params = (annotator_id,)
    assigned_row = conn.execute(
        "SELECT COUNT(*) AS n FROM assignments WHERE annotator_id = ?",
        params,
    ).fetchone()
    labeled_row = conn.execute(
        "SELECT COUNT(*) AS n FROM labels WHERE annotator_id = ?",
        params,
    ).fetchone()
    return {
        "assigned": int(assigned_row["n"]),
        "labeled": int(labeled_row["n"]),
    }
def iter_labels(conn: sqlite3.Connection) -> Iterator[LabelRow]:
    """Lazily yield every row of the ``labels`` table as a ``LabelRow``."""
    cursor = conn.execute(
        "SELECT item_id, annotator_id, chosen_index, seconds, confidence, submitted_at FROM labels"
    )
    for record in cursor:
        # seconds and confidence are nullable and passed through unchanged.
        yield LabelRow(
            item_id=record["item_id"],
            annotator_id=record["annotator_id"],
            chosen_index=int(record["chosen_index"]),
            seconds=record["seconds"],
            confidence=record["confidence"],
            submitted_at=float(record["submitted_at"]),
        )
@contextmanager
def open_db(db_path: str | Path) -> Iterator[sqlite3.Connection]:
    """Context manager yielding a schema-initialised connection to *db_path*.

    The connection is obtained from ``connect`` and passed through
    ``init_schema`` before being yielded. It is closed on exit, including
    when schema initialisation or the ``with`` body raises.
    """
    conn = connect(db_path)
    try:
        init_schema(conn)
        yield conn
    finally:
        conn.close()