STBack23 committed on
Commit
8564246
·
verified ·
1 Parent(s): 3700132

Fix pyannote 4.x DiarizeOutput API (no itertracks)

Browse files
Files changed (1) hide show
  1. src/omnisub/diarize.py +19 -3
src/omnisub/diarize.py CHANGED
@@ -27,6 +27,22 @@ class SpeakerTurn:
27
  return max(0.0, self.end - self.start)
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def diarize_audio(
31
  audio_path: str | Path,
32
  *,
@@ -58,11 +74,11 @@ def diarize_audio(
58
  diarization = pipeline(str(audio_path), **params)
59
 
60
  turns: List[SpeakerTurn] = []
61
- for segment, _, speaker in diarization.itertracks(yield_label=True):
62
- if max_time is not None and segment.start > max_time:
63
  continue
64
  turns.append(
65
- SpeakerTurn(speaker=str(speaker), start=float(segment.start), end=float(segment.end))
66
  )
67
  turns.sort(key=lambda t: t.start)
68
  return turns
 
27
  return max(0.0, self.end - self.start)
28
 
29
 
30
def _iter_diarization_turns(diarization) -> List[tuple]:
    """Collect the speech turns of a diarization result as ``(turn, speaker)`` pairs.

    Supports both the pyannote 4.x output object (DiarizeOutput) and the
    legacy pyannote 3.x result (Annotation).

    Args:
        diarization: Object returned by the diarization pipeline.

    Returns:
        A list of ``(turn, speaker)`` tuples in the order the underlying
        object yields them.

    Raises:
        TypeError: If the object exposes neither of the 4.x attributes nor
            an ``itertracks`` method.
    """
    # pyannote 4.x (community-1): try the exclusive variant first because it
    # is tidier when assigning a speaker to each subtitle.
    candidates = (
        getattr(diarization, name, None)
        for name in ("exclusive_speaker_diarization", "speaker_diarization")
    )
    annotation = next((found for found in candidates if found is not None), None)
    if annotation is not None:
        pairs: List[tuple] = []
        for segment, label in annotation:
            pairs.append((segment, label))
        return pairs

    # Neither 4.x attribute is usable and there is no legacy API: give up.
    if not hasattr(diarization, "itertracks"):
        raise TypeError(f"Không hiểu output diarization: {type(diarization)!r}")

    # pyannote 3.x legacy: itertracks yields (segment, track, label); the
    # track identifier is not needed here.
    legacy_pairs: List[tuple] = []
    for segment, _track, label in diarization.itertracks(yield_label=True):
        legacy_pairs.append((segment, label))
    return legacy_pairs
44
+
45
+
46
  def diarize_audio(
47
  audio_path: str | Path,
48
  *,
 
74
  diarization = pipeline(str(audio_path), **params)
75
 
76
  turns: List[SpeakerTurn] = []
77
+ for turn, speaker in _iter_diarization_turns(diarization):
78
+ if max_time is not None and turn.start > max_time:
79
  continue
80
  turns.append(
81
+ SpeakerTurn(speaker=str(speaker), start=float(turn.start), end=float(turn.end))
82
  )
83
  turns.sort(key=lambda t: t.start)
84
  return turns