""" Instance Selector for Solo Mode This module implements weighted instance selection for human annotation. It combines multiple signals (LLM confidence, diversity, disagreements, random) to prioritize which instances the human annotator should see next. """ import logging import random from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple import threading logger = logging.getLogger(__name__) @dataclass class SelectionWeights: """Configuration for instance selection weights.""" low_confidence: float = 0.4 # Low LLM confidence instances diverse: float = 0.3 # Diverse instances (embedding clusters) random: float = 0.2 # Random sample for calibration disagreement: float = 0.1 # Instances with prior disagreements edge_case_rule: float = 0.0 # Instances matching edge case rule patterns cartography: float = 0.0 # Instances with high confidence variability llm_predicted: float = 0.0 # Instances with LLM predictions needing human comparison def validate(self) -> None: """Validate that weights sum to 1.0.""" total = ( self.low_confidence + self.diverse + self.random + self.disagreement + self.edge_case_rule + self.cartography + self.llm_predicted ) if abs(total - 1.0) > 0.001: # Normalize self.low_confidence /= total self.diverse /= total self.random /= total self.disagreement /= total self.edge_case_rule /= total self.cartography /= total self.llm_predicted /= total logger.warning(f"Normalized selection weights (original sum: {total})") class InstanceSelector: """ Weighted instance selector for Solo Mode. Combines multiple signals to select which instances the human should annotate, optimizing for efficient use of human labeling time. Selection pools: 1. Low confidence: Instances where LLM is uncertain 2. Diverse: Instances from different embedding clusters 3. Random: Random sample for calibration 4. Disagreement: Instances with prior human-LLM disagreement """ def __init__( self, weights: Optional[SelectionWeights] = None, config: Optional[Dict[str, Any]] = None ): """ Initialize the instance selector. Args: weights: Selection weight configuration config: Full application configuration """ self.weights = weights or SelectionWeights() self.weights.validate() self.config = config or {} self._lock = threading.RLock() # Random state self.random = random.Random() # Track selection history self.selection_history: List[Dict[str, Any]] = [] # Pool state self._low_confidence_pool: List[str] = [] self._diverse_pool: List[str] = [] self._random_pool: List[str] = [] self._disagreement_pool: List[str] = [] self._edge_case_rule_pool: List[str] = [] self._cartography_pool: List[str] = [] self._llm_predicted_pool: List[str] = [] # Cache predictions for use in _select_lowest_confidence self._predictions_cache: Dict[str, Dict[str, Any]] = {} def configure( self, low_confidence_weight: float = 0.4, diversity_weight: float = 0.3, random_weight: float = 0.2, disagreement_weight: float = 0.1, edge_case_rule_weight: float = 0.0, cartography_weight: float = 0.0, ) -> None: """Configure selection weights.""" self.weights = SelectionWeights( low_confidence=low_confidence_weight, diverse=diversity_weight, random=random_weight, disagreement=disagreement_weight, edge_case_rule=edge_case_rule_weight, cartography=cartography_weight, ) self.weights.validate() def refresh_pools( self, available_ids: Set[str], llm_predictions: Optional[Dict[str, Dict[str, Any]]] = None, disagreement_ids: Optional[Set[str]] = None, confidence_threshold: float = 0.5, edge_case_rule_ids: Optional[Set[str]] = None, cartography_scores: Optional[Dict[str, float]] = None, ) -> None: """ Refresh the selection pools based on current state. Args: available_ids: Set of instance IDs available for selection llm_predictions: Dict of instance_id -> schema -> prediction disagreement_ids: Set of instance IDs with disagreements confidence_threshold: Threshold for low confidence pool edge_case_rule_ids: Set of instance IDs matching edge case rule patterns cartography_scores: Dict of instance_id -> variability score """ with self._lock: available_list = list(available_ids) # Cache predictions for use in _select_lowest_confidence self._predictions_cache = llm_predictions or {} # Clear pools self._low_confidence_pool = [] self._diverse_pool = [] self._random_pool = [] self._disagreement_pool = [] self._edge_case_rule_pool = [] self._cartography_pool = [] self._llm_predicted_pool = [] # Build low confidence pool if llm_predictions: for instance_id in available_list: if instance_id in llm_predictions: preds = llm_predictions[instance_id] # Check if any prediction is below threshold for pred in preds.values(): confidence = pred.get('confidence_score', 1.0) if confidence < confidence_threshold: self._low_confidence_pool.append(instance_id) break # Build LLM-predicted pool: instances with predictions that aren't # already in the low_confidence pool (those are more valuable there). # This pool steers human annotations toward instances where a comparison # with the LLM can happen immediately. if llm_predictions: low_conf_set = set(self._low_confidence_pool) self._llm_predicted_pool = [ iid for iid in available_list if iid in llm_predictions and iid not in low_conf_set ] # Build disagreement pool if disagreement_ids: self._disagreement_pool = [ iid for iid in available_list if iid in disagreement_ids ] # Build edge case rule pool if edge_case_rule_ids: self._edge_case_rule_pool = [ iid for iid in available_list if iid in edge_case_rule_ids ] # Build cartography pool (high variability = ambiguous instances) if cartography_scores: scored = [ (iid, score) for iid, score in cartography_scores.items() if iid in available_ids and score > 0 ] scored.sort(key=lambda x: x[1], reverse=True) self._cartography_pool = [iid for iid, _ in scored] # Diverse pool uses diversity manager if available self._diverse_pool = self._build_diverse_pool(available_list) # Random pool is just all available (sampling happens at selection time) self._random_pool = available_list.copy() logger.debug( f"Refreshed pools: low_conf={len(self._low_confidence_pool)}, " f"llm_predicted={len(self._llm_predicted_pool)}, " f"diverse={len(self._diverse_pool)}, " f"random={len(self._random_pool)}, " f"disagreement={len(self._disagreement_pool)}, " f"edge_case_rule={len(self._edge_case_rule_pool)}, " f"cartography={len(self._cartography_pool)}" ) def _build_diverse_pool(self, available_ids: List[str]) -> List[str]: """ Build the diverse instances pool using DiversityManager. Returns instances ordered by diversity (from different clusters). """ try: from potato.diversity_manager import get_diversity_manager dm = get_diversity_manager() if dm is None or not dm.enabled: return [] # Get diverse ordering from all clusters diverse = dm.generate_diverse_ordering( user_id='solo_mode', available_ids=available_ids, preserve_ids=set() ) return diverse except Exception as e: logger.debug(f"Could not build diverse pool: {e}") return [] def select_next( self, available_ids: Set[str], exclude_ids: Optional[Set[str]] = None ) -> Optional[str]: """ Select the next instance for human annotation. Uses weighted random selection across the pools. Args: available_ids: Set of available instance IDs exclude_ids: Set of IDs to exclude from selection Returns: Selected instance ID, or None if no instances available """ with self._lock: # Filter pools by available and exclude exclude = exclude_ids or set() pools = { 'low_confidence': [ iid for iid in self._low_confidence_pool if iid in available_ids and iid not in exclude ], 'diverse': [ iid for iid in self._diverse_pool if iid in available_ids and iid not in exclude ], 'random': [ iid for iid in self._random_pool if iid in available_ids and iid not in exclude ], 'disagreement': [ iid for iid in self._disagreement_pool if iid in available_ids and iid not in exclude ], 'edge_case_rule': [ iid for iid in self._edge_case_rule_pool if iid in available_ids and iid not in exclude ], 'cartography': [ iid for iid in self._cartography_pool if iid in available_ids and iid not in exclude ], 'llm_predicted': [ iid for iid in self._llm_predicted_pool if iid in available_ids and iid not in exclude ], } # Select pool based on weights selected_pool, pool_name = self._weighted_pool_selection(pools) if not selected_pool: # Fallback to any available instance remaining = [iid for iid in available_ids if iid not in exclude] if remaining: instance_id = self.random.choice(remaining) self._record_selection(instance_id, 'fallback') return instance_id return None # Select from pool if pool_name == 'low_confidence': # Sort by confidence (lowest first) and take first instance_id = self._select_lowest_confidence(selected_pool) elif pool_name == 'diverse': # Take first (already ordered by diversity) instance_id = selected_pool[0] elif pool_name == 'disagreement': # Random from disagreements instance_id = self.random.choice(selected_pool) elif pool_name == 'edge_case_rule': # Random from edge case rule matches instance_id = self.random.choice(selected_pool) elif pool_name == 'cartography': # Take first (already sorted by variability, highest first) instance_id = selected_pool[0] elif pool_name == 'llm_predicted': # Random from LLM-predicted instances instance_id = self.random.choice(selected_pool) else: # random instance_id = self.random.choice(selected_pool) self._record_selection(instance_id, pool_name) return instance_id def _weighted_pool_selection( self, pools: Dict[str, List[str]] ) -> Tuple[List[str], str]: """ Select a pool based on configured weights. Returns empty list if all pools are empty. """ # Build list of (pool, name, weight) for non-empty pools candidates = [] weights = [] pool_weights = { 'low_confidence': self.weights.low_confidence, 'diverse': self.weights.diverse, 'random': self.weights.random, 'disagreement': self.weights.disagreement, 'edge_case_rule': self.weights.edge_case_rule, 'cartography': self.weights.cartography, 'llm_predicted': self.weights.llm_predicted, } for name, pool in pools.items(): if pool: # Only consider non-empty pools candidates.append((pool, name)) weights.append(pool_weights[name]) if not candidates: return [], '' # Normalize weights total = sum(weights) if total > 0: weights = [w / total for w in weights] # Weighted random selection r = self.random.random() cumsum = 0 for (pool, name), weight in zip(candidates, weights): cumsum += weight if r <= cumsum: return pool, name # Fallback to last return candidates[-1] def _select_lowest_confidence(self, pool: List[str]) -> str: """Select the instance with lowest LLM confidence from cached predictions.""" min_conf = float('inf') best_id = pool[0] for instance_id in pool: if instance_id in self._predictions_cache: for pred in self._predictions_cache[instance_id].values(): conf = pred.get('confidence_score', 1.0) if conf < min_conf: min_conf = conf best_id = instance_id return best_id def _record_selection(self, instance_id: str, pool_name: str) -> None: """Record a selection for analytics.""" from datetime import datetime self.selection_history.append({ 'instance_id': instance_id, 'pool': pool_name, 'timestamp': datetime.now().isoformat(), }) def select_batch( self, available_ids: Set[str], batch_size: int, exclude_ids: Optional[Set[str]] = None ) -> List[str]: """ Select a batch of instances for annotation. Args: available_ids: Available instance IDs batch_size: Number of instances to select exclude_ids: IDs to exclude Returns: List of selected instance IDs """ selected = [] exclude = set(exclude_ids) if exclude_ids else set() for _ in range(batch_size): instance_id = self.select_next(available_ids, exclude) if instance_id is None: break selected.append(instance_id) exclude.add(instance_id) return selected def get_selection_stats(self) -> Dict[str, Any]: """Get statistics about selections made.""" with self._lock: from collections import Counter pool_counts = Counter(s['pool'] for s in self.selection_history) return { 'total_selections': len(self.selection_history), 'by_pool': dict(pool_counts), 'pool_sizes': { 'low_confidence': len(self._low_confidence_pool), 'diverse': len(self._diverse_pool), 'random': len(self._random_pool), 'disagreement': len(self._disagreement_pool), 'edge_case_rule': len(self._edge_case_rule_pool), 'cartography': len(self._cartography_pool), 'llm_predicted': len(self._llm_predicted_pool), }, 'weights': { 'low_confidence': self.weights.low_confidence, 'diverse': self.weights.diverse, 'random': self.weights.random, 'disagreement': self.weights.disagreement, 'edge_case_rule': self.weights.edge_case_rule, 'cartography': self.weights.cartography, 'llm_predicted': self.weights.llm_predicted, }, } def clear_history(self) -> None: """Clear selection history.""" with self._lock: self.selection_history.clear()