# Subtrans / app/services/precision_patch.py
# Committed by arjun-ms
# Initial commit: Subtrans Subtitle Pipeline (57bbccb)
"""
Precision Patch: Post-transcription NER + Confidence correction service.
This service identifies proper nouns and ambiguous tokens (ORG, PRODUCT, PERSON,
GPE, LOC, CARDINAL) in transcribed text using spaCy, cross-references their
confidence against Whisper's word-level probabilities, and sends only "suspicious"
segments to the LLM for correction.
Key design decisions:
- CARDINAL is included because spaCy sometimes mis-tags unknown proper nouns
(e.g. "NowCree") as CARDINAL - we still want to catch those.
- URLs (e.g. "notebookklem.google.com") are NOT tagged by spaCy's NER at all.
They are captured separately via a regex fallback.
- The LLM correction pass is batched: all suspicious segments, each with one neighbouring line of context on either side, are sent in ONE call.
"""
import re
import spacy
# Fallback matcher for URL-like tokens (e.g. "notebookklem.google.com") that
# Whisper may have garbled and spaCy's NER does not tag; case-insensitive.
_URL_PATTERN = re.compile(
    r'\b[\w.-]+\.(?:com|org|net|io|ai|google|co)\b',
    flags=re.IGNORECASE,
)
class PrecisionPatch:
    """
    Identifies and corrects low-confidence proper nouns in Whisper transcriptions.

    ``get_suspicious_indices`` flags segments containing a named entity (or a
    URL-like token) whose average Whisper word probability is below
    ``CONFIDENCE_THRESHOLD``; ``apply_patch`` sends those segments, with one
    line of context on either side, to Gemini and writes the corrections back.
    """

    # Entity labels considered "name-like" - includes CARDINAL because spaCy
    # sometimes misclassifies unknown capitalized words (like brand names) as CARDINAL.
    ENTITY_LABELS = {"ORG", "PRODUCT", "PERSON", "GPE", "LOC", "CARDINAL"}

    # Confidence threshold - entities whose average word probability is below
    # this value are considered suspicious.
    CONFIDENCE_THRESHOLD = 0.85

    def __init__(self):
        """Load the spaCy ``en_core_web_sm`` model, downloading it if loading fails."""
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            # OSError from spacy.load is treated as "model not installed":
            # download it with the current interpreter, then retry once.
            import subprocess
            import sys

            subprocess.run(
                [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
                check=True,
            )
            self.nlp = spacy.load("en_core_web_sm")

    def find_entities(self, text: str) -> list[dict]:
        """
        Identify named entities AND URL-like tokens in text that could be
        brand names or proper nouns worth verifying.

        Args:
            text: The transcript segment text.

        Returns:
            List of dicts with keys: text, start (char offset), end (char offset),
            label. Regex-found URLs carry the label "URL".
        """
        doc = self.nlp(text)
        entities = [
            {
                "text": ent.text,
                "start": ent.start_char,
                "end": ent.end_char,
                "label": ent.label_,
            }
            for ent in doc.ents
            if ent.label_ in self.ENTITY_LABELS
        ]

        # Regex fallback: catch URL-like tokens spaCy's NER misses entirely.
        # Only exact span duplicates of an existing entity are skipped.
        seen_spans = {(e["start"], e["end"]) for e in entities}
        for m in _URL_PATTERN.finditer(text):
            span = (m.start(), m.end())
            if span in seen_spans:
                continue
            entities.append({
                "text": m.group(),
                "start": m.start(),
                "end": m.end(),
                "label": "URL",
            })
            seen_spans.add(span)
        return entities

    def map_entities_to_confidence(self, entities: list[dict], whisper_words: list, segment_text: str) -> list[dict]:
        """
        Attach an average Whisper word probability to each entity.

        Each Whisper word is located in ``segment_text`` (searching forward from
        the end of the previously located word) to obtain a character span. An
        entity's confidence is the mean probability of every word whose span
        overlaps the entity's span.

        Args:
            entities: Dicts as produced by ``find_entities`` ("start"/"end" required).
            whisper_words: Word objects exposing ``.word`` and ``.probability``.
                May be empty or None.
            segment_text: The text that the entity offsets refer to.

        Returns:
            The same ``entities`` list; each dict is mutated in place to carry a
            "confidence" key. Entities with no overlapping word, or all entities
            when ``whisper_words`` is empty, get 0.0 (i.e. always suspicious).
        """
        if not whisper_words:
            for ent in entities:
                ent["confidence"] = 0.0
            return entities

        # Pre-calculate (start, end, probability) for each whisper word.
        word_offsets = []
        current_pos = 0
        for w in whisper_words:
            # Whisper words usually carry a leading space (" Naukri"). Search for
            # the stripped token: a space-prefixed search fails when the segment
            # text was stripped, and can skip the real occurrence and match a
            # later one (e.g. the second "Naukri" in "Naukri and Naukri").
            token = w.word.strip()
            if not token:
                continue
            start_idx = segment_text.find(token, current_pos)
            if start_idx == -1:
                # Fallback: if not found, just assume it follows immediately.
                start_idx = current_pos
            end_idx = start_idx + len(token)
            word_offsets.append((start_idx, end_idx, w.probability))
            current_pos = end_idx

        for ent in entities:
            # Any overlap between the entity span and a word span counts.
            overlapping_probs = [
                prob
                for w_start, w_end, prob in word_offsets
                if max(ent["start"], w_start) < min(ent["end"], w_end)
            ]
            ent["confidence"] = (
                sum(overlapping_probs) / len(overlapping_probs)
                if overlapping_probs
                else 0.0
            )
        return entities

    def get_suspicious_indices(self, segments: list) -> list[int]:
        """
        Return indices of segments containing at least one entity whose
        confidence is below ``CONFIDENCE_THRESHOLD``.

        Segments must expose ``.text``; ``.words`` (Whisper word objects) is
        optional. A segment without words scores 0.0 for every entity, so it
        is flagged whenever it contains any entity.
        """
        suspicious_indices = []
        for i, seg in enumerate(segments):
            entities = self.find_entities(seg.text)
            if not entities:
                continue
            # getattr: tolerate segment objects that have no `words` attribute
            # (treated the same as an empty word list).
            words = getattr(seg, "words", None)
            entities = self.map_entities_to_confidence(entities, words, seg.text)
            if any(e["confidence"] < self.CONFIDENCE_THRESHOLD for e in entities):
                suspicious_indices.append(i)
        return suspicious_indices

    def apply_patch(self, segments: list, suspicious_indices: list[int]):
        """
        Correct suspicious segments with Gemini and update them in place.

        Each suspicious segment is sent together with one neighbouring segment
        on either side as context, all in a single ``correct_batch`` call.
        Corrections are rejected (the segment is left untouched) when the
        response line count does not match the request, or when an individual
        correction is empty or looks like a one-word fragment.

        Args:
            segments: Objects with a mutable ``.text`` attribute.
            suspicious_indices: Indices into ``segments`` to correct.

        Returns:
            ``segments`` (the same list, mutated in place).
        """
        if not suspicious_indices:
            return segments

        # Build the set of indices to send: each suspicious line plus one line
        # of context on either side, clipped to the valid index range.
        last_idx = len(segments) - 1
        indices_to_send = set()
        for idx in suspicious_indices:
            for j in (idx - 1, idx, idx + 1):
                if 0 <= j <= last_idx:
                    indices_to_send.add(j)
        sorted_indices = sorted(indices_to_send)
        if not sorted_indices:
            return segments

        # Deferred import: only reached when there is something to correct.
        from app.services.translators.gemini_adapter import GeminiAdapter

        gemini = GeminiAdapter()
        original_lines = [segments[i].text for i in sorted_indices]

        # Call Gemini for batch correction (a None result is treated as empty).
        corrected_lines = list(gemini.correct_batch(original_lines) or [])

        # Corrections are matched back to segments purely by position. If the
        # model merged, dropped or added a line, every later correction would
        # land on the wrong segment, so reject the whole batch instead.
        if len(corrected_lines) != len(sorted_indices):
            print(
                f" ⚠️ Warning: Precision Patch expected {len(sorted_indices)} corrected lines "
                f"but received {len(corrected_lines)}; leaving segments unchanged."
            )
            return segments

        for i, corrected_text in zip(sorted_indices, corrected_lines):
            original_text = segments[i].text
            corr_words = (corrected_text or "").split()

            # An empty correction would wipe the line entirely.
            if not corr_words:
                print(f" ⚠️ Warning: Precision Patch rejected an empty response for line {i+1}.")
                continue

            # Defensive check: If the correction is a fragment (e.g. just the word "Naukri")
            # we reject it to prevent massive context loss.
            # Rule: If original has > 2 words and correction has 1 word, it's likely a fragment.
            if len(original_text.split()) > 2 and len(corr_words) == 1:
                print(f" ⚠️ Warning: Precision Patch rejected a fragmented response for line {i+1} to preserve context.")
                continue

            # NOTE(review): context lines are overwritten too, not only the
            # suspicious ones - confirm that is intended.
            segments[i].text = corrected_text
        return segments
def apply_precision_patch(segments: list):
    """
    Run the whole Precision Patch pipeline over ``segments``.

    Low-confidence entity segments are detected first; only if any are found
    is the correction pass invoked, which edits the segments in place.
    Returns the same ``segments`` list.
    """
    # NOTE: a new PrecisionPatch is built per call, so the spaCy model is
    # loaded each time this function runs.
    patcher = PrecisionPatch()
    suspicious_indices = patcher.get_suspicious_indices(segments)

    if not suspicious_indices:
        print(" ✅ Precision Patch: No suspicious entities found.")
        return segments

    print(f" ✨ Precision Patch: Found {len(suspicious_indices)} segments with low-confidence entities. Correcting...")
    patcher.apply_patch(segments, suspicious_indices)
    return segments