from typing import Callable, Dict, List, Optional, Tuple
from fastcoref import LingMessCoref
import math
import bisect
import spacy


class CoreferenceResolver:
    """
    Coreference resolution with confidence scoring and LLM fallback.

    Wraps fastcoref's LingMessCoref model, scores each predicted cluster via
    the model's pairwise logits, optionally validates clusters with spaCy
    linguistic rules, and flags low-confidence clusters for LLM verification.
    """

    def __init__(self, confidence_threshold=18.0, min_confidence_threshold=6.0,
                 use_gpu=False, enable_validation=True):
        """
        Initialize coreference resolver.

        Args:
            confidence_threshold: Average logit threshold (recommend 7-10).
            min_confidence_threshold: Minimum logit threshold (recommend 3-5).
                Any cluster with min_confidence below this needs LLM verification.
            use_gpu: Whether to use GPU (faster but requires CUDA).
            enable_validation: Enable linguistic validation rules (HIGHLY
                RECOMMENDED). Catches errors like verbs in noun clusters,
                even with high confidence.

        Logit Scale Reference:
            -2:  ~12% probability (probably NOT coreferent)
             0:   50% probability (neutral)
            +2:  ~88% probability (probably coreferent)
            +5:   99.3% probability (very confident)
            +8:   99.97% probability (extremely confident)
            +10:  99.995% probability (near certain)
            +15:  99.9997% probability (essentially certain)
            +30+: ~100% probability (but can still be WRONG!)

        CRITICAL: High confidence doesn't guarantee correctness! Bigger models
        may have logits of 30+ but still make linguistic errors. Always enable
        validation to catch these issues.
        """
        device = 'cuda:0' if use_gpu else 'cpu'
        print(f"Loading fastcoref model on {device}...")
        self.model = LingMessCoref(device=device, nlp='en_core_web_lg')
        self.confidence_threshold = confidence_threshold
        self.min_confidence_threshold = min_confidence_threshold
        self.enable_validation = enable_validation

        # Load spaCy for validation if enabled; degrade gracefully (validation
        # becomes a no-op) when the model package is missing.
        if enable_validation:
            try:
                self.nlp = spacy.load('en_core_web_lg')
            except OSError:
                # spacy.load raises OSError when the model package is absent.
                print("Warning: spaCy model not found. "
                      "Install with: python -m spacy download en_core_web_lg")
                self.nlp = None
        else:
            self.nlp = None

    def _validate_cluster(self, text: str, cluster_strings: List[str]) -> Dict:
        """
        Validate a coreference cluster for linguistic correctness.

        Args:
            text: Full input text the cluster was extracted from.
            cluster_strings: The mention strings belonging to one cluster.

        Returns:
            {
                'is_valid': bool,
                'issues': List[str],
                'severity': 'high' | 'medium' | 'low' | None,
                'mention_pos': Dict[str, str]   # only when validation ran
            }
        """
        if not self.enable_validation or not self.nlp:
            return {'is_valid': True, 'issues': [], 'severity': None}

        issues = []
        severity = None  # set explicitly so later checks can't hit an unbound name
        doc = self.nlp(text)  # local parse; avoid mutating instance state

        # Extract a POS tag for each mention (head word approximated as the
        # last token of the mention parsed in isolation).
        mention_pos = {}
        for mention in cluster_strings:
            mention_doc = self.nlp(mention)
            if len(mention_doc) > 0:
                head_pos = mention_doc[-1].pos_  # last word usually the head
                mention_pos[mention] = head_pos

        # Rule 1: No verbs in noun coreference clusters.
        has_verb = any(pos == 'VERB' for pos in mention_pos.values())
        has_noun = any(pos in ['NOUN', 'PROPN', 'PRON'] for pos in mention_pos.values())
        if has_verb and has_noun:
            issues.append(f"Cluster contains both VERBs and NOUNs: {mention_pos}")
            severity = 'high'

        # Rule 2: Check for mixed entity types (if spaCy NER available).
        entity_types = set()
        for ent in doc.ents:
            for mention in cluster_strings:
                if mention.lower() in ent.text.lower():
                    entity_types.add(ent.label_)

        # Pairs of NER label sets that should never co-occur in one cluster.
        incompatible = [
            ({'PERSON'}, {'ORG', 'GPE'}),
            ({'ORG'}, {'PERSON'}),
            ({'DATE'}, {'PERSON', 'ORG'}),
        ]
        for incomp_set1, incomp_set2 in incompatible:
            if entity_types & incomp_set1 and entity_types & incomp_set2:
                issues.append(f"Incompatible entity types: {entity_types}")
                severity = 'high'

        # Rule 3: Pronouns should not cluster with verbs.
        pronouns = {'he', 'she', 'it', 'they', 'him', 'her', 'them',
                    'his', 'hers', 'its', 'their'}
        has_pronoun = any(m.lower() in pronouns for m in cluster_strings)
        if has_pronoun and has_verb:
            issues.append(f"Pronoun clustered with VERB")
            severity = 'high'

        # Determine overall severity.
        if not issues:
            severity = None
        elif not severity:
            severity = 'low'

        return {
            'is_valid': len(issues) == 0,
            'issues': issues,
            'severity': severity,
            'mention_pos': mention_pos
        }

    def resolve_with_confidence(
        self,
        text: str,
        return_clusters=True,
        return_resolved_text=True
    ) -> Dict:
        """
        Resolve coreferences and return results with confidence scores.

        Args:
            text: Input text.
            return_clusters: Whether to return coreference clusters.
            return_resolved_text: Whether to return text with resolved pronouns.

        Returns:
            {
                'clusters': List of clusters with confidence,
                'resolved_text': Text with pronouns replaced,
                'low_confidence_spans': Spans that need LLM verification,
                'needs_llm_fallback': Boolean indicating if LLM fallback needed,
                ... plus counts and the original text.
            }
        """
        # Get predictions for the single input text.
        preds = self.model.predict(texts=[text])
        pred = preds[0]

        # Clusters as character indices and as surface strings (parallel lists).
        clusters = pred.get_clusters(as_strings=False)
        clusters_strings = pred.get_clusters(as_strings=True)

        clusters_with_confidence = []
        low_confidence_spans = []

        for i, (cluster_indices, cluster_strings) in enumerate(zip(clusters, clusters_strings)):
            # Pairwise logits between consecutive mention spans in the cluster.
            logits = []
            for j in range(len(cluster_indices) - 1):
                span_i = cluster_indices[j]
                span_j = cluster_indices[j + 1]
                try:
                    logit = pred.get_logit(span_i, span_j)
                    logits.append(logit)
                except Exception:
                    # If we can't get a logit, assume neutral/low confidence
                    # rather than aborting the whole resolution.
                    logits.append(0.0)

            # Aggregate confidence for the cluster (0.0 for singleton clusters).
            avg_logit = float(sum(logits) / len(logits) if logits else 0.0)
            min_logit = float(min(logits) if logits else 0.0)

            # Convert logit to probability for interpretability.
            avg_prob = self._logit_to_prob(avg_logit)

            # Validate cluster for linguistic correctness.
            validation = self._validate_cluster(text, cluster_strings)

            # Determine if cluster needs verification using BOTH thresholds AND
            # validation. Fail if ANY condition is true:
            #   1. Average confidence is low (overall cluster quality)
            #   2. Minimum confidence is low (at least one bad pairing)
            #   3. Validation fails (linguistic errors)
            needs_verification = (
                avg_logit < self.confidence_threshold or
                min_logit < self.min_confidence_threshold or
                not validation['is_valid']
            )

            cluster_info = {
                'cluster_id': i,
                'mentions': cluster_strings,
                'spans': cluster_indices,
                'avg_confidence': avg_logit,
                'min_confidence': min_logit,
                'avg_probability': avg_prob,
                'validation': validation,
                'is_confident': not needs_verification,
                'reason': self._get_confidence_reason(avg_logit, min_logit, validation)
            }
            clusters_with_confidence.append(cluster_info)

            # Track low confidence clusters for the LLM fallback path.
            if needs_verification:
                low_confidence_spans.append(cluster_info)

        # Generate resolved text (replace pronouns with main mentions).
        resolved_text = self._generate_resolved_text(text, clusters, clusters_strings)

        # Determine if LLM fallback is needed.
        needs_llm = len(low_confidence_spans) > 0

        return {
            'original_text': text,
            'clusters': clusters_with_confidence,
            'resolved_text': resolved_text,
            'low_confidence_spans': low_confidence_spans,
            'needs_llm_fallback': needs_llm,
            'num_clusters': len(clusters),
            'num_low_confidence': len(low_confidence_spans),
        }

    def _get_confidence_reason(self, avg_logit: float, min_logit: float,
                               validation: Dict) -> str:
        """Explain why a cluster has low confidence or failed validation."""
        reasons = []

        if avg_logit < self.confidence_threshold:
            prob = self._logit_to_prob(avg_logit)
            reasons.append(f"Low average confidence (logit {avg_logit:.2f} = {prob:.2%})")

        if min_logit < self.min_confidence_threshold:
            prob = self._logit_to_prob(min_logit)
            reasons.append(f"Low minimum confidence (logit {min_logit:.2f} = {prob:.2%})")

        if not validation['is_valid']:
            for issue in validation['issues']:
                reasons.append(f"Validation failed: {issue}")

        if not reasons:
            avg_prob = self._logit_to_prob(avg_logit)
            return f"High confidence (avg logit {avg_logit:.2f} = {avg_prob:.2%}), validation passed"

        return "; ".join(reasons)

    def _logit_to_prob(self, logit: float) -> float:
        """Convert logit to probability using a numerically stable sigmoid.

        The naive 1/(1+exp(-x)) overflows for large negative x; split on the
        sign so exp() only ever sees a non-positive argument.
        """
        if logit >= 0:
            return 1 / (1 + math.exp(-logit))
        e = math.exp(logit)
        return e / (1 + e)

    def _generate_resolved_text(
        self,
        text: str,
        clusters: List[List[Tuple[int, int]]],
        clusters_strings: List[List[str]]
    ) -> str:
        """
        Generate text with pronouns replaced by their antecedents.

        Args:
            text: Original text.
            clusters: List of clusters as character indices.
            clusters_strings: List of clusters as strings.

        Returns:
            Text with resolved coreferences.
        """
        # Create replacement map: (start, end) -> replacement_text
        replacements = {}

        for cluster_indices, cluster_strings in zip(clusters, clusters_strings):
            if len(cluster_strings) < 2:
                continue

            # Use first mention as the main mention (could be improved).
            main_mention = cluster_strings[0]

            # Replace all subsequent mentions with main mention.
            for i, (start, end) in enumerate(cluster_indices[1:], 1):
                mention = cluster_strings[i]
                # Only replace pronouns, not full names.
                if self._is_pronoun(mention):
                    replacements[(start, end)] = main_mention

        # Apply replacements from end to start so earlier indices stay valid.
        sorted_replacements = sorted(replacements.items(),
                                     key=lambda x: x[0][0], reverse=True)

        resolved = text
        for (start, end), replacement in sorted_replacements:
            resolved = resolved[:start] + replacement + resolved[end:]

        return resolved

    def _is_pronoun(self, text: str) -> bool:
        """Simple pronoun detection via a closed word list."""
        pronouns = {
            'he', 'she', 'it', 'they', 'him', 'her', 'them',
            'his', 'hers', 'its', 'their', 'theirs',
            'himself', 'herself', 'itself', 'themselves'
        }
        return text.lower().strip() in pronouns

    def resolve_with_llm_fallback(
        self,
        text: str,
        llm_resolve_func: Optional[Callable] = None
    ) -> Dict:
        """
        Resolve coreferences with automatic LLM fallback for low confidence.

        Args:
            text: Input text.
            llm_resolve_func: Function to call for LLM resolution.
                Should take (text, low_confidence_info) and return resolved_text.

        Returns:
            Resolution results with LLM fallback applied if needed.
        """
        # First try with fastcoref.
        result = self.resolve_with_confidence(text)

        # If low confidence and LLM function provided, use fallback.
        if result['needs_llm_fallback'] and llm_resolve_func:
            print(f"\n⚠️ Low confidence detected for {result['num_low_confidence']} clusters")
            print("🤖 Falling back to LLM for resolution...")

            # Call LLM for low confidence spans.
            llm_result = llm_resolve_func(text, result['low_confidence_spans'])
            result['llm_resolved_text'] = llm_result
            result['resolution_method'] = 'hybrid (fastcoref + LLM)'
        else:
            result['resolution_method'] = 'fastcoref only'

        return result