Ratnesh-dev commited on
Commit
96ec82d
·
1 Parent(s): 1103803

Simplify API-focused diarization app

Browse files
Files changed (1) hide show
  1. app.py +18 -79
app.py CHANGED
@@ -1,7 +1,5 @@
1
  from __future__ import annotations
2
 
3
- import csv
4
- import io
5
  import subprocess
6
  import tempfile
7
  import time
@@ -9,54 +7,21 @@ from pathlib import Path
9
  from typing import Any
10
 
11
  import gradio as gr
 
12
  import torch
13
  from pyannote.audio import Pipeline
14
 
15
- try:
16
- import spaces
17
- except ImportError: # local fallback when the ZeroGPU helper is unavailable
18
- class _SpacesShim:
19
- def GPU(self, *args, **kwargs):
20
- if args and callable(args[0]) and len(args) == 1 and not kwargs:
21
- return args[0]
22
-
23
- def decorator(func):
24
- return func
25
-
26
- return decorator
27
-
28
- spaces = _SpacesShim()
29
-
30
-
31
- MODEL_ID = "pyannote/speaker-diarization-community-1"
32
- GPU_DURATION_SECONDS = 30
33
 
34
  _PIPELINE: Pipeline | None = None
35
 
36
 
37
- def _resolve_token(hf_token: str | None) -> str:
38
- if hf_token and hf_token.strip():
39
- return hf_token.strip()
40
-
41
- raise gr.Error(
42
- "A Hugging Face access token is required. Accept the model conditions first, then pass `HF_TOKEN` in the UI or API call."
43
- )
44
-
45
-
46
  def get_pipeline(hf_token: str) -> Pipeline:
47
  global _PIPELINE
48
 
49
  if _PIPELINE is not None:
50
  return _PIPELINE
51
 
52
- try:
53
- _PIPELINE = Pipeline.from_pretrained(MODEL_ID, token=hf_token)
54
- except Exception as exc: # pragma: no cover - depends on runtime/network/token state
55
- raise gr.Error(
56
- "Failed to load the pyannote pipeline. Make sure you accepted the model conditions "
57
- f"for {MODEL_ID} and provided a valid token. Original error: {exc}"
58
- ) from exc
59
-
60
  return _PIPELINE
61
 
62
 
@@ -98,11 +63,11 @@ def _normalize_audio(audio_path: str) -> str:
98
  return str(normalized_path)
99
 
100
 
101
- @spaces.GPU(duration=GPU_DURATION_SECONDS)
102
  def _run_diarization(
103
  audio_path: str,
104
  hf_token: str,
105
- ) -> tuple[list[dict[str, Any]], str, str, float]:
106
  pipeline = get_pipeline(hf_token)
107
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
108
  started_at = time.perf_counter()
@@ -110,8 +75,6 @@ def _run_diarization(
110
  pipeline.to(device)
111
  try:
112
  output = pipeline(audio_path)
113
- except Exception as exc: # pragma: no cover - depends on model/runtime/audio
114
- raise gr.Error(f"Diarization failed: {exc}") from exc
115
  finally:
116
  if device.type == "cuda":
117
  pipeline.to(torch.device("cpu"))
@@ -140,33 +103,7 @@ def _run_diarization(
140
  }
141
  )
142
 
143
- rttm_buffer = io.StringIO()
144
- annotation.write_rttm(rttm_buffer)
145
-
146
- return segments, rttm_buffer.getvalue(), annotation_label, zerogpu_seconds
147
-
148
-
149
- def _write_artifacts(segments: list[dict[str, Any]], rttm_text: str) -> list[str]:
150
- output_dir = Path(tempfile.mkdtemp(prefix="pyannote_diarization_"))
151
-
152
- csv_path = output_dir / "segments.csv"
153
- with csv_path.open("w", newline="", encoding="utf-8") as csv_file:
154
- writer = csv.DictWriter(csv_file, fieldnames=["speaker", "start", "end", "duration"])
155
- writer.writeheader()
156
- writer.writerows(segments)
157
-
158
- txt_path = output_dir / "segments.txt"
159
- with txt_path.open("w", encoding="utf-8") as txt_file:
160
- for segment in segments:
161
- txt_file.write(
162
- f"{segment['speaker']} | {_format_timestamp(segment['start'])} --> "
163
- f"{_format_timestamp(segment['end'])}\n"
164
- )
165
-
166
- rttm_path = output_dir / "diarization.rttm"
167
- rttm_path.write_text(rttm_text, encoding="utf-8")
168
-
169
- return [str(csv_path), str(txt_path), str(rttm_path)]
170
 
171
 
172
  def diarize(
@@ -179,15 +116,20 @@ def diarize(
179
  if not Path(audio_path).exists():
180
  raise gr.Error("The uploaded audio file could not be found. Please re-upload it and try again.")
181
 
 
 
 
 
 
182
  normalized_audio_path = _normalize_audio(audio_path)
183
- resolved_token = _resolve_token(hf_token)
184
 
185
  # Load on CPU first so the ZeroGPU decorator only wraps actual inference.
186
- get_pipeline(resolved_token)
187
 
188
- segments, rttm_text, annotation_label, zerogpu_seconds = _run_diarization(
189
  audio_path=normalized_audio_path,
190
- hf_token=resolved_token,
191
  )
192
 
193
  if not segments:
@@ -196,7 +138,7 @@ def diarize(
196
  f"Inference completed with `{annotation_label}` output, but it contained no segments."
197
  )
198
  summary += f"\n- ZeroGPU time used: **{zerogpu_seconds:.2f}s**"
199
- return summary, round(zerogpu_seconds, 3), [], "", _write_artifacts(segments, rttm_text)
200
 
201
  unique_speakers = sorted({segment["speaker"] for segment in segments})
202
  total_speech = sum(segment["duration"] for segment in segments)
@@ -227,8 +169,7 @@ def diarize(
227
  for segment in segments
228
  )
229
 
230
- artifacts = _write_artifacts(segments, rttm_text)
231
- return summary, round(zerogpu_seconds, 3), segments_json, turns_text, artifacts
232
 
233
 
234
  def build_demo() -> gr.Blocks:
@@ -268,7 +209,6 @@ def build_demo() -> gr.Blocks:
268
  lines=14,
269
  buttons=["copy"],
270
  )
271
- files_output = gr.File(label="Download outputs", file_count="multiple")
272
 
273
  run_button.click(
274
  fn=diarize,
@@ -276,13 +216,12 @@ def build_demo() -> gr.Blocks:
276
  audio_input,
277
  token_input,
278
  ],
279
- outputs=[summary_output, zerogpu_seconds_output, segments_output, turns_output, files_output],
280
  )
281
 
282
  gr.Markdown(
283
  """
284
- Outputs include segments as JSON, a plain-text speaker-turn list, and downloadable
285
- `CSV`, `TXT`, and `RTTM` files.
286
  """
287
  )
288
 
 
1
  from __future__ import annotations
2
 
 
 
3
  import subprocess
4
  import tempfile
5
  import time
 
7
  from typing import Any
8
 
9
  import gradio as gr
10
+ import spaces
11
  import torch
12
  from pyannote.audio import Pipeline
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  _PIPELINE: Pipeline | None = None
16
 
17
 
 
 
 
 
 
 
 
 
 
18
  def get_pipeline(hf_token: str) -> Pipeline:
19
  global _PIPELINE
20
 
21
  if _PIPELINE is not None:
22
  return _PIPELINE
23
 
24
+ _PIPELINE = Pipeline.from_pretrained("pyannote/speaker-diarization-community-1", token=hf_token)
 
 
 
 
 
 
 
25
  return _PIPELINE
26
 
27
 
 
63
  return str(normalized_path)
64
 
65
 
66
+ @spaces.GPU(duration=120)
67
  def _run_diarization(
68
  audio_path: str,
69
  hf_token: str,
70
+ ) -> tuple[list[dict[str, Any]], str, float]:
71
  pipeline = get_pipeline(hf_token)
72
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
  started_at = time.perf_counter()
 
75
  pipeline.to(device)
76
  try:
77
  output = pipeline(audio_path)
 
 
78
  finally:
79
  if device.type == "cuda":
80
  pipeline.to(torch.device("cpu"))
 
103
  }
104
  )
105
 
106
+ return segments, annotation_label, zerogpu_seconds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
 
109
  def diarize(
 
116
  if not Path(audio_path).exists():
117
  raise gr.Error("The uploaded audio file could not be found. Please re-upload it and try again.")
118
 
119
+ if not hf_token or not hf_token.strip():
120
+ raise gr.Error(
121
+ "A Hugging Face access token is required. Accept the model conditions first, then pass `HF_TOKEN` in the UI or API call."
122
+ )
123
+
124
  normalized_audio_path = _normalize_audio(audio_path)
125
+ hf_token = hf_token.strip()
126
 
127
  # Load on CPU first so the ZeroGPU decorator only wraps actual inference.
128
+ get_pipeline(hf_token)
129
 
130
+ segments, annotation_label, zerogpu_seconds = _run_diarization(
131
  audio_path=normalized_audio_path,
132
+ hf_token=hf_token,
133
  )
134
 
135
  if not segments:
 
138
  f"Inference completed with `{annotation_label}` output, but it contained no segments."
139
  )
140
  summary += f"\n- ZeroGPU time used: **{zerogpu_seconds:.2f}s**"
141
+ return summary, round(zerogpu_seconds, 3), [], ""
142
 
143
  unique_speakers = sorted({segment["speaker"] for segment in segments})
144
  total_speech = sum(segment["duration"] for segment in segments)
 
169
  for segment in segments
170
  )
171
 
172
+ return summary, round(zerogpu_seconds, 3), segments_json, turns_text
 
173
 
174
 
175
  def build_demo() -> gr.Blocks:
 
209
  lines=14,
210
  buttons=["copy"],
211
  )
 
212
 
213
  run_button.click(
214
  fn=diarize,
 
216
  audio_input,
217
  token_input,
218
  ],
219
+ outputs=[summary_output, zerogpu_seconds_output, segments_output, turns_output],
220
  )
221
 
222
  gr.Markdown(
223
  """
224
+ Outputs include segments as JSON and a plain-text speaker-turn list.
 
225
  """
226
  )
227