import gradio as gr
import whisper
import datetime
import torch
import subprocess
from pyannote.audio import Audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.core import Segment
import wave
import contextlib
from sklearn.cluster import AgglomerativeClustering
import numpy as np
# Load the Whisper speech-recognition model (English-only, medium size)
model_size = "medium.en"
model = whisper.load_model(model_size)

# pyannote helper for cropping audio, plus a pretrained speaker-embedding
# model (SpeechBrain ECAPA-TDNN trained on VoxCeleb)
audio = Audio()
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
def transcribe_and_diarize(audio_file, num_speakers=2):
    try:
        # gr.Audio(type="filepath") passes the upload as a path string
        path = audio_file

        # Convert to 16 kHz mono WAV if necessary (the embedding model
        # expects mono audio; requires ffmpeg on PATH)
        if not path.endswith('.wav'):
            subprocess.call(['ffmpeg', '-y', '-i', path, '-ar', '16000', '-ac', '1', 'audio.wav'])
            path = 'audio.wav'

        # Transcribe audio with Whisper
        result = model.transcribe(path)
        segments = result["segments"]

        # Get audio duration from the WAV header
        with contextlib.closing(wave.open(path, 'r')) as f:
            frames = f.getnframes()
            rate = f.getframerate()
            duration = frames / float(rate)

        # Extract a speaker embedding for a single transcript segment
        def segment_embedding(segment):
            start = segment["start"]
            # Whisper sometimes overshoots the end of the file
            end = min(duration, segment["end"])
            clip = Segment(start, end)
            waveform, sample_rate = audio.crop(path, clip)
            return embedding_model(waveform[None])

        # Embed every segment (ECAPA-TDNN embeddings are 192-dimensional)
        embeddings = np.zeros(shape=(len(segments), 192))
        for i, segment in enumerate(segments):
            embeddings[i] = segment_embedding(segment)
        embeddings = np.nan_to_num(embeddings)

        # Cluster the embeddings into the requested number of speakers
        # (gr.Number returns a float, so cast to int)
        clustering = AgglomerativeClustering(int(num_speakers)).fit(embeddings)
        labels = clustering.labels_
        for i in range(len(segments)):
            segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)

        # Build the transcript, starting a new block whenever the speaker changes
        transcript = ""
        for i, segment in enumerate(segments):
            if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
                transcript += "\n" + segment["speaker"] + ' ' + str(datetime.timedelta(seconds=round(segment["start"]))) + '\n'
            transcript += segment["text"][1:] + ' '
        transcript += "\n\n"
        return transcript
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Gradio UI: an audio upload plus a speaker-count control
iface = gr.Interface(
    fn=transcribe_and_diarize,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),
        gr.Number(value=2, label="Number of Speakers"),
    ],
    outputs="text",
    title="Audio Transcription and Speaker Diarization",
    description="Upload an audio file to get a transcription with speaker diarization.",
)

iface.launch()
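# To run this app locally (a sketch; the exact package set is an assumption,
# check your environment): install ffmpeg, then
#
#     pip install gradio openai-whisper pyannote.audio scikit-learn numpy torch
#     python app.py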