# test / coreference_resolution.py
# sohom004's picture
# Update coreference_resolution.py
# 3ce4f6c verified
import bisect
import math
import re
from typing import Callable, Dict, List, Optional, Tuple

import spacy
from fastcoref import LingMessCoref
class CoreferenceResolver:
    """
    Coreference resolution with confidence scoring and LLM fallback.

    Wraps a fastcoref ``LingMessCoref`` model. Every predicted cluster is
    scored from the logits between consecutive mentions, optionally checked
    against a few linguistic sanity rules using spaCy, and flagged for
    verification (e.g. by an LLM) when it looks unreliable.
    """

    # Third-person pronouns that may be replaced by an antecedent. Shared by
    # the text rewriter and the validation rules so both agree on what counts
    # as a pronoun.
    _PRONOUNS = frozenset({
        'he', 'she', 'it', 'they', 'him', 'her', 'them',
        'his', 'hers', 'its', 'their', 'theirs',
        'himself', 'herself', 'itself', 'themselves',
    })

    def __init__(
        self,
        confidence_threshold: float = 18.0,
        min_confidence_threshold: float = 6.0,
        use_gpu: bool = False,
        enable_validation: bool = True,
    ) -> None:
        """
        Initialize coreference resolver.

        Args:
            confidence_threshold: Average logit threshold (recommend 7-10).
                Note that the default of 18.0 is stricter than that range.
            min_confidence_threshold: Minimum logit threshold (recommend 3-5).
                Any cluster with min_confidence below this needs LLM
                verification. The default of 6.0 is stricter than that range.
            use_gpu: Whether to use GPU (faster but requires CUDA).
            enable_validation: Enable linguistic validation rules (HIGHLY
                RECOMMENDED). Catches errors like verbs in noun clusters,
                even with high confidence.

        Logit Scale Reference (sigmoid of the logit):
            -2:  ~12% probability (probably NOT coreferent)
             0:  50% probability (neutral)
            +2:  ~88% probability (probably coreferent)
            +5:  99.3% probability (very confident)
            +8:  99.97% probability (extremely confident)
            +10: 99.995% probability (near certain)
            +15: 99.99997% probability (essentially certain)
            +30+: ~100% probability (but can still be WRONG!)

        CRITICAL: High confidence doesn't guarantee correctness!
        Bigger models may have logits of 30+ but still make linguistic errors.
        Always enable validation to catch these issues.
        """
        device = 'cuda:0' if use_gpu else 'cpu'
        print(f"Loading fastcoref model on {device}...")
        self.model = LingMessCoref(device=device, nlp='en_core_web_lg')
        self.confidence_threshold = confidence_threshold
        self.min_confidence_threshold = min_confidence_threshold
        self.enable_validation = enable_validation

        # Most recently parsed document and the text it was parsed from, so
        # validating many clusters of one text parses that text only once.
        self.doc = None
        self._doc_source = None

        # Load spaCy for validation if enabled
        self.nlp = None
        if enable_validation:
            try:
                self.nlp = spacy.load('en_core_web_lg')
            except Exception:
                # Best effort: run without validation rather than fail.
                # Unlike a bare ``except``, this lets KeyboardInterrupt and
                # SystemExit propagate.
                print("Warning: spaCy model not found. Install with: python -m spacy download en_core_web_lg")
                self.nlp = None

    def _parse_text(self, text: str):
        """Return the spaCy parse of ``text``, reusing the last parse if the text is unchanged."""
        if getattr(self, '_doc_source', None) != text or getattr(self, 'doc', None) is None:
            self.doc = self.nlp(text)
            self._doc_source = text
        return self.doc

    @staticmethod
    def _mention_matches_entity(mention: str, entity_text: str) -> bool:
        """
        Return True if ``mention`` occurs as whole word(s) inside ``entity_text``.

        The comparison is case-insensitive. Whole-word matching keeps short
        mentions such as "he" from matching inside unrelated entities such as
        "The Hague". Empty or whitespace-only mentions never match.
        """
        needle = mention.strip().lower()
        if not needle:
            return False
        pattern = r'(?<!\w)' + re.escape(needle) + r'(?!\w)'
        return re.search(pattern, entity_text.lower()) is not None

    def _validate_cluster(self, text: str, cluster_strings: List[str]) -> Dict:
        """
        Validate a coreference cluster for linguistic correctness.

        Args:
            text: Full input text the cluster was predicted from.
            cluster_strings: Surface strings of the cluster's mentions.

        Returns:
            {
                'is_valid': bool,
                'issues': List[str],
                'severity': 'high' when any issue was found, otherwise None,
                'mention_pos': Dict mapping each mention to the POS tag of
                    its last token (empty when validation is unavailable)
            }
        """
        if not self.enable_validation or not self.nlp:
            return {'is_valid': True, 'issues': [], 'severity': None, 'mention_pos': {}}

        issues = []
        doc = self._parse_text(text)

        # POS tag per distinct mention. Each mention is parsed in isolation
        # and its last token is taken as an approximation of the head word.
        mention_pos = {}
        for mention in cluster_strings:
            if mention in mention_pos:
                continue
            mention_doc = self.nlp(mention)
            if len(mention_doc) > 0:
                mention_pos[mention] = mention_doc[-1].pos_

        # Rule 1: No verbs in noun coreference clusters
        has_verb = any(pos == 'VERB' for pos in mention_pos.values())
        has_noun = any(pos in ('NOUN', 'PROPN', 'PRON') for pos in mention_pos.values())
        if has_verb and has_noun:
            issues.append(f"Cluster contains both VERBs and NOUNs: {mention_pos}")

        # Rule 2: Check for mixed entity types. A mention contributes the
        # label of every entity in the text that contains it as whole word(s).
        entity_types = {
            ent.label_
            for ent in doc.ents
            for mention in cluster_strings
            if self._mention_matches_entity(mention, ent.text)
        }
        incompatible = [
            ({'PERSON'}, {'ORG', 'GPE'}),
            ({'ORG'}, {'PERSON'}),
            ({'DATE'}, {'PERSON', 'ORG'}),
        ]
        # Report the conflict once, even if several table rows match.
        if any(entity_types & left and entity_types & right for left, right in incompatible):
            issues.append(f"Incompatible entity types: {entity_types}")

        # Rule 3: Pronouns should not cluster with verbs
        has_pronoun = any(self._is_pronoun(m) for m in cluster_strings)
        if has_pronoun and has_verb:
            issues.append("Pronoun clustered with VERB")

        # Every rule above is treated as a high-severity problem.
        severity = 'high' if issues else None
        return {
            'is_valid': not issues,
            'issues': issues,
            'severity': severity,
            'mention_pos': mention_pos,
        }

    def resolve_with_confidence(
        self,
        text: str,
        return_clusters=True,
        return_resolved_text=True
    ) -> Dict:
        """
        Resolve coreferences and return results with confidence scores.

        Args:
            text: Input text.
            return_clusters: Whether to return coreference clusters. When
                False, 'clusters' is None (counts and low-confidence spans
                are still reported).
            return_resolved_text: Whether to return text with resolved
                pronouns. When False, the text is not generated and
                'resolved_text' is None.

        Returns:
            {
                'original_text': The input text,
                'clusters': List of clusters with confidence (or None),
                'resolved_text': Text with pronouns replaced (or None),
                'low_confidence_spans': Clusters that need LLM verification,
                'needs_llm_fallback': Boolean indicating if LLM fallback needed,
                'num_clusters': Number of predicted clusters,
                'num_low_confidence': Number of clusters needing verification
            }
        """
        # Get predictions
        pred = self.model.predict(texts=[text])[0]

        # Get clusters as character indices and as strings
        clusters = pred.get_clusters(as_strings=False)
        clusters_strings = pred.get_clusters(as_strings=True)

        clusters_with_confidence = []
        low_confidence_spans = []
        for i, (cluster_indices, cluster_strings) in enumerate(zip(clusters, clusters_strings)):
            # Logits between consecutive mentions of the cluster
            logits = []
            for span_i, span_j in zip(cluster_indices, cluster_indices[1:]):
                try:
                    logits.append(pred.get_logit(span_i, span_j))
                except Exception:
                    # If can't get logit, assume low confidence
                    logits.append(0.0)

            # A cluster without any pair scores 0.0 and is therefore flagged.
            avg_logit = float(sum(logits) / len(logits) if logits else 0.0)
            min_logit = float(min(logits) if logits else 0.0)

            # Convert logit to probability for interpretability
            avg_prob = self._logit_to_prob(avg_logit)

            # Validate cluster for linguistic correctness
            validation = self._validate_cluster(text, cluster_strings)

            # Needs verification if ANY condition is true:
            #   1. Average confidence is low (overall cluster quality)
            #   2. Minimum confidence is low (at least one bad pairing)
            #   3. Validation fails (linguistic errors)
            needs_verification = (
                avg_logit < self.confidence_threshold or
                min_logit < self.min_confidence_threshold or
                not validation['is_valid']
            )

            cluster_info = {
                'cluster_id': i,
                'mentions': cluster_strings,
                'spans': cluster_indices,
                'avg_confidence': avg_logit,
                'min_confidence': min_logit,
                'avg_probability': avg_prob,
                'validation': validation,
                'is_confident': not needs_verification,
                'reason': self._get_confidence_reason(avg_logit, min_logit, validation)
            }
            clusters_with_confidence.append(cluster_info)

            # Track low confidence clusters
            if needs_verification:
                low_confidence_spans.append(cluster_info)

        # Generate resolved text only when the caller asked for it
        resolved_text = None
        if return_resolved_text:
            resolved_text = self._generate_resolved_text(text, clusters, clusters_strings)

        return {
            'original_text': text,
            'clusters': clusters_with_confidence if return_clusters else None,
            'resolved_text': resolved_text,
            'low_confidence_spans': low_confidence_spans,
            'needs_llm_fallback': len(low_confidence_spans) > 0,
            'num_clusters': len(clusters),
            'num_low_confidence': len(low_confidence_spans),
        }

    def _get_confidence_reason(self, avg_logit: float, min_logit: float, validation: Dict) -> str:
        """
        Explain why a cluster has low confidence or failed validation.

        Returns a "; "-joined list of every failed check, or a single
        "High confidence ..." sentence when all checks pass.
        """
        reasons = []
        if avg_logit < self.confidence_threshold:
            prob = self._logit_to_prob(avg_logit)
            reasons.append(f"Low average confidence (logit {avg_logit:.2f} = {prob:.2%})")
        if min_logit < self.min_confidence_threshold:
            prob = self._logit_to_prob(min_logit)
            reasons.append(f"Low minimum confidence (logit {min_logit:.2f} = {prob:.2%})")
        if not validation['is_valid']:
            reasons.extend(f"Validation failed: {issue}" for issue in validation['issues'])

        if not reasons:
            avg_prob = self._logit_to_prob(avg_logit)
            return f"High confidence (avg logit {avg_logit:.2f} = {avg_prob:.2%}), validation passed"
        return "; ".join(reasons)

    def _logit_to_prob(self, logit: float) -> float:
        """
        Convert logit to probability using a numerically stable sigmoid.

        The exponent passed to ``math.exp`` is never positive, so very large
        negative logits yield 0.0 instead of raising ``OverflowError``.
        """
        if logit >= 0:
            return 1 / (1 + math.exp(-logit))
        scaled = math.exp(logit)
        return scaled / (1 + scaled)

    def _generate_resolved_text(
        self,
        text: str,
        clusters: List[List[Tuple[int, int]]],
        clusters_strings: List[List[str]]
    ) -> str:
        """
        Generate text with pronouns replaced by their antecedents.

        For every cluster with at least two mentions, the first mention that
        is not a pronoun is used as the antecedent and every pronoun mention
        of the cluster (before or after it) is replaced by it. Clusters made
        up only of pronouns are left untouched. Non-pronoun mentions are
        never rewritten.

        NOTE: possessive pronouns ("his", "their", ...) are replaced by the
        bare antecedent; no "'s" is added.

        Args:
            text: Original text.
            clusters: List of clusters as (start, end) character indices.
            clusters_strings: List of clusters as strings.

        Returns:
            Text with resolved coreferences.
        """
        # Replacement map: (start, end) -> replacement_text
        replacements = {}
        for cluster_indices, cluster_strings in zip(clusters, clusters_strings):
            if len(cluster_strings) < 2:
                continue
            main_mention = next(
                (m for m in cluster_strings if not self._is_pronoun(m)),
                None,
            )
            if main_mention is None:
                # Replacing a pronoun with another pronoun resolves nothing.
                continue
            for (start, end), mention in zip(cluster_indices, cluster_strings):
                # Only replace pronouns, not full names
                if self._is_pronoun(mention):
                    replacements[(start, end)] = main_mention

        # Assemble the output in one left-to-right pass over the sorted spans.
        pieces = []
        cursor = 0
        for (start, end), replacement in sorted(replacements.items()):
            if start < cursor:
                # Overlaps a span already replaced; keep the earlier one.
                continue
            pieces.append(text[cursor:start])
            pieces.append(replacement)
            cursor = end
        pieces.append(text[cursor:])
        return "".join(pieces)

    def _is_pronoun(self, text: str) -> bool:
        """Simple pronoun detection: case-insensitive lookup of the stripped text."""
        return text.lower().strip() in self._PRONOUNS

    def resolve_with_llm_fallback(
        self,
        text: str,
        llm_resolve_func: Optional[Callable[[str, List[Dict]], str]] = None
    ) -> Dict:
        """
        Resolve coreferences with automatic LLM fallback for low confidence.

        Args:
            text: Input text.
            llm_resolve_func: Function to call for LLM resolution. Should
                take (text, low_confidence_info) and return resolved_text.
                When omitted, no fallback happens even if clusters are
                flagged as low confidence.

        Returns:
            The result of ``resolve_with_confidence`` plus
            'resolution_method', and 'llm_resolved_text' when the LLM
            fallback was used.
        """
        # First try with fastcoref
        result = self.resolve_with_confidence(text)

        # If low confidence and LLM function provided, use fallback
        if result['needs_llm_fallback'] and llm_resolve_func:
            print(f"\n⚠️ Low confidence detected for {result['num_low_confidence']} clusters")
            print("🤖 Falling back to LLM for resolution...")
            # Call LLM for low confidence spans
            result['llm_resolved_text'] = llm_resolve_func(text, result['low_confidence_spans'])
            result['resolution_method'] = 'hybrid (fastcoref + LLM)'
        else:
            result['resolution_method'] = 'fastcoref only'
        return result