File size: 13,257 Bytes
1eb49d3
 
 
 
 
 
 
 
 
 
 
3ce4f6c
1eb49d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a1a4d0
 
1eb49d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
import bisect
import math
import re
from typing import Callable, Dict, List, Optional, Tuple

import spacy
from fastcoref import LingMessCoref

class CoreferenceResolver:
    """
    Coreference resolution with confidence scoring and LLM fallback.

    Wraps a fastcoref ``LingMessCoref`` model, scores every predicted cluster
    from the model's pairwise logits, optionally runs spaCy-based linguistic
    sanity checks, and flags clusters that should be re-checked by an LLM.
    """

    # Pronouns that may be replaced by their antecedent in resolved text and
    # that count as pronoun mentions during validation. Shared so that both
    # code paths agree on what a pronoun is.
    _PRONOUNS = frozenset({
        'he', 'she', 'it', 'they', 'him', 'her', 'them',
        'his', 'hers', 'its', 'their', 'theirs',
        'himself', 'herself', 'itself', 'themselves',
    })

    # Pairs of spaCy NER label sets that should never describe the same referent.
    _INCOMPATIBLE_ENTITY_TYPES = (
        (frozenset({'PERSON'}), frozenset({'ORG', 'GPE'})),
        (frozenset({'ORG'}), frozenset({'PERSON'})),
        (frozenset({'DATE'}), frozenset({'PERSON', 'ORG'})),
    )

    def __init__(self, confidence_threshold=18.0, min_confidence_threshold=6.0, use_gpu=False, enable_validation=True):
        """
        Initialize the coreference resolver.

        Args:
            confidence_threshold: Minimum acceptable *average* pairwise logit of a
                cluster (default 18.0). Lower values (e.g. 7-10) flag fewer clusters.
            min_confidence_threshold: Minimum acceptable logit of the *weakest*
                adjacent mention pair in a cluster (default 6.0; e.g. 3-5 is more
                lenient). Any cluster whose min_confidence falls below this needs
                LLM verification.
            use_gpu: Run the model on 'cuda:0' instead of 'cpu' (requires CUDA).
            enable_validation: Enable spaCy-based linguistic validation rules
                (HIGHLY RECOMMENDED). Catches errors such as verbs clustered with
                nouns, even when the model is highly confident.

        Logit Scale Reference (probability = sigmoid(logit)):
            -2: ~12% probability (probably NOT coreferent)
             0: 50% probability (neutral)
            +2: ~88% probability (probably coreferent)
            +5: 99.3% probability (very confident)
            +8: 99.97% probability (extremely confident)
            +10: 99.995% probability (near certain)
            +15: 99.99997% probability (essentially certain)
            +30+: ~100% probability (but can still be WRONG!)

        CRITICAL: High confidence doesn't guarantee correctness!
        Bigger models may have logits of 30+ but still make linguistic errors.
        Always enable validation to catch these issues.
        """
        device = 'cuda:0' if use_gpu else 'cpu'
        print(f"Loading fastcoref model on {device}...")
        self.model = LingMessCoref(device=device, nlp='en_core_web_lg')
        self.confidence_threshold = confidence_threshold
        self.min_confidence_threshold = min_confidence_threshold
        self.enable_validation = enable_validation

        # spaCy Doc of the most recently validated text, plus that text, so that
        # validating several clusters of one input parses the input only once.
        self.doc = None
        self._doc_text = None

        # Load spaCy for validation if enabled; validation degrades to a no-op
        # when the model package is missing.
        self.nlp = None
        if enable_validation:
            try:
                self.nlp = spacy.load('en_core_web_lg')
            except OSError:
                # spacy.load raises OSError when the named model is not installed.
                print("Warning: spaCy model not found. Install with: python -m spacy download en_core_web_lg")

    @staticmethod
    def _mention_in_entity_text(mention: str, entity_text: str) -> bool:
        """
        Return True if ``mention`` occurs in ``entity_text`` as whole word(s).

        Matching is case-insensitive. A bare substring test would wrongly match
        pronouns inside names (e.g. "it" in "Smith", "he" in "Shell").
        """
        mention = mention.strip()
        if not mention:
            return False
        pattern = r'(?<!\w)' + re.escape(mention) + r'(?!\w)'
        return re.search(pattern, entity_text, flags=re.IGNORECASE) is not None

    def _validate_cluster(
        self,
        text: str,
        cluster_strings: List[str],
        cluster_spans: Optional[List[Tuple[int, int]]] = None,
    ) -> Dict:
        """
        Validate a coreference cluster for linguistic correctness.

        Args:
            text: The full input text the cluster was predicted from.
            cluster_strings: Surface strings of the cluster's mentions.
            cluster_spans: Optional (start_char, end_char) offsets of the mentions
                in ``text`` (end exclusive). When given, named entities are matched
                by position. Otherwise they are matched by whole-word text.

        Returns:
            {
                'is_valid': bool,
                'issues': List[str],
                'severity': 'high' | 'low' | None,
                'mention_pos': Dict[str, str]  # mention -> POS of its last token
            }
        """
        if not self.enable_validation or not self.nlp:
            return {'is_valid': True, 'issues': [], 'severity': None, 'mention_pos': {}}

        issues: List[str] = []
        severity: Optional[str] = None

        # Parse the full text once per distinct input rather than once per cluster.
        if getattr(self, 'doc', None) is None or getattr(self, '_doc_text', None) != text:
            self.doc = self.nlp(text)
            self._doc_text = text
        doc = self.doc

        # Extract a POS tag for each mention, parsed in isolation.
        mention_pos: Dict[str, str] = {}
        for mention in cluster_strings:
            mention_doc = self.nlp(mention)
            if len(mention_doc) > 0:
                # Heuristic: the last token of an English noun phrase is usually its head.
                mention_pos[mention] = mention_doc[-1].pos_

        # Rule 1: No verbs in noun coreference clusters
        has_verb = any(pos == 'VERB' for pos in mention_pos.values())
        has_noun = any(pos in ('NOUN', 'PROPN', 'PRON') for pos in mention_pos.values())
        if has_verb and has_noun:
            issues.append(f"Cluster contains both VERBs and NOUNs: {mention_pos}")
            severity = 'high'

        # Rule 2: Check for mixed entity types among named entities that contain a mention.
        entity_types = set()
        for ent in doc.ents:
            if cluster_spans is not None:
                matched = any(ent.start_char <= start and end <= ent.end_char
                              for start, end in cluster_spans)
            else:
                matched = any(self._mention_in_entity_text(m, ent.text) for m in cluster_strings)
            if matched:
                entity_types.add(ent.label_)

        for incomp_set1, incomp_set2 in self._INCOMPATIBLE_ENTITY_TYPES:
            if entity_types & incomp_set1 and entity_types & incomp_set2:
                issues.append(f"Incompatible entity types: {entity_types}")
                severity = 'high'
                break  # one report suffices; further pairs would repeat the same message

        # Rule 3: Pronouns should not cluster with verbs
        has_pronoun = any(self._is_pronoun(m) for m in cluster_strings)
        if has_pronoun and has_verb:
            issues.append("Pronoun clustered with VERB")
            severity = 'high'

        # Any issue without an explicit severity defaults to 'low'.
        if issues and severity is None:
            severity = 'low'

        return {
            'is_valid': not issues,
            'issues': issues,
            'severity': severity,
            'mention_pos': mention_pos
        }

    @staticmethod
    def _score_cluster(pred, cluster_indices) -> Tuple[float, float]:
        """
        Return (average, minimum) logit over adjacent mention pairs of a cluster.

        A pair whose logit cannot be retrieved counts as 0.0 (50% probability),
        which pushes the cluster toward verification. A cluster with fewer than
        two mentions scores (0.0, 0.0).
        """
        logits: List[float] = []
        for span_i, span_j in zip(cluster_indices, cluster_indices[1:]):
            try:
                logits.append(float(pred.get_logit(span_i, span_j)))
            except Exception:
                # If the logit can't be retrieved, assume low confidence.
                logits.append(0.0)
        if not logits:
            return 0.0, 0.0
        return sum(logits) / len(logits), min(logits)

    def resolve_with_confidence(
        self,
        text: str,
        return_clusters=True,
        return_resolved_text=True
    ) -> Dict:
        """
        Resolve coreferences and return results with confidence scores.

        Args:
            text: Input text
            return_clusters: If False, 'clusters' is None in the result
            return_resolved_text: If False, 'resolved_text' is None and pronoun
                replacement is skipped

        Returns:
            {
                'original_text': The input text,
                'clusters': List of clusters with confidence (or None),
                'resolved_text': Text with pronouns replaced (or None),
                'low_confidence_spans': Clusters that need LLM verification,
                'needs_llm_fallback': Boolean indicating if LLM fallback needed,
                'num_clusters': Number of predicted clusters,
                'num_low_confidence': Number of clusters needing verification
            }
        """
        pred = self.model.predict(texts=[text])[0]

        # Clusters as character offsets and as mention strings (parallel lists).
        clusters = pred.get_clusters(as_strings=False)
        clusters_strings = pred.get_clusters(as_strings=True)

        clusters_with_confidence = []
        low_confidence_spans = []

        for cluster_id, (cluster_indices, cluster_strings) in enumerate(zip(clusters, clusters_strings)):
            avg_logit, min_logit = self._score_cluster(pred, cluster_indices)
            validation = self._validate_cluster(text, cluster_strings, cluster_indices)

            # Needs verification if ANY holds: low average confidence, at least one
            # weak pairing, or a linguistic validation failure.
            needs_verification = (
                avg_logit < self.confidence_threshold or
                min_logit < self.min_confidence_threshold or
                not validation['is_valid']
            )

            cluster_info = {
                'cluster_id': cluster_id,
                'mentions': cluster_strings,
                'spans': cluster_indices,
                'avg_confidence': avg_logit,
                'min_confidence': min_logit,
                'avg_probability': self._logit_to_prob(avg_logit),
                'validation': validation,
                'is_confident': not needs_verification,
                'reason': self._get_confidence_reason(avg_logit, min_logit, validation)
            }

            clusters_with_confidence.append(cluster_info)
            if needs_verification:
                low_confidence_spans.append(cluster_info)

        resolved_text = (
            self._generate_resolved_text(text, clusters, clusters_strings)
            if return_resolved_text else None
        )

        return {
            'original_text': text,
            'clusters': clusters_with_confidence if return_clusters else None,
            'resolved_text': resolved_text,
            'low_confidence_spans': low_confidence_spans,
            'needs_llm_fallback': bool(low_confidence_spans),
            'num_clusters': len(clusters),
            'num_low_confidence': len(low_confidence_spans),
        }

    def _get_confidence_reason(self, avg_logit: float, min_logit: float, validation: Dict) -> str:
        """Explain why a cluster has low confidence or failed validation."""
        reasons = []

        if avg_logit < self.confidence_threshold:
            prob = self._logit_to_prob(avg_logit)
            reasons.append(f"Low average confidence (logit {avg_logit:.2f} = {prob:.2%})")

        if min_logit < self.min_confidence_threshold:
            prob = self._logit_to_prob(min_logit)
            reasons.append(f"Low minimum confidence (logit {min_logit:.2f} = {prob:.2%})")

        if not validation['is_valid']:
            for issue in validation['issues']:
                reasons.append(f"Validation failed: {issue}")

        if not reasons:
            avg_prob = self._logit_to_prob(avg_logit)
            return f"High confidence (avg logit {avg_logit:.2f} = {avg_prob:.2%}), validation passed"

        return "; ".join(reasons)

    def _logit_to_prob(self, logit: float) -> float:
        """Convert a logit to a probability using a numerically stable sigmoid.

        Avoids the OverflowError that ``math.exp(-logit)`` raises for very
        negative logits.
        """
        if logit >= 0:
            return 1.0 / (1.0 + math.exp(-logit))
        z = math.exp(logit)
        return z / (1.0 + z)

    def _generate_resolved_text(
        self,
        text: str,
        clusters: List[List[Tuple[int, int]]],
        clusters_strings: List[List[str]]
    ) -> str:
        """
        Generate text with pronouns replaced by their antecedents.

        The antecedent of a cluster is its first non-pronoun mention (or the
        first mention if every mention is a pronoun). Every other pronoun
        mention in the cluster is replaced by it. Possessive forms are replaced
        verbatim (e.g. "his" becomes "John", not "John's").

        Args:
            text: Original text
            clusters: List of clusters as (start, end) character offsets, end exclusive
            clusters_strings: List of clusters as strings (parallel to ``clusters``)

        Returns:
            Text with resolved coreferences
        """
        # (start, end) -> replacement text
        replacements: Dict[Tuple[int, int], str] = {}

        for cluster_indices, cluster_strings in zip(clusters, clusters_strings):
            if len(cluster_strings) < 2:
                continue

            # Prefer a non-pronoun antecedent so cataphora ("When he arrived,
            # John ...") resolves to the name rather than to another pronoun.
            main_idx = next(
                (idx for idx, m in enumerate(cluster_strings) if not self._is_pronoun(m)),
                0,
            )
            main_mention = cluster_strings[main_idx]

            for idx, ((start, end), mention) in enumerate(zip(cluster_indices, cluster_strings)):
                # Only replace pronouns, not full names
                if idx != main_idx and self._is_pronoun(mention):
                    replacements[(start, end)] = main_mention

        # Rebuild the text left to right in one pass. Overlapping spans are
        # skipped because splicing both would corrupt the output.
        pieces = []
        cursor = 0
        for (start, end), replacement in sorted(replacements.items()):
            if start < cursor:
                continue
            pieces.append(text[cursor:start])
            pieces.append(replacement)
            cursor = end
        pieces.append(text[cursor:])
        return "".join(pieces)

    def _is_pronoun(self, text: str) -> bool:
        """Return True if ``text`` (case- and whitespace-insensitive) is a known pronoun."""
        return text.lower().strip() in self._PRONOUNS

    def resolve_with_llm_fallback(
        self,
        text: str,
        llm_resolve_func: Optional[Callable[[str, List[Dict]], str]] = None
    ) -> Dict:
        """
        Resolve coreferences with automatic LLM fallback for low confidence.

        Args:
            text: Input text
            llm_resolve_func: Function to call for LLM resolution.
                Should take (text, low_confidence_info) and return resolved_text.

        Returns:
            Resolution results (see ``resolve_with_confidence``) plus
            'resolution_method', and 'llm_resolved_text' when the LLM was called.
        """
        # First try with fastcoref
        result = self.resolve_with_confidence(text)

        # If low confidence and an LLM function is provided, use the fallback
        if result['needs_llm_fallback'] and llm_resolve_func is not None:
            print(f"\n⚠️  Low confidence detected for {result['num_low_confidence']} clusters")
            print("🤖 Falling back to LLM for resolution...")

            llm_result = llm_resolve_func(text, result['low_confidence_spans'])

            result['llm_resolved_text'] = llm_result
            result['resolution_method'] = 'hybrid (fastcoref + LLM)'
        else:
            result['resolution_method'] = 'fastcoref only'

        return result