codebook / potato /solo_mode /instance_selector.py
davidjurgens's picture
Deploy: Potato — Codebook Annotation
aceb1b2 verified
Raw
History Blame Contribute Delete
17.5 kB
"""
Instance Selector for Solo Mode
This module implements weighted instance selection for human annotation.
It combines multiple signals (LLM confidence, diversity, disagreements, random)
to prioritize which instances the human annotator should see next.
"""
import logging
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple
import threading
logger = logging.getLogger(__name__)
@dataclass
class SelectionWeights:
    """Relative weights of the instance selection pools.

    Each attribute is the weight given to the selection pool of the same
    name. The defaults sum to 1.0; other non-negative values are rescaled
    by :meth:`validate`.
    """

    low_confidence: float = 0.4  # Low LLM confidence instances
    diverse: float = 0.3  # Diverse instances (embedding clusters)
    random: float = 0.2  # Random sample for calibration
    disagreement: float = 0.1  # Instances with prior disagreements
    edge_case_rule: float = 0.0  # Instances matching edge case rule patterns
    cartography: float = 0.0  # Instances with high confidence variability
    llm_predicted: float = 0.0  # Instances with LLM predictions needing human comparison

    def validate(self) -> None:
        """Check the weights and normalize them in place to sum to 1.0.

        Weights whose sum is already within 0.001 of 1.0 are left untouched.
        Otherwise every weight is divided by the sum and a warning is logged.

        Raises:
            ValueError: If any weight is negative, or if the weights sum to
                zero (there is nothing to normalize).
        """
        # Local import: the file's top-level import only brings in `dataclass`.
        from dataclasses import fields

        names = [f.name for f in fields(self)]

        negative = [name for name in names if getattr(self, name) < 0]
        if negative:
            raise ValueError(
                "Selection weights must be non-negative; got negative "
                f"value(s) for: {', '.join(negative)}"
            )

        # Accumulate left to right, in field declaration order.
        total = 0
        for name in names:
            total += getattr(self, name)

        if total <= 0:
            # Dividing by a zero sum would raise ZeroDivisionError below.
            raise ValueError("Selection weights must not all be zero")

        if abs(total - 1.0) > 0.001:
            # Normalize
            for name in names:
                setattr(self, name, getattr(self, name) / total)
            logger.warning(
                "Normalized selection weights (original sum: %s)", total
            )
class InstanceSelector:
    """
    Weighted instance selector for Solo Mode.

    Combines multiple signals to select which instances the human
    should annotate, optimizing for efficient use of human labeling time.

    Selection pools:
    1. Low confidence: Instances with an LLM prediction whose confidence
       is below the threshold given to ``refresh_pools``
    2. Diverse: Ordering returned by the diversity manager, when enabled
    3. Random: Every available instance, sampled at selection time
    4. Disagreement: Instances listed in ``disagreement_ids``
    5. Edge case rule: Instances listed in ``edge_case_rule_ids``
    6. Cartography: Instances with a positive variability score,
       highest score first
    7. LLM predicted: Instances that have predictions but are not in the
       low confidence pool

    Pools are rebuilt by ``refresh_pools``; ``select_next`` picks a pool at
    random in proportion to its weight and then picks an instance from it.
    """

    # Pool names in the order they are considered during selection. Each
    # name matches a ``SelectionWeights`` attribute and a ``_<name>_pool``
    # list on the instance.
    _POOL_NAMES = (
        'low_confidence',
        'diverse',
        'random',
        'disagreement',
        'edge_case_rule',
        'cartography',
        'llm_predicted',
    )

    # Pools whose list order is meaningful: the first eligible entry is taken.
    _ORDERED_POOLS = frozenset({'diverse', 'cartography'})

    def __init__(
        self,
        weights: Optional["SelectionWeights"] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize the instance selector.

        Args:
            weights: Selection weight configuration. Defaults to a fresh
                ``SelectionWeights()``. ``validate()`` is called on it, which
                may modify the object that was passed in.
            config: Full application configuration
        """
        self.weights = weights or SelectionWeights()
        self.weights.validate()
        self.config = config or {}
        # Re-entrant so public methods can call each other while locked.
        self._lock = threading.RLock()

        # Random state
        self.random = random.Random()

        # Track selection history
        self.selection_history: List[Dict[str, Any]] = []

        # Pool state
        self._low_confidence_pool: List[str] = []
        self._diverse_pool: List[str] = []
        self._random_pool: List[str] = []
        self._disagreement_pool: List[str] = []
        self._edge_case_rule_pool: List[str] = []
        self._cartography_pool: List[str] = []
        self._llm_predicted_pool: List[str] = []

        # Cache predictions for use in _select_lowest_confidence
        self._predictions_cache: Dict[str, Dict[str, Any]] = {}

    def configure(
        self,
        low_confidence_weight: float = 0.4,
        diversity_weight: float = 0.3,
        random_weight: float = 0.2,
        disagreement_weight: float = 0.1,
        edge_case_rule_weight: float = 0.0,
        cartography_weight: float = 0.0,
        llm_predicted_weight: float = 0.0,
    ) -> None:
        """
        Configure selection weights.

        Builds a new ``SelectionWeights`` from the arguments, validates it,
        and only then installs it, so a failed validation leaves the
        current weights in place.

        Args:
            low_confidence_weight: Weight of the low confidence pool
            diversity_weight: Weight of the diverse pool
            random_weight: Weight of the random pool
            disagreement_weight: Weight of the disagreement pool
            edge_case_rule_weight: Weight of the edge case rule pool
            cartography_weight: Weight of the cartography pool
            llm_predicted_weight: Weight of the LLM-predicted pool
        """
        new_weights = SelectionWeights(
            low_confidence=low_confidence_weight,
            diverse=diversity_weight,
            random=random_weight,
            disagreement=disagreement_weight,
            edge_case_rule=edge_case_rule_weight,
            cartography=cartography_weight,
            llm_predicted=llm_predicted_weight,
        )
        new_weights.validate()
        with self._lock:
            self.weights = new_weights

    def refresh_pools(
        self,
        available_ids: Set[str],
        llm_predictions: Optional[Dict[str, Dict[str, Any]]] = None,
        disagreement_ids: Optional[Set[str]] = None,
        confidence_threshold: float = 0.5,
        edge_case_rule_ids: Optional[Set[str]] = None,
        cartography_scores: Optional[Dict[str, float]] = None,
    ) -> None:
        """
        Refresh the selection pools based on current state.

        Every pool is rebuilt from scratch; pools whose input is missing or
        empty end up empty.

        Args:
            available_ids: Set of instance IDs available for selection
            llm_predictions: Dict of instance_id -> schema -> prediction.
                A prediction without a ``'confidence_score'`` key is
                treated as confidence 1.0.
            disagreement_ids: Set of instance IDs with disagreements
            confidence_threshold: Threshold for low confidence pool
            edge_case_rule_ids: Set of instance IDs matching edge case rule patterns
            cartography_scores: Dict of instance_id -> variability score
        """
        with self._lock:
            available_list = list(available_ids)

            # Cache predictions for use in _select_lowest_confidence
            self._predictions_cache = llm_predictions or {}

            # Clear pools
            for name in self._POOL_NAMES:
                setattr(self, f"_{name}_pool", [])

            if llm_predictions:
                # Low confidence pool: any prediction below the threshold.
                self._low_confidence_pool = [
                    iid for iid in available_list
                    if iid in llm_predictions and any(
                        pred.get('confidence_score', 1.0) < confidence_threshold
                        for pred in llm_predictions[iid].values()
                    )
                ]

                # LLM-predicted pool: instances with predictions that aren't
                # already in the low_confidence pool (those are more valuable
                # there). This pool steers human annotations toward instances
                # where a comparison with the LLM can happen immediately.
                low_conf_set = set(self._low_confidence_pool)
                self._llm_predicted_pool = [
                    iid for iid in available_list
                    if iid in llm_predictions and iid not in low_conf_set
                ]

            # Build disagreement pool
            if disagreement_ids:
                self._disagreement_pool = [
                    iid for iid in available_list
                    if iid in disagreement_ids
                ]

            # Build edge case rule pool
            if edge_case_rule_ids:
                self._edge_case_rule_pool = [
                    iid for iid in available_list
                    if iid in edge_case_rule_ids
                ]

            # Build cartography pool (high variability = ambiguous instances)
            if cartography_scores:
                scored = [
                    (iid, score) for iid, score in cartography_scores.items()
                    if iid in available_ids and score > 0
                ]
                scored.sort(key=lambda item: item[1], reverse=True)
                self._cartography_pool = [iid for iid, _ in scored]

            # Diverse pool uses diversity manager if available
            self._diverse_pool = self._build_diverse_pool(available_list)

            # Random pool is just all available (sampling happens at selection time)
            self._random_pool = available_list.copy()

            logger.debug(
                "Refreshed pools: low_conf=%d, llm_predicted=%d, diverse=%d, "
                "random=%d, disagreement=%d, edge_case_rule=%d, cartography=%d",
                len(self._low_confidence_pool),
                len(self._llm_predicted_pool),
                len(self._diverse_pool),
                len(self._random_pool),
                len(self._disagreement_pool),
                len(self._edge_case_rule_pool),
                len(self._cartography_pool),
            )

    def _build_diverse_pool(self, available_ids: List[str]) -> List[str]:
        """
        Build the diverse instances pool using DiversityManager.

        Returns the ordering produced by
        ``DiversityManager.generate_diverse_ordering``, or an empty list when
        the manager is missing, disabled, or raises. Failures are logged at
        debug level rather than propagated, so the diverse pool is simply
        skipped.
        """
        try:
            from potato.diversity_manager import get_diversity_manager

            dm = get_diversity_manager()
            if dm is None or not dm.enabled:
                return []

            # Get diverse ordering from all clusters
            return dm.generate_diverse_ordering(
                user_id='solo_mode',
                available_ids=available_ids,
                preserve_ids=set()
            )
        except Exception as e:
            # Deliberate best effort: diversity is optional.
            logger.debug("Could not build diverse pool: %s", e)
            return []

    def select_next(
        self,
        available_ids: Set[str],
        exclude_ids: Optional[Set[str]] = None
    ) -> Optional[str]:
        """
        Select the next instance for human annotation.

        Uses weighted random selection across the pools. When no pool is
        eligible, an instance is picked uniformly from the remaining
        available IDs and recorded under the pool name ``'fallback'``.

        Args:
            available_ids: Set of available instance IDs
            exclude_ids: Set of IDs to exclude from selection

        Returns:
            Selected instance ID, or None if no instances available
        """
        with self._lock:
            exclude = exclude_ids or set()

            # Filter pools by available and exclude
            pools = {
                name: [
                    iid for iid in getattr(self, f"_{name}_pool")
                    if iid in available_ids and iid not in exclude
                ]
                for name in self._POOL_NAMES
            }

            # Select pool based on weights
            selected_pool, pool_name = self._weighted_pool_selection(pools)

            if not selected_pool:
                # Fallback to any available instance
                remaining = [iid for iid in available_ids if iid not in exclude]
                if not remaining:
                    return None
                instance_id = self.random.choice(remaining)
                self._record_selection(instance_id, 'fallback')
                return instance_id

            if pool_name == 'low_confidence':
                # Lowest cached confidence first
                instance_id = self._select_lowest_confidence(selected_pool)
            elif pool_name in self._ORDERED_POOLS:
                # Already ordered (by diversity / by variability, highest first)
                instance_id = selected_pool[0]
            else:
                # random, disagreement, edge_case_rule, llm_predicted
                instance_id = self.random.choice(selected_pool)

            self._record_selection(instance_id, pool_name)
            return instance_id

    def _weighted_pool_selection(
        self,
        pools: Dict[str, List[str]]
    ) -> Tuple[List[str], str]:
        """
        Select a pool based on configured weights.

        Only non-empty pools with a strictly positive weight are candidates;
        each is chosen with probability proportional to its weight, so a
        pool whose weight is 0 is never selected.

        Returns:
            ``(pool, name)`` for the chosen pool, or ``([], '')`` when no
            pool is both non-empty and positively weighted.
        """
        pool_weights = {
            name: getattr(self.weights, name) for name in self._POOL_NAMES
        }

        candidates = [
            (pool, name, pool_weights[name])
            for name, pool in pools.items()
            if pool and pool_weights[name] > 0
        ]
        if not candidates:
            return [], ''

        # Weighted random selection over the cumulative weights
        total = sum(weight for _, _, weight in candidates)
        threshold = self.random.random() * total
        cumulative = 0.0
        for pool, name, weight in candidates:
            cumulative += weight
            if threshold < cumulative:
                return pool, name

        # Floating-point rounding can leave the threshold at the very top of
        # the range; the last candidate owns that end of the interval.
        pool, name, _ = candidates[-1]
        return pool, name

    def _select_lowest_confidence(self, pool: List[str]) -> str:
        """
        Select the instance with lowest LLM confidence from cached predictions.

        Instances missing from the prediction cache are skipped; if none of
        the pool is cached, the first pool entry is returned. Ties keep the
        earliest instance. ``pool`` must be non-empty.
        """
        min_conf = float('inf')
        best_id = pool[0]

        for instance_id in pool:
            preds = self._predictions_cache.get(instance_id)
            if preds is None:
                continue
            for pred in preds.values():
                conf = pred.get('confidence_score', 1.0)
                if conf < min_conf:
                    min_conf = conf
                    best_id = instance_id

        return best_id

    def _record_selection(self, instance_id: str, pool_name: str) -> None:
        """Record a selection for analytics."""
        from datetime import datetime

        self.selection_history.append({
            'instance_id': instance_id,
            'pool': pool_name,
            'timestamp': datetime.now().isoformat(),
        })

    def select_batch(
        self,
        available_ids: Set[str],
        batch_size: int,
        exclude_ids: Optional[Set[str]] = None
    ) -> List[str]:
        """
        Select a batch of instances for annotation.

        Calls ``select_next`` repeatedly, excluding each pick from the
        following ones, and stops early when nothing is left. The caller's
        ``exclude_ids`` set is copied, not modified.

        Args:
            available_ids: Available instance IDs
            batch_size: Number of instances to select
            exclude_ids: IDs to exclude

        Returns:
            List of selected instance IDs
        """
        selected: List[str] = []
        exclude = set(exclude_ids) if exclude_ids else set()

        # Hold the (re-entrant) lock for the whole batch so the pools cannot
        # be refreshed between picks.
        with self._lock:
            for _ in range(batch_size):
                instance_id = self.select_next(available_ids, exclude)
                if instance_id is None:
                    break
                selected.append(instance_id)
                exclude.add(instance_id)

        return selected

    def get_selection_stats(self) -> Dict[str, Any]:
        """
        Get statistics about selections made.

        Returns:
            Dict with the total number of recorded selections, a per-pool
            count, the current size of every pool, and the current weights.
        """
        with self._lock:
            from collections import Counter

            pool_counts = Counter(s['pool'] for s in self.selection_history)

            return {
                'total_selections': len(self.selection_history),
                'by_pool': dict(pool_counts),
                'pool_sizes': {
                    name: len(getattr(self, f"_{name}_pool"))
                    for name in self._POOL_NAMES
                },
                'weights': {
                    name: getattr(self.weights, name)
                    for name in self._POOL_NAMES
                },
            }

    def clear_history(self) -> None:
        """Clear selection history."""
        with self._lock:
            self.selection_history.clear()