File size: 2,458 Bytes
d73543f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import os
import dotenv
from pyannote.audio import Pipeline
import torch
import torchaudio
# Load variables from a local .env file (if present) into the process environment
# so the token below can be supplied without exporting it in the shell.
dotenv.load_dotenv()
# Access token passed as `use_auth_token` when diarize() loads the pretrained
# pipeline. os.getenv returns None when the variable is not set.
SUBTIFY_TOKEN = os.getenv("SUBTIFY_TOKEN")
def diarize(audio_path: str, num_speakers: int = 0, min_speakers: int = 0, max_speakers: int = 0, device: str = "cpu"):
    """
    Run the pretrained pyannote speaker-diarization pipeline on an audio file.

    Args:
        audio_path (str): The path to the audio file to diarize.
        num_speakers (int): Exact number of speakers. Forwarded to the
            pipeline only when greater than 0.
        min_speakers (int): Lower bound on the number of speakers. Forwarded
            to the pipeline only when greater than 0.
        max_speakers (int): Upper bound on the number of speakers. Forwarded
            to the pipeline only when greater than 0.
        device (str): Torch device string the pipeline is moved to
            (for example "cpu" or "cuda").

    Returns:
        The object returned by the pipeline call. This is NOT a list of
        segments; diarize_audio() converts it by calling ``to_rttm()`` on it.

    Raises:
        RuntimeError: If the pretrained pipeline could not be loaded.
    """
    # Load audio
    waveform, sample_rate = torchaudio.load(audio_path)

    # Forward only the speaker-count hints the caller actually set (> 0);
    # unset hints are left out so the pipeline uses its own defaults.
    hints = {
        "num_speakers": num_speakers,
        "min_speakers": min_speakers,
        "max_speakers": max_speakers,
    }
    params = {name: value for name, value in hints.items() if value > 0}

    # Create pipeline
    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=SUBTIFY_TOKEN)
    if pipeline is None:
        # from_pretrained can return None instead of raising when the gated
        # model cannot be fetched; fail here with an actionable message rather
        # than an AttributeError on the next line.
        raise RuntimeError(
            "Could not load 'pyannote/speaker-diarization-3.1'. Check that "
            "SUBTIFY_TOKEN is set to a valid Hugging Face access token and "
            "that the model's user conditions have been accepted."
        )
    pipeline.to(torch.device(device))

    # Diarize
    return pipeline({"waveform": waveform, "sample_rate": sample_rate}, **params)
def parse_rttm(rttm_string: str) -> list:
    """
    Parse an RTTM string into a list of segments.

    Each non-blank line is split on whitespace; field 3 is read as the start
    time, field 4 as the duration and field 7 as the speaker label.

    Args:
        rttm_string (str): The RTTM string to parse. May be empty.

    Returns:
        list: One dict per non-blank line with the keys 'start', 'duration',
        'end' (start + duration) and 'speaker'. Empty or whitespace-only
        input yields an empty list.

    Raises:
        IndexError: If a non-blank line has fewer than 8 fields.
        ValueError: If the start or duration field is not a number.
    """
    segments = []
    for line in rttm_string.split('\n'):
        parts = line.split()
        if not parts:
            # Blank line (or entirely empty input, e.g. no speech detected):
            # nothing to parse, and indexing into it would raise.
            continue
        start = float(parts[3])
        duration = float(parts[4])
        segments.append({
            'start': start,
            'duration': duration,
            'end': start + duration,
            'speaker': parts[7],
        })
    return segments
def diarize_audio(audio_path: str, num_speakers: int = 0, min_speakers: int = 0, max_speakers: int = 0, device: str = "cpu") -> list:
    """
    Diarize an audio file and return the result as a list of segments.

    Delegates to diarize(), serialises its result with ``to_rttm()`` and
    parses that text with parse_rttm().

    Args:
        audio_path (str): The path to the audio file to diarize.
        num_speakers (int): Passed through to diarize().
        min_speakers (int): Passed through to diarize().
        max_speakers (int): Passed through to diarize().
        device (str): Passed through to diarize().

    Returns:
        list: A list of segments with start, duration, end, and speaker.
    """
    annotation = diarize(
        audio_path,
        num_speakers=num_speakers,
        min_speakers=min_speakers,
        max_speakers=max_speakers,
        device=device,
    )
    # Go through the RTTM text form so parse_rttm() can build the segments.
    return parse_rttm(annotation.to_rttm())
|