|
|
from typing import List, Dict, Tuple, Optional |
|
|
from fastcoref import LingMessCoref |
|
|
import math |
|
|
import bisect |
|
|
import spacy |
|
|
|
|
|
class CoreferenceResolver:
    """
    Coreference resolution with confidence scoring and LLM fallback.

    Wraps fastcoref's LingMessCoref model, attaches logit-based confidence
    scores to each predicted cluster, optionally applies spaCy-based
    linguistic validation rules, and flags low-confidence or invalid
    clusters for LLM verification.
    """

    def __init__(self, confidence_threshold=18.0, min_confidence_threshold=6.0, use_gpu=False, enable_validation=True):
        """
        Initialize coreference resolver.

        Args:
            confidence_threshold: Average logit threshold (recommend 7-10)
            min_confidence_threshold: Minimum logit threshold (recommend 3-5)
                Any cluster with min_confidence below this needs LLM verification
            use_gpu: Whether to use GPU (faster but requires CUDA)
            enable_validation: Enable linguistic validation rules (HIGHLY RECOMMENDED)
                Catches errors like verbs in noun clusters, even with high confidence

        Logit Scale Reference:
            -2: ~12% probability (probably NOT coreferent)
             0: 50% probability (neutral)
            +2: ~88% probability (probably coreferent)
            +5: 99.3% probability (very confident)
            +8: 99.97% probability (extremely confident)
            +10: 99.995% probability (near certain)
            +15: 99.9997% probability (essentially certain)
            +30+: ~100% probability (but can still be WRONG!)

        CRITICAL: High confidence doesn't guarantee correctness!
        Bigger models may have logits of 30+ but still make linguistic errors.
        Always enable validation to catch these issues.
        """
        device = 'cuda:0' if use_gpu else 'cpu'
        print(f"Loading fastcoref model on {device}...")
        self.model = LingMessCoref(device=device, nlp='en_core_web_lg')
        self.confidence_threshold = confidence_threshold
        self.min_confidence_threshold = min_confidence_threshold
        self.enable_validation = enable_validation

        if enable_validation:
            try:
                # spacy.load raises OSError when the model package is absent.
                self.nlp = spacy.load('en_core_web_lg')
            except OSError:
                # Degrade gracefully: validation is simply skipped without the model.
                print("Warning: spaCy model not found. Install with: python -m spacy download en_core_web_lg")
                self.nlp = None
        else:
            self.nlp = None

    def _validate_cluster(self, text: str, cluster_strings: List[str]) -> Dict:
        """
        Validate a coreference cluster for linguistic correctness.

        Args:
            text: Full source text the cluster was extracted from
            cluster_strings: Mention strings belonging to one cluster

        Returns:
            {
                'is_valid': bool,
                'issues': List[str],
                'severity': 'high' | 'medium' | 'low' | None,
                'mention_pos': Dict[str, str]  # present only when validation ran
            }
        """
        # Validation disabled (by flag or missing spaCy model): trivially valid.
        if not self.enable_validation or not self.nlp:
            return {'is_valid': True, 'issues': [], 'severity': None}

        issues = []
        severity = None
        doc = self.nlp(text)
        # NOTE(review): kept for backward compatibility — earlier versions
        # exposed the parsed doc as an instance attribute; verify no caller
        # relies on it before removing.
        self.doc = doc

        # Part-of-speech of each mention's final token, used as a rough
        # approximation of the mention's syntactic head.
        mention_pos = {}
        for mention in cluster_strings:
            mention_doc = self.nlp(mention)
            if len(mention_doc) > 0:
                mention_pos[mention] = mention_doc[-1].pos_

        # Rule 1: a single cluster should not mix verbal and nominal mentions.
        has_verb = any(pos == 'VERB' for pos in mention_pos.values())
        has_noun = any(pos in ['NOUN', 'PROPN', 'PRON'] for pos in mention_pos.values())
        if has_verb and has_noun:
            issues.append(f"Cluster contains both VERBs and NOUNs: {mention_pos}")
            severity = 'high'

        # Rule 2: collect NER labels touching any mention; certain label
        # combinations cannot co-refer (a PERSON is never a DATE, etc.).
        entity_types = set()
        for ent in doc.ents:
            for mention in cluster_strings:
                if mention.lower() in ent.text.lower():
                    entity_types.add(ent.label_)

        incompatible = [
            ({'PERSON'}, {'ORG', 'GPE'}),
            ({'ORG'}, {'PERSON'}),
            ({'DATE'}, {'PERSON', 'ORG'}),
        ]
        for incomp_set1, incomp_set2 in incompatible:
            if entity_types & incomp_set1 and entity_types & incomp_set2:
                issues.append(f"Incompatible entity types: {entity_types}")
                severity = 'high'

        # Rule 3: a personal/possessive pronoun should never corefer with a verb.
        pronouns = {'he', 'she', 'it', 'they', 'him', 'her', 'them', 'his', 'hers', 'its', 'their'}
        has_pronoun = any(m.lower() in pronouns for m in cluster_strings)
        if has_pronoun and has_verb:
            issues.append("Pronoun clustered with VERB")
            severity = 'high'

        # Issues without an explicit severity default to 'low'.
        if issues and severity is None:
            severity = 'low'

        return {
            'is_valid': not issues,
            'issues': issues,
            'severity': severity,
            'mention_pos': mention_pos,
        }

    def resolve_with_confidence(
        self,
        text: str,
        return_clusters=True,
        return_resolved_text=True
    ) -> Dict:
        """
        Resolve coreferences and return results with confidence scores.

        Args:
            text: Input text
            return_clusters: If False, the returned 'clusters' list is empty
                (scoring still runs so low-confidence detection is unaffected)
            return_resolved_text: If False, skip building the resolved text
                ('resolved_text' is returned as None)

        Returns:
            {
                'original_text': input text,
                'clusters': List of clusters with confidence,
                'resolved_text': Text with pronouns replaced (or None),
                'low_confidence_spans': Spans that need LLM verification,
                'needs_llm_fallback': Boolean indicating if LLM fallback needed,
                'num_clusters': total cluster count,
                'num_low_confidence': count of flagged clusters,
            }
        """
        preds = self.model.predict(texts=[text])
        pred = preds[0]

        # Clusters both as character spans and as surface strings.
        clusters = pred.get_clusters(as_strings=False)
        clusters_strings = pred.get_clusters(as_strings=True)

        clusters_with_confidence = []
        low_confidence_spans = []

        for i, (cluster_indices, cluster_strings) in enumerate(zip(clusters, clusters_strings)):
            # Pairwise logits between consecutive mentions in the cluster.
            logits = []
            for span_i, span_j in zip(cluster_indices, cluster_indices[1:]):
                try:
                    logits.append(pred.get_logit(span_i, span_j))
                except Exception:
                    # Logit unavailable for this pair; fall back to neutral (50%).
                    logits.append(0.0)

            avg_logit = float(sum(logits) / len(logits)) if logits else 0.0
            min_logit = float(min(logits)) if logits else 0.0
            avg_prob = self._logit_to_prob(avg_logit)

            validation = self._validate_cluster(text, cluster_strings)

            # Flag the cluster if either confidence threshold is missed OR
            # linguistic validation failed (high logits can still be wrong).
            needs_verification = (
                avg_logit < self.confidence_threshold or
                min_logit < self.min_confidence_threshold or
                not validation['is_valid']
            )

            cluster_info = {
                'cluster_id': i,
                'mentions': cluster_strings,
                'spans': cluster_indices,
                'avg_confidence': avg_logit,
                'min_confidence': min_logit,
                'avg_probability': avg_prob,
                'validation': validation,
                'is_confident': not needs_verification,
                'reason': self._get_confidence_reason(avg_logit, min_logit, validation)
            }

            clusters_with_confidence.append(cluster_info)
            if needs_verification:
                low_confidence_spans.append(cluster_info)

        resolved_text = (
            self._generate_resolved_text(text, clusters, clusters_strings)
            if return_resolved_text else None
        )

        needs_llm = len(low_confidence_spans) > 0

        return {
            'original_text': text,
            'clusters': clusters_with_confidence if return_clusters else [],
            'resolved_text': resolved_text,
            'low_confidence_spans': low_confidence_spans,
            'needs_llm_fallback': needs_llm,
            'num_clusters': len(clusters),
            'num_low_confidence': len(low_confidence_spans),
        }

    def _get_confidence_reason(self, avg_logit: float, min_logit: float, validation: Dict) -> str:
        """
        Explain why a cluster has low confidence or failed validation.

        Args:
            avg_logit: Average pairwise logit for the cluster
            min_logit: Minimum pairwise logit for the cluster
            validation: Result dict from _validate_cluster

        Returns:
            Human-readable, semicolon-joined reason string; a "High
            confidence" message when no issue applies.
        """
        reasons = []

        if avg_logit < self.confidence_threshold:
            prob = self._logit_to_prob(avg_logit)
            reasons.append(f"Low average confidence (logit {avg_logit:.2f} = {prob:.2%})")

        if min_logit < self.min_confidence_threshold:
            prob = self._logit_to_prob(min_logit)
            reasons.append(f"Low minimum confidence (logit {min_logit:.2f} = {prob:.2%})")

        if not validation['is_valid']:
            for issue in validation['issues']:
                reasons.append(f"Validation failed: {issue}")

        if not reasons:
            avg_prob = self._logit_to_prob(avg_logit)
            return f"High confidence (avg logit {avg_logit:.2f} = {avg_prob:.2%}), validation passed"

        return "; ".join(reasons)

    def _logit_to_prob(self, logit: float) -> float:
        """
        Convert a logit to a probability via the sigmoid function.

        Uses the numerically stable two-branch form so large-magnitude
        logits (|logit| > ~709) cannot overflow math.exp.
        """
        if logit >= 0:
            return 1 / (1 + math.exp(-logit))
        z = math.exp(logit)
        return z / (1 + z)

    def _generate_resolved_text(
        self,
        text: str,
        clusters: List[List[Tuple[int, int]]],
        clusters_strings: List[List[str]]
    ) -> str:
        """
        Generate text with pronouns replaced by their antecedents.

        Args:
            text: Original text
            clusters: List of clusters as character (start, end) spans
            clusters_strings: List of clusters as strings, parallel to clusters

        Returns:
            Text with pronominal mentions replaced by each cluster's
            first mention.
        """
        replacements = {}

        for cluster_indices, cluster_strings in zip(clusters, clusters_strings):
            # A single-mention cluster has nothing to resolve.
            if len(cluster_strings) < 2:
                continue

            # The first mention is treated as the canonical antecedent.
            main_mention = cluster_strings[0]

            # Only pronoun mentions after the first are rewritten; zip keeps
            # spans and strings aligned even if lengths ever disagreed.
            for (start, end), mention in zip(cluster_indices[1:], cluster_strings[1:]):
                if self._is_pronoun(mention):
                    replacements[(start, end)] = main_mention

        # Apply right-to-left so earlier character offsets stay valid.
        sorted_replacements = sorted(replacements.items(),
                                     key=lambda x: x[0][0],
                                     reverse=True)

        resolved = text
        for (start, end), replacement in sorted_replacements:
            resolved = resolved[:start] + replacement + resolved[end:]

        return resolved

    def _is_pronoun(self, text: str) -> bool:
        """Return True if the (case/whitespace-normalized) text is a pronoun."""
        pronouns = {
            'he', 'she', 'it', 'they', 'him', 'her', 'them',
            'his', 'hers', 'its', 'their', 'theirs',
            'himself', 'herself', 'itself', 'themselves'
        }
        return text.lower().strip() in pronouns

    def resolve_with_llm_fallback(
        self,
        text: str,
        llm_resolve_func: Optional[callable] = None
    ) -> Dict:
        """
        Resolve coreferences with automatic LLM fallback for low confidence.

        Args:
            text: Input text
            llm_resolve_func: Function to call for LLM resolution
                Should take (text, low_confidence_info) and return resolved_text

        Returns:
            Resolution results with LLM fallback applied if needed; adds
            'llm_resolved_text' (when the fallback ran) and
            'resolution_method' to the base result dict.
        """
        result = self.resolve_with_confidence(text)

        if result['needs_llm_fallback'] and llm_resolve_func:
            print(f"\n⚠️ Low confidence detected for {result['num_low_confidence']} clusters")
            print("🤖 Falling back to LLM for resolution...")

            llm_result = llm_resolve_func(text, result['low_confidence_spans'])

            result['llm_resolved_text'] = llm_result
            result['resolution_method'] = 'hybrid (fastcoref + LLM)'
        else:
            result['resolution_method'] = 'fastcoref only'

        return result