File size: 2,458 Bytes
d73543f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import dotenv
from pyannote.audio import Pipeline
import torch
import torchaudio

# Let python-dotenv populate the process environment before the token is read below.
dotenv.load_dotenv()
# Auth token passed to Pipeline.from_pretrained(); None when SUBTIFY_TOKEN is not set.
SUBTIFY_TOKEN = os.getenv("SUBTIFY_TOKEN")

def diarize(audio_path: str, num_speakers: int = 0, min_speakers: int = 0, max_speakers: int = 0, device: str = "cpu"):
    """
    Diarize an audio file using Pyannote.

    Args:
        audio_path (str): The path to the audio file to diarize.
        num_speakers (int): Exact number of speakers. Values <= 0 are not
            forwarded to the pipeline.
        min_speakers (int): Lower bound on the number of speakers. Values <= 0
            are not forwarded to the pipeline.
        max_speakers (int): Upper bound on the number of speakers. Values <= 0
            are not forwarded to the pipeline.
        device (str): Torch device string the pipeline is moved to.

    Returns:
        The object returned by the pyannote pipeline call (not a list of
        segments). diarize_audio() calls ``to_rttm()`` on it.

    Raises:
        ValueError: If both bounds are given and min_speakers > max_speakers.
        RuntimeError: If the pretrained pipeline could not be loaded.
    """
    # Reject contradictory bounds up front, before the expensive audio/model loads.
    if min_speakers > 0 and max_speakers > 0 and min_speakers > max_speakers:
        raise ValueError(
            f"min_speakers ({min_speakers}) must not exceed max_speakers ({max_speakers})"
        )

    # Load audio
    waveform, sample_rate = torchaudio.load(audio_path)

    # Parameters: only forward the constraints the caller actually set
    params = {}
    if num_speakers > 0:
        params["num_speakers"] = num_speakers
    if min_speakers > 0:
        params["min_speakers"] = min_speakers
    if max_speakers > 0:
        params["max_speakers"] = max_speakers

    # Device
    device = torch.device(device)

    # Create pipeline
    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=SUBTIFY_TOKEN)
    # NOTE(review): from_pretrained appears to return None instead of raising when the
    # gated model cannot be fetched (e.g. missing/invalid token) — confirm against the
    # installed pyannote.audio version. Fail with a clear message in that case.
    if pipeline is None:
        raise RuntimeError(
            "Could not load 'pyannote/speaker-diarization-3.1'; "
            "check that SUBTIFY_TOKEN is set and has access to the model."
        )
    pipeline.to(device)

    # Diarize from the in-memory waveform
    diarization = pipeline({"waveform": waveform, "sample_rate": sample_rate}, **params)

    return diarization

def parse_rttm(rttm_string: str) -> list:
    """
    Parse an RTTM string into a list of segments.

    Blank lines are skipped, so an empty or whitespace-only string yields an
    empty list.

    Args:
        rttm_string (str): The RTTM string to parse, one record per line.

    Returns:
        list: One dict per non-blank line with keys 'start', 'duration',
        'end' (start + duration) and 'speaker'.

    Raises:
        IndexError: If a non-blank line has fewer than 8 whitespace-separated fields.
        ValueError: If the 4th or 5th field of a line is not a valid float.
    """
    segments = []

    # Parse each line
    for line in rttm_string.split('\n'):
        # Split line into whitespace-separated fields
        parts = line.split()

        # An empty RTTM (no segments) or a stray blank line has no fields to read.
        if not parts:
            continue

        # Fields are read by position: 4th = start, 5th = duration, 8th = speaker.
        start = float(parts[3])
        duration = float(parts[4])

        segments.append({
            'start': start,
            'duration': duration,
            'end': start + duration,
            'speaker': parts[7],
        })
    return segments

def diarize_audio(audio_path: str, num_speakers: int = 0, min_speakers: int = 0, max_speakers: int = 0, device: str = "cpu") -> list:
    """
    Run speaker diarization on an audio file and return plain segment dicts.

    Calls diarize() with the given arguments, serializes its result with
    ``to_rttm()`` and hands that text to parse_rttm().

    Args:
        audio_path (str): The path to the audio file to diarize.
        num_speakers (int): Forwarded unchanged to diarize().
        min_speakers (int): Forwarded unchanged to diarize().
        max_speakers (int): Forwarded unchanged to diarize().
        device (str): Forwarded unchanged to diarize().

    Returns:
        list: A list of segments with start, duration, end, and speaker.
    """
    # Run the pipeline, forwarding every option by name
    annotation = diarize(
        audio_path,
        num_speakers=num_speakers,
        min_speakers=min_speakers,
        max_speakers=max_speakers,
        device=device,
    )

    # Serialize to RTTM text and convert it into segment dicts
    return parse_rttm(annotation.to_rttm())