"""Gradio app running pyannote Community-1 speaker diarization on ZeroGPU.

Uploaded audio is normalized to 16 kHz mono pcm_s16le with ffmpeg (unless it
already is), loaded into memory, and diarized on a GPU allocated through the
``spaces.GPU`` decorator. The result is returned as a JSON-friendly dict of
speaker segments.
"""

from __future__ import annotations

import json
import shutil
import subprocess
import tempfile
import time
from pathlib import Path
from typing import Any

import gradio as gr
import spaces
import torch
import torchaudio
from huggingface_hub import snapshot_download
from pyannote.audio import Pipeline

# Allow TF32 for CUDA matmuls and cuDNN convolutions: faster on GPUs that
# support it, at a small precision cost.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Hugging Face repo the pipeline is downloaded from; also echoed as "source"
# in every diarization response.
_MODEL_REPO_ID = "pyannote/speaker-diarization-community-1"
# Local snapshot directory for the model files.
_LOCAL_MODEL_DIR = Path("models/pyannote-speaker-diarization-community-1-b64")
# Sample rate the audio is normalized to before diarization.
_TARGET_SAMPLE_RATE = 16000

# Process-wide pipeline cache, populated lazily by get_pipeline().
_PIPELINE: Pipeline | None = None


def get_pipeline(hf_token: str) -> Pipeline:
    """Return the process-wide diarization pipeline, loading it on first use.

    The model snapshot is downloaded into ``_LOCAL_MODEL_DIR`` only when that
    directory does not exist yet. The pipeline is built from the local copy,
    configured with fixed batch sizes, and cached in ``_PIPELINE``.

    NOTE: because only the directory's existence is checked, an interrupted
    download leaves a partial directory that will not be re-fetched.

    Args:
        hf_token: Hugging Face access token, used only for the initial
            ``snapshot_download``. Once the pipeline is cached the token is not
            consulted again.

    Returns:
        The cached pipeline. It is not moved to any GPU here.
    """
    global _PIPELINE
    if _PIPELINE is not None:
        return _PIPELINE

    if not _LOCAL_MODEL_DIR.exists():
        snapshot_download(
            repo_id=_MODEL_REPO_ID,
            token=hf_token,
            local_dir=str(_LOCAL_MODEL_DIR),
        )

    pipeline = Pipeline.from_pretrained(str(_LOCAL_MODEL_DIR))
    pipeline.segmentation_batch_size = 64
    pipeline.embedding_batch_size = 32
    # Publish to the cache only once fully configured, so a failure above never
    # leaves a half-configured pipeline behind for later calls.
    _PIPELINE = pipeline
    return _PIPELINE


def _audio_is_already_optimized(audio_path: str) -> bool:
    """Report whether the first audio stream is already 16 kHz mono pcm_s16le.

    This is a best-effort probe via ``ffprobe``. Any failure (binary missing,
    non-zero exit, unparsable JSON, non-numeric fields, no audio stream)
    yields ``False``, so the caller falls back to re-encoding.
    """
    command = [
        "ffprobe",
        "-v", "error",
        "-print_format", "json",
        "-show_streams",
        "-select_streams", "a:0",
        audio_path,
    ]
    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True)
        payload = json.loads(result.stdout)
    except (OSError, subprocess.CalledProcessError, json.JSONDecodeError):
        # OSError covers a missing ffprobe executable.
        return False

    streams = payload.get("streams") or []
    if not streams:
        return False
    stream = streams[0]
    try:
        sample_rate = int(stream.get("sample_rate", 0))
        channels = int(stream.get("channels", 0))
    except (TypeError, ValueError):
        return False
    return (
        stream.get("codec_name") == "pcm_s16le"
        and sample_rate == _TARGET_SAMPLE_RATE
        and channels == 1
    )


def _normalize_audio(audio_path: str) -> str:
    """Return a path to a 16 kHz mono pcm_s16le version of *audio_path*.

    If the input already has that format, *audio_path* itself is returned.
    Otherwise ffmpeg writes ``normalized.wav`` into a fresh temporary
    directory. The caller then owns that directory and should release it with
    ``_cleanup_normalized_audio``.

    Raises:
        gr.Error: if ffmpeg cannot be run or fails to convert the input. In
            that case the temporary directory has already been removed.
    """
    if _audio_is_already_optimized(audio_path):
        return audio_path

    normalized_dir = Path(tempfile.mkdtemp(prefix="pyannote_audio_"))
    normalized_path = normalized_dir / "normalized.wav"
    command = [
        "ffmpeg",
        "-y",
        "-hide_banner",
        "-loglevel", "error",
        "-i", audio_path,
        "-ac", "1",
        "-ar", str(_TARGET_SAMPLE_RATE),
        "-c:a", "pcm_s16le",
        str(normalized_path),
    ]
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as exc:
        shutil.rmtree(normalized_dir, ignore_errors=True)
        message = exc.stderr.strip() or exc.stdout.strip() or str(exc)
        raise gr.Error(f"Failed to normalize audio with ffmpeg: {message}") from exc
    except OSError as exc:
        # e.g. FileNotFoundError when the ffmpeg executable is not installed.
        shutil.rmtree(normalized_dir, ignore_errors=True)
        raise gr.Error(f"Failed to normalize audio with ffmpeg: {exc}") from exc
    return str(normalized_path)


def _cleanup_normalized_audio(original_path: str, normalized_path: str) -> None:
    """Delete the temporary directory created by ``_normalize_audio``, if any.

    Nothing is removed when normalization returned the original upload.
    """
    if normalized_path != original_path:
        shutil.rmtree(Path(normalized_path).parent, ignore_errors=True)


@spaces.GPU(duration=180)
def _run_diarization(audio: dict[str, Any]) -> tuple[Any, float]:
    """Run the cached pipeline on CUDA and return ``(output, elapsed_seconds)``.

    Decorated with ``spaces.GPU(duration=180)`` so that only this inference
    step requests a ZeroGPU allocation. The pipeline is moved to CUDA for the
    call and always moved back to CPU afterwards, even on failure.

    Args:
        audio: Mapping with ``"waveform"`` and ``"sample_rate"`` keys, as built
            by ``diarize`` from ``torchaudio.load``.

    Returns:
        The raw pipeline output and the wall-clock seconds spent in this
        function, including the device transfers.

    Raises:
        RuntimeError: if ``get_pipeline`` has not populated the cache yet.
    """
    if _PIPELINE is None:
        raise RuntimeError("Pipeline must be loaded before running diarization.")

    started_at = time.perf_counter()
    try:
        # Inside the try so that a failed or partial transfer is still undone
        # by the CPU move in ``finally``.
        _PIPELINE.to(torch.device("cuda"))
        with torch.inference_mode():
            output = _PIPELINE(audio)
    finally:
        _PIPELINE.to(torch.device("cpu"))
        torch.cuda.empty_cache()
    zerogpu_seconds = time.perf_counter() - started_at
    return output, zerogpu_seconds


def diarize(
    audio_path: str | None,
    hf_token: str | None,
) -> dict[str, Any]:
    """Diarize an uploaded audio file and return a JSON-serializable summary.

    Args:
        audio_path: Filesystem path of the uploaded audio, as provided by the
            ``gr.Audio(type="filepath")`` input.
        hf_token: Hugging Face token. It must be non-blank; surrounding
            whitespace is stripped.

    Returns:
        A dict with ``source`` (model repo id), ``zerogpu_seconds`` (rounded to
        milliseconds) and ``segments``. Each segment has ``segment_id``,
        ``speaker_id``, ``start``, ``end`` and ``duration`` (seconds, rounded
        to 3 decimals). The exclusive (non-overlapping) diarization is used
        when the pipeline output provides one.

    Raises:
        gr.Error: on missing input, a missing file, a blank token, or an audio
            normalization failure.
    """
    if not audio_path:
        raise gr.Error("Upload an audio file first.")
    if not Path(audio_path).exists():
        raise gr.Error(
            "The uploaded audio file could not be found. "
            "Please re-upload it and try again."
        )
    if not hf_token or not hf_token.strip():
        raise gr.Error(
            "A Hugging Face access token is required. Accept the model conditions "
            "first, then pass `HF_TOKEN` in the UI or API call."
        )
    hf_token = hf_token.strip()

    normalized_audio_path = _normalize_audio(audio_path)
    try:
        waveform, sample_rate = torchaudio.load(normalized_audio_path)
    finally:
        # The samples are in memory now; drop the temporary WAV (if one was
        # created) so repeated requests do not accumulate files on disk.
        _cleanup_normalized_audio(audio_path, normalized_audio_path)

    # Load on CPU first so the ZeroGPU decorator only wraps actual inference.
    get_pipeline(hf_token)
    output, zerogpu_seconds = _run_diarization(
        audio={"waveform": waveform, "sample_rate": sample_rate}
    )

    annotation = output.speaker_diarization
    exclusive_annotation = getattr(output, "exclusive_speaker_diarization", None)
    if exclusive_annotation is not None:
        annotation = exclusive_annotation

    segments: list[dict[str, Any]] = []
    for index, (turn, _, speaker) in enumerate(
        annotation.itertracks(yield_label=True), start=1
    ):
        start = float(turn.start)
        end = float(turn.end)
        segments.append(
            {
                "segment_id": f"seg_{index:06d}",
                "speaker_id": str(speaker),
                "start": round(start, 3),
                "end": round(end, 3),
                "duration": round(max(0.0, end - start), 3),
            }
        )

    return {
        "source": _MODEL_REPO_ID,
        "zerogpu_seconds": round(zerogpu_seconds, 3),
        "segments": segments,
    }


def build_demo() -> gr.Blocks:
    """Build the Gradio UI: audio upload + token input -> diarization JSON."""
    with gr.Blocks(title="Pyannote Speaker Diarization") as demo:
        # The description matches the input below, which accepts uploads only
        # (sources=["upload"]); recording is not enabled.
        gr.Markdown(
            """
            # Speaker diarization with pyannote Community-1

            Upload audio and run speaker diarization using
            [`pyannote/speaker-diarization-community-1`](https://huggingface.co/pyannote/speaker-diarization-community-1).

            **Important:** before first use, accept the model conditions on Hugging Face and provide a read-access token.
            Pass it as the `HF_TOKEN` input in the UI or API call.
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Audio",
                )
                token_input = gr.Textbox(
                    label="HF_TOKEN",
                    type="password",
                    placeholder="hf_xxx",
                )
                run_button = gr.Button("Run diarization", variant="primary")
            with gr.Column(scale=1):
                response_output = gr.JSON(label="Diarization JSON")

        run_button.click(
            fn=diarize,
            inputs=[
                audio_input,
                token_input,
            ],
            outputs=[response_output],
        )
    return demo


demo = build_demo()
# Process at most one request at a time per event listener; the pipeline is a
# single shared global that is moved between devices per call.
demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft())