"""Gradio app: simple speaker diarization via sliding-window speaker embeddings.

The audio is cut into fixed 2-second windows, each window is turned into a
speaker embedding (ECAPA-TDNN via pyannote/speechbrain), and the embeddings
are grouped with agglomerative clustering into the user-chosen number of
speakers. The output is a plain-text timeline of speaker turns.
"""

import datetime

import gradio as gr
import numpy as np
import torch
from pyannote.audio import Audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.core import Segment
from sklearn.cluster import AgglomerativeClustering

# Load the embedding model once at startup (expensive: downloads weights on
# first run), and prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=device,
)
audio_helper = Audio()

# Length of each analysis window, in seconds.
WINDOW_SECONDS = 2.0


def time_str(secs):
    """Format a duration in seconds as H:MM:SS (rounded to whole seconds)."""
    return str(datetime.timedelta(seconds=round(secs)))


def process_audio(audio_file, num_speakers):
    """Diarize an audio file into a plain-text speaker timeline.

    Args:
        audio_file: Filesystem path to the uploaded audio (provided by the
            Gradio ``Audio(type="filepath")`` component). May be ``None`` if
            the user clicked the button without uploading anything.
        num_speakers: Desired number of speakers. Gradio sliders may deliver
            a float, so it is cast to ``int`` before clustering.

    Returns:
        A human-readable timeline string, one line per speaker change.
    """
    if audio_file is None:
        return "Please upload or record an audio file first."

    duration = audio_helper.get_duration(audio_file)

    # 1. Extract voiceprints (embeddings) over fixed sliding windows.
    embeddings = []
    timestamps = []
    for start in np.arange(0, duration, WINDOW_SECONDS):
        end = min(duration, start + WINDOW_SECONDS)
        clip = Segment(start, end)
        waveform, sample_rate = audio_helper.crop(audio_file, clip)
        # Downmix to mono: the embedding model expects single-channel input.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # The model takes a (batch, channel, samples) tensor and returns a
        # (1, embedding_dim) array per call.
        embeddings.append(embedding_model(waveform[None]))
        timestamps.append((start, end))

    if not embeddings:
        return "The audio file appears to be empty."

    # Stack the per-window (1, D) arrays into a single (N, D) matrix —
    # AgglomerativeClustering requires 2-D input. (np.array on the raw list
    # would yield a 3-D (N, 1, D) array and fail.)
    embedding_matrix = np.nan_to_num(np.vstack(embeddings))

    # 2. Cluster windows into speakers. Clamp so sklearn never sees
    # n_clusters > n_samples (e.g. very short clips).
    n_clusters = min(int(num_speakers), len(embedding_matrix))
    clustering = AgglomerativeClustering(n_clusters).fit(embedding_matrix)
    labels = clustering.labels_

    # 3. Render the timeline: emit a line only when the speaker changes.
    lines = ["--- SPEAKER DIARIZATION TIMELINE ---\n\n"]
    current_speaker = None
    for (start, end), label in zip(timestamps, labels):
        speaker_name = f"Speaker {label + 1}"
        if speaker_name != current_speaker:
            lines.append(f"[{time_str(start)}] {speaker_name} starts speaking\n")
            current_speaker = speaker_name
    return "".join(lines)


# 4. Define the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speaker Diarization Tool")
    with gr.Row():
        input_audio = gr.Audio(type="filepath", label="1. Upload or Record Audio")
        num_spks = gr.Slider(
            minimum=1,
            maximum=10,
            value=2,
            step=1,
            label="2. Number of Speakers",
        )
    btn = gr.Button("Analyze Speakers")
    output_text = gr.Textbox(label="3. Diarization Results", lines=10)
    btn.click(fn=process_audio, inputs=[input_audio, num_spks], outputs=output_text)

if __name__ == "__main__":
    demo.launch()