File size: 2,547 Bytes
871ff87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
083fb75
871ff87
083fb75
 
 
 
 
 
871ff87
 
 
083fb75
871ff87
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Assign MCQ items to annotators with gold-item injection + k-fold coverage."""

from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Iterable

from aamcq.annotation import db as dbmod


@dataclass(frozen=True)
class AssignmentPolicy:
    """Immutable knobs controlling how items are distributed to annotators."""

    # How many distinct annotators each non-gold item should be given to.
    labels_per_item: int = 3
    # Target fraction of each annotator's queue length to add as gold items.
    gold_injection_rate: float = 0.10


def assign_items_round_robin(
    conn,
    annotator_ids: list[str],
    item_ids: list[str],
    gold_item_ids: list[str],
    policy: AssignmentPolicy,
    rng: random.Random,
) -> dict[str, list[str]]:
    """Create assignments so each non-gold item gets `labels_per_item` distinct annotators.

    Items in ``item_ids`` are visited in a shuffled order. Each one is given
    to ``policy.labels_per_item`` annotators drawn uniformly at random
    without replacement. Per-annotator load is therefore balanced only in
    expectation, not exactly.

    Gold items are then inserted at random positions into each annotator's
    queue. The count is roughly ``policy.gold_injection_rate`` times that
    queue's length, with a minimum of one. Golds are drawn without
    replacement and skip any item already in that annotator's queue, so no
    (item, annotator) pair is assigned twice. If fewer distinct golds are
    available than requested, all available ones are used.

    Every resulting (item, annotator) pair is persisted via
    ``dbmod.insert_assignment``.

    Args:
        conn: Connection passed through to ``dbmod.insert_assignment``.
        annotator_ids: Annotators to distribute work across; duplicate ids
            are collapsed (first occurrence order is kept).
        item_ids: Non-gold items that each need ``labels_per_item`` labels.
        gold_item_ids: Pool of gold items to inject; may be empty.
        policy: Coverage and gold-injection settings.
        rng: Source of all randomness used here, so a seeded ``Random``
            makes the result reproducible.

    Returns:
        Mapping ``{annotator_id: [item_id, ...]}`` in assigned order.

    Raises:
        ValueError: If there are no annotators, ``labels_per_item`` is < 1,
            or ``labels_per_item`` exceeds the number of distinct annotators.
    """
    # Order-preserving dedupe: with repeated ids, rng.sample could pick the
    # same annotator twice for one item, breaking the "distinct" guarantee.
    annotators = list(dict.fromkeys(annotator_ids))
    if not annotators:
        raise ValueError("need at least 1 annotator")
    if policy.labels_per_item < 1:
        raise ValueError(f"labels_per_item={policy.labels_per_item} must be >= 1")
    if policy.labels_per_item > len(annotators):
        raise ValueError(
            f"labels_per_item={policy.labels_per_item} > annotators={len(annotators)}"
        )

    queues: dict[str, list[str]] = {aid: [] for aid in annotators}

    shuffled_items = list(item_ids)
    rng.shuffle(shuffled_items)
    for item_id in shuffled_items:
        for aid in rng.sample(annotators, policy.labels_per_item):
            queues[aid].append(item_id)

    if gold_item_ids and policy.gold_injection_rate > 0:
        gold_pool = list(dict.fromkeys(gold_item_ids))
        for queue in queues.values():
            # Exclude golds this annotator already holds (e.g. a gold id that
            # was also passed in item_ids) to avoid duplicate assignments.
            already = set(queue)
            candidates = [g for g in gold_pool if g not in already]
            n_gold = max(1, round(len(queue) * policy.gold_injection_rate))
            n_gold = min(n_gold, len(candidates))
            # Sample without replacement: a repeated gold would insert a
            # duplicate assignment row and make the gold easy to recognise.
            for gold_id in rng.sample(candidates, n_gold):
                # interleave golds at random positions
                pos = rng.randrange(len(queue) + 1)
                queue.insert(pos, gold_id)

    for aid, queue in queues.items():
        for item_id in queue:
            dbmod.insert_assignment(conn, item_id, aid)

    return queues


def bootstrap_annotators(
    conn, annotator_ids: Iterable[str], cap: int | None = None,
) -> dict[str, str]:
    """Register each annotator with a newly minted access token.

    For every id in ``annotator_ids`` (consumed once, in order), a token is
    obtained from ``dbmod.mint_token`` and an annotator row is written with
    ``dbmod.insert_annotator``.

    Args:
        conn: Connection passed through to ``dbmod.insert_annotator``.
        annotator_ids: Ids of the annotators to create.
        cap: Optional per-annotator label cap forwarded to
            ``insert_annotator``. It overrides the server default for just
            these annotators (used for anonymous registration to customise
            the cap per session). ``None`` is forwarded unchanged.

    Returns:
        Mapping ``{annotator_id: token}``. If an id repeats, the entry holds
        the token minted for its last occurrence.
    """
    minted: dict[str, str] = {}
    for annotator_id in annotator_ids:
        new_token = dbmod.mint_token()
        dbmod.insert_annotator(conn, annotator_id, new_token, cap=cap)
        minted[annotator_id] = new_token
    return minted