# NOTE(review): the three lines originally here ("Spaces:", "Sleeping",
# "Sleeping") were Hugging Face Spaces page-status residue from a copy/paste,
# not Python code — replaced with this comment so the module parses.
| from __future__ import annotations | |
| import json | |
| import subprocess | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| import torchaudio | |
| from huggingface_hub import snapshot_download | |
| from pyannote.audio import Pipeline | |
# Allow TensorFloat-32 for matmuls and cuDNN kernels on Ampere+ GPUs:
# faster inference at a slightly reduced float32 precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Process-wide pipeline cache, populated lazily by get_pipeline().
_PIPELINE: Pipeline | None = None
def get_pipeline(hf_token: str) -> Pipeline:
    """Return the cached diarization pipeline, downloading the model on first use.

    The model snapshot is mirrored into a local directory so subsequent calls
    (and restarts that keep the disk) skip the network round-trip entirely.
    """
    global _PIPELINE
    if _PIPELINE is None:
        model_dir = Path("models/pyannote-speaker-diarization-community-1-b64")
        if not model_dir.exists():
            # One-time download; requires a token that has accepted the model gate.
            snapshot_download(
                repo_id="pyannote/speaker-diarization-community-1",
                token=hf_token,
                local_dir=str(model_dir),
            )
        pipeline = Pipeline.from_pretrained(str(model_dir))
        # Larger batches speed up the segmentation/embedding passes.
        pipeline.segmentation_batch_size = 64
        pipeline.embedding_batch_size = 32
        _PIPELINE = pipeline
    return _PIPELINE
| def _audio_is_already_optimized(audio_path: str) -> bool: | |
| command = [ | |
| "ffprobe", | |
| "-v", | |
| "error", | |
| "-print_format", | |
| "json", | |
| "-show_streams", | |
| "-select_streams", | |
| "a:0", | |
| audio_path, | |
| ] | |
| try: | |
| result = subprocess.run(command, check=True, capture_output=True, text=True) | |
| payload = json.loads(result.stdout) | |
| except (subprocess.CalledProcessError, json.JSONDecodeError): | |
| return False | |
| streams = payload.get("streams") or [] | |
| if not streams: | |
| return False | |
| stream = streams[0] | |
| return ( | |
| stream.get("codec_name") == "pcm_s16le" | |
| and int(stream.get("sample_rate", 0)) == 16000 | |
| and int(stream.get("channels", 0)) == 1 | |
| ) | |
def _normalize_audio(audio_path: str) -> str:
    """Return a path to a 16 kHz mono 16-bit PCM WAV copy of *audio_path*.

    If the input already matches that format, the original path is returned
    unchanged; otherwise ffmpeg transcodes it into a fresh temp directory.

    Raises:
        gr.Error: if ffmpeg is missing from PATH or the conversion fails.
    """
    if _audio_is_already_optimized(audio_path):
        return audio_path
    # NOTE(review): the temp directory is never cleaned up here — acceptable
    # for short-lived Space containers, but a leak in long-running processes.
    normalized_dir = Path(tempfile.mkdtemp(prefix="pyannote_audio_"))
    normalized_path = normalized_dir / "normalized.wav"
    command = [
        "ffmpeg",
        "-y",
        "-hide_banner",
        "-loglevel",
        "error",
        "-i",
        audio_path,
        "-ac",
        "1",
        "-ar",
        "16000",
        "-c:a",
        "pcm_s16le",
        str(normalized_path),
    ]
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except FileNotFoundError as exc:
        # Previously an uncaught FileNotFoundError: surface a clear,
        # user-facing error when the ffmpeg binary is absent.
        raise gr.Error("ffmpeg is not installed or not on PATH.") from exc
    except subprocess.CalledProcessError as exc:
        message = exc.stderr.strip() or exc.stdout.strip() or str(exc)
        raise gr.Error(f"Failed to normalize audio with ffmpeg: {message}") from exc
    return str(normalized_path)
@spaces.GPU
def _run_diarization(audio: dict[str, Any]) -> tuple[Any, float]:
    """Run the globally loaded pipeline on the GPU and return (output, seconds).

    The comment in diarize() says "the ZeroGPU decorator only wraps actual
    inference" — @spaces.GPU is that decorator. Without it the file's
    `spaces` import was unused and a ZeroGPU Space would never be allocated
    a CUDA device, making `torch.device("cuda")` fail at inference time.

    Args:
        audio: mapping with "waveform" (torch.Tensor) and "sample_rate" keys —
            the in-memory input format accepted by pyannote pipelines.

    Returns:
        Tuple of (pipeline output, wall-clock seconds spent holding the GPU).

    Raises:
        RuntimeError: if get_pipeline() has not been called yet.
    """
    if _PIPELINE is None:
        raise RuntimeError("Pipeline must be loaded before running diarization.")
    device = torch.device("cuda")
    started_at = time.perf_counter()
    _PIPELINE.to(device)
    try:
        with torch.inference_mode():
            output = _PIPELINE(audio)
    finally:
        # Always hand the GPU back and release cached memory, even if
        # inference raised.
        _PIPELINE.to(torch.device("cpu"))
        torch.cuda.empty_cache()
    zerogpu_seconds = time.perf_counter() - started_at
    return output, zerogpu_seconds
def diarize(
    audio_path: str | None,
    hf_token: str | None,
):
    """Validate inputs, normalize the audio, and run speaker diarization.

    Returns a JSON-serializable dict with the model id, the GPU seconds
    consumed, and the list of speaker-turn segments.
    """
    if not audio_path:
        raise gr.Error("Upload an audio file first.")
    if not Path(audio_path).exists():
        raise gr.Error("The uploaded audio file could not be found. Please re-upload it and try again.")
    if not hf_token or not hf_token.strip():
        raise gr.Error(
            "A Hugging Face access token is required. Accept the model conditions first, then pass `HF_TOKEN` in the UI or API call."
        )
    waveform, sample_rate = torchaudio.load(_normalize_audio(audio_path))
    # Load on CPU first so the ZeroGPU decorator only wraps actual inference.
    get_pipeline(hf_token.strip())
    output, zerogpu_seconds = _run_diarization(
        audio={"waveform": waveform, "sample_rate": sample_rate}
    )
    # Prefer the exclusive (non-overlapping) diarization when the pipeline
    # provides one.
    annotation = output.speaker_diarization
    exclusive = getattr(output, "exclusive_speaker_diarization", None)
    if exclusive is not None:
        annotation = exclusive
    segments: list[dict[str, Any]] = []
    for index, (turn, _, speaker) in enumerate(annotation.itertracks(yield_label=True), start=1):
        start, end = float(turn.start), float(turn.end)
        segments.append(
            {
                "segment_id": f"seg_{index:06d}",
                "speaker_id": str(speaker),
                "start": round(start, 3),
                "end": round(end, 3),
                "duration": round(max(0.0, end - start), 3),
            }
        )
    return {
        "source": "pyannote/speaker-diarization-community-1",
        "zerogpu_seconds": round(zerogpu_seconds, 3),
        "segments": segments,
    }
def build_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI for the diarization Space.

    The theme is applied here because `gr.Blocks` is where Gradio accepts a
    `theme` argument — `demo.launch()` has no such parameter.
    """
    with gr.Blocks(
        title="Pyannote Speaker Diarization",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            """
            # Speaker diarization with pyannote Community-1
            Upload or record audio and run speaker diarization using
            [`pyannote/speaker-diarization-community-1`](https://huggingface.co/pyannote/speaker-diarization-community-1).
            **Important:** before first use, accept the model conditions on Hugging Face and provide a
            read-access token. Pass it as the `HF_TOKEN` input in the UI or API call.
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Audio",
                )
                token_input = gr.Textbox(
                    label="HF_TOKEN",
                    type="password",
                    placeholder="hf_xxx",
                )
                run_button = gr.Button("Run diarization", variant="primary")
            with gr.Column(scale=1):
                response_output = gr.JSON(label="Diarization JSON")
        run_button.click(
            fn=diarize,
            inputs=[
                audio_input,
                token_input,
            ],
            outputs=[response_output],
        )
    return demo
demo = build_demo()
# Serialize requests: the single cached pipeline is moved on/off the GPU per
# call, so concurrent runs would contend for the same object.
demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    # Bug fix: launch() has no `theme` parameter, so the original
    # `demo.launch(theme=gr.themes.Soft())` raised a TypeError. Theming
    # must be configured on gr.Blocks inside build_demo() instead.
    demo.launch()