File size: 6,506 Bytes
99a85b5
 
83e68e1
667e520
99a85b5
7def15a
99a85b5
 
 
 
96ec82d
99a85b5
83e68e1
 
99a85b5
 
 
83e68e1
 
 
99a85b5
 
 
 
 
 
 
 
 
46c3876
83e68e1
 
 
 
 
 
 
 
46c3876
1984b17
99a85b5
 
 
83e68e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
667e520
83e68e1
 
 
667e520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ccfd03
83e68e1
7a4d826
 
 
 
7def15a
99a85b5
7a4d826
99a85b5
83e68e1
 
99a85b5
7a4d826
 
99a85b5
7def15a
7a4d826
99a85b5
 
 
 
 
 
 
18bf750
99a85b5
 
 
 
96ec82d
 
 
 
 
667e520
83e68e1
96ec82d
99a85b5
 
96ec82d
99a85b5
83e68e1
 
 
99a85b5
7a4d826
 
 
 
 
 
 
 
 
 
18bf750
 
7a4d826
 
 
 
18bf750
7a4d826
 
 
 
 
 
18bf750
99a85b5
18bf750
99a85b5
 
 
d09fc0f
99a85b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1103803
99a85b5
 
 
 
 
 
 
 
 
 
 
18bf750
99a85b5
 
 
 
 
 
 
18bf750
99a85b5
 
 
 
 
 
 
 
 
 
d09fc0f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
from __future__ import annotations

import json
import shutil
import subprocess
import tempfile
import time
from pathlib import Path
from typing import Any

import gradio as gr
import spaces
import torch
import torchaudio
from huggingface_hub import snapshot_download
from pyannote.audio import Pipeline


# Process-wide cache for the diarization pipeline. It starts empty and is
# filled lazily by get_pipeline(), so the model is loaded at most once.
_PIPELINE: Pipeline | None = None

# Allow TF32 math for CUDA matmuls and cuDNN convolutions: a small loss of
# float32 precision in exchange for faster GPU inference.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


def get_pipeline(hf_token: str) -> Pipeline:
    """Return the process-wide diarization pipeline, loading it on first use.

    On the first call the pipeline snapshot is downloaded from the Hugging Face
    Hub into a local directory (unless a usable copy is already there). It is
    then loaded from disk and given larger segmentation/embedding batch sizes.
    Later calls return the cached instance and ignore ``hf_token``.

    Args:
        hf_token: Hugging Face access token used for the snapshot download.

    Returns:
        The loaded pipeline.

    Raises:
        RuntimeError: If ``Pipeline.from_pretrained`` returns no pipeline.
    """
    global _PIPELINE

    if _PIPELINE is not None:
        return _PIPELINE

    local_model_dir = Path("models/pyannote-speaker-diarization-community-1-b64")
    # Key the "already downloaded" check on the pipeline config file, not on
    # the directory. An interrupted download can leave the directory behind
    # without a loadable snapshot; calling snapshot_download again fetches
    # whatever is missing.
    if not (local_model_dir / "config.yaml").is_file():
        snapshot_download(
            repo_id="pyannote/speaker-diarization-community-1",
            token=hf_token,
            local_dir=str(local_model_dir),
        )

    pipeline = Pipeline.from_pretrained(str(local_model_dir))
    if pipeline is None:
        raise RuntimeError(
            f"Failed to load the diarization pipeline from {local_model_dir}."
        )

    pipeline.segmentation_batch_size = 64
    pipeline.embedding_batch_size = 32

    # Publish only a fully configured pipeline, so a failure above never
    # leaves a half-initialised object in the cache.
    _PIPELINE = pipeline
    return _PIPELINE


def _stream_is_pcm16_mono_16k(stream: dict[str, Any]) -> bool:
    """Return True if an ffprobe stream entry is pcm_s16le, mono, at 16000 Hz.

    ``sample_rate`` and ``channels`` are coerced with ``int``. Missing, null or
    non-numeric values make the stream count as "not optimized" instead of
    raising.
    """
    if stream.get("codec_name") != "pcm_s16le":
        return False
    try:
        sample_rate = int(stream.get("sample_rate", 0))
        channels = int(stream.get("channels", 0))
    except (TypeError, ValueError):
        return False
    return sample_rate == 16000 and channels == 1


def _audio_is_already_optimized(audio_path: str) -> bool:
    """Report whether ``audio_path`` can be used without re-encoding.

    Probes the first audio stream (``a:0``) with ffprobe and checks that it is
    16-bit PCM, mono, 16 kHz. Any probe failure is treated as "not optimized",
    so callers fall back to transcoding rather than crashing. Failures include:

    - ffprobe missing from PATH;
    - a non-zero exit status;
    - output that is not parseable JSON.
    """
    command = [
        "ffprobe",
        "-v",
        "error",
        "-print_format",
        "json",
        "-show_streams",
        "-select_streams",
        "a:0",
        audio_path,
    ]

    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True)
        payload = json.loads(result.stdout)
    except (OSError, subprocess.CalledProcessError, json.JSONDecodeError):
        # OSError covers FileNotFoundError when the ffprobe binary is absent.
        return False

    if not isinstance(payload, dict):
        return False

    streams = payload.get("streams") or []
    if not streams:
        return False

    return _stream_is_pcm16_mono_16k(streams[0])


def _normalize_audio(audio_path: str) -> str:
    """Return a path to a mono, 16 kHz, pcm_s16le WAV version of ``audio_path``.

    If ``_audio_is_already_optimized`` accepts the input, it is returned
    unchanged. Otherwise ffmpeg transcodes it into ``normalized.wav`` inside a
    freshly created temporary directory, and that file's path is returned. The
    caller is responsible for removing it afterwards.

    Raises:
        gr.Error: If ffmpeg cannot be executed or fails to convert the file.
    """
    if _audio_is_already_optimized(audio_path):
        return audio_path

    normalized_dir = Path(tempfile.mkdtemp(prefix="pyannote_audio_"))
    normalized_path = normalized_dir / "normalized.wav"

    command = [
        "ffmpeg",
        "-y",
        "-hide_banner",
        "-loglevel",
        "error",
        "-i",
        audio_path,
        "-ac",
        "1",
        "-ar",
        "16000",
        "-c:a",
        "pcm_s16le",
        str(normalized_path),
    ]

    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except FileNotFoundError as exc:
        # The ffmpeg binary itself is missing; nothing was written, so drop
        # the temp dir we just created.
        shutil.rmtree(normalized_dir, ignore_errors=True)
        raise gr.Error(
            "Failed to normalize audio with ffmpeg: ffmpeg executable not found."
        ) from exc
    except subprocess.CalledProcessError as exc:
        # Don't leave an empty or partial temp dir behind for every failed upload.
        shutil.rmtree(normalized_dir, ignore_errors=True)
        message = exc.stderr.strip() or exc.stdout.strip() or str(exc)
        raise gr.Error(f"Failed to normalize audio with ffmpeg: {message}") from exc

    return str(normalized_path)


@spaces.GPU(duration=180)
def _run_diarization(audio: dict[str, Any]) -> tuple[Any, float]:
    """Run the cached pipeline on ``audio`` on the GPU and time the call.

    The pipeline must already have been loaded by ``get_pipeline``. It is
    moved to CUDA, run under ``torch.inference_mode()``, and always moved back
    to the CPU afterwards (with the CUDA cache emptied), even if inference
    raises.

    Returns:
        ``(output, seconds)``. ``output`` is the raw pipeline result.
        ``seconds`` is the wall-clock time spanning the move to the GPU,
        inference, and the move back to the CPU.

    Raises:
        RuntimeError: If the pipeline has not been loaded yet.
    """
    pipeline = _PIPELINE
    if pipeline is None:
        raise RuntimeError("Pipeline must be loaded before running diarization.")

    t0 = time.perf_counter()
    pipeline.to(torch.device("cuda"))
    try:
        with torch.inference_mode():
            result = pipeline(audio)
    finally:
        # Hand the GPU back cleanly regardless of how inference ended.
        pipeline.to(torch.device("cpu"))
        torch.cuda.empty_cache()

    elapsed = time.perf_counter() - t0
    return result, elapsed


def diarize(
    audio_path: str | None,
    hf_token: str | None,
):
    if not audio_path:
        raise gr.Error("Upload an audio file first.")

    if not Path(audio_path).exists():
        raise gr.Error("The uploaded audio file could not be found. Please re-upload it and try again.")

    if not hf_token or not hf_token.strip():
        raise gr.Error(
            "A Hugging Face access token is required. Accept the model conditions first, then pass `HF_TOKEN` in the UI or API call."
        )

    normalized_audio_path = _normalize_audio(audio_path)
    waveform, sample_rate = torchaudio.load(normalized_audio_path)
    hf_token = hf_token.strip()

    # Load on CPU first so the ZeroGPU decorator only wraps actual inference.
    get_pipeline(hf_token)

    output, zerogpu_seconds = _run_diarization(
        audio={"waveform": waveform, "sample_rate": sample_rate}
    )

    annotation = output.speaker_diarization
    exclusive_annotation = getattr(output, "exclusive_speaker_diarization", None)
    if exclusive_annotation is not None:
        annotation = exclusive_annotation

    segments: list[dict[str, Any]] = []
    for index, (turn, _, speaker) in enumerate(annotation.itertracks(yield_label=True), start=1):
        start = float(turn.start)
        end = float(turn.end)
        segments.append(
            {
                "segment_id": f"seg_{index:06d}",
                "speaker_id": str(speaker),
                "start": round(start, 3),
                "end": round(end, 3),
                "duration": round(max(0.0, end - start), 3),
            }
        )

    response = {
        "source": "pyannote/speaker-diarization-community-1",
        "zerogpu_seconds": round(zerogpu_seconds, 3),
        "segments": segments,
    }

    return response


def build_demo() -> gr.Blocks:
    """Build the Gradio interface for the diarization app.

    The left column holds an audio upload, a password field for ``HF_TOKEN``
    and a run button. The right column shows the JSON returned by
    :func:`diarize`, which the button's click event is wired to.
    """
    with gr.Blocks(title="Pyannote Speaker Diarization") as demo:
        # The Audio component below only accepts uploads (no microphone
        # source), so the intro text must not advertise recording.
        gr.Markdown(
            """
            # Speaker diarization with pyannote Community-1

            Upload audio and run speaker diarization using
            [`pyannote/speaker-diarization-community-1`](https://huggingface.co/pyannote/speaker-diarization-community-1).

            **Important:** before first use, accept the model conditions on Hugging Face and provide a
            read-access token. Pass it as the `HF_TOKEN` input in the UI or API call.
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                # type="filepath" makes diarize() receive a path string.
                audio_input = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Audio",
                )
                token_input = gr.Textbox(
                    label="HF_TOKEN",
                    type="password",
                    placeholder="hf_xxx",
                )
                run_button = gr.Button("Run diarization", variant="primary")

            with gr.Column(scale=1):
                response_output = gr.JSON(label="Diarization JSON")

        run_button.click(
            fn=diarize,
            inputs=[
                audio_input,
                token_input,
            ],
            outputs=[response_output],
        )

    return demo


# Build the app at module level so `demo` exists both when this file is
# imported and when it is run as a script.
demo = build_demo()
# Queue requests. The default concurrency limit of 1 applies to event
# listeners that don't set their own (the run button's click doesn't), so
# diarization requests are processed one at a time.
demo.queue(default_concurrency_limit=1)

# Start the server (with the Soft theme) only when executed directly.
if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft())