File size: 2,611 Bytes
8cac351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import torch
import numpy as np
import datetime
from pyannote.audio import Audio
from pyannote.core import Segment
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from sklearn.cluster import AgglomerativeClustering

# Load the model (runs once when the Space starts)
# Pick GPU when available; PretrainedSpeakerEmbedding places its weights on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Speaker-embedding model: maps a waveform chunk to a fixed-size voiceprint vector.
# Downloaded from the Hugging Face Hub on first run.
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=device
)
# Helper for reading durations and cropping segments from audio files on disk.
audio_helper = Audio()

def time_str(secs):
    """Format a duration in seconds as an H:MM:SS string (rounded to whole seconds)."""
    whole_seconds = round(secs)
    delta = datetime.timedelta(seconds=whole_seconds)
    return str(delta)

def process_audio(audio_file, num_speakers):
    """Run simple speaker diarization over an uploaded audio file.

    Args:
        audio_file: Filesystem path to the audio, as provided by Gradio's
            ``gr.Audio(type="filepath")`` component (``None`` if nothing was uploaded).
        num_speakers: Desired number of speakers from the slider; may arrive as a
            float, so it is coerced to ``int`` before clustering.

    Returns:
        A human-readable timeline string marking where each speaker starts.
    """
    # Guard: the user may click "Analyze" before uploading anything.
    if audio_file is None:
        return "Please upload or record an audio file first."

    duration = audio_helper.get_duration(audio_file)

    # 1. Extract Voiceprints (Embeddings) over fixed 2-second windows.
    step = 2.0
    embeddings = []
    timestamps = []

    for start in np.arange(0, duration, step):
        end = min(duration, start + step)
        clip = Segment(start, end)
        waveform, sample_rate = audio_helper.crop(audio_file, clip)

        # The embedding model expects mono input; average channels if stereo.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # waveform[None] adds the batch dimension -> (1, channel, samples);
        # the model returns a (1, D) voiceprint per window.
        embeddings.append(embedding_model(waveform[None]))
        timestamps.append((start, end))

    # Guard: zero-length audio produces no windows and nothing to cluster.
    if not embeddings:
        return "The audio file appears to be empty."

    # BUG FIX: each embedding is (1, D); np.array(...) would build a 3-D
    # (n, 1, D) array, which AgglomerativeClustering.fit rejects. vstack
    # yields the required 2-D (n, D) matrix. nan_to_num cleans any NaNs
    # the model emits on silent/degenerate windows.
    embeddings = np.nan_to_num(np.vstack(embeddings))

    # 2. Cluster voiceprints into the requested number of speakers.
    # Coerce the slider value to int and clamp: sklearn requires
    # n_clusters <= n_samples.
    n_clusters = min(int(num_speakers), len(embeddings))
    clustering = AgglomerativeClustering(n_clusters=n_clusters).fit(embeddings)
    labels = clustering.labels_

    # 3. Build the timeline, emitting a line only when the speaker changes.
    result = "--- SPEAKER DIARIZATION TIMELINE ---\n\n"
    current_speaker = None

    for i, label in enumerate(labels):
        speaker_name = f"Speaker {label + 1}"
        start, end = timestamps[i]

        if speaker_name != current_speaker:
            result += f"[{time_str(start)}] {speaker_name} starts speaking\n"
            current_speaker = speaker_name

    return result

# 4. Define the Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speaker Diarization Tool")

    with gr.Row():
        # type="filepath" makes Gradio hand process_audio a path string,
        # which is what audio_helper (pyannote.audio.Audio) expects.
        input_audio = gr.Audio(type="filepath", label="1. Upload or Record Audio")
        # step=1 keeps the slider on whole numbers of speakers.
        num_spks = gr.Slider(minimum=1, maximum=10, value=2, step=1, label="2. Number of Speakers")

    btn = gr.Button("Analyze Speakers")
    output_text = gr.Textbox(label="3. Diarization Results", lines=10)

    # Wire the button: inputs are passed positionally to process_audio(audio_file, num_speakers).
    btn.click(fn=process_audio, inputs=[input_audio, num_spks], outputs=output_text)

# Launch at import time — standard for Hugging Face Spaces, where this file is the entry point.
demo.launch()