lanczos's picture
deploy: labeling server
083fb75 verified
"""Assign MCQ items to annotators with gold-item injection + k-fold coverage."""
from __future__ import annotations
import random
from dataclasses import dataclass
from typing import Iterable
from aamcq.annotation import db as dbmod
@dataclass(frozen=True)
class AssignmentPolicy:
    """Knobs controlling how items are spread across annotators.

    Attributes:
        labels_per_item: number of distinct annotators each non-gold item is
            assigned to; must be >= 1.
        gold_injection_rate: number of gold items to mix into an annotator's
            queue per non-gold item already in it (assign_items_round_robin
            injects ``round(len(queue) * rate)`` of them, at least one);
            must lie within [0, 1].

    Raises:
        ValueError: on construction, if either field is out of range.
    """

    labels_per_item: int = 3
    gold_injection_rate: float = 0.10

    def __post_init__(self) -> None:
        # Reject bad configuration up front instead of failing (or silently
        # misbehaving) in the middle of an assignment run.
        if self.labels_per_item < 1:
            raise ValueError(
                f"labels_per_item must be >= 1, got {self.labels_per_item}"
            )
        # Written as a chained comparison so NaN is rejected too.
        if not 0.0 <= self.gold_injection_rate <= 1.0:
            raise ValueError(
                "gold_injection_rate must be within [0, 1], "
                f"got {self.gold_injection_rate}"
            )
def _distribute_items(
    annotators: list[str],
    item_ids: list[str],
    labels_per_item: int,
    rng: random.Random,
) -> dict[str, list[str]]:
    """Give each distinct item to `labels_per_item` distinct, least-loaded annotators."""
    queues: dict[str, list[str]] = {aid: [] for aid in annotators}
    # De-duplicate (order-preserving) so an item is never handed to the
    # same annotator twice, then randomise the order items are handed out.
    shuffled_items = list(dict.fromkeys(item_ids))
    rng.shuffle(shuffled_items)
    for item_id in shuffled_items:
        candidates = list(annotators)
        # Shuffle first, then stable-sort by load: equally loaded
        # annotators are tie-broken at random, and per-annotator loads
        # never differ by more than one.
        rng.shuffle(candidates)
        candidates.sort(key=lambda aid: len(queues[aid]))
        for aid in candidates[:labels_per_item]:
            queues[aid].append(item_id)
    return queues


def _inject_gold(
    queues: dict[str, list[str]],
    gold_item_ids: list[str],
    rate: float,
    rng: random.Random,
) -> None:
    """Insert gold items at random positions of every queue, in place."""
    gold_pool = list(dict.fromkeys(gold_item_ids))
    for queue in queues.values():
        # Count is based on the queue length before any gold is inserted;
        # every annotator gets at least one gold item.
        n_gold = max(1, round(len(queue) * rate))
        already_queued = set(queue)
        candidates = [gid for gid in gold_pool if gid not in already_queued]
        # Sample WITHOUT replacement so no id appears twice in one queue
        # (which would create duplicate assignment rows). If the pool is
        # smaller than n_gold, fewer golds are inserted.
        for gold_id in rng.sample(candidates, min(n_gold, len(candidates))):
            queue.insert(rng.randrange(len(queue) + 1), gold_id)


def assign_items_round_robin(
    conn,
    annotator_ids: list[str],
    item_ids: list[str],
    gold_item_ids: list[str],
    policy: AssignmentPolicy,
    rng: random.Random,
) -> dict[str, list[str]]:
    """Create assignments so each non-gold item gets `labels_per_item` distinct annotators.

    Items (de-duplicated) are handed out in a shuffled order. Each one goes
    to the `labels_per_item` annotators with the fewest items so far, with
    ties broken at random, so workloads stay balanced to within one item.

    When `gold_item_ids` is non-empty and `policy.gold_injection_rate > 0`,
    each queue then receives ``max(1, round(len(queue) * rate))`` gold items
    at random positions. Golds are drawn without replacement and never
    repeat an id already in that queue.

    Every (item, annotator) pair is persisted via ``dbmod.insert_assignment``.

    Returns a mapping {annotator_id: [item_id, ...] in assigned order}.

    Raises:
        ValueError: if there are no annotators, if `labels_per_item` is
            negative, or if it exceeds the number of distinct annotators.
    """
    # Duplicate annotator ids must not count as separate annotators, or one
    # person could be given the same item more than once.
    annotators = list(dict.fromkeys(annotator_ids))
    if not annotators:
        raise ValueError("need at least 1 annotator")
    if policy.labels_per_item < 0:
        raise ValueError(
            f"labels_per_item must be >= 0, got {policy.labels_per_item}"
        )
    if policy.labels_per_item > len(annotators):
        raise ValueError(
            f"labels_per_item={policy.labels_per_item} > annotators={len(annotators)}"
        )

    queues = _distribute_items(annotators, item_ids, policy.labels_per_item, rng)

    if gold_item_ids and policy.gold_injection_rate > 0:
        _inject_gold(queues, gold_item_ids, policy.gold_injection_rate, rng)

    for aid, queue in queues.items():
        for item_id in queue:
            dbmod.insert_assignment(conn, item_id, aid)
    return queues
def bootstrap_annotators(
    conn,
    annotator_ids: Iterable[str],
    cap: int | None = None,
) -> dict[str, str]:
    """Register each annotator under a newly minted token.

    For every id, a token is minted with ``dbmod.mint_token`` and an
    annotator row is written with ``dbmod.insert_annotator``. The optional
    `cap` is forwarded to every insert as a per-annotator label cap. It
    overrides the server default for just these annotators, which is how
    anonymous registration customises the cap per session.

    Returns a mapping {annotator_id: token}.
    """

    def register(annotator_id: str) -> str:
        # Mint first, then persist the row that carries that token.
        fresh_token = dbmod.mint_token()
        dbmod.insert_annotator(conn, annotator_id, fresh_token, cap=cap)
        return fresh_token

    issued: dict[str, str] = {}
    for annotator_id in annotator_ids:
        issued[annotator_id] = register(annotator_id)
    return issued