transcribe-diarize / src /merge_service.py
Ratnesh-dev's picture
Revert repository state to c7d2aa0
ca6855a
from dataclasses import dataclass
from typing import Any
@dataclass
class Word:
    """One transcribed word together with its time span.

    Instances are created from the word-level timestamp items of a Parakeet
    response.
    """

    text: str  # the word itself
    start: float  # time at which the word begins (presumably seconds -- confirm)
    end: float  # time at which the word ends (same unit as ``start``)
@dataclass
class Segment:
    """One diarization segment: a speaker label and the time span it covers.

    Instances are created from the segment items of a Pyannote response.
    """

    speaker: str  # speaker label attached to the segment
    start: float  # time at which the segment begins (presumably seconds -- confirm)
    end: float  # time at which the segment ends (same unit as ``start``)
def _as_float(value: Any) -> float | None:
try:
out = float(value)
except Exception:
return None
if out != out: # NaN
return None
return out
def _extract_parakeet_words(parakeet_response: dict[str, Any]) -> list[Word]:
    """Collect the word-level timestamps from a Parakeet response.

    The word items are read from ``raw_output -> output -> timestamp -> word``.
    A level that is missing or has an unexpected type produces an empty result
    instead of raising.

    Args:
        parakeet_response: Response payload; anything other than a dict yields
            an empty list.

    Returns:
        The usable words sorted by ``(start, end)``. An item is dropped when it
        is not a dict, when its ``word`` is missing, null or blank, when
        ``start``/``end`` cannot be converted to a number, or when ``end`` is
        smaller than ``start``.
    """
    # Walk the nested payload one key at a time; any non-dict level ends the walk.
    node: Any = parakeet_response
    for key in ("raw_output", "output", "timestamp", "word"):
        node = node.get(key) if isinstance(node, dict) else None
    word_items = node if isinstance(node, list) else []

    words: list[Word] = []
    for item in word_items:
        if not isinstance(item, dict):
            continue
        raw_text = item.get("word")
        # A null word must count as missing: str(None) would otherwise put the
        # literal text "None" into the transcript.
        text = "" if raw_text is None else str(raw_text).strip()
        start = _as_float(item.get("start"))
        end = _as_float(item.get("end"))
        if not text or start is None or end is None or end < start:
            continue
        words.append(Word(text=text, start=start, end=end))

    words.sort(key=lambda w: (w.start, w.end))
    return words
def _extract_pyannote_segments(pyannote_response: dict[str, Any], diarization_key: str) -> list[Segment]:
    """Collect the speaker segments stored under *diarization_key*.

    Two layouts are tried, in this order:

    1. ``raw_output["stitched"][diarization_key]`` -- a list of segment items.
    2. ``raw_output[diarization_key]["segments"]`` -- used only when the first
       lookup produced nothing.

    Args:
        pyannote_response: Response payload; anything other than a dict yields
            an empty list.
        diarization_key: Name of the diarization result to read.

    Returns:
        The usable segments sorted by ``(start, end)``. An item is dropped when
        it is not a dict, when ``start``/``end`` cannot be converted to a
        number, or when ``end`` is smaller than ``start``. A missing, null or
        blank speaker label is replaced by ``"SPEAKER_XX"``.
    """
    raw_output = pyannote_response.get("raw_output") if isinstance(pyannote_response, dict) else None
    if not isinstance(raw_output, dict):
        raw_output = {}

    stitched = raw_output.get("stitched")
    seg_items = stitched.get(diarization_key, []) if isinstance(stitched, dict) else []
    if not seg_items:
        # Nothing under "stitched": fall back to raw_output[key]["segments"].
        unstitched = raw_output.get(diarization_key)
        seg_items = unstitched.get("segments", []) if isinstance(unstitched, dict) else []
    if not isinstance(seg_items, list):
        seg_items = []

    segments: list[Segment] = []
    for item in seg_items:
        if not isinstance(item, dict):
            continue
        raw_speaker = item.get("speaker")
        # A null speaker must take the placeholder too: str(None) is the
        # truthy text "None" and would bypass the fallback.
        label = "" if raw_speaker is None else str(raw_speaker).strip()
        speaker = label or "SPEAKER_XX"
        start = _as_float(item.get("start"))
        end = _as_float(item.get("end"))
        if start is None or end is None or end < start:
            continue
        segments.append(Segment(speaker=speaker, start=start, end=end))

    segments.sort(key=lambda s: (s.start, s.end))
    return segments
def _segment_distance_to_time(seg: Segment, t: float) -> float:
    """Return how far the time *t* lies from the segment *seg*.

    Args:
        seg: Segment whose ``start``/``end`` bounds are compared against *t*.
        t: Point in time, in the same unit as the segment bounds.

    Returns:
        ``0.0`` when *t* is inside the segment (both bounds inclusive),
        otherwise the gap between *t* and the nearer bound.
    """
    inside = seg.start <= t <= seg.end
    if inside:
        return 0.0
    # Outside the segment: measure against whichever bound *t* lies beyond.
    return seg.start - t if t < seg.start else t - seg.end
def _assign_words_to_segments(words: list[Word], segments: list[Segment]) -> list[list[Word]]:
    """Assign every word to the segment nearest to the word's midpoint.

    A cursor is advanced through *segments* while the current segment ends at
    or before the word's midpoint. The word then goes to whichever of the
    cursor segment and its immediate neighbours is closest to that midpoint.
    The cursor never moves backwards, so both inputs are expected to be in
    time order.

    Args:
        words: Words to distribute.
        segments: Segments to distribute the words over.

    Returns:
        A list parallel to *segments*; entry ``i`` holds the words assigned to
        ``segments[i]`` in the order they were visited. Every word is assigned
        to exactly one segment. If either input is empty, every entry is empty.
    """
    assigned: list[list[Word]] = [[] for _ in segments]
    if not words or not segments:
        return assigned

    last = len(segments) - 1
    cursor = 0
    for word in words:
        mid = (word.start + word.end) / 2.0
        while cursor < last and segments[cursor].end <= mid:
            cursor += 1
        # min() keeps the first of several equally good items, so scanning the
        # candidates in ascending order sends every tie to the earliest
        # segment. (Iterating a set here would make the winner depend on the
        # set's internal ordering and therefore on the segment index.)
        neighbours = range(max(cursor - 1, 0), min(cursor + 1, last) + 1)
        best_idx = min(neighbours, key=lambda i: _segment_distance_to_time(segments[i], mid))
        assigned[best_idx].append(word)
    return assigned
def _join_words(words: list[Word]) -> str:
    """Join the texts of *words* into a single string.

    Words are separated by one space, except that a word whose text begins
    with one of ``,.!?;:)]}`` is attached directly to what precedes it.

    Args:
        words: Words to join, in output order.

    Returns:
        The joined text with leading/trailing whitespace removed, or ``""``
        for an empty list.
    """
    if not words:
        return ""

    attach_chars = ",.!?;:)]}"
    pieces = [words[0].text]
    for word in words[1:]:
        token = word.text
        # No separating space in front of punctuation-like tokens.
        glue = "" if token and token[0] in attach_chars else " "
        pieces.append(glue + token)
    return "".join(pieces).strip()
def merge_parakeet_pyannote_outputs(
    parakeet_response: dict[str, Any],
    pyannote_response: dict[str, Any],
    diarization_key: str = "exclusive_speaker_diarization",
) -> dict[str, Any]:
    """Combine Parakeet word timestamps with Pyannote segments into speaker turns.

    Words are assigned to segments, each non-empty segment becomes a turn, and
    consecutive turns of the same speaker are merged into one.

    Args:
        parakeet_response: Parakeet payload carrying word-level timestamps.
        pyannote_response: Pyannote payload carrying speaker segments.
        diarization_key: Name of the diarization result to read from
            *pyannote_response*.

    Returns:
        A dict with three entries:

        * ``"summary"`` -- the key used plus word, segment and turn counts.
        * ``"turns"`` -- a list of dicts with ``speaker``, ``start``, ``end``
          (both rounded to 4 decimals) and ``text``.
        * ``"transcript_text"`` -- one ``"<speaker>: <text>"`` line per turn,
          joined with newlines.

    Raises:
        ValueError: If no usable words or no usable segments were found.
    """
    words = _extract_parakeet_words(parakeet_response)
    segments = _extract_pyannote_segments(pyannote_response, diarization_key=diarization_key)
    if not words:
        raise ValueError("No Parakeet word-level timestamps found.")
    if not segments:
        raise ValueError(f"No Pyannote segments found for key '{diarization_key}'.")

    words_by_segment = _assign_words_to_segments(words, segments)

    turns: list[dict[str, Any]] = []
    for seg, seg_words in zip(segments, words_by_segment):
        if not seg_words:
            continue
        text = _join_words(seg_words)
        if not text:
            continue

        # The turn spans the segment, widened to cover every assigned word.
        # Words are ordered by start time, so the first word starts earliest,
        # but the last word does not necessarily end latest.
        start = min(seg.start, seg_words[0].start)
        end = max(seg.end, max(w.end for w in seg_words))

        previous = turns[-1] if turns else None
        if previous is not None and previous["speaker"] == seg.speaker:
            # Same speaker continues: extend the open turn. The end may only
            # grow -- a segment that starts later can still end earlier than
            # the turn already does.
            previous["end"] = max(previous["end"], round(end, 4))
            previous["text"] = (previous["text"] + " " + text).strip()
        else:
            turns.append(
                {
                    "speaker": seg.speaker,
                    "start": round(start, 4),
                    "end": round(end, 4),
                    "text": text,
                }
            )

    assigned_word_count = sum(len(seg_words) for seg_words in words_by_segment)
    transcript_lines = [f'{t["speaker"]}: {t["text"]}' for t in turns]
    return {
        "summary": {
            "diarization_key_used": diarization_key,
            "parakeet_word_count": len(words),
            "pyannote_segment_count": len(segments),
            "turn_count": len(turns),
            "assigned_word_count": assigned_word_count,
            "unassigned_word_count": len(words) - assigned_word_count,
        },
        "turns": turns,
        "transcript_text": "\n".join(transcript_lines),
    }