sohom004 committed on
Commit
1eb49d3
·
verified ·
1 Parent(s): 4e9f463

Create coreference_resolution.py

Browse files
Files changed (1) hide show
  1. coreference_resolution.py +346 -0
coreference_resolution.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Tuple, Optional
2
+ from fastcoref import LingMessCoref
3
+ import math
4
+ import bisect
5
+ import spacy
6
+
7
class CoreferenceResolver:
    """
    Coreference resolution with confidence scoring and LLM fallback.

    Clusters come from a fastcoref ``LingMessCoref`` model. Each cluster is
    scored from the pairwise logits of consecutive mentions, optionally
    checked against a few spaCy-based linguistic rules, and flagged for
    LLM verification when it looks unreliable.
    """

    # Single source of truth for pronoun detection (used both when rewriting
    # text and when validating clusters).
    _PRONOUNS = frozenset({
        'he', 'she', 'it', 'they', 'him', 'her', 'them',
        'his', 'hers', 'its', 'their', 'theirs',
        'himself', 'herself', 'itself', 'themselves',
    })

    # Pairs of spaCy entity-label groups that should not share a cluster.
    _INCOMPATIBLE_ENTITY_TYPES = (
        ({'PERSON'}, {'ORG', 'GPE'}),
        ({'ORG'}, {'PERSON'}),
        ({'DATE'}, {'PERSON', 'ORG'}),
    )

    def __init__(self, confidence_threshold=30.0, min_confidence_threshold=20.0, use_gpu=False, enable_validation=True):
        """
        Initialize coreference resolver

        Args:
            confidence_threshold: Average logit threshold (default 30.0).
                A cluster whose average pairwise logit is below this value
                is flagged for verification.
            min_confidence_threshold: Minimum logit threshold (default 20.0).
                Any cluster with min_confidence below this needs LLM
                verification.
            use_gpu: Whether to use GPU ('cuda:0'); otherwise 'cpu' is used.
            enable_validation: Enable linguistic validation rules (HIGHLY
                RECOMMENDED). Catches errors like verbs in noun clusters,
                even with high confidence.

        Logit Scale Reference (sigmoid of the logit):
            -2: ~12% probability (probably NOT coreferent)
            0: 50% probability (neutral)
            +2: ~88% probability (probably coreferent)
            +5: 99.3% probability (very confident)
            +8: 99.97% probability (extremely confident)
            +10: 99.995% probability (near certain)
            +15: 99.99997% probability (essentially certain)
            +30+: ~100% probability (but can still be WRONG!)

        CRITICAL: High confidence doesn't guarantee correctness!
        Bigger models may have logits of 30+ but still make linguistic errors.
        Always enable validation to catch these issues.
        """
        device = 'cuda:0' if use_gpu else 'cpu'
        print(f"Loading fastcoref model on {device}...")
        self.model = LingMessCoref(device=device, nlp='en_core_web_lg')
        self.confidence_threshold = confidence_threshold
        self.min_confidence_threshold = min_confidence_threshold
        self.enable_validation = enable_validation

        # spaCy pipeline used by _validate_cluster; stays None when validation
        # is disabled or the model cannot be loaded.
        self.nlp = None
        # Last parsed document and the text it was parsed from (parse cache).
        self.doc = None
        self._doc_text = None

        if enable_validation:
            try:
                self.nlp = spacy.load('en_core_web_lg')
            except OSError:
                # spacy.load raises OSError when the model package is missing;
                # validation is then skipped instead of failing construction.
                print("Warning: spaCy model not found. Install with: python -m spacy download en_core_web_lg")

    @staticmethod
    def _mention_matches_entity(mention: str, entity_text: str) -> bool:
        """
        Return True if ``mention`` occurs in ``entity_text`` as whole word(s).

        The comparison is case-insensitive. Word boundaries are enforced so
        that short mentions such as "he" do not match inside longer words
        such as "Chelsea". An empty mention never matches.
        """
        import re  # local import: only this helper needs it

        needle = mention.strip().lower()
        if not needle:
            return False
        pattern = r'(?<!\w)' + re.escape(needle) + r'(?!\w)'
        return re.search(pattern, entity_text.lower()) is not None

    def _validate_cluster(self, text: str, cluster_strings: List[str]) -> Dict:
        """
        Validate a coreference cluster for linguistic correctness

        Args:
            text: Full input text the cluster was extracted from
            cluster_strings: Surface strings of the mentions in the cluster

        Returns:
            {
                'is_valid': bool,
                'issues': List[str],
                'severity': 'high' | 'low' | None,
                'mention_pos': Dict[str, str]  (mention -> POS of its last token)
            }
            When validation is disabled or no spaCy pipeline is loaded the
            cluster is reported as valid with empty 'issues'/'mention_pos'.
        """
        if not self.enable_validation or not self.nlp:
            return {'is_valid': True, 'issues': [], 'severity': None, 'mention_pos': {}}

        issues = []
        severity = None

        # Parse the full text once and reuse it for every cluster of the same
        # text; the parsed document is still exposed as self.doc.
        if getattr(self, '_doc_text', None) != text or getattr(self, 'doc', None) is None:
            self.doc = self.nlp(text)
            self._doc_text = text
        doc = self.doc

        # POS tag per mention. Each mention is parsed on its own and the last
        # token is taken as its head (a heuristic, not a real head search).
        mention_pos = {}
        for mention in cluster_strings:
            mention_doc = self.nlp(mention)
            if len(mention_doc) > 0:
                mention_pos[mention] = mention_doc[-1].pos_

        pos_tags = set(mention_pos.values())
        has_verb = 'VERB' in pos_tags
        has_noun = bool(pos_tags & {'NOUN', 'PROPN', 'PRON'})

        # Rule 1: No verbs in noun coreference clusters
        if has_verb and has_noun:
            issues.append(f"Cluster contains both VERBs and NOUNs: {mention_pos}")
            severity = 'high'

        # Rule 2: Check for mixed entity types among entities that contain a
        # mention as whole word(s).
        entity_types = {
            ent.label_
            for ent in doc.ents
            if any(self._mention_matches_entity(m, ent.text) for m in cluster_strings)
        }
        # Report at most once, even if several incompatibility rows match.
        if any(entity_types & group_a and entity_types & group_b
               for group_a, group_b in self._INCOMPATIBLE_ENTITY_TYPES):
            issues.append(f"Incompatible entity types: {entity_types}")
            severity = 'high'

        # Rule 3: Pronouns should not cluster with verbs
        has_pronoun = any(self._is_pronoun(m) for m in cluster_strings)
        if has_pronoun and has_verb:
            issues.append("Pronoun clustered with VERB")
            severity = 'high'

        # Any issue that did not set an explicit severity counts as 'low'.
        if issues and severity is None:
            severity = 'low'

        return {
            'is_valid': len(issues) == 0,
            'issues': issues,
            'severity': severity,
            'mention_pos': mention_pos
        }

    def resolve_with_confidence(
        self,
        text: str,
        return_clusters=True,
        return_resolved_text=True
    ) -> Dict:
        """
        Resolve coreferences and return results with confidence scores

        Args:
            text: Input text
            return_clusters: Whether to return coreference clusters; when
                False the 'clusters' entry is None
            return_resolved_text: Whether to return text with resolved
                pronouns; when False the 'resolved_text' entry is None and
                the rewrite is skipped

        Returns:
            {
                'original_text': The input text,
                'clusters': List of clusters with confidence (or None),
                'resolved_text': Text with pronouns replaced (or None),
                'low_confidence_spans': Clusters that need LLM verification,
                'needs_llm_fallback': Boolean indicating if LLM fallback needed,
                'num_clusters': Number of clusters found,
                'num_low_confidence': Number of clusters needing verification,
                'preds': Raw prediction list returned by the model
            }
        """
        # Get predictions
        preds = self.model.predict(texts=[text])
        pred = preds[0]

        # Same clusters twice: as character spans and as surface strings
        clusters = pred.get_clusters(as_strings=False)
        clusters_strings = pred.get_clusters(as_strings=True)

        clusters_with_confidence = []
        low_confidence_spans = []

        for i, (cluster_indices, cluster_strings) in enumerate(zip(clusters, clusters_strings)):
            # Pairwise logits between consecutive mentions of the cluster
            logits = []
            for span_i, span_j in zip(cluster_indices, cluster_indices[1:]):
                try:
                    logits.append(pred.get_logit(span_i, span_j))
                except Exception:
                    # If can't get logit, assume low confidence
                    logits.append(0.0)

            # A cluster without any pair gets 0.0 (treated as low confidence)
            avg_logit = sum(logits) / len(logits) if logits else 0.0
            min_logit = min(logits) if logits else 0.0

            # Convert logit to probability for interpretability
            avg_prob = self._logit_to_prob(avg_logit)

            # Validate cluster for linguistic correctness
            validation = self._validate_cluster(text, cluster_strings)

            # Needs verification if ANY condition is true:
            # 1. Average confidence is low (overall cluster quality)
            # 2. Minimum confidence is low (at least one bad pairing)
            # 3. Validation fails (linguistic errors)
            needs_verification = (
                avg_logit < self.confidence_threshold or
                min_logit < self.min_confidence_threshold or
                not validation['is_valid']
            )

            cluster_info = {
                'cluster_id': i,
                'mentions': cluster_strings,
                'spans': cluster_indices,
                'avg_confidence': avg_logit,
                'min_confidence': min_logit,
                'avg_probability': avg_prob,
                'validation': validation,
                'is_confident': not needs_verification,
                'reason': self._get_confidence_reason(avg_logit, min_logit, validation)
            }

            clusters_with_confidence.append(cluster_info)

            # Track low confidence clusters
            if needs_verification:
                low_confidence_spans.append(cluster_info)

        # Generate resolved text (replace pronouns with main mentions)
        resolved_text = (
            self._generate_resolved_text(text, clusters, clusters_strings)
            if return_resolved_text else None
        )

        return {
            'original_text': text,
            'clusters': clusters_with_confidence if return_clusters else None,
            'resolved_text': resolved_text,
            'low_confidence_spans': low_confidence_spans,
            'needs_llm_fallback': len(low_confidence_spans) > 0,
            'num_clusters': len(clusters),
            'num_low_confidence': len(low_confidence_spans),
            'preds': preds,
        }

    def _get_confidence_reason(self, avg_logit: float, min_logit: float, validation: Dict) -> str:
        """
        Explain why a cluster has low confidence or failed validation

        Args:
            avg_logit: Average pairwise logit of the cluster
            min_logit: Smallest pairwise logit of the cluster
            validation: Result dict with 'is_valid' and 'issues' keys

        Returns:
            All applicable reasons joined by "; ", or a single
            "High confidence ..." message when nothing is wrong.
        """
        reasons = []

        if avg_logit < self.confidence_threshold:
            prob = self._logit_to_prob(avg_logit)
            reasons.append(f"Low average confidence (logit {avg_logit:.2f} = {prob:.2%})")

        if min_logit < self.min_confidence_threshold:
            prob = self._logit_to_prob(min_logit)
            reasons.append(f"Low minimum confidence (logit {min_logit:.2f} = {prob:.2%})")

        if not validation['is_valid']:
            reasons.extend(f"Validation failed: {issue}" for issue in validation['issues'])

        if not reasons:
            avg_prob = self._logit_to_prob(avg_logit)
            return f"High confidence (avg logit {avg_logit:.2f} = {avg_prob:.2%}), validation passed"

        return "; ".join(reasons)

    def _logit_to_prob(self, logit: float) -> float:
        """
        Convert logit to probability using sigmoid

        Uses the form of the sigmoid that never exponentiates a large
        positive number, so extreme logits saturate to 0.0 / 1.0 instead of
        raising OverflowError.
        """
        if logit >= 0:
            return 1.0 / (1.0 + math.exp(-logit))
        exp_logit = math.exp(logit)
        return exp_logit / (1.0 + exp_logit)

    def _generate_resolved_text(
        self,
        text: str,
        clusters: List[List[Tuple[int, int]]],
        clusters_strings: List[List[str]]
    ) -> str:
        """
        Generate text with pronouns replaced by their antecedents

        The first mention of each cluster is used as the antecedent and is
        inserted verbatim (no possessive suffix is added for possessive
        pronouns). Only later mentions that are pronouns are replaced.

        Args:
            text: Original text
            clusters: List of clusters as (start, end) character indices
            clusters_strings: List of clusters as strings, parallel to
                ``clusters``

        Returns:
            Text with resolved coreferences
        """
        # Replacement map: (start, end) -> replacement_text
        replacements = {}

        for cluster_indices, cluster_strings in zip(clusters, clusters_strings):
            if len(cluster_strings) < 2:
                continue

            # Use first mention as the main mention (could be improved)
            main_mention = cluster_strings[0]

            # Only replace pronouns, not full names
            for (start, end), mention in zip(cluster_indices[1:], cluster_strings[1:]):
                if self._is_pronoun(mention):
                    replacements[(start, end)] = main_mention

        # Apply from the end of the text backwards so that earlier character
        # offsets stay valid while the string changes length.
        resolved = text
        for (start, end), replacement in sorted(
            replacements.items(), key=lambda item: item[0][0], reverse=True
        ):
            resolved = resolved[:start] + replacement + resolved[end:]

        return resolved

    def _is_pronoun(self, text: str) -> bool:
        """Simple pronoun detection (case-insensitive, ignores outer whitespace)"""
        return text.lower().strip() in self._PRONOUNS

    def resolve_with_llm_fallback(
        self,
        text: str,
        llm_resolve_func: Optional[callable] = None
    ) -> Dict:
        """
        Resolve coreferences with automatic LLM fallback for low confidence

        Args:
            text: Input text
            llm_resolve_func: Function to call for LLM resolution
                Should take (text, low_confidence_info) and return resolved_text

        Returns:
            The dict from resolve_with_confidence plus 'resolution_method';
            when the fallback ran, its result is stored under
            'llm_resolved_text'.
        """
        # First try with fastcoref
        result = self.resolve_with_confidence(text)

        # If low confidence and LLM function provided, use fallback
        if result['needs_llm_fallback'] and llm_resolve_func:
            print(f"\n⚠️ Low confidence detected for {result['num_low_confidence']} clusters")
            print("🤖 Falling back to LLM for resolution...")

            # Call LLM for low confidence spans
            result['llm_resolved_text'] = llm_resolve_func(text, result['low_confidence_spans'])
            result['resolution_method'] = 'hybrid (fastcoref + LLM)'
        else:
            result['resolution_method'] = 'fastcoref only'

        return result