PoC_ASR_v6_dev / app /services /alignment.py
vyluong's picture
Update app/services/alignment.py
ad7b24d verified
"""
Precision alignment service: word-center-based speaker assignment.

- Keeps the softformer diarization service
- Removes diarization noise using confidence + duration
- Preserves DOUBLE_TALK word by word
- Reduces transcript fragmentation
- Improves KH/NV continuity
- Stabilizes realtime transcript rendering

Merges word-level transcription with speaker diarization using precise timestamps.
"""
import copy
import logging
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional

from app.core.config import get_settings
from app.services.transcription import WordTimestamp
from app.services.diarization import SpeakerSegment
from app.schemas.models import TranscriptSegment
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Application settings, fetched once when this module is imported.
# NOTE(review): `settings` is not referenced in the visible part of this
# module (all thresholds are class constants) - confirm whether it is still
# needed here.
settings = get_settings()
@dataclass
class WordWithSpeaker:
    """A transcribed word together with the speaker it was attributed to.

    Instances are produced by ``AlignmentService.assign_speakers_to_words``
    and consumed when words are grouped back into transcript segments.
    """

    # The word text as produced by the transcription step.
    word: str
    # Word start time (same time base as the diarization segments).
    start: float
    # Word end time.
    end: float
    # Label of the speaker the word was assigned to.
    speaker: str
    # Confidence carried over from the source word; defaults to 1.0.
    confidence: float = 1.0
class AlignmentService:
    """
    Precision alignment service.

    Uses a word-center-based algorithm for accurate speaker-to-text mapping.
    The full pipeline (see ``align_precision``) is:

    1. clean the diarization segments (drop noise, merge same-speaker runs),
    2. assign a speaker to every word,
    3. group consecutive words into transcript segments,
    4. drop low-confidence noise segments,
    5. merge adjacent same-speaker segments to reduce fragmentation.

    All methods are class/static methods; the class only carries thresholds.
    """

    # --- word -> speaker assignment ---
    # Slack (0.18 = 180 ms) allowed around a diarization segment when
    # matching a word's center time to it.
    CENTER_TOL = 0.18
    # Minimum fraction of a word's duration that must overlap a segment for
    # the overlap fallback to accept that segment's speaker.
    OVERLAP_TH = 0.10

    # --- diarization cleanup ---
    # Same-speaker diarization segments closer than this are merged.
    DIA_MERGE_GAP = 0.35
    # A diarization segment is dropped only when it is BOTH shorter than
    # MIN_DIAR_DURATION and less confident than MIN_DIAR_CONFIDENCE.
    MIN_DIAR_DURATION = 0.12
    MIN_DIAR_CONFIDENCE = 0.45

    # --- segment building ---
    # A pause between two words longer than this starts a new segment.
    PAUSE_THRESHOLD = 0.65
    # A segment longer than this is split at the next pause that exceeds
    # SPLIT_MIN_PAUSE.
    MAX_SEGMENT_DURATION = 12.0
    SPLIT_MIN_PAUSE = 0.25

    # --- segment merging ---
    # Same-speaker segments are merged when the gap between them is at most
    # MERGE_GAP and the merged span is at most MAX_MERGED_DURATION.
    MERGE_GAP = 0.55
    MAX_MERGED_DURATION = 10.0

    # --- noise filtering ---
    # A segment is dropped when it is BOTH shorter than MIN_SEGMENT_DURATION
    # and has an average word confidence below MIN_SEGMENT_AVG_CONF.
    MIN_SEGMENT_DURATION = 0.35
    MIN_SEGMENT_AVG_CONF = 0.28
    # A segment of at most one word is dropped below this average confidence.
    MIN_SINGLE_WORD_CONF = 0.20

    # --- short interruptions ---
    # A run of another speaker's words is absorbed into the current segment
    # when it has at most this many words, lasts at most this long, and the
    # current speaker resumes right after it.
    SHORT_INTERRUPT_MAX_WORDS = 2
    SHORT_INTERRUPT_MAX_DURATION = 1.25

    @staticmethod
    def get_word_center(word: WordTimestamp) -> float:
        """Return the midpoint between ``word.start`` and ``word.end``."""
        return (word.start + word.end) / 2

    @staticmethod
    def overlap_ratio(
        w_start: float,
        w_end: float,
        s_start: float,
        s_end: float,
    ) -> float:
        """
        Return the fraction of the word span covered by the segment span.

        Args:
            w_start: Word start time.
            w_end: Word end time.
            s_start: Segment start time.
            s_end: Segment end time.
        Returns:
            Overlap length divided by the word duration. The duration is
            floored at 1e-6 so zero-length words do not divide by zero.
        """
        overlap = max(0.0, min(w_end, s_end) - max(w_start, s_start))
        duration = max(1e-6, w_end - w_start)
        return overlap / duration

    @classmethod
    def clean_diarization_segments(
        cls,
        segments: List[SpeakerSegment],
    ) -> List[SpeakerSegment]:
        """
        Drop obvious diarization noise and merge same-speaker neighbours.

        A segment is treated as noise only when it is both shorter than
        ``MIN_DIAR_DURATION`` and less confident than ``MIN_DIAR_CONFIDENCE``
        (a missing ``confidence`` attribute counts as 1.0). Remaining
        segments are sorted by start time, and consecutive segments of the
        same speaker separated by at most ``DIA_MERGE_GAP`` are merged; the
        merged segment keeps the highest confidence of its parts.

        The input objects are never modified: merged output segments are
        shallow copies, so callers can keep using the raw diarization.

        Args:
            segments: Raw diarization segments, in any order.
        Returns:
            New list of cleaned segments ordered by start time.
        """
        if not segments:
            return []

        kept = [
            seg
            for seg in sorted(segments, key=lambda s: s.start)
            if not (
                seg.end - seg.start < cls.MIN_DIAR_DURATION
                and getattr(seg, "confidence", 1.0) < cls.MIN_DIAR_CONFIDENCE
            )
        ]
        if not kept:
            return []

        # Copy before merging so extending `end` never leaks into the
        # caller's segment objects.
        merged = [copy.copy(kept[0])]
        for seg in kept[1:]:
            prev = merged[-1]
            same_turn = (
                seg.speaker == prev.speaker
                and seg.start - prev.end <= cls.DIA_MERGE_GAP
            )
            if not same_turn:
                merged.append(copy.copy(seg))
                continue
            prev.end = max(prev.end, seg.end)
            if hasattr(prev, "confidence"):
                prev.confidence = max(
                    getattr(prev, "confidence", 1.0),
                    getattr(seg, "confidence", 1.0),
                )
        return merged

    # FIND SPEAKER
    @classmethod
    def find_speaker_center(
        cls,
        time: float,
        speaker_segments: List[SpeakerSegment],
    ) -> Optional[str]:
        """
        Return the speaker whose segment matches the given time.

        A segment that actually contains ``time`` always wins (the first one
        in list order if several do). Only when no segment contains it is
        the ``CENTER_TOL`` slack applied, and then the segment whose boundary
        is nearest to ``time`` is chosen. This keeps a word that lies just
        inside a new speaker's segment from being attributed to the previous
        speaker merely because it is within the tolerance of that earlier
        segment.

        Args:
            time: Time to look up (typically a word's center).
            speaker_segments: Diarization segments to search.
        Returns:
            The matching speaker label, or None when no segment is within
            ``CENTER_TOL`` of ``time``.
        """
        best_speaker = None
        best_gap = None
        for seg in speaker_segments:
            if seg.start <= time <= seg.end:
                return seg.speaker
            # Distance from `time` to the nearer boundary of this segment.
            gap = seg.start - time if time < seg.start else time - seg.end
            if gap <= cls.CENTER_TOL and (best_gap is None or gap < best_gap):
                best_gap = gap
                best_speaker = seg.speaker
        return best_speaker

    @staticmethod
    def find_closest_speaker(
        time: float,
        speaker_segments: List[SpeakerSegment],
    ) -> str:
        """
        Return the speaker of the segment whose boundary is nearest to ``time``.

        Distance is measured to the closer of the segment's start and end.
        Ties go to the earliest segment in list order.

        Args:
            time: Time to look up.
            speaker_segments: Diarization segments to search.
        Returns:
            The nearest segment's speaker, or "UNKNOWN" when there are no
            segments.
        """
        if not speaker_segments:
            return "UNKNOWN"
        closest = min(
            speaker_segments,
            key=lambda seg: min(abs(time - seg.start), abs(time - seg.end)),
        )
        return closest.speaker

    @classmethod
    def _best_overlap_speaker(
        cls,
        word: WordTimestamp,
        speaker_segments: List[SpeakerSegment],
    ) -> Optional[str]:
        """Return the speaker overlapping ``word`` the most, if >= OVERLAP_TH."""
        best_ratio = 0.0
        best_speaker = None
        for seg in speaker_segments:
            ratio = cls.overlap_ratio(word.start, word.end, seg.start, seg.end)
            if ratio > best_ratio:
                best_ratio = ratio
                best_speaker = seg.speaker
        return best_speaker if best_ratio >= cls.OVERLAP_TH else None

    # ASSIGN SPEAKER TO WORDS
    @classmethod
    def assign_speakers_to_words(
        cls,
        words: List[WordTimestamp],
        speaker_segments: List[SpeakerSegment],
    ) -> List[WordWithSpeaker]:
        """
        Attach a speaker label to every non-blank word.

        Words whose text is empty or whitespace-only are discarded. The
        diarization is cleaned first; if nothing remains, every word is
        attributed to "Speaker 1". Otherwise each word is resolved in order
        of preference:

        1. the segment matching the word's center (``find_speaker_center``),
        2. the segment overlapping the word the most, if that overlap is at
           least ``OVERLAP_TH`` of the word,
        3. the segment nearest to the word's center.

        Args:
            words: Word-level timestamps from transcription.
            speaker_segments: Speaker segments from diarization.
        Returns:
            One ``WordWithSpeaker`` per kept word, in input order. A word's
            confidence defaults to 1.0 when the source has none.
        """
        spoken = [w for w in words if w.word and w.word.strip()]
        if not spoken:
            return []

        speaker_segments = cls.clean_diarization_segments(speaker_segments)

        results = []
        for word in spoken:
            if not speaker_segments:
                # No usable diarization: single-speaker fallback.
                speaker = "Speaker 1"
            else:
                center = cls.get_word_center(word)
                speaker = cls.find_speaker_center(center, speaker_segments)
                if speaker is None:
                    speaker = cls._best_overlap_speaker(word, speaker_segments)
                if speaker is None:
                    speaker = cls.find_closest_speaker(center, speaker_segments)
            results.append(
                WordWithSpeaker(
                    word=word.word,
                    start=word.start,
                    end=word.end,
                    speaker=speaker,
                    confidence=getattr(word, "confidence", 1.0),
                )
            )
        return results

    # ========================================================
    # BUILD SEGMENT
    # ========================================================
    @classmethod
    def build_segment(
        cls,
        words: List[WordWithSpeaker],
    ) -> Optional[TranscriptSegment]:
        """
        Build one transcript segment from a run of words.

        The segment spans from the first word's start to the last word's
        end, its text is the words joined by single spaces, and its speaker
        is the most frequent speaker among the words (ties go to the speaker
        seen first). ``role`` is always set to "UNKNOWN" here.

        Two internal attributes are attached for later noise filtering:
        ``_avg_conf`` (mean word confidence) and ``_word_count``.

        Args:
            words: Words belonging to the segment, in time order.
        Returns:
            The built segment, or None when ``words`` is empty.
        """
        if not words:
            return None

        speaker = Counter(w.speaker for w in words).most_common(1)[0][0]
        avg_conf = sum(w.confidence for w in words) / len(words)

        segment = TranscriptSegment(
            start=words[0].start,
            end=words[-1].end,
            speaker=speaker,
            role="UNKNOWN",
            text=" ".join(w.word for w in words),
        )
        # INTERNAL ONLY
        # NOTE(review): relies on TranscriptSegment accepting ad-hoc
        # attributes - confirm against its definition.
        setattr(segment, "_avg_conf", avg_conf)
        setattr(segment, "_word_count", len(words))
        return segment

    @classmethod
    def _tiny_interruption_length(
        cls,
        words: List[WordWithSpeaker],
        start: int,
        owner: str,
    ) -> int:
        """
        Return how many words starting at ``start`` form a tiny interruption.

        The run of consecutive words spoken by ``words[start].speaker`` is a
        tiny interruption when it has at most ``SHORT_INTERRUPT_MAX_WORDS``
        words, lasts at most ``SHORT_INTERRUPT_MAX_DURATION``, and the word
        right after it is spoken by ``owner`` again. Returns 0 otherwise,
        including when the input ends before ``owner`` resumes.
        """
        interrupter = words[start].speaker
        # Look one word past the allowed length so an over-long run is seen.
        limit = min(len(words), start + cls.SHORT_INTERRUPT_MAX_WORDS + 1)
        stop = start
        while stop < limit and words[stop].speaker == interrupter:
            stop += 1

        run_length = stop - start
        if run_length > cls.SHORT_INTERRUPT_MAX_WORDS:
            return 0
        if stop >= len(words) or words[stop].speaker != owner:
            return 0
        if words[stop - 1].end - words[start].start > cls.SHORT_INTERRUPT_MAX_DURATION:
            return 0
        return run_length

    @classmethod
    def _group_words(
        cls,
        words_with_speakers: List[WordWithSpeaker],
    ) -> List[List[WordWithSpeaker]]:
        """
        Partition words into runs that each become one transcript segment.

        A new run starts when the speaker really changes, when the pause
        before a word exceeds ``PAUSE_THRESHOLD``, or when the current run is
        longer than ``MAX_SEGMENT_DURATION`` and the pause exceeds
        ``SPLIT_MIN_PAUSE``. Tiny interruptions (see
        ``_tiny_interruption_length``) stay inside the current run.
        """
        if not words_with_speakers:
            return []

        groups = []
        current = [words_with_speakers[0]]
        total = len(words_with_speakers)
        i = 1
        while i < total:
            curr = words_with_speakers[i]
            # The first word of a run always belongs to the run's speaker,
            # even after interruption words have been absorbed.
            owner = current[0].speaker

            if curr.speaker != owner:
                absorbed = cls._tiny_interruption_length(
                    words_with_speakers, i, owner
                )
                if absorbed:
                    # preserve continuity
                    current.extend(words_with_speakers[i:i + absorbed])
                    i += absorbed
                    continue
                # real speaker switch
                groups.append(current)
                current = [curr]
                i += 1
                continue

            pause = curr.start - words_with_speakers[i - 1].end
            run_duration = current[-1].end - current[0].start
            too_long = (
                run_duration > cls.MAX_SEGMENT_DURATION
                and pause > cls.SPLIT_MIN_PAUSE
            )
            if pause > cls.PAUSE_THRESHOLD or too_long:
                groups.append(current)
                current = [curr]
            else:
                current.append(curr)
            i += 1

        groups.append(current)
        return groups

    @classmethod
    def reconstruct_segments(
        cls,
        words_with_speakers: List[WordWithSpeaker],
    ) -> List[TranscriptSegment]:
        """
        Group speaker-labelled words into transcript segments.

        Segments are split on real speaker changes, long pauses, and
        over-long runs. A short interjection by another speaker (at most
        ``SHORT_INTERRUPT_MAX_WORDS`` words and
        ``SHORT_INTERRUPT_MAX_DURATION`` long) is kept inside the current
        segment as long as the original speaker resumes right after it, so
        that speaker's sentence is not fragmented. A short run at the very
        end of the input is a separate segment because nobody resumes.

        Args:
            words_with_speakers: Words in time order with speakers assigned.
        Returns:
            Segments in time order; empty list for empty input.
        """
        return [
            cls.build_segment(group)
            for group in cls._group_words(words_with_speakers)
        ]

    # ========================================================
    # FILTER NOISE
    # ========================================================
    @classmethod
    def filter_noise_segments(
        cls,
        segments: List[TranscriptSegment],
    ) -> List[TranscriptSegment]:
        """
        Remove segments that look like hallucinations or noise.

        A segment is dropped when it is both shorter than
        ``MIN_SEGMENT_DURATION`` and below ``MIN_SEGMENT_AVG_CONF``, or when
        it has at most one word and is below ``MIN_SINGLE_WORD_CONF``.
        Segments without the internal ``_avg_conf`` attribute are treated as
        fully confident; a missing ``_word_count`` is derived from the text.

        Args:
            segments: Candidate segments.
        Returns:
            The surviving segments, in their original order.
        """
        filtered = []
        for seg in segments:
            avg_conf = getattr(seg, "_avg_conf", 1.0)
            word_count = getattr(seg, "_word_count", len(seg.text.split()))

            # hallucination/noise
            short_and_unsure = (
                seg.end - seg.start < cls.MIN_SEGMENT_DURATION
                and avg_conf < cls.MIN_SEGMENT_AVG_CONF
            )
            # single-word garbage
            lone_garbage = (
                word_count <= 1
                and avg_conf < cls.MIN_SINGLE_WORD_CONF
            )
            if short_and_unsure or lone_garbage:
                continue
            filtered.append(seg)
        return filtered

    # ========================================================
    # REDUCE FRAGMENTATION
    # ========================================================
    @classmethod
    def resize_and_merge_segments(
        cls,
        segments: List[TranscriptSegment],
    ) -> List[TranscriptSegment]:
        """
        Merge adjacent segments of the same speaker.

        Segments are sorted by start time. A segment is folded into the
        previous one when both have the same speaker, the gap between them
        is at most ``MERGE_GAP``, and the merged span is at most
        ``MAX_MERGED_DURATION``. The merged end is the later of the two ends,
        so a segment nested inside the previous one cannot shorten it.

        Merging updates ``text`` and ``end`` of the earlier segment object in
        place; the internal ``_avg_conf`` / ``_word_count`` attributes are
        not recomputed.

        Args:
            segments: Segments to merge, in any order.
        Returns:
            Merged segments ordered by start time.
        """
        if not segments:
            return []

        ordered = sorted(segments, key=lambda s: s.start)
        merged = [ordered[0]]
        for seg in ordered[1:]:
            prev = merged[-1]
            merged_end = max(prev.end, seg.end)
            can_merge = (
                seg.speaker == prev.speaker
                and seg.start - prev.end <= cls.MERGE_GAP
                and merged_end - prev.start <= cls.MAX_MERGED_DURATION
            )
            if can_merge:
                prev.text = (
                    prev.text.strip() + " " + seg.text.strip()
                ).strip()
                prev.end = merged_end
            else:
                merged.append(seg)
        return merged

    @classmethod
    def align_precision(
        cls,
        words: List[WordTimestamp],
        speaker_segments: List[SpeakerSegment]
    ) -> List[TranscriptSegment]:
        """
        Full precision alignment pipeline.

        Args:
            words: Word-level timestamps from transcription
            speaker_segments: Speaker segments from diarization
        Returns:
            List of TranscriptSegment with proper speaker assignments
        """
        # Step 1: Assign speakers to words
        words_with_speakers = cls.assign_speakers_to_words(words, speaker_segments)
        # Step 2: Reconstruct segments
        segments = cls.reconstruct_segments(words_with_speakers)
        # Step 3: Remove noise
        segments = cls.filter_noise_segments(segments)
        # Step 4: Clustering/Merging (Optimization)
        segments = cls.resize_and_merge_segments(segments)
        logger.info("Alignment output segments = %d", len(segments))
        return segments