File size: 3,363 Bytes
aceb1b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
Overlap sampling for heterogeneous annotator coverage.

Given the ``num_annotators_per_item.overlap_sample`` config block, selects a
deterministic fraction of items to receive a raised annotator cap. The cap
override is written onto each sampled item's metadata so the centralized
``ItemStateManager._get_annotator_cap_for_item`` helper picks it up.

This module is called once after all items have been loaded.
"""

from __future__ import annotations

from collections import defaultdict
from typing import Dict, List, Optional

import logging
import random as _random

logger = logging.getLogger(__name__)


# Per-item metadata field that carries the raised annotator cap.
_METADATA_KEY = "required_annotations"

# Stratum label used for items that have no value for the stratify field.
_UNCATEGORIZED = "__uncategorized__"


def _stratum_key(item_state_manager, iid, stratify_by):
    """Return the stratum label for one item.

    The label is the item's ``stratify_by`` data field; when that is missing
    and the manager exposes ``instance_id_to_categories``, the alphabetically
    first category is used instead. Items with neither fall into the
    ``__uncategorized__`` stratum.
    """
    item = item_state_manager.instance_id_to_instance[iid]
    data = item.get_data() if hasattr(item, "get_data") else {}
    key = data.get(stratify_by) if isinstance(data, dict) else None
    # Also accept the indexed category if the data field is absent.
    if key is None and hasattr(item_state_manager, "instance_id_to_categories"):
        cats = item_state_manager.instance_id_to_categories.get(iid)
        if cats:
            key = sorted(cats)[0]
    if key is None:
        return _UNCATEGORIZED
    try:
        hash(key)
    except TypeError:
        # Unhashable values (e.g. a list of tags) cannot be dict keys; their
        # string form is what the per-stratum seed uses anyway.
        return str(key)
    return key


def _build_strata(item_state_manager, all_ids, stratify_by) -> Dict[object, List[str]]:
    """Group instance ids by stratum; a single ``None`` stratum if unstratified."""
    strata: Dict[object, List[str]] = defaultdict(list)
    if stratify_by:
        for iid in all_ids:
            strata[_stratum_key(item_state_manager, iid, stratify_by)].append(iid)
    else:
        strata[None] = list(all_ids)
    return strata


def apply_overlap_sample(item_state_manager, config: dict) -> Dict[str, int]:
    """
    Stamp ``required_annotations`` on a deterministic sample of items.

    Reads ``config["num_annotators_per_item"]["overlap_sample"]``:

    * ``fraction`` (required): share of each stratum to sample. A positive
      fraction always selects at least one item per stratum; a fraction of
      zero or less selects nothing.
    * ``count`` (required): the cap written onto each sampled item.
    * ``stratify_by`` (optional): item data field used to form strata.
    * ``seed`` (optional): falls back to ``item_state_manager.random_seed``
      when absent or ``None``.

    Returns a mapping of sampled instance_id -> assigned cap, for reporting.
    Items whose existing metadata already carries ``required_annotations``
    (e.g., from a previous load with persisted state) are not overwritten
    and are not included in the returned mapping.

    Raises ``KeyError`` if the ``overlap_sample`` block is present but lacks
    ``fraction`` or ``count``.
    """
    nap = config.get("num_annotators_per_item")
    if not isinstance(nap, dict):
        return {}
    overlap = nap.get("overlap_sample")
    if not overlap:
        return {}

    fraction = float(overlap["fraction"])
    count = int(overlap["count"])
    stratify_by = overlap.get("stratify_by")

    # Only touch the manager's seed when the config does not supply one.
    seed_value = overlap.get("seed")
    if seed_value is None:
        seed_value = item_state_manager.random_seed
    seed = int(seed_value)

    if fraction <= 0:
        # The "at least one per stratum" floor below is meant for small
        # strata, not for a sample that was configured to be empty.
        logger.debug("Overlap sample skipped: fraction=%s selects no items", fraction)
        return {}

    all_ids = list(item_state_manager.instance_id_to_instance.keys())
    if not all_ids:
        return {}

    strata = _build_strata(item_state_manager, all_ids, stratify_by)

    sampled: Dict[str, int] = {}
    for key, ids in strata.items():
        if not ids:
            continue
        # Sort before shuffling so the outcome does not depend on load order.
        ids_sorted = sorted(ids)
        # Each stratum gets its own generator, so adding or removing one
        # stratum never changes which items are picked in another.
        rng_local = _random.Random(f"{seed}:{key}" if key is not None else seed)
        rng_local.shuffle(ids_sorted)
        target = max(1, int(round(len(ids_sorted) * fraction)))
        for iid in ids_sorted[:target]:
            item = item_state_manager.instance_id_to_instance[iid]
            # Don't clobber an existing per-item override (operator may have
            # set one via item_data; respect that authority).
            if item.get_metadata(_METADATA_KEY) is not None:
                continue
            item.add_metadata(_METADATA_KEY, count)
            sampled[iid] = count

    if sampled:
        logger.info(
            "Overlap sample: %d / %d items raised to %d annotators (fraction=%s, stratify_by=%s, seed=%d)",
            len(sampled), len(all_ids), count, fraction, stratify_by, seed,
        )
    return sampled