File size: 7,794 Bytes
57bbccb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""
Precision Patch: Post-transcription NER + Confidence correction service.

This service identifies proper nouns and ambiguous tokens (ORG, PRODUCT, PERSON,
GPE, LOC, CARDINAL) in transcribed text using spaCy, cross-references their
confidence against Whisper's word-level probabilities, and sends only "suspicious"
segments to the LLM for correction.

Key design decisions:
- CARDINAL is included because spaCy sometimes mis-tags unknown proper nouns
  (e.g. "NowCree") as CARDINAL - we still want to catch those.
- URLs (e.g. "notebookklem.google.com") are NOT tagged by spaCy's NER at all.
  They are captured separately via a regex fallback.
- The LLM correction pass is batched: all suspicious segments are sent in ONE call.
"""
import re
import spacy


# Regex to find URL-like tokens Whisper may have garbled.
# It matches a word-boundary-delimited run of word characters, dots and
# hyphens, followed by a literal dot and one of the listed suffixes
# (com, org, net, io, ai, google, co). Matching is case-insensitive.
_URL_PATTERN = re.compile(r'\b[\w.-]+\.(?:com|org|net|io|ai|google|co)\b', re.IGNORECASE)


class PrecisionPatch:
    """
    Identifies and corrects low-confidence proper nouns in Whisper transcriptions.

    Typical flow: ``get_suspicious_indices`` (which uses ``find_entities`` and
    ``map_entities_to_confidence``) followed by ``apply_patch``.
    """

    # Entity labels considered "name-like" - includes CARDINAL because spaCy
    # sometimes misclassifies unknown capitalized words (like brand names) as CARDINAL.
    ENTITY_LABELS = {"ORG", "PRODUCT", "PERSON", "GPE", "LOC", "CARDINAL"}

    # Confidence threshold - entities below this are considered suspicious
    CONFIDENCE_THRESHOLD = 0.85

    def __init__(self):
        """
        Load the spaCy ``en_core_web_sm`` pipeline into ``self.nlp``.

        If loading fails with ``OSError`` the model is downloaded with
        ``python -m spacy download`` (using the current interpreter) and the
        load is retried once.

        Raises:
            subprocess.CalledProcessError: if the download command fails.
        """
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            # OSError is treated as "model not installed": fetch it, then retry.
            import subprocess
            import sys
            subprocess.run(
                [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
                check=True
            )
            self.nlp = spacy.load("en_core_web_sm")

    def find_entities(self, text: str) -> list[dict]:
        """
        Identify named entities AND URL-like tokens in text that could be
        brand names or proper nouns worth verifying.

        Args:
            text: The transcript segment text.

        Returns:
            List of dicts with keys: text, start (char offset), end (char offset), label.
            spaCy entities come first (only labels in ``ENTITY_LABELS``), followed by
            regex matches labelled ``"URL"`` whose exact span was not already reported.
        """
        doc = self.nlp(text)
        entities = [
            {
                "text": ent.text,
                "start": ent.start_char,
                "end": ent.end_char,
                "label": ent.label_,
            }
            for ent in doc.ents
            if ent.label_ in self.ENTITY_LABELS
        ]

        # Regex fallback: catch URL-like tokens spaCy's NER misses entirely.
        # Only exact (start, end) duplicates are skipped.
        seen_spans = {(e["start"], e["end"]) for e in entities}
        for m in _URL_PATTERN.finditer(text):
            span = (m.start(), m.end())
            if span not in seen_spans:
                entities.append({
                    "text": m.group(),
                    "start": m.start(),
                    "end": m.end(),
                    "label": "URL",
                })
                seen_spans.add(span)

        return entities

    def map_entities_to_confidence(self, entities: list[dict], whisper_words: list, segment_text: str) -> list[dict]:
        """
        Attach a ``"confidence"`` value to every entity dict, in place.

        Each Whisper word is located in ``segment_text`` by a left-to-right
        search, and an entity's confidence is the mean ``probability`` of all
        words whose character span overlaps the entity's span.

        Args:
            entities: Dicts with at least ``start`` and ``end`` char offsets
                into ``segment_text``.
            whisper_words: Objects exposing ``.word`` (str) and ``.probability``.
                May be empty or ``None``.
            segment_text: The text the entity offsets refer to.

        Returns:
            The same ``entities`` list. Entities with no overlapping word, or
            all entities when ``whisper_words`` is empty, get confidence ``0.0``.
        """
        if not whisper_words:
            for ent in entities:
                ent["confidence"] = 0.0
            return entities

        # Pre-calculate char offsets for each whisper word in the segment_text
        word_spans = []
        cursor = 0
        for w in whisper_words:
            # Whisper words usually carry a leading space. Search for the bare
            # token instead: with the space included, the first word of a
            # stripped segment is not found at offset 0 and the search can jump
            # to a later duplicate, mis-aligning every following word.
            token = w.word.strip()
            start_idx = segment_text.find(token, cursor)
            if start_idx == -1:
                # Fallback: if not found, just assume it follows immediately
                start_idx = cursor

            end_idx = start_idx + len(token)
            word_spans.append((start_idx, end_idx, w.probability))
            cursor = end_idx

        for ent in entities:
            # A word counts when its span and the entity span share at least one char.
            overlapping_probs = [
                prob
                for w_start, w_end, prob in word_spans
                if max(ent["start"], w_start) < min(ent["end"], w_end)
            ]

            if overlapping_probs:
                ent["confidence"] = sum(overlapping_probs) / len(overlapping_probs)
            else:
                ent["confidence"] = 0.0

        return entities

    def get_suspicious_indices(self, segments: list) -> list[int]:
        """
        Identifies indices of segments that contain low-confidence entities.

        Args:
            segments: Objects exposing ``.text`` and, optionally, ``.words``
                (word objects with ``.word`` / ``.probability``).

        Returns:
            Indices (ascending) of segments with at least one entity whose
            confidence is below ``CONFIDENCE_THRESHOLD``. A segment with entities
            but no word data is always flagged, because its entities get 0.0.
        """
        suspicious_indices = []
        for i, seg in enumerate(segments):
            entities = self.find_entities(seg.text)
            if not entities:
                continue

            # Tolerate segments produced without word-level data.
            words = getattr(seg, "words", None)
            entities = self.map_entities_to_confidence(entities, words, seg.text)

            if any(e["confidence"] < self.CONFIDENCE_THRESHOLD for e in entities):
                suspicious_indices.append(i)

        return suspicious_indices

    def apply_patch(self, segments: list, suspicious_indices: list[int]):
        """
        Takes segments and suspicious indices, uses Gemini to correct them,
        and updates segments in place. Includes surrounding context for better accuracy.

        Each suspicious segment is sent together with its immediate neighbours
        (one before, one after) in a single batch call. Every line in the batch,
        including the context lines, is overwritten with the returned text unless
        that text is rejected by the defensive checks below.

        Args:
            segments: Objects with a writable ``.text`` attribute.
            suspicious_indices: Indices into ``segments`` to correct.

        Returns:
            The same ``segments`` list.
        """
        if not suspicious_indices:
            return segments

        from app.services.translators.gemini_adapter import GeminiAdapter
        gemini = GeminiAdapter()

        # Build a set of indices to send, including 1 line of context
        last_index = len(segments) - 1
        indices_to_send = set()
        for idx in suspicious_indices:
            if idx > 0:
                indices_to_send.add(idx - 1)
            indices_to_send.add(idx)
            if idx < last_index:
                indices_to_send.add(idx + 1)

        sorted_indices = sorted(indices_to_send)
        original_lines = [segments[i].text for i in sorted_indices]

        # Call Gemini for batch correction
        # NOTE(review): assumes correct_batch returns one corrected string per
        # input line, in order - confirm against GeminiAdapter.
        corrected_lines = gemini.correct_batch(original_lines)
        corrected_lines = [] if corrected_lines is None else list(corrected_lines)

        # If the counts differ we cannot tell which correction belongs to which
        # segment; applying them positionally would write text into the wrong lines.
        if len(corrected_lines) != len(sorted_indices):
            print(
                f"  ⚠️ Warning: Precision Patch received {len(corrected_lines)} corrected lines "
                f"for {len(sorted_indices)} inputs; skipping corrections to avoid misalignment."
            )
            return segments

        # Apply corrections back to segments
        for i, corrected_text in zip(sorted_indices, corrected_lines):
            original_text = segments[i].text

            # An empty or non-string correction would erase the segment (or crash
            # on .split()), so it is rejected like a fragment.
            is_blank = not isinstance(corrected_text, str) or not corrected_text.strip()

            # Defensive check: If the correction is a fragment (e.g. just the word "Naukri")
            # we reject it to prevent massive context loss.
            # Rule: If original has > 2 words and correction has 1 word, it's likely a fragment.
            is_fragment = (
                not is_blank
                and len(original_text.split()) > 2
                and len(corrected_text.split()) <= 1
            )

            if is_blank or is_fragment:
                print(f"  ⚠️ Warning: Precision Patch rejected a fragmented response for line {i+1} to preserve context.")
                continue

            segments[i].text = corrected_text

        return segments

def apply_precision_patch(segments: list):
    """
    Run the complete Precision Patch pipeline over ``segments``.

    A fresh ``PrecisionPatch`` is created, segments containing low-confidence
    entities are detected, and, if any were found, they are corrected in place.
    A one-line status message is printed either way.

    Args:
        segments: Transcript segments to inspect and possibly correct.

    Returns:
        The same ``segments`` list that was passed in.
    """
    patch = PrecisionPatch()
    flagged = patch.get_suspicious_indices(segments)

    # Nothing flagged: report and hand the input straight back.
    if not flagged:
        print("  ✅ Precision Patch: No suspicious entities found.")
        return segments

    print(f"  ✨ Precision Patch: Found {len(flagged)} segments with low-confidence entities. Correcting...")
    patch.apply_patch(segments, flagged)
    return segments