Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline | |
| import numpy as np | |
| from pyannote.audio import Pipeline | |
| import os | |
| from dotenv import load_dotenv | |
| import plotly.graph_objects as go | |
# Load variables from a local .env file so os.getenv("HF_KEY") below can see them.
load_dotenv()

# Check and set device: first CUDA GPU when available, otherwise CPU.
# Half precision is only selected when CUDA is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Speech-recognition model and pipeline setup
model_id = "distil-whisper/distil-small.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    # Whisper-style models consume 30-second windows; chunking lets the
    # pipeline cover longer recordings. 15 s is the chunk length the
    # distil-whisper model card recommends for chunked long-form decoding.
    chunk_length_s=15,
    torch_dtype=torch_dtype,
    device=device,
)

# Speaker-diarization pipeline (gated model: needs a Hugging Face token in HF_KEY)
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_KEY")
)
# NOTE(review): pyannote.audio 3.x appears to report a failed download by
# returning None instead of raising — confirm against the installed version.
# Fail here with an actionable message rather than later with
# "'NoneType' object is not callable" on the first request.
if diarization_pipeline is None:
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'. Make sure HF_KEY "
        "holds a valid Hugging Face access token and that the model's user "
        "conditions have been accepted."
    )
# Run diarization on the same device as the ASR model.
diarization_pipeline.to(torch.device(device))
def diarization_info(res):
    """Unpack a diarization result into three parallel lists.

    Args:
        res: Object exposing ``itertracks(yield_label=True)`` that yields
            ``(segment, track, speaker)`` triples, where ``segment`` has
            ``start`` and ``end`` attributes.

    Returns:
        A ``(starts, ends, speakers)`` tuple of lists, in iteration order:
        each segment's start, each segment's end, and each segment's speaker
        label. All three lists are empty when ``res`` yields nothing.
    """
    # Materialise once; the track name (middle item) is not needed.
    turns = [(turn, label) for turn, _, label in res.itertracks(yield_label=True)]

    starts = [turn.start for turn, _ in turns]
    ends = [turn.end for turn, _ in turns]
    speakers = [label for _, label in turns]
    return starts, ends, speakers
def plot_diarization(starts, ends, speakers):
    """Plot diarization segments as horizontal bars, one row per speaker.

    Args:
        starts: Segment start times, parallel to ``ends`` and ``speakers``.
        ends: Segment end times.
        speakers: Speaker label of each segment (must be hashable).

    Returns:
        A ``plotly.graph_objects.Figure`` with one line trace per segment.
        Speakers are assigned rows in order of first appearance, and the
        y axis is labelled with the speaker labels.
    """
    fig = go.Figure()

    # dict.fromkeys keeps first-appearance order, so the speaker -> row/colour
    # assignment is reproducible (set iteration order is arbitrary).
    unique_speakers = list(dict.fromkeys(speakers))
    row_of = {speaker: row for row, speaker in enumerate(unique_speakers)}

    # endpoint=False: hue 0 and hue 360 are the same colour, so including the
    # endpoint would give the first and last speaker identical colours.
    hues = np.linspace(0, 360, len(unique_speakers), endpoint=False)
    colors = [f"hsl({h},80%,60%)" for h in hues]

    # Plot each segment with its speaker's colour
    for start, end, speaker in zip(starts, ends, speakers):
        row = row_of[speaker]
        fig.add_trace(
            go.Scatter(
                x=[start, end],
                y=[row, row],
                mode="lines",
                line=dict(color=colors[row], width=15),
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Speaker Diarization",
        xaxis=dict(title="Time"),
        yaxis=dict(
            title="Speaker",
            # Show the speaker labels instead of bare row numbers.
            tickmode="array",
            tickvals=list(row_of.values()),
            ticktext=[str(speaker) for speaker in unique_speakers],
        ),
        height=600,
        width=800,
    )
    return fig
def transcribe(sr, data):
    """Transcribe raw audio samples to text with the module-level ASR pipeline.

    Args:
        sr: Sampling rate of ``data`` in Hz.
        data: Audio samples, either 1-D or 2-D. A 2-D array is assumed to be
            laid out as (samples, channels) — TODO confirm against the caller.
            Values are assumed to be 16-bit PCM, which is what the 32767
            scaling below implies.

    Returns:
        The transcribed text, or an empty string when ``data`` has no samples.
    """
    samples = np.asarray(data)

    # Nothing to transcribe; avoid handing an empty array to the model.
    if samples.size == 0:
        return ""

    # The ASR pipeline expects single-channel audio, so average the channels
    # of multi-channel input (e.g. a stereo upload).
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    processed_data = samples.astype(np.float32) / 32767.0

    # results from the pipeline
    return pipe({"sampling_rate": sr, "raw": processed_data})["text"]
def transcribe_diarize(audio):
    """Transcribe a recording and attribute the speech to speakers.

    Args:
        audio: ``(sampling_rate, samples)`` tuple. ``samples`` is assumed to
            be 16-bit PCM, 1-D or (samples, channels) — TODO confirm against
            the Gradio component feeding this function. ``None`` means no
            audio was provided.

    Returns:
        A 3-tuple of:
        - the transcription of the whole recording,
        - a text block with one line per diarized segment
          (``"<speaker> <start>:<end> \\t <text>"``),
        - the diarization figure from ``plot_diarization``.

    Raises:
        gr.Error: If ``audio`` is ``None``.
    """
    if audio is None:
        raise gr.Error("Please record or upload an audio clip first.")

    sr, data = audio
    data = np.asarray(data)

    # Downmix multi-channel input once, up front: both the ASR pipeline and
    # the (channel, time) waveform built below need single-channel samples.
    if data.ndim > 1:
        data = data.mean(axis=1)

    processed_data = data.astype(np.float32) / 32767.0
    # Add a leading channel axis -> shape (1, time).
    waveform_tensor = torch.tensor(processed_data[np.newaxis, :])

    transcription_res = transcribe(sr, data)

    # results from the diarization pipeline
    diarization_res = diarization_pipeline(
        {"waveform": waveform_tensor, "sample_rate": sr}
    )

    # Get diarization information
    starts, ends, speakers = diarization_info(diarization_res)

    # Get transcription results for each speaker segment
    lines = []
    for start_time, end_time, speaker_id in zip(starts, ends, speakers):
        segment = data[int(start_time * sr) : int(end_time * sr)]
        # A segment can round down to zero samples; skip the model call then.
        text = transcribe(sr, segment) if segment.size else ""
        lines.append(
            f"{speaker_id} {round(start_time, 2)}:{round(end_time, 2)} \t {text}\n"
        )
    diarized_transcription = "".join(lines)

    # Plot diarization
    diarization_plot = plot_diarization(starts, ends, speakers)

    return transcription_res, diarized_transcription, diarization_plot
# creating the gradio interface
def _build_interface():
    """Build the Gradio interface around ``transcribe_diarize``.

    One audio input (upload or microphone) feeds three outputs: the plain
    transcription, the per-speaker transcription, and the diarization plot.
    """
    audio_input = gr.Audio(sources=["upload", "microphone"])
    result_views = [
        gr.Textbox(lines=3, label="Text Transcription"),
        gr.Textbox(label="Diarized Transcription"),
        gr.Plot(label="Visualization"),
    ]
    return gr.Interface(
        fn=transcribe_diarize,
        inputs=audio_input,
        outputs=result_views,
        examples=["sample1.wav"],
        title="Automatic Speech Recognition with Diarization 🗣️",
        description="Transcribe your speech to text with distilled whisper and diarization with pyannote. Get started by recording from your mic or uploading an audio file (.wav) 🎙️",
    )


demo = _build_interface()

if __name__ == "__main__":
    demo.launch()