# Hugging Face Space file header (bhuvanabala0504 — "Upload 2 files", commit 8cac351 verified)
import gradio as gr
import torch
import numpy as np
import datetime
from pyannote.audio import Audio
from pyannote.core import Segment
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from sklearn.cluster import AgglomerativeClustering
# Load the model (runs once when the Space starts).
# Pick the GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Speaker-embedding model (SpeechBrain ECAPA checkpoint) used to turn each
# audio window into a fixed-size voiceprint vector for clustering.
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=device
)
# pyannote helper for reading file durations and cropping waveform segments.
audio_helper = Audio()
def time_str(secs):
    """Format a duration given in seconds as H:MM:SS, rounding to the nearest second."""
    whole_seconds = round(secs)
    return str(datetime.timedelta(seconds=whole_seconds))
def process_audio(audio_file, num_speakers):
    """Diarize an audio file and return a human-readable speaker timeline.

    Parameters
    ----------
    audio_file : str
        Path to the uploaded/recorded file (Gradio ``type="filepath"`` mode).
    num_speakers : int | float
        Desired number of speakers; Gradio sliders may deliver floats.

    Returns
    -------
    str
        One line per speaker change, each stamped with its start time.
    """
    # Gradio passes None when the user clicks Analyze without uploading.
    if not audio_file:
        return "Please upload or record an audio file first."
    duration = audio_helper.get_duration(audio_file)

    # 1. Extract voiceprints: one embedding per fixed-length window.
    step = 2.0
    embeddings = []
    timestamps = []
    for start in np.arange(0, duration, step):
        end = min(duration, start + step)
        clip = Segment(start, end)
        waveform, sample_rate = audio_helper.crop(audio_file, clip)
        # Ensure mono for the model
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        embeddings.append(embedding_model(waveform[None]))
        timestamps.append((start, end))

    if not embeddings:
        return "The audio file appears to be empty."

    # Each call returns a (1, dim) row; stack into the 2-D (n_windows, dim)
    # matrix AgglomerativeClustering requires. (np.array on the list built a
    # 3-D array, which made .fit() raise.)
    embeddings = np.nan_to_num(np.vstack(embeddings))

    # 2. Cluster windows into the requested number of speakers. Cast to int
    # (the slider may hand back a float) and clamp to the window count so
    # short clips don't crash sklearn with n_clusters > n_samples.
    n_clusters = max(1, min(int(num_speakers), len(embeddings)))
    clustering = AgglomerativeClustering(n_clusters).fit(embeddings)
    labels = clustering.labels_

    # 3. Emit a line each time the active speaker changes.
    result = "--- SPEAKER DIARIZATION TIMELINE ---\n\n"
    current_speaker = None
    for (start, end), label in zip(timestamps, labels):
        speaker_name = f"Speaker {label + 1}"
        if speaker_name != current_speaker:
            result += f"[{time_str(start)}] {speaker_name} starts speaking\n"
            current_speaker = speaker_name
    return result
# 4. Define the Gradio Interface: audio input + speaker-count slider in,
# plain-text diarization timeline out.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speaker Diarization Tool")
    with gr.Row():
        # "filepath" mode hands process_audio a path string rather than raw samples.
        input_audio = gr.Audio(type="filepath", label="1. Upload or Record Audio")
        num_spks = gr.Slider(minimum=1, maximum=10, value=2, step=1, label="2. Number of Speakers")
    btn = gr.Button("Analyze Speakers")
    output_text = gr.Textbox(label="3. Diarization Results", lines=10)
    # Wire the button to the diarization function.
    btn.click(fn=process_audio, inputs=[input_audio, num_spks], outputs=output_text)
demo.launch()