Fix pyannote 4.x DiarizeOutput API (no itertracks)
Browse files- src/omnisub/diarize.py +19 -3
src/omnisub/diarize.py
CHANGED
|
@@ -27,6 +27,22 @@ class SpeakerTurn:
|
|
| 27 |
return max(0.0, self.end - self.start)
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def diarize_audio(
|
| 31 |
audio_path: str | Path,
|
| 32 |
*,
|
|
@@ -58,11 +74,11 @@ def diarize_audio(
|
|
| 58 |
diarization = pipeline(str(audio_path), **params)
|
| 59 |
|
| 60 |
turns: List[SpeakerTurn] = []
|
| 61 |
-
for
|
| 62 |
-
if max_time is not None and
|
| 63 |
continue
|
| 64 |
turns.append(
|
| 65 |
-
SpeakerTurn(speaker=str(speaker), start=float(
|
| 66 |
)
|
| 67 |
turns.sort(key=lambda t: t.start)
|
| 68 |
return turns
|
|
|
|
| 27 |
return max(0.0, self.end - self.start)
|
| 28 |
|
| 29 |
|
| 30 |
+
def _iter_diarization_turns(diarization) -> List[tuple]:
|
| 31 |
+
"""Duyệt lượt nói — hỗ trợ pyannote 4.x (DiarizeOutput) và 3.x (Annotation)."""
|
| 32 |
+
# pyannote 4.x community-1: exclusive mode gọn hơn khi gán speaker cho subtitle
|
| 33 |
+
for attr in ("exclusive_speaker_diarization", "speaker_diarization"):
|
| 34 |
+
tracks = getattr(diarization, attr, None)
|
| 35 |
+
if tracks is not None:
|
| 36 |
+
return [(turn, speaker) for turn, speaker in tracks]
|
| 37 |
+
# pyannote 3.x legacy
|
| 38 |
+
if hasattr(diarization, "itertracks"):
|
| 39 |
+
return [
|
| 40 |
+
(segment, speaker)
|
| 41 |
+
for segment, _, speaker in diarization.itertracks(yield_label=True)
|
| 42 |
+
]
|
| 43 |
+
raise TypeError(f"Không hiểu output diarization: {type(diarization)!r}")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
def diarize_audio(
|
| 47 |
audio_path: str | Path,
|
| 48 |
*,
|
|
|
|
| 74 |
diarization = pipeline(str(audio_path), **params)
|
| 75 |
|
| 76 |
turns: List[SpeakerTurn] = []
|
| 77 |
+
for turn, speaker in _iter_diarization_turns(diarization):
|
| 78 |
+
if max_time is not None and turn.start > max_time:
|
| 79 |
continue
|
| 80 |
turns.append(
|
| 81 |
+
SpeakerTurn(speaker=str(speaker), start=float(turn.start), end=float(turn.end))
|
| 82 |
)
|
| 83 |
turns.sort(key=lambda t: t.start)
|
| 84 |
return turns
|