# diarize / app.py
# Author: Ratnesh-dev — "Update app.py" (commit 7dbe887, verified)
from __future__ import annotations
import json
import subprocess
import tempfile
import time
from pathlib import Path
from typing import Any
import gradio as gr
import spaces
import torch
import torchaudio
from huggingface_hub import snapshot_download
from pyannote.audio import Pipeline
# Enable TensorFloat-32 matmul/cudnn kernels (faster inference on Ampere+ GPUs
# at slightly reduced float32 precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Process-wide pipeline cache, populated lazily by get_pipeline().
_PIPELINE: Pipeline | None = None
def get_pipeline(hf_token: str) -> Pipeline:
    """Return the memoised diarization pipeline, downloading the model on first use.

    The model snapshot is stored under ``models/`` so later runs skip the
    download entirely; the constructed ``Pipeline`` is cached in the
    module-level ``_PIPELINE`` for the life of the process.
    """
    global _PIPELINE
    if _PIPELINE is None:
        model_dir = Path("models/pyannote-speaker-diarization-community-1-b64")
        if not model_dir.exists():
            # Gated model: the token is only needed for the initial download.
            snapshot_download(
                repo_id="pyannote/speaker-diarization-community-1",
                token=hf_token,
                local_dir=str(model_dir),
            )
        pipeline = Pipeline.from_pretrained(str(model_dir))
        # Larger batches keep the ZeroGPU slot busy instead of idling on overhead.
        pipeline.segmentation_batch_size = 64
        pipeline.embedding_batch_size = 32
        _PIPELINE = pipeline
    return _PIPELINE
def _audio_is_already_optimized(audio_path: str) -> bool:
command = [
"ffprobe",
"-v",
"error",
"-print_format",
"json",
"-show_streams",
"-select_streams",
"a:0",
audio_path,
]
try:
result = subprocess.run(command, check=True, capture_output=True, text=True)
payload = json.loads(result.stdout)
except (subprocess.CalledProcessError, json.JSONDecodeError):
return False
streams = payload.get("streams") or []
if not streams:
return False
stream = streams[0]
return (
stream.get("codec_name") == "pcm_s16le"
and int(stream.get("sample_rate", 0)) == 16000
and int(stream.get("channels", 0)) == 1
)
def _normalize_audio(audio_path: str) -> str:
    """Re-encode *audio_path* to 16 kHz mono 16-bit PCM WAV when necessary.

    Returns the original path unchanged if the file already matches the target
    format; otherwise writes a normalized copy into a fresh temporary
    directory and returns that new path.

    Raises:
        gr.Error: if ffmpeg is not installed or the conversion fails.
    """
    if _audio_is_already_optimized(audio_path):
        return audio_path
    normalized_dir = Path(tempfile.mkdtemp(prefix="pyannote_audio_"))
    normalized_path = normalized_dir / "normalized.wav"
    command = [
        "ffmpeg",
        "-y",
        "-hide_banner",
        "-loglevel",
        "error",
        "-i",
        audio_path,
        "-ac",
        "1",
        "-ar",
        "16000",
        "-c:a",
        "pcm_s16le",
        str(normalized_path),
    ]
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except FileNotFoundError as exc:
        # ffmpeg binary missing entirely — surface a friendly UI error instead
        # of a raw traceback (previously uncaught).
        raise gr.Error("ffmpeg is not installed or not on PATH; cannot normalize audio.") from exc
    except subprocess.CalledProcessError as exc:
        message = exc.stderr.strip() or exc.stdout.strip() or str(exc)
        raise gr.Error(f"Failed to normalize audio with ffmpeg: {message}") from exc
    return str(normalized_path)
@spaces.GPU(duration=180)
def _run_diarization(audio: dict[str, Any]) -> tuple[Any, float]:
    """Run the preloaded pipeline on the GPU and report elapsed wall-clock seconds.

    The returned float covers device transfer plus inference, i.e. the time
    billed against the ZeroGPU allocation.
    """
    if _PIPELINE is None:
        raise RuntimeError("Pipeline must be loaded before running diarization.")
    start = time.perf_counter()
    _PIPELINE.to(torch.device("cuda"))
    try:
        with torch.inference_mode():
            result = _PIPELINE(audio)
    finally:
        # Always park the pipeline back on CPU and release cached VRAM so the
        # ZeroGPU slot is returned clean, even if inference raised.
        _PIPELINE.to(torch.device("cpu"))
        torch.cuda.empty_cache()
    return result, time.perf_counter() - start
def diarize(
    audio_path: str | None,
    hf_token: str | None,
):
    """Validate inputs, normalize the audio, run diarization, and build the JSON response.

    Returns a dict with the model id, the ZeroGPU seconds consumed, and a list
    of speaker segments (id, speaker label, start/end/duration in seconds).

    Raises:
        gr.Error: on missing audio, missing file, or missing HF token.
    """
    if not audio_path:
        raise gr.Error("Upload an audio file first.")
    if not Path(audio_path).exists():
        raise gr.Error("The uploaded audio file could not be found. Please re-upload it and try again.")
    if not hf_token or not hf_token.strip():
        raise gr.Error(
            "A Hugging Face access token is required. Accept the model conditions first, then pass `HF_TOKEN` in the UI or API call."
        )

    wav_path = _normalize_audio(audio_path)
    waveform, sample_rate = torchaudio.load(wav_path)

    # Load on CPU first so the ZeroGPU decorator only wraps actual inference.
    get_pipeline(hf_token.strip())
    output, zerogpu_seconds = _run_diarization(
        audio={"waveform": waveform, "sample_rate": sample_rate}
    )

    # Prefer the exclusive (non-overlapping) diarization when the model
    # exposes it; "is not None" matters — an empty annotation may be falsy.
    annotation = output.speaker_diarization
    exclusive = getattr(output, "exclusive_speaker_diarization", None)
    if exclusive is not None:
        annotation = exclusive

    segments: list[dict[str, Any]] = [
        {
            "segment_id": f"seg_{index:06d}",
            "speaker_id": str(speaker),
            "start": round(float(turn.start), 3),
            "end": round(float(turn.end), 3),
            "duration": round(max(0.0, float(turn.end) - float(turn.start)), 3),
        }
        for index, (turn, _, speaker) in enumerate(
            annotation.itertracks(yield_label=True), start=1
        )
    ]

    return {
        "source": "pyannote/speaker-diarization-community-1",
        "zerogpu_seconds": round(zerogpu_seconds, 3),
        "segments": segments,
    }
def build_demo() -> gr.Blocks:
    """Construct the Gradio UI: audio + token inputs, run button, JSON output.

    The theme is passed to the ``gr.Blocks`` constructor — ``launch()`` has no
    ``theme`` parameter (the original passed it to ``launch``, which fails).
    """
    with gr.Blocks(
        title="Pyannote Speaker Diarization",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            """
            # Speaker diarization with pyannote Community-1
            Upload or record audio and run speaker diarization using
            [`pyannote/speaker-diarization-community-1`](https://huggingface.co/pyannote/speaker-diarization-community-1).
            **Important:** before first use, accept the model conditions on Hugging Face and provide a
            read-access token. Pass it as the `HF_TOKEN` input in the UI or API call.
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Audio",
                )
                token_input = gr.Textbox(
                    label="HF_TOKEN",
                    type="password",
                    placeholder="hf_xxx",
                )
                run_button = gr.Button("Run diarization", variant="primary")
            with gr.Column(scale=1):
                response_output = gr.JSON(label="Diarization JSON")
        run_button.click(
            fn=diarize,
            inputs=[
                audio_input,
                token_input,
            ],
            outputs=[response_output],
        )
    return demo


# Build once at import time so `gradio app.py` / Spaces can serve `demo`;
# single-worker queue keeps the ZeroGPU allocation uncontended.
demo = build_demo()
demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch()