Spaces: Running on Zero
import json
import os
import tempfile
from io import BytesIO

import gradio as gr
import requests
import spaces
import torch
import torchaudio
from pyannote.audio import Pipeline
from pydub import AudioSegment
# Authenticate with Huggingface
# Read the access token from the environment; required to download the
# gated pyannote pipeline weights.
AUTH_TOKEN = os.getenv("HF_TOKEN")

# Load the diarization pipeline
# NOTE(review): assumes a CUDA device is available at import time (e.g. a
# GPU-backed Space) — torch.device("cuda") alone does not verify this.
device = torch.device("cuda")
# Downloaded once at startup and shared by all requests.
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-community-1",
    token=AUTH_TOKEN).to(device)
def preprocess_audio(audio_path):
    """Convert audio to a mono, 16 kHz WAV file suitable for pyannote.

    Args:
        audio_path: Either a filesystem path (str) or raw audio bytes
            (e.g. the body of a downloaded file).

    Returns:
        str: Path to a temporary WAV file. The caller is responsible
        for deleting it.

    Raises:
        ValueError: If the audio cannot be decoded or converted.
    """
    try:
        # Raw bytes (e.g. an HTTP download) must be wrapped in a
        # file-like object before pydub can decode them.
        # (Renamed from `bytes`, which shadowed the builtin.)
        is_raw_bytes = not isinstance(audio_path, str)
        audio = AudioSegment.from_file(BytesIO(audio_path) if is_raw_bytes else audio_path)
        # pyannote expects mono audio at a 16 kHz sample rate.
        audio = audio.set_channels(1).set_frame_rate(16000)
        # Use a unique temp file instead of a fixed "temp_audio.wav" so
        # concurrent requests (concurrency_limit=2) don't clobber each
        # other's files.
        fd, temp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(fd)  # pydub opens the path itself; release our handle
        audio.export(temp_wav, format="wav")
        return temp_wav
    except Exception as e:
        # Chain the cause so the original decode error stays visible.
        raise ValueError(f"Error preprocessing audio: {str(e)}") from e
def handle_audio(url, audio_path, num_speakers):
    """Download (if a URL is given), preprocess, and diarize audio.

    Args:
        url: Optional URL to fetch audio from; when non-empty it takes
            precedence over the uploaded file.
        audio_path: Path to an uploaded audio file (may be None when a
            URL is supplied).
        num_speakers: Exact speaker count, or 0 to let the pipeline
            estimate it.

    Returns:
        str: JSON-formatted diarization results, or an error message.
    """
    if url:
        response = requests.get(url, timeout=60)
        # Fail fast on HTTP errors instead of handing an error page's
        # bytes to the audio decoder.
        response.raise_for_status()
        audio_path = response.content
    audio_path = preprocess_audio(audio_path)
    try:
        return diarize_audio(audio_path, num_speakers)
    finally:
        # Always clean up the temporary WAV, even if diarization raises.
        if os.path.exists(audio_path):
            os.remove(audio_path)
def diarize_audio(audio_path, num_speakers):
    """Perform speaker diarization on a WAV file.

    Args:
        audio_path: Path to a (preferably mono, 16 kHz) WAV file.
        num_speakers: Exact number of speakers; 0 means "unknown", in
            which case the pipeline searches between 2 and 6 speakers.

    Returns:
        str: A JSON array of {"start", "end", "speaker_id"} segments on
        success, or an "Error: ..." message on failure.
    """
    try:
        # Load the waveform in memory and hand it to the pipeline as a
        # dict so it doesn't have to re-read the file from disk.
        waveform, sample_rate = torchaudio.load(audio_path)
        audio_dict = {"waveform": waveform, "sample_rate": sample_rate}
        # Pin the speaker count when the user supplied one; otherwise
        # let the pipeline estimate within a plausible range.
        if num_speakers > 0:
            pipeline_params = {"num_speakers": num_speakers}
        else:
            pipeline_params = {"min_speakers": 2, "max_speakers": 6}
        diarization = pipeline(audio_dict, **pipeline_params)
        # NOTE(review): assumes iterating exclusive_speaker_diarization
        # yields (segment, label) pairs — confirm against the
        # community-1 pipeline API; itertracks(yield_label=True) is the
        # documented way to get labelled segments from an Annotation.
        results = [
            {
                "start": round(turn.start, 3),
                "end": round(turn.end, 3),
                "speaker_id": speaker,
            }
            for turn, speaker in diarization.exclusive_speaker_diarization
        ]
        return json.dumps(results, indent=2)
    except Exception as e:
        # Return a single string: the Gradio click handler has exactly
        # one output component, so the original (message, "") tuple
        # would be mis-mapped onto it.
        return f"Error: {str(e)}"
# --- Gradio interface ------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Speaker Diarization with speaker-diarization-community-1")
    gr.Markdown("Upload an audio file and specify the number of speakers to diarize the audio.")

    # Input row: a URL field and a file upload, side by side.
    with gr.Row():
        url_box = gr.Textbox(label="URL")
        file_box = gr.Audio(label="Upload Audio File", type="filepath")

    # 0 lets the pipeline estimate the speaker count itself.
    speaker_slider = gr.Slider(minimum=0, maximum=10, step=1, label="Number of Speakers", value=2)
    run_btn = gr.Button("Diarize")

    with gr.Row():
        result_box = gr.Textbox(label="Diarization Results (JSON)")

    # Wire the button to the diarization handler; cap concurrent jobs.
    run_btn.click(
        fn=handle_audio,
        inputs=[url_box, file_box, speaker_slider],
        outputs=[result_box],
        concurrency_limit=2,
    )

# Launch the Gradio app
demo.launch()