Spaces:
Build error
Build error
import math
from dataclasses import dataclass
from typing import Any
@dataclass
class Word:
    """A single transcribed word and its time span.

    Built by ``_extract_parakeet_words`` from Parakeet word-level timestamps;
    units are whatever the Parakeet response uses (presumably seconds — confirm).
    """

    text: str  # word text; the extractor strips whitespace and drops empty words
    start: float  # start time of the word
    end: float  # end time; the extractor guarantees end >= start
@dataclass
class Segment:
    """A diarization segment attributing a time span to one speaker label.

    Built by ``_extract_pyannote_segments`` from a pyannote response.
    """

    speaker: str  # speaker label, or "SPEAKER_XX" when the response has none
    start: float  # segment start time
    end: float  # segment end time; the extractor guarantees end >= start
| def _as_float(value: Any) -> float | None: | |
| try: | |
| out = float(value) | |
| except Exception: | |
| return None | |
| if out != out: # NaN | |
| return None | |
| return out | |
def _extract_parakeet_words(parakeet_response: dict[str, Any]) -> list[Word]:
    """Extract word-level timestamps from a Parakeet response.

    Reads ``parakeet_response["raw_output"]["output"]["timestamp"]["word"]``,
    tolerating a missing or wrongly-typed value at any level of that path
    (the result is then empty). Each list item should be a dict with ``word``,
    ``start`` and ``end`` keys. Items with empty/null text, unparseable or
    non-finite times, or ``end < start`` are dropped.

    Returns:
        The surviving words, sorted by ``(start, end)``.
    """
    raw_output = parakeet_response.get("raw_output", {}) if isinstance(parakeet_response, dict) else {}
    output = raw_output.get("output", {}) if isinstance(raw_output, dict) else {}
    timestamp = output.get("timestamp", {}) if isinstance(output, dict) else {}
    word_items = timestamp.get("word", []) if isinstance(timestamp, dict) else []

    words: list[Word] = []
    for item in word_items if isinstance(word_items, list) else []:
        if not isinstance(item, dict):
            continue
        raw_text = item.get("word")
        # A JSON null must not turn into the literal word "None" via str(None).
        text = "" if raw_text is None else str(raw_text).strip()
        start = _as_float(item.get("start"))
        end = _as_float(item.get("end"))
        if not text or start is None or end is None or end < start:
            continue
        words.append(Word(text=text, start=start, end=end))
    words.sort(key=lambda w: (w.start, w.end))
    return words
def _extract_pyannote_segments(pyannote_response: dict[str, Any], diarization_key: str) -> list[Segment]:
    """Extract speaker segments from a pyannote response.

    The segment list is looked up first at
    ``raw_output["stitched"][diarization_key]``. If that is missing, empty, or
    not a list, it falls back to ``raw_output[diarization_key]["segments"]``.
    Each item should be a dict with ``speaker``, ``start`` and ``end`` keys.
    A missing, null or blank speaker becomes ``"SPEAKER_XX"``. Items with
    unparseable or non-finite times, or ``end < start``, are dropped.

    Returns:
        The surviving segments, sorted by ``(start, end)``.
    """
    raw_output = pyannote_response.get("raw_output", {}) if isinstance(pyannote_response, dict) else {}
    if not isinstance(raw_output, dict):
        raw_output = {}

    seg_items: Any = []
    stitched = raw_output.get("stitched")
    if isinstance(stitched, dict):
        seg_items = stitched.get(diarization_key, [])
    # Fall back when the stitched entry is absent, empty, or has an unexpected type;
    # otherwise a truthy non-list value would silently yield no segments.
    if not seg_items or not isinstance(seg_items, list):
        fallback = raw_output.get(diarization_key)
        seg_items = fallback.get("segments", []) if isinstance(fallback, dict) else []

    segments: list[Segment] = []
    for item in seg_items if isinstance(seg_items, list) else []:
        if not isinstance(item, dict):
            continue
        raw_speaker = item.get("speaker")
        # Treat JSON null like a missing label instead of producing "None".
        speaker = ("" if raw_speaker is None else str(raw_speaker).strip()) or "SPEAKER_XX"
        start = _as_float(item.get("start"))
        end = _as_float(item.get("end"))
        if start is None or end is None or end < start:
            continue
        segments.append(Segment(speaker=speaker, start=start, end=end))
    segments.sort(key=lambda s: (s.start, s.end))
    return segments
def _segment_distance_to_time(seg: Segment, t: float) -> float:
    """Return how far time ``t`` lies outside ``seg``.

    The result is 0.0 when ``seg.start <= t <= seg.end``. Otherwise it is the
    gap to the nearer boundary: ``seg.start - t`` when ``t`` precedes the
    segment, and ``t - seg.end`` in every remaining case.
    """
    if t < seg.start:
        # t comes before the segment begins.
        return seg.start - t
    if seg.start <= t <= seg.end:
        # t falls inside the segment.
        return 0.0
    # t comes after the segment ends.
    return t - seg.end
def _assign_words_to_segments(words: list[Word], segments: list[Segment]) -> list[list[Word]]:
    """Assign each word to the nearest diarization segment by its midpoint.

    Both lists are expected to be sorted by start time, as the extractors
    return them. A forward-only pointer tracks the first segment whose end lies
    after the word's midpoint. The word then goes to whichever of that segment
    and its immediate neighbours is closest to the midpoint (distance 0 when the
    midpoint is inside a segment). Ties go to the lowest-index segment.

    Returns:
        A list parallel to ``segments``: entry ``i`` holds the words assigned
        to ``segments[i]``, in input order. When both inputs are non-empty,
        every word is assigned to exactly one segment.
    """
    assigned: list[list[Word]] = [[] for _ in segments]
    if not words or not segments:
        return assigned

    seg_idx = 0
    n = len(segments)
    for w in words:
        mid = (w.start + w.end) / 2.0
        # The pointer never moves back; this assumes word midpoints are
        # (close to) non-decreasing, which holds for sequential word timestamps.
        while seg_idx + 1 < n and segments[seg_idx].end <= mid:
            seg_idx += 1
        # Use an ascending range rather than a set: min() keeps the first
        # minimum, so distance ties resolve deterministically to the earlier
        # segment. Set iteration order is not guaranteed to be ascending.
        candidates = range(max(seg_idx - 1, 0), min(seg_idx + 2, n))
        best_idx = min(candidates, key=lambda i: _segment_distance_to_time(segments[i], mid))
        assigned[best_idx].append(w)
    return assigned
def _join_words(words: list[Word]) -> str:
    """Concatenate word texts with punctuation-aware spacing.

    Words are separated by a single space. A word whose text begins with one of
    ``, . ! ? ; : ) ] }`` is glued directly onto the preceding text. The result
    is stripped, and an empty input yields ``""``.
    """
    if not words:
        return ""
    closing_marks = ",.!?;:)]}"
    pieces = [words[0].text]
    for word in words[1:]:
        attaches = bool(word.text) and word.text[0] in closing_marks
        pieces.append(word.text if attaches else " " + word.text)
    return "".join(pieces).strip()
def merge_parakeet_pyannote_outputs(
    parakeet_response: dict[str, Any],
    pyannote_response: dict[str, Any],
    diarization_key: str = "exclusive_speaker_diarization",
) -> dict[str, Any]:
    """Combine Parakeet word timestamps with pyannote speaker segments.

    Each word is assigned to its nearest segment. Segments that receive words
    become turns, and consecutive turns with the same speaker are merged. A
    turn's span covers both its segment(s) and the words assigned to them.

    Args:
        parakeet_response: Parakeet response containing
            ``raw_output.output.timestamp.word``.
        pyannote_response: pyannote response containing segments under
            ``raw_output.stitched[diarization_key]`` or
            ``raw_output[diarization_key].segments``.
        diarization_key: Which diarization output to read segments from.

    Returns:
        A dict with three keys:
        - ``summary``: counts of words, segments, turns and assigned words.
        - ``turns``: a list of ``{"speaker", "start", "end", "text"}`` dicts,
          with times rounded to 4 decimals.
        - ``transcript_text``: one ``"SPEAKER: text"`` line per turn.

    Raises:
        ValueError: If no usable word timestamps, or no segments for
            ``diarization_key``, are found.
    """
    words = _extract_parakeet_words(parakeet_response)
    segments = _extract_pyannote_segments(pyannote_response, diarization_key=diarization_key)
    if not words:
        raise ValueError("No Parakeet word-level timestamps found.")
    if not segments:
        raise ValueError(f"No Pyannote segments found for key '{diarization_key}'.")

    words_by_segment = _assign_words_to_segments(words, segments)

    turns: list[dict[str, Any]] = []
    for seg, seg_words in zip(segments, words_by_segment):
        if not seg_words:
            continue
        text = _join_words(seg_words)
        if not text:
            continue
        # Widen the span to cover words that spill past the segment boundaries.
        # Words are sorted by start, so the first word has the earliest start,
        # but the latest end is not necessarily the last word's.
        start = round(min(seg.start, seg_words[0].start), 4)
        end = round(max(seg.end, max(w.end for w in seg_words)), 4)
        if turns and turns[-1]["speaker"] == seg.speaker:
            prev = turns[-1]
            # Segments are ordered by start, not end. An overlapping later
            # segment may finish before the current turn does, so only widen.
            prev["start"] = min(prev["start"], start)
            prev["end"] = max(prev["end"], end)
            prev["text"] = (prev["text"] + " " + text).strip()
        else:
            turns.append(
                {
                    "speaker": seg.speaker,
                    "start": start,
                    "end": end,
                    "text": text,
                }
            )

    assigned_word_count = sum(len(seg_words) for seg_words in words_by_segment)
    transcript_lines = [f'{t["speaker"]}: {t["text"]}' for t in turns]
    return {
        "summary": {
            "diarization_key_used": diarization_key,
            "parakeet_word_count": len(words),
            "pyannote_segment_count": len(segments),
            "turn_count": len(turns),
            "assigned_word_count": assigned_word_count,
            "unassigned_word_count": len(words) - assigned_word_count,
        },
        "turns": turns,
        "transcript_text": "\n".join(transcript_lines),
    }