Spaces:
Sleeping
Sleeping
File size: 6,506 Bytes
99a85b5 83e68e1 667e520 99a85b5 7def15a 99a85b5 96ec82d 99a85b5 83e68e1 99a85b5 83e68e1 99a85b5 46c3876 83e68e1 46c3876 1984b17 99a85b5 83e68e1 667e520 83e68e1 667e520 3ccfd03 83e68e1 7a4d826 7def15a 99a85b5 7a4d826 99a85b5 83e68e1 99a85b5 7a4d826 99a85b5 7def15a 7a4d826 99a85b5 18bf750 99a85b5 96ec82d 667e520 83e68e1 96ec82d 99a85b5 96ec82d 99a85b5 83e68e1 99a85b5 7a4d826 18bf750 7a4d826 18bf750 7a4d826 18bf750 99a85b5 18bf750 99a85b5 d09fc0f 99a85b5 1103803 99a85b5 18bf750 99a85b5 18bf750 99a85b5 d09fc0f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 | from __future__ import annotations
import json
import subprocess
import tempfile
import time
from pathlib import Path
from typing import Any
import gradio as gr
import spaces
import torch
import torchaudio
from huggingface_hub import snapshot_download
from pyannote.audio import Pipeline
# Enable TF32 for CUDA matmul and cuDNN operations; these flags are set once at import time.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Process-wide cache for the loaded pipeline; get_pipeline() fills it on first use
# and _run_diarization() refuses to run while it is still None.
_PIPELINE: Pipeline | None = None
def get_pipeline(hf_token: str) -> Pipeline:
    """Return the shared diarization pipeline, creating it on the first call.

    The model snapshot is downloaded into a fixed local directory only when
    that directory does not exist yet; the pipeline is then loaded from the
    directory and its batch sizes are configured. Later calls return the
    cached instance and do not look at ``hf_token`` at all.

    Args:
        hf_token: Hugging Face access token, used solely for the snapshot
            download.

    Returns:
        The module-level cached ``Pipeline`` instance.
    """
    global _PIPELINE
    if _PIPELINE is None:
        model_dir = Path("models/pyannote-speaker-diarization-community-1-b64")
        already_downloaded = model_dir.exists()
        if not already_downloaded:
            snapshot_download(
                repo_id="pyannote/speaker-diarization-community-1",
                token=hf_token,
                local_dir=str(model_dir),
            )
        _PIPELINE = Pipeline.from_pretrained(str(model_dir))
        # Batch sizes are applied to the cached object right after loading.
        for attribute, batch_size in (
            ("segmentation_batch_size", 64),
            ("embedding_batch_size", 32),
        ):
            setattr(_PIPELINE, attribute, batch_size)
    return _PIPELINE
def _audio_is_already_optimized(audio_path: str) -> bool:
    """Report whether the first audio stream is already 16 kHz mono ``pcm_s16le``.

    The file is inspected with ``ffprobe``. This is a best-effort check: any
    failure to probe (binary missing, non-zero exit, unparsable output,
    unexpected field values) is reported as ``False`` so the caller simply
    falls back to re-encoding the file.

    Args:
        audio_path: Path of the audio file to inspect.

    Returns:
        ``True`` only when ffprobe reports codec ``pcm_s16le``, a sample rate
        of 16000 and exactly one channel for stream ``a:0``; ``False``
        otherwise.
    """
    command = [
        "ffprobe",
        "-v",
        "error",
        "-print_format",
        "json",
        "-show_streams",
        "-select_streams",
        "a:0",
        audio_path,
    ]
    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True)
        payload = json.loads(result.stdout)
    except (OSError, subprocess.CalledProcessError, ValueError):
        # OSError: ffprobe is not installed or not executable.
        # CalledProcessError: ffprobe rejected the file.
        # ValueError: invalid JSON (JSONDecodeError) or an unusable argument.
        return False

    streams = payload.get("streams") if isinstance(payload, dict) else None
    if not streams:
        return False

    stream = streams[0]
    try:
        sample_rate = int(stream.get("sample_rate", 0))
        channels = int(stream.get("channels", 0))
    except (TypeError, ValueError):
        # A null or non-numeric field means the format cannot be confirmed.
        return False

    return (
        stream.get("codec_name") == "pcm_s16le"
        and sample_rate == 16000
        and channels == 1
    )
def _normalize_audio(audio_path: str) -> str:
    """Return a path to a 16 kHz mono ``pcm_s16le`` version of ``audio_path``.

    When the input already has that format the original path is returned
    unchanged. Otherwise ffmpeg writes ``normalized.wav`` into a freshly
    created temporary directory (prefix ``pyannote_audio_``) and that file's
    path is returned; the caller owns the temporary file afterwards.

    Args:
        audio_path: Path of the source audio file.

    Returns:
        Either ``audio_path`` itself or the path of the converted WAV file.

    Raises:
        gr.Error: If ffmpeg cannot be started or exits with an error.
    """
    import shutil  # local import: only needed to clean up after a failed conversion

    if _audio_is_already_optimized(audio_path):
        return audio_path

    normalized_dir = Path(tempfile.mkdtemp(prefix="pyannote_audio_"))
    normalized_path = normalized_dir / "normalized.wav"
    command = [
        "ffmpeg",
        "-y",
        "-hide_banner",
        "-loglevel",
        "error",
        "-i",
        audio_path,
        "-ac",
        "1",
        "-ar",
        "16000",
        "-c:a",
        "pcm_s16le",
        str(normalized_path),
    ]
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as exc:
        # Do not leave the temp directory behind when the conversion fails.
        shutil.rmtree(normalized_dir, ignore_errors=True)
        message = (exc.stderr or "").strip() or (exc.stdout or "").strip() or str(exc)
        raise gr.Error(f"Failed to normalize audio with ffmpeg: {message}") from exc
    except OSError as exc:
        # ffmpeg missing or not executable: report it like any other ffmpeg failure.
        shutil.rmtree(normalized_dir, ignore_errors=True)
        raise gr.Error(f"Failed to normalize audio with ffmpeg: {exc}") from exc
    return str(normalized_path)
@spaces.GPU(duration=180)
def _run_diarization(audio: dict[str, Any]) -> tuple[Any, float]:
    """Run the cached pipeline on CUDA and time the whole GPU round trip.

    The pipeline is moved to the CUDA device, invoked under
    ``torch.inference_mode()``, and then always moved back to the CPU with the
    CUDA cache emptied, even when inference raises.

    Args:
        audio: Input handed to the pipeline as-is.

    Returns:
        A ``(output, seconds)`` pair where ``seconds`` spans from just before
        the move to CUDA until after the move back to CPU.

    Raises:
        RuntimeError: If the module-level pipeline has not been loaded yet.
    """
    if _PIPELINE is None:
        raise RuntimeError("Pipeline must be loaded before running diarization.")

    cuda_device = torch.device("cuda")
    clock_start = time.perf_counter()
    _PIPELINE.to(cuda_device)
    try:
        with torch.inference_mode():
            result = _PIPELINE(audio)
    finally:
        # Hand the GPU memory back regardless of how inference ended.
        _PIPELINE.to(torch.device("cpu"))
        torch.cuda.empty_cache()

    return result, time.perf_counter() - clock_start
def diarize(
    audio_path: str | None,
    hf_token: str | None,
) -> dict[str, Any]:
    """Diarize an uploaded audio file and return the speaker segments.

    The input is validated, converted to 16 kHz mono PCM when necessary,
    loaded into memory, and passed to the diarization pipeline. Any temporary
    file produced by the conversion is removed as soon as the waveform has
    been loaded.

    Args:
        audio_path: Path of the uploaded audio file.
        hf_token: Hugging Face access token used to fetch the model.

    Returns:
        A dict with the keys ``source``, ``zerogpu_seconds`` and ``segments``;
        each segment carries ``segment_id``, ``speaker_id``, ``start``,
        ``end`` and ``duration`` (rounded to 3 decimals).

    Raises:
        gr.Error: If the audio path is empty or missing on disk, or the token
            is empty.
    """
    if not audio_path:
        raise gr.Error("Upload an audio file first.")
    if not Path(audio_path).exists():
        raise gr.Error("The uploaded audio file could not be found. Please re-upload it and try again.")
    if not hf_token or not hf_token.strip():
        raise gr.Error(
            "A Hugging Face access token is required. Accept the model conditions first, then pass `HF_TOKEN` in the UI or API call."
        )

    normalized_audio_path = _normalize_audio(audio_path)
    try:
        waveform, sample_rate = torchaudio.load(normalized_audio_path)
    finally:
        # The waveform now lives in memory, so the converted copy (if one was
        # made) is no longer needed. Cleanup is best-effort: rmdir only removes
        # an empty directory, and a failure here must not fail the request.
        if normalized_audio_path != audio_path:
            temp_wav = Path(normalized_audio_path)
            try:
                temp_wav.unlink(missing_ok=True)
                temp_wav.parent.rmdir()
            except OSError:
                pass

    hf_token = hf_token.strip()
    # Load on CPU first so the ZeroGPU decorator only wraps actual inference.
    get_pipeline(hf_token)
    output, zerogpu_seconds = _run_diarization(
        audio={"waveform": waveform, "sample_rate": sample_rate}
    )

    # Use the exclusive diarization when the output provides one.
    annotation = output.speaker_diarization
    exclusive_annotation = getattr(output, "exclusive_speaker_diarization", None)
    if exclusive_annotation is not None:
        annotation = exclusive_annotation

    segments: list[dict[str, Any]] = [
        {
            "segment_id": f"seg_{index:06d}",
            "speaker_id": str(speaker),
            "start": round(float(turn.start), 3),
            "end": round(float(turn.end), 3),
            # Clamp at zero so a degenerate turn never reports a negative length.
            "duration": round(max(0.0, float(turn.end) - float(turn.start)), 3),
        }
        for index, (turn, _, speaker) in enumerate(
            annotation.itertracks(yield_label=True), start=1
        )
    ]

    return {
        "source": "pyannote/speaker-diarization-community-1",
        "zerogpu_seconds": round(zerogpu_seconds, 3),
        "segments": segments,
    }
def build_demo() -> gr.Blocks:
    """Assemble the Gradio interface for the diarization app.

    The layout is an introduction, a left column with the audio upload, the
    token field and the run button, and a right column with the JSON result.
    Clicking the button calls ``diarize`` with the audio path and the token.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    # The audio component below only accepts uploads, so the introduction
    # describes uploading only (it previously also promised recording).
    intro_markdown = (
        "\n"
        "# Speaker diarization with pyannote Community-1\n"
        "Upload audio and run speaker diarization using\n"
        "[`pyannote/speaker-diarization-community-1`](https://huggingface.co/pyannote/speaker-diarization-community-1).\n"
        "**Important:** before first use, accept the model conditions on Hugging Face and provide a\n"
        "read-access token. Pass it as the `HF_TOKEN` input in the UI or API call.\n"
    )
    with gr.Blocks(title="Pyannote Speaker Diarization") as demo:
        gr.Markdown(intro_markdown)
        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Audio",
                )
                token_input = gr.Textbox(
                    label="HF_TOKEN",
                    type="password",
                    placeholder="hf_xxx",
                )
                run_button = gr.Button("Run diarization", variant="primary")
            with gr.Column(scale=1):
                response_output = gr.JSON(label="Diarization JSON")
        # Event listeners are registered inside the Blocks context.
        run_button.click(
            fn=diarize,
            inputs=[audio_input, token_input],
            outputs=[response_output],
        )
    return demo
# Build the interface at import time so `demo` is available as a module-level name.
demo = build_demo()
# Turn on the request queue with a default concurrency limit of one per event listener.
demo.queue(default_concurrency_limit=1)
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    # NOTE(review): `theme=` is passed to launch() here rather than to gr.Blocks();
    # confirm the installed Gradio version accepts it in this position.
    demo.launch(theme=gr.themes.Soft())
|