"""Speaker diarization helpers built on the pyannote.audio pretrained pipeline."""

import os

import dotenv
from pyannote.audio import Pipeline
import torch
import torchaudio

dotenv.load_dotenv()

# Hugging Face access token passed to the pretrained pipeline loader.
SUBTIFY_TOKEN = os.getenv("SUBTIFY_TOKEN")

# Identifier of the pretrained diarization pipeline to load.
_DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"


def diarize(audio_path: str, num_speakers: int = 0, min_speakers: int = 0,
            max_speakers: int = 0, device: str = "cpu"):
    """
    Diarize an audio file using Pyannote.

    Args:
        audio_path (str): The path to the audio file to diarize.
        num_speakers (int): Exact number of speakers. Values <= 0 mean
            "not specified" and are not forwarded to the pipeline.
        min_speakers (int): Lower bound on the number of speakers. Values
            <= 0 are not forwarded to the pipeline.
        max_speakers (int): Upper bound on the number of speakers. Values
            <= 0 are not forwarded to the pipeline.
        device (str): Torch device string the pipeline is moved to
            (e.g. "cpu").

    Returns:
        The raw result object returned by the pipeline call (not a list).
        ``diarize_audio`` converts it to segments through its ``to_rttm()``
        method.

    Raises:
        RuntimeError: If the pretrained pipeline could not be loaded.
    """
    # Load audio
    waveform, sample_rate = torchaudio.load(audio_path)

    # Only forward the speaker constraints the caller actually set.
    requested = {
        "num_speakers": num_speakers,
        "min_speakers": min_speakers,
        "max_speakers": max_speakers,
    }
    params = {name: value for name, value in requested.items() if value > 0}

    # Create pipeline
    pipeline = Pipeline.from_pretrained(_DIARIZATION_MODEL, use_auth_token=SUBTIFY_TOKEN)

    # The loader can hand back None instead of raising (e.g. missing/invalid
    # token or unaccepted model conditions); fail here with a clear message
    # rather than with an AttributeError on the next line.
    if pipeline is None:
        raise RuntimeError(
            f"Could not load pipeline '{_DIARIZATION_MODEL}'. Check that "
            "SUBTIFY_TOKEN holds a valid Hugging Face token and that the "
            "model's user conditions have been accepted."
        )

    pipeline.to(torch.device(device))

    # Diarize the in-memory waveform
    return pipeline({"waveform": waveform, "sample_rate": sample_rate}, **params)


def parse_rttm(rttm_string):
    """
    Parse an RTTM string into a list of segments.

    Each non-blank line is split on whitespace; field 3 is read as the start
    time, field 4 as the duration and field 7 as the speaker label
    (0-based indices).

    Args:
        rttm_string (str): The RTTM string to parse.

    Returns:
        list: A list of dicts with 'start', 'duration', 'end', and 'speaker'.
        Empty when the input contains no non-blank lines.

    Raises:
        IndexError: If a non-blank line has fewer than 8 fields.
        ValueError: If the start or duration field is not a number.
    """
    segments = []

    for line in rttm_string.splitlines():
        parts = line.split()

        # A diarization without any speech produces an empty RTTM string;
        # blank lines carry no segment, so skip them instead of crashing.
        if not parts:
            continue

        start = float(parts[3])
        duration = float(parts[4])
        segments.append({
            'start': start,
            'duration': duration,
            'end': start + duration,
            'speaker': parts[7],
        })

    return segments


def diarize_audio(audio_path: str, num_speakers: int = 0, min_speakers: int = 0,
                  max_speakers: int = 0, device: str = "cpu") -> list:
    """
    Diarize an audio file using Pyannote and return plain segment dicts.

    Args:
        audio_path (str): The path to the audio file to diarize.
        num_speakers (int): Exact number of speakers; <= 0 means unspecified.
        min_speakers (int): Minimum number of speakers; <= 0 means unspecified.
        max_speakers (int): Maximum number of speakers; <= 0 means unspecified.
        device (str): Torch device string to run the pipeline on.

    Returns:
        list: A list of segments with start, duration, end, and speaker.
    """
    # Diarize
    diarization = diarize(audio_path, num_speakers, min_speakers, max_speakers, device)

    # Serialize to RTTM text, then parse that text into segment dicts
    rttm_output = diarization.to_rttm()
    return parse_rttm(rttm_output)