| | import argparse |
| | import gc |
| | import json |
| | import os |
| | from pathlib import Path |
| | import tempfile |
| | from typing import TYPE_CHECKING, List |
| | import torch |
| |
|
| | import ffmpeg |
| |
|
class DiarizationEntry:
    """One speaker turn: a time span together with the label of who spoke.

    Attributes:
        start: Beginning of the turn (presumably seconds - confirm against
            the diarization pipeline that produces these values).
        end: End of the turn, in the same unit as ``start``.
        speaker: Label identifying the speaker of this turn.
    """

    _FIELDS = ("start", "end", "speaker")

    def __init__(self, start, end, speaker):
        self.start, self.end, self.speaker = start, end, speaker

    def __repr__(self):
        start, end, speaker = (getattr(self, field) for field in self._FIELDS)
        return f"<DiarizationEntry start={start} end={end} speaker={speaker}>"

    def toJson(self):
        """Return the entry as a plain dict with "start", "end" and "speaker" keys."""
        return {field: getattr(self, field) for field in self._FIELDS}
| |
|
class Diarization:
    """Speaker diarization on top of the ``pyannote/speaker-diarization`` pipeline.

    The pipeline is loaded lazily on the first call to :meth:`initialize`
    (which :meth:`run` calls for you), so constructing an instance is cheap
    and does not import pyannote.

    Attributes:
        auth_token: HuggingFace token passed to ``Pipeline.from_pretrained``.
        initialized: True once the pipeline has been loaded successfully.
        pipeline: The loaded pipeline, or None before initialization.
    """

    def __init__(self, auth_token=None):
        """Store the HuggingFace token, falling back to ``HK_ACCESS_TOKEN``.

        Args:
            auth_token: HuggingFace API token. When None, the
                ``HK_ACCESS_TOKEN`` environment variable is used instead.

        Raises:
            ValueError: If no token was passed and the environment variable
                is not set.
        """
        if auth_token is None:
            auth_token = os.environ.get("HK_ACCESS_TOKEN")
            if auth_token is None:
                raise ValueError("No HuggingFace API Token provided - please use the --auth_token argument or set the HK_ACCESS_TOKEN environment variable")

        self.auth_token = auth_token
        self.initialized = False
        self.pipeline = None

    @staticmethod
    def has_libraries():
        """Return True when both ``pyannote.audio`` and ``intervaltree`` import."""
        try:
            import pyannote.audio
            import intervaltree
            return True
        except ImportError:
            return False

    def initialize(self):
        """Load the diarization pipeline once and move it to the GPU if available.

        Subsequent calls are no-ops. The instance is only marked as
        initialized after the pipeline was loaded and placed on its device,
        so a failed attempt can be retried.

        Raises:
            RuntimeError: If ``Pipeline.from_pretrained`` returned None.
        """
        if self.initialized:
            return
        from pyannote.audio import Pipeline

        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1", use_auth_token=self.auth_token)
        # NOTE(review): pyannote is reported to signal a failed download of the
        # gated model by returning None rather than raising - fail loudly here
        # instead of crashing later with "'NoneType' object is not callable".
        if pipeline is None:
            raise RuntimeError("Could not load the pyannote/speaker-diarization pipeline - please check that the HuggingFace API token is valid and that the model's user conditions have been accepted")

        if torch.cuda.is_available():
            print("Diarization - using GPU")
            pipeline = pipeline.to(torch.device("cuda:0"))
        else:
            print("Diarization - using CPU")

        self.pipeline = pipeline
        self.initialized = True

    def run(self, audio_file, **kwargs):
        """Diarize ``audio_file`` and yield one entry per speaker turn.

        Files whose suffix is not one of ``.wav``, ``.flac``, ``.ogg`` or
        ``.mat`` are first converted with ffmpeg to a mono WAV inside a
        private temporary directory, which is removed again even when the
        conversion or the pipeline fails.

        This is a generator: nothing (including pipeline loading) happens
        until the first item is requested.

        Args:
            audio_file: Path of the audio file to process.
            **kwargs: Forwarded unchanged to the pipeline call.

        Yields:
            DiarizationEntry: One entry per track returned by the pipeline.

        Raises:
            ffmpeg.Error: If the audio conversion fails.
        """
        self.initialize()

        if Path(audio_file).suffix in [".wav", ".flac", ".ogg", ".mat"]:
            diarization = self.pipeline(audio_file, **kwargs)
        else:
            # A private directory avoids the name race of tempfile.mktemp and
            # guarantees cleanup on every exit path.
            with tempfile.TemporaryDirectory(prefix="diarization_") as temp_dir:
                target_file = os.path.join(temp_dir, "audio.wav")
                try:
                    ffmpeg.input(audio_file).output(target_file, ac=1).run()
                except ffmpeg.Error as e:
                    print(f"Error occurred during audio conversion: {e.stderr}")
                    # Without a converted file there is nothing to diarize.
                    raise

                diarization = self.pipeline(target_file, **kwargs)

        for turn, _, speaker in diarization.itertracks(yield_label=True):
            yield DiarizationEntry(turn.start, turn.end, speaker)

    def mark_speakers(self, diarization_result: "List[DiarizationEntry]", whisper_result: dict):
        """Annotate transcript segments with the speakers that overlap them.

        Every segment that overlaps at least one diarization entry gains two
        keys: ``"longest_speaker"`` (label of the entry with the largest
        overlap) and ``"speakers"`` (the ``toJson()`` dicts of all overlapping
        entries, ordered by start time). Segments without any overlap are
        left without these keys.

        Args:
            diarization_result: Speaker turns, e.g. from :meth:`run`.
            whisper_result: Transcript dict with a ``"segments"`` list whose
                items are dicts carrying ``"start"`` and ``"end"``.

        Returns:
            A copy of ``whisper_result`` with annotated copies of its
            segments; the input dict and its segment dicts are not modified.
        """
        from intervaltree import IntervalTree

        result = whisper_result.copy()
        # dict.copy() is shallow: copy each segment too, otherwise the keys
        # added below would leak into the caller's transcript.
        result["segments"] = [dict(segment) for segment in whisper_result["segments"]]

        tree = IntervalTree()
        for entry in diarization_result:
            # IntervalTree rejects null (zero or negative length) intervals.
            if entry.end > entry.start:
                tree[entry.start:entry.end] = entry

        for segment in result["segments"]:
            segment_start = segment["start"]
            segment_end = segment["end"]

            # The tree query returns a set; sort it so that the output order
            # and the winner among equally long overlaps are deterministic.
            overlapping_speakers = sorted(
                tree[segment_start:segment_end],
                key=lambda interval: (interval.begin, interval.end, str(interval.data.speaker)),
            )
            if not overlapping_speakers:
                continue

            longest_speaker = None
            longest_duration = 0
            for speaker_interval in overlapping_speakers:
                overlap_start = max(speaker_interval.begin, segment_start)
                overlap_end = min(speaker_interval.end, segment_end)
                overlap_duration = overlap_end - overlap_start

                if overlap_duration > longest_duration:
                    longest_speaker = speaker_interval.data.speaker
                    longest_duration = overlap_duration

            segment["longest_speaker"] = longest_speaker
            segment["speakers"] = [speaker_interval.data.toJson() for speaker_interval in overlapping_speakers]

        return result
| |
|
| | def _write_file(input_file: str, output_path: str, output_extension: str, file_writer: lambda f: None): |
| | if input_file is None: |
| | raise ValueError("input_file is required") |
| | if file_writer is None: |
| | raise ValueError("file_writer is required") |
| |
|
| | |
| | if output_path is None: |
| | effective_path = os.path.splitext(input_file)[0] + "_output" + output_extension |
| | else: |
| | effective_path = output_path |
| |
|
| | with open(effective_path, 'w+', encoding="utf-8") as f: |
| | file_writer(f) |
| |
|
| | print(f"Output saved to {effective_path}") |
| |
|
def main():
    """Command-line entry point.

    Parses the arguments, loads the transcript, runs diarization on the
    audio file, prints the detected speaker turns, annotates the transcript
    with speakers and writes the result both as JSON and as SRT.
    """
    from src.utils import write_srt
    from src.diarization.transcriptLoader import load_transcript

    parser = argparse.ArgumentParser(description='Add speakers to a SRT file or Whisper JSON file using pyannote/speaker-diarization.')

    # Positional inputs.
    for name, help_text in (
        ('audio_file', 'Input audio file'),
        ('whisper_file', 'Input Whisper JSON/SRT file'),
    ):
        parser.add_argument(name, type=str, help=help_text)

    # Optional string arguments.
    for flag, help_text in (
        ('--output_json_file', 'Output JSON file (optional)'),
        ('--output_srt_file', 'Output SRT file (optional)'),
        ('--auth_token', 'HuggingFace API Token (optional)'),
    ):
        parser.add_argument(flag, type=str, default=None, help=help_text)

    parser.add_argument("--max_line_width", type=int, default=40, help="Maximum line width for SRT file (default: 40)")

    # Optional speaker-count hints, forwarded to the diarization run.
    for flag, help_text in (
        ("--num_speakers", "Number of speakers"),
        ("--min_speakers", "Minimum number of speakers"),
        ("--max_speakers", "Maximum number of speakers"),
    ):
        parser.add_argument(flag, type=int, default=None, help=help_text)

    options = parser.parse_args()

    print("\nReading whisper JSON from " + options.whisper_file)
    transcript = load_transcript(options.whisper_file)

    diarizer = Diarization(auth_token=options.auth_token)
    speaker_hints = {
        "num_speakers": options.num_speakers,
        "min_speakers": options.min_speakers,
        "max_speakers": options.max_speakers,
    }
    # run() is a generator; materialize it because the turns are used twice.
    turns = list(diarizer.run(options.audio_file, **speaker_hints))

    print("Diarization result:")
    for turn in turns:
        print(f" start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{turn.speaker}")

    marked_transcript = diarizer.mark_speakers(turns, transcript)

    def dump_json(stream):
        return json.dump(marked_transcript, stream, indent=4, ensure_ascii=False)

    def dump_srt(stream):
        return write_srt(marked_transcript["segments"], stream, maxLineWidth=options.max_line_width)

    # Both outputs default to "<whisper_file stem>_output.<ext>".
    _write_file(options.whisper_file, options.output_json_file, ".json", dump_json)
    _write_file(options.whisper_file, options.output_srt_file, ".srt", dump_srt)


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
| | |
| | |
| | |
| | |
| |
|
| | |