codebook / potato /server_utils /overlap_sampler.py
davidjurgens's picture
Deploy: Potato — Codebook Annotation
aceb1b2 verified
Raw
History Blame Contribute Delete
3.36 kB
"""
Overlap sampling for heterogeneous annotator coverage.
Given the ``num_annotators_per_item.overlap_sample`` config block, selects a
deterministic fraction of items to receive a raised annotator cap. The cap
override is written onto each sampled item's metadata so the centralized
``ItemStateManager._get_annotator_cap_for_item`` helper picks it up.
This module is called once after all items have been loaded.
"""
from __future__ import annotations
from collections import defaultdict
from typing import Dict, List, Optional
import logging
import random as _random
logger = logging.getLogger(__name__)

# Metadata key under which the per-item annotator cap override is stored.
_METADATA_KEY = "required_annotations"


def apply_overlap_sample(item_state_manager, config: dict) -> Dict[str, int]:
    """
    Stamp ``required_annotations`` on a deterministic sample of items.

    Reads ``config["num_annotators_per_item"]["overlap_sample"]``:

    * ``fraction`` (required) -- share of each stratum to draw. At least one
      item is drawn from every stratum, even when the rounded share is zero.
    * ``count`` (required) -- cap value written onto each sampled item.
    * ``stratify_by`` (optional) -- item data field used to group items into
      strata. When an item has no value for it, the first (sorted) entry of
      ``item_state_manager.instance_id_to_categories`` is used if available,
      otherwise the item lands in the ``"__uncategorized__"`` stratum.
    * ``seed`` (optional) -- RNG seed. When missing or ``None`` it falls back
      to ``item_state_manager.random_seed``.

    Items whose metadata already carries ``required_annotations`` are not
    overwritten. Such items still occupy a slot in the drawn sample, so fewer
    than the target number of items may end up stamped.

    Args:
        item_state_manager: object exposing ``instance_id_to_instance`` (a
            mapping of instance id -> item with ``get_metadata`` and
            ``add_metadata``), and ``random_seed`` when no seed is configured.
        config: the loaded configuration dictionary.

    Returns:
        Mapping of sampled instance id -> assigned cap, for reporting. Empty
        when the feature is not configured or there are no items.

    Raises:
        KeyError: if ``fraction`` or ``count`` is missing from the block.
    """
    nap = config.get("num_annotators_per_item")
    if not isinstance(nap, dict):
        return {}
    overlap = nap.get("overlap_sample")
    if not overlap:
        return {}

    fraction = float(overlap["fraction"])
    count = int(overlap["count"])
    stratify_by = overlap.get("stratify_by")

    # Resolve the fallback lazily: the manager's seed is only consulted when
    # the config does not provide a usable one.
    seed_value = overlap.get("seed")
    if seed_value is None:
        seed_value = item_state_manager.random_seed
    seed = int(seed_value)

    instances = item_state_manager.instance_id_to_instance
    all_ids = list(instances.keys())
    if not all_ids:
        return {}

    # Build strata
    strata: Dict[Optional[str], List[str]] = defaultdict(list)
    if stratify_by:
        # Loop-invariant lookup; the category index is optional on the manager.
        categories = getattr(item_state_manager, "instance_id_to_categories", None)
        for iid in all_ids:
            item = instances[iid]
            data = item.get_data() if hasattr(item, "get_data") else {}
            key = data.get(stratify_by) if isinstance(data, dict) else None
            # Also accept the indexed category if it matches
            if key is None and categories is not None:
                cats = categories.get(iid)
                if cats:
                    key = sorted(cats)[0]
            strata[key if key is not None else "__uncategorized__"].append(iid)
    else:
        strata[None] = list(all_ids)

    sampled: Dict[str, int] = {}
    for key, ids in strata.items():
        if not ids:
            continue
        # Deterministic ordering across runs: sort first, then shuffle with an
        # RNG derived only from the seed and the stratum key.
        ids_sorted = sorted(ids)
        rng_local = _random.Random(f"{seed}:{key}" if key is not None else seed)
        rng_local.shuffle(ids_sorted)
        target = max(1, int(round(len(ids_sorted) * fraction)))
        for iid in ids_sorted[:target]:
            item = instances[iid]
            # Don't clobber an existing per-item override (operator may have
            # set one via item_data; respect that authority).
            if item.get_metadata(_METADATA_KEY) is not None:
                continue
            item.add_metadata(_METADATA_KEY, count)
            sampled[iid] = count

    if sampled:
        logger.info(
            "Overlap sample: %d / %d items raised to %d annotators (fraction=%s, stratify_by=%s, seed=%d)",
            len(sampled), len(all_ids), count, fraction, stratify_by, seed,
        )
    return sampled