Spaces:
Paused
Paused
| """ | |
| Validation Tracker for Solo Mode | |
| This module tracks agreement metrics between human and LLM annotations, | |
| manages thresholds for phase transitions, and provides validation | |
| sampling for final quality assurance. | |
| """ | |
| import logging | |
| import random | |
| import threading | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Optional, Set, Tuple | |
| from collections import defaultdict | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
@dataclass
class AgreementMetrics:
    """
    Metrics for human-LLM agreement.

    Must be a dataclass: it is constructed with keyword arguments and its
    dict-valued fields rely on ``field(default_factory=...)`` so that each
    instance gets its own containers.
    """

    # Number of human/LLM comparisons recorded.
    total_compared: int = 0
    # Comparisons where the labels agreed / disagreed.
    agreements: int = 0
    disagreements: int = 0
    # agreements / total_compared (0.0 until something is compared).
    agreement_rate: float = 0.0

    # Per-label metrics, keyed by the stringified human label.
    label_agreements: Dict[str, int] = field(default_factory=dict)
    label_disagreements: Dict[str, int] = field(default_factory=dict)

    # Confusion tracking: (llm_label_str, human_label_str) -> count.
    confusion_matrix: Dict[Tuple[str, str], int] = field(default_factory=dict)

    # Time-based tracking
    recent_agreement_rate: float = 0.0  # Last N comparisons
    trend: str = "stable"  # "improving", "declining", "stable"

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize to a JSON-friendly dictionary.

        Confusion-matrix tuple keys are flattened to ``"llm|human"``
        strings. Labels that themselves contain ``'|'`` cannot be split
        back unambiguously. The per-label dicts are copied so callers
        cannot mutate this object's state through the result.

        Returns:
            Dict with one entry per field.
        """
        return {
            'total_compared': self.total_compared,
            'agreements': self.agreements,
            'disagreements': self.disagreements,
            'agreement_rate': self.agreement_rate,
            'label_agreements': dict(self.label_agreements),
            'label_disagreements': dict(self.label_disagreements),
            'confusion_matrix': {
                f"{llm}|{human}": count
                for (llm, human), count in self.confusion_matrix.items()
            },
            'recent_agreement_rate': self.recent_agreement_rate,
            'trend': self.trend,
        }
@dataclass
class ValidationSample:
    """
    A sample selected for final validation.

    Must be a dataclass: instances are created with keyword arguments and
    ``selected_at`` uses ``field(default_factory=datetime.now)``.
    """

    instance_id: str
    # The LLM's predicted label and its confidence at selection time.
    llm_label: Any
    llm_confidence: float
    # Local (naive) time at which the sample was selected.
    selected_at: datetime = field(default_factory=datetime.now)

    # Human validation results; all None until the sample is validated.
    human_label: Optional[Any] = None
    validated_at: Optional[datetime] = None
    agrees: Optional[bool] = None
    notes: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize to a JSON-friendly dictionary.

        Datetimes are emitted as ISO-8601 strings; ``validated_at`` is
        None when the sample has not been validated yet.
        """
        return {
            'instance_id': self.instance_id,
            'llm_label': self.llm_label,
            'llm_confidence': self.llm_confidence,
            'selected_at': self.selected_at.isoformat(),
            'human_label': self.human_label,
            'validated_at': self.validated_at.isoformat() if self.validated_at else None,
            'agrees': self.agrees,
            'notes': self.notes,
        }
class ValidationTracker:
    """
    Tracks agreement metrics and manages validation sampling.

    This class monitors the agreement between human and LLM annotations,
    determines when thresholds are met for phase transitions, and
    manages sampling for final validation.

    Every public method other than ``__init__`` holds ``self._lock``, a
    re-entrant lock, so methods may call one another while holding it.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the validation tracker.

        Args:
            config: Configuration dictionary. Thresholds are read from
                ``config['solo_mode']['thresholds']``; missing keys fall
                back to the defaults below.
        """
        self.config = config or {}
        self._lock = threading.RLock()

        # Load thresholds from config
        solo_config = self.config.get('solo_mode', {})
        thresholds = solo_config.get('thresholds', {})
        # Agreement rate at/above which human annotation may end.
        self.end_human_threshold = thresholds.get(
            'end_human_annotation_agreement', 0.90
        )
        # Minimum comparisons before the threshold is considered; also the
        # default size of the final validation sample.
        self.minimum_validation_sample = thresholds.get(
            'minimum_validation_sample', 50
        )
        # Number of LLM labels between periodic reviews.
        self.periodic_review_interval = thresholds.get(
            'periodic_review_interval', 100
        )

        # Metrics tracking
        self._metrics = AgreementMetrics()
        self._comparison_history: List[Dict[str, Any]] = []
        self._recent_window = 50  # Window for recent agreement rate

        # Validation samples, keyed by instance ID
        self._validation_samples: Dict[str, ValidationSample] = {}
        self._validation_sample_size = self.minimum_validation_sample

        # Private RNG for sampling. It is unseeded, so sampling is only
        # reproducible if a caller seeds ``self._random`` explicitly.
        self._random = random.Random()

        # Tracking for periodic review
        self._llm_labels_since_review = 0

    def record_comparison(
        self,
        instance_id: str,
        human_label: Any,
        llm_label: Any,
        schema_name: str,
        agrees: bool
    ) -> None:
        """
        Record a comparison between human and LLM labels.

        Updates the overall counters and agreement rate, the per-label
        counters (keyed by ``str(human_label)``), the confusion matrix on
        disagreement, the comparison history, and the recent-window metrics.

        Args:
            instance_id: The instance ID
            human_label: The human-assigned label
            llm_label: The LLM-predicted label
            schema_name: The annotation schema name
            agrees: Whether the labels agree
        """
        with self._lock:
            metrics = self._metrics

            # Update overall metrics
            metrics.total_compared += 1
            if agrees:
                metrics.agreements += 1
            else:
                metrics.disagreements += 1
            # total_compared was just incremented, so it is always >= 1 here.
            metrics.agreement_rate = metrics.agreements / metrics.total_compared

            # Update per-label metrics
            human_str = str(human_label)
            llm_str = str(llm_label)
            if agrees:
                metrics.label_agreements[human_str] = (
                    metrics.label_agreements.get(human_str, 0) + 1
                )
            else:
                metrics.label_disagreements[human_str] = (
                    metrics.label_disagreements.get(human_str, 0) + 1
                )
                # Track confusion
                key = (llm_str, human_str)  # LLM predicted, human corrected
                metrics.confusion_matrix[key] = (
                    metrics.confusion_matrix.get(key, 0) + 1
                )

            # Record in history
            self._comparison_history.append({
                'instance_id': instance_id,
                'human_label': human_label,
                'llm_label': llm_label,
                'schema_name': schema_name,
                'agrees': agrees,
                'timestamp': datetime.now().isoformat(),
            })

            # Update recent agreement rate and trend
            self._update_recent_metrics()

            logger.debug(
                f"Recorded comparison for {instance_id}: "
                f"agrees={agrees}, rate={metrics.agreement_rate:.2%}"
            )

    def _update_recent_metrics(self) -> None:
        """
        Update the recent agreement rate and trend from the history.

        The recent rate covers the last ``self._recent_window`` comparisons.
        The trend compares that window with the one immediately before it.
        It is only updated once two full windows of history exist. A change
        of more than 5 percentage points counts as "improving" or
        "declining". Called from ``record_comparison`` with ``self._lock``
        held.
        """
        history = self._comparison_history
        # Update even after the very first comparison, so the recent rate
        # is never left at its 0.0 default while real data exists.
        if not history:
            return

        window = self._recent_window
        recent = history[-window:]
        self._metrics.recent_agreement_rate = (
            sum(1 for c in recent if c['agrees']) / len(recent)
        )

        # Determine trend (needs two full windows to compare)
        if len(history) < window * 2:
            return
        older = history[-window * 2:-window]
        older_rate = sum(1 for c in older if c['agrees']) / len(older)
        diff = self._metrics.recent_agreement_rate - older_rate
        if diff > 0.05:
            self._metrics.trend = "improving"
        elif diff < -0.05:
            self._metrics.trend = "declining"
        else:
            self._metrics.trend = "stable"

    def get_metrics(self) -> AgreementMetrics:
        """
        Get a snapshot of the current agreement metrics.

        Returns a deep copy, consistent with ``get_comparison_history``
        returning a copy. Callers therefore never observe, or cause,
        mutation of the tracker's internal state after the lock is released.
        """
        import copy  # only needed here

        with self._lock:
            return copy.deepcopy(self._metrics)

    def get_comparison_history(self) -> List[Dict[str, Any]]:
        """Get a copy of the full comparison history."""
        with self._lock:
            return list(self._comparison_history)

    def should_end_human_annotation(self) -> bool:
        """
        Check if agreement threshold is met for ending human annotation.

        Returns:
            True if at least ``minimum_validation_sample`` comparisons have
            been recorded and the overall agreement rate meets
            ``end_human_threshold``.
        """
        with self._lock:
            # Need minimum number of comparisons
            if self._metrics.total_compared < self.minimum_validation_sample:
                return False
            # Check if agreement rate meets threshold
            return self._metrics.agreement_rate >= self.end_human_threshold

    def should_trigger_periodic_review(self) -> bool:
        """
        Check if it's time for periodic review of LLM labels.

        Returns:
            True once ``periodic_review_interval`` LLM labels have been
            recorded since the last reset.
        """
        with self._lock:
            return self._llm_labels_since_review >= self.periodic_review_interval

    def record_llm_label(self, instance_id: str) -> None:
        """
        Record that an LLM label was generated (for periodic review tracking).

        ``instance_id`` is accepted for interface stability but is not used.
        """
        with self._lock:
            self._llm_labels_since_review += 1

    def reset_periodic_review_counter(self) -> None:
        """Reset the periodic review counter after a review."""
        with self._lock:
            self._llm_labels_since_review = 0

    def select_validation_sample(
        self,
        llm_labeled_instances: Dict[str, Dict[str, Any]],
        sample_size: Optional[int] = None
    ) -> List[str]:
        """
        Select a sample of LLM-labeled instances for final validation.

        Uses stratified sampling based on confidence levels. A
        ``ValidationSample`` is stored for each selected ID. Any existing
        sample with the same ID is replaced, discarding its validation result.

        Args:
            llm_labeled_instances: Dict of instance_id -> {label, confidence}
            sample_size: Number of instances to sample (default:
                minimum_validation_sample). An explicit 0 is honoured.

        Returns:
            List of selected instance IDs
        """
        with self._lock:
            # ``is None`` rather than ``or`` so an explicit 0 is not silently
            # replaced by the default size.
            if sample_size is None:
                size = self._validation_sample_size
            else:
                size = sample_size

            if len(llm_labeled_instances) <= size:
                # If we have fewer instances than sample size, use all
                selected_ids = list(llm_labeled_instances.keys())
            else:
                # Stratified sampling by confidence
                selected_ids = self._stratified_sample(
                    llm_labeled_instances, size
                )

            # Create validation samples
            for instance_id in selected_ids:
                pred = llm_labeled_instances[instance_id]
                self._validation_samples[instance_id] = ValidationSample(
                    instance_id=instance_id,
                    llm_label=pred.get('label'),
                    llm_confidence=pred.get('confidence', 0.5),
                )

            logger.info(f"Selected {len(selected_ids)} instances for validation")
            return selected_ids

    def _stratified_sample(
        self,
        instances: Dict[str, Dict[str, Any]],
        sample_size: int
    ) -> List[str]:
        """
        Perform stratified sampling based on confidence levels.

        Instances are bucketed by ``confidence`` (0.5 when absent) into
        low (< 0.5), mid (0.5 - 0.8) and high (>= 0.8). The target split is
        40% low, 35% mid and the remainder high. This oversamples
        low-confidence predictions to catch potential errors. Any shortfall
        caused by a small stratum is filled from the other strata: high
        first, then mid, then low.

        Args:
            instances: Dict of instance_id -> prediction dict.
            sample_size: Number of instance IDs wanted.

        Returns:
            At most ``sample_size`` distinct IDs, grouped low, mid, high.
        """
        low_conf: List[str] = []   # < 0.5
        mid_conf: List[str] = []   # 0.5 - 0.8
        high_conf: List[str] = []  # >= 0.8
        for instance_id, pred in instances.items():
            conf = pred.get('confidence', 0.5)
            if conf < 0.5:
                low_conf.append(instance_id)
            elif conf < 0.8:
                mid_conf.append(instance_id)
            else:
                high_conf.append(instance_id)

        strata = {'low': low_conf, 'mid': mid_conf, 'high': high_conf}
        counts = {
            'low': min(len(low_conf), int(sample_size * 0.4)),
            'mid': min(len(mid_conf), int(sample_size * 0.35)),
        }
        counts['high'] = min(
            len(high_conf), sample_size - counts['low'] - counts['mid']
        )

        # Redistribute any shortfall to strata that still have spare items.
        remaining = sample_size - sum(counts.values())
        for name in ('high', 'mid', 'low'):
            if remaining <= 0:
                break
            take = min(len(strata[name]) - counts[name], remaining)
            counts[name] += take
            remaining -= take

        # Random sampling from each stratum (order: low, mid, high)
        selected: List[str] = []
        for name in ('low', 'mid', 'high'):
            if strata[name] and counts[name] > 0:
                selected.extend(self._random.sample(strata[name], counts[name]))
        return selected

    def record_validation_result(
        self,
        instance_id: str,
        human_label: Any,
        notes: Optional[str] = None
    ) -> bool:
        """
        Record the human validation result for a sample.

        Agreement is decided by ``==`` between the stored LLM label and
        ``human_label``.

        Args:
            instance_id: The instance ID
            human_label: The human-assigned label
            notes: Optional validation notes

        Returns:
            True if the result was recorded, False (with a warning logged)
            if ``instance_id`` is not a selected validation sample.
        """
        with self._lock:
            sample = self._validation_samples.get(instance_id)
            if sample is None:
                logger.warning(f"Unknown validation sample: {instance_id}")
                return False

            sample.human_label = human_label
            sample.validated_at = datetime.now()
            sample.agrees = (sample.llm_label == human_label)
            sample.notes = notes

            logger.debug(
                f"Recorded validation for {instance_id}: "
                f"agrees={sample.agrees}"
            )
            return True

    def get_validation_progress(self) -> Dict[str, Any]:
        """Get progress counts and accuracy for the validation sample."""
        with self._lock:
            samples = self._validation_samples.values()
            total = len(self._validation_samples)
            validated = sum(1 for s in samples if s.validated_at is not None)
            agreements = sum(1 for s in samples if s.agrees is True)

            return {
                'total_samples': total,
                'validated': validated,
                'remaining': total - validated,
                'agreements': agreements,
                'disagreements': validated - agreements,
                'validation_accuracy': (
                    agreements / validated if validated > 0 else 0.0
                ),
                'percent_complete': (
                    validated / total * 100 if total > 0 else 0.0
                ),
            }

    def get_unvalidated_samples(self) -> List[ValidationSample]:
        """Get validation samples that haven't been validated yet."""
        with self._lock:
            return [
                s for s in self._validation_samples.values()
                if s.validated_at is None
            ]

    def get_validation_samples(self) -> List[ValidationSample]:
        """Get all validation samples."""
        with self._lock:
            return list(self._validation_samples.values())

    def get_confusion_analysis(self) -> Dict[str, Any]:
        """
        Analyze confusion patterns between human and LLM labels.

        Returns:
            Dict with the (up to) 10 most frequent confusion patterns, the
            most frequent one (or None), and the total number of
            disagreements. The same keys are present even when no
            confusion has been recorded.
        """
        with self._lock:
            disagreements = self._metrics.disagreements
            if not self._metrics.confusion_matrix:
                return {
                    'patterns': [],
                    'most_confused': None,
                    'total_disagreements': disagreements,
                }

            # Sort by frequency (stable, so ties keep insertion order)
            top = sorted(
                self._metrics.confusion_matrix.items(),
                key=lambda item: item[1],
                reverse=True
            )[:10]

            patterns = [
                {
                    'llm_predicted': llm_label,
                    'human_corrected': human_label,
                    'count': count,
                    'percent': (
                        count / disagreements * 100
                        if disagreements > 0 else 0
                    ),
                }
                for (llm_label, human_label), count in top
            ]

            return {
                'patterns': patterns,
                'most_confused': patterns[0] if patterns else None,
                'total_disagreements': disagreements,
            }

    def get_label_accuracy(self) -> Dict[str, float]:
        """
        Get per-label accuracy rates.

        Returns:
            Dict of stringified human label -> agreements / comparisons
            for that label.
        """
        with self._lock:
            agreed_by_label = self._metrics.label_agreements
            disagreed_by_label = self._metrics.label_disagreements
            accuracies: Dict[str, float] = {}
            for label in set(agreed_by_label) | set(disagreed_by_label):
                agreed = agreed_by_label.get(label, 0)
                total = agreed + disagreed_by_label.get(label, 0)
                accuracies[label] = agreed / total if total > 0 else 0.0
            return accuracies

    def get_status(self) -> Dict[str, Any]:
        """Get comprehensive tracker status."""
        with self._lock:
            # Nested calls re-acquire the re-entrant lock.
            return {
                'metrics': self._metrics.to_dict(),
                'thresholds': {
                    'end_human_annotation': self.end_human_threshold,
                    'minimum_validation_sample': self.minimum_validation_sample,
                    'periodic_review_interval': self.periodic_review_interval,
                },
                'should_end_human_annotation': self.should_end_human_annotation(),
                'should_trigger_review': self.should_trigger_periodic_review(),
                'llm_labels_since_review': self._llm_labels_since_review,
                'validation_progress': self.get_validation_progress(),
                'label_accuracy': self.get_label_accuracy(),
            }

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize to dictionary for persistence.

        The history list is copied so later comparisons do not mutate an
        already-returned snapshot. History entries themselves are shared.
        """
        with self._lock:
            return {
                'metrics': self._metrics.to_dict(),
                'comparison_history': list(self._comparison_history),
                'validation_samples': {
                    sid: sample.to_dict()
                    for sid, sample in self._validation_samples.items()
                },
                'llm_labels_since_review': self._llm_labels_since_review,
            }

    def from_dict(self, data: Dict[str, Any]) -> None:
        """
        Load state from a dictionary produced by ``to_dict``.

        Metrics, history and validation samples are all replaced, not
        merged. Containers from ``data`` are copied so later mutation of
        ``data`` cannot affect the tracker.
        """
        with self._lock:
            # Restore metrics
            metrics_data = data.get('metrics', {})
            self._metrics = AgreementMetrics(
                total_compared=metrics_data.get('total_compared', 0),
                agreements=metrics_data.get('agreements', 0),
                disagreements=metrics_data.get('disagreements', 0),
                agreement_rate=metrics_data.get('agreement_rate', 0.0),
                label_agreements=dict(metrics_data.get('label_agreements', {})),
                label_disagreements=dict(
                    metrics_data.get('label_disagreements', {})
                ),
                recent_agreement_rate=metrics_data.get('recent_agreement_rate', 0.0),
                trend=metrics_data.get('trend', 'stable'),
            )

            # Restore confusion matrix. Keys were serialized as "llm|human";
            # keys whose labels contain '|' cannot be split unambiguously
            # and are skipped.
            confusion_data = metrics_data.get('confusion_matrix', {})
            for key_str, count in confusion_data.items():
                parts = key_str.split('|')
                if len(parts) == 2:
                    self._metrics.confusion_matrix[(parts[0], parts[1])] = count

            # Restore history
            self._comparison_history = list(data.get('comparison_history', []))

            # Restore validation samples, replacing any pre-load samples so
            # stale state cannot leak into the restored tracker.
            self._validation_samples = {}
            samples_data = data.get('validation_samples', {})
            for sid, sample_data in samples_data.items():
                validated_at_str = sample_data.get('validated_at')
                self._validation_samples[sid] = ValidationSample(
                    instance_id=sample_data['instance_id'],
                    llm_label=sample_data['llm_label'],
                    llm_confidence=sample_data['llm_confidence'],
                    selected_at=datetime.fromisoformat(sample_data['selected_at']),
                    human_label=sample_data.get('human_label'),
                    validated_at=(
                        datetime.fromisoformat(validated_at_str)
                        if validated_at_str else None
                    ),
                    agrees=sample_data.get('agrees'),
                    notes=sample_data.get('notes'),
                )

            self._llm_labels_since_review = data.get('llm_labels_since_review', 0)