"""Assign MCQ items to annotators with gold-item injection + k-fold coverage.""" from __future__ import annotations import random from dataclasses import dataclass from typing import Iterable from aamcq.annotation import db as dbmod @dataclass(frozen=True) class AssignmentPolicy: labels_per_item: int = 3 gold_injection_rate: float = 0.10 def assign_items_round_robin( conn, annotator_ids: list[str], item_ids: list[str], gold_item_ids: list[str], policy: AssignmentPolicy, rng: random.Random, ) -> dict[str, list[str]]: """Create assignments so each non-gold item gets `labels_per_item` distinct annotators. Gold items are inserted randomly into each annotator's queue at the requested rate. Returns a mapping {annotator_id: [item_id, ...] in assigned order}. """ if not annotator_ids: raise ValueError("need at least 1 annotator") if policy.labels_per_item > len(annotator_ids): raise ValueError( f"labels_per_item={policy.labels_per_item} > annotators={len(annotator_ids)}" ) queues: dict[str, list[str]] = {aid: [] for aid in annotator_ids} shuffled_items = list(item_ids) rng.shuffle(shuffled_items) for item_id in shuffled_items: chosen = rng.sample(annotator_ids, policy.labels_per_item) for aid in chosen: queues[aid].append(item_id) if gold_item_ids and policy.gold_injection_rate > 0: for aid, queue in queues.items(): n_gold = max(1, int(round(len(queue) * policy.gold_injection_rate))) gold_pick = rng.choices(gold_item_ids, k=n_gold) # interleave golds at random positions for gold_id in gold_pick: pos = rng.randrange(len(queue) + 1) queue.insert(pos, gold_id) for aid, queue in queues.items(): for item_id in queue: dbmod.insert_assignment(conn, item_id, aid) return queues def bootstrap_annotators( conn, annotator_ids: Iterable[str], cap: int | None = None, ) -> dict[str, str]: """Create annotator rows with freshly minted tokens. Returns {annotator_id: token}. `cap` sets a per-annotator label cap that overrides the server default for just these annotators (used for anonymous registration to customise the cap per session). """ tokens: dict[str, str] = {} for aid in annotator_ids: token = dbmod.mint_token() dbmod.insert_annotator(conn, aid, token, cap=cap) tokens[aid] = token return tokens