ConvxO2 committed on
Commit
f54b658
·
1 Parent(s): 7ff62ef

Fix portability, model cache handling, and deploy token safety

Browse files
Files changed (3) hide show
  1. app/main.py +63 -52
  2. deploy_hf.py +60 -0
  3. models/embedder.py +35 -28
app/main.py CHANGED
@@ -1,9 +1,5 @@
1
- """
2
- Speaker Diarization API - FastAPI Application
3
- """
4
 
5
- import io
6
- import time
7
  import asyncio
8
  import tempfile
9
  import traceback
@@ -19,12 +15,9 @@ from fastapi import (
19
  from fastapi.middleware.cors import CORSMiddleware
20
  from fastapi.staticfiles import StaticFiles
21
  from fastapi.responses import HTMLResponse
22
- from pydantic import BaseModel, Field
23
  from loguru import logger
24
 
25
- # ---------------------------------------------------------------------------
26
- # Schemas
27
- # ---------------------------------------------------------------------------
28
 
29
  class SegmentOut(BaseModel):
30
  start: float
@@ -49,13 +42,9 @@ class HealthResponse(BaseModel):
49
  version: str = "1.0.0"
50
 
51
 
52
- # ---------------------------------------------------------------------------
53
- # App
54
- # ---------------------------------------------------------------------------
55
-
56
  app = FastAPI(
57
  title="Speaker Diarization API",
58
- description="Who Spoke When Speaker diarization using ECAPA-TDNN + AHC Clustering",
59
  version="1.0.0",
60
  )
61
 
@@ -69,12 +58,16 @@ app.add_middleware(
69
 
70
  _pipeline = None
71
 
 
72
  def get_pipeline():
73
  global _pipeline
74
  if _pipeline is None:
75
  from app.pipeline import DiarizationPipeline
76
- import os
77
- cache_dir = os.getenv("CACHE_DIR", "/tmp/model_cache")
 
 
 
78
  _pipeline = DiarizationPipeline(
79
  device="auto",
80
  use_pyannote_vad=True,
@@ -85,10 +78,6 @@ def get_pipeline():
85
  return _pipeline
86
 
87
 
88
- # ---------------------------------------------------------------------------
89
- # Endpoints
90
- # ---------------------------------------------------------------------------
91
-
92
  @app.get("/health", response_model=HealthResponse, tags=["System"])
93
  async def health_check():
94
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -136,6 +125,7 @@ async def diarize_from_url(
136
  ):
137
  """Diarize audio from a URL."""
138
  import httpx
 
139
  try:
140
  async with httpx.AsyncClient(timeout=60.0) as client:
141
  resp = await client.get(audio_url)
@@ -169,6 +159,7 @@ async def stream_diarization(websocket: WebSocket):
169
  """Real-time streaming diarization via WebSocket."""
170
  await websocket.accept()
171
  import numpy as np
 
172
  audio_buffer = bytearray()
173
  sample_rate = 16000
174
  num_speakers = None
@@ -179,10 +170,12 @@ async def stream_diarization(websocket: WebSocket):
179
  sample_rate = config_msg.get("sample_rate", 16000)
180
  num_speakers = config_msg.get("num_speakers", None)
181
 
182
- await websocket.send_json({
183
- "type": "progress",
184
- "data": {"message": "Config received. Send audio chunks.", "chunks_received": 0},
185
- })
 
 
186
 
187
  while True:
188
  try:
@@ -194,12 +187,18 @@ async def stream_diarization(websocket: WebSocket):
194
  if "bytes" in msg:
195
  audio_buffer.extend(msg["bytes"])
196
  chunk_count += 1
197
- await websocket.send_json({
198
- "type": "progress",
199
- "data": {"message": f"Received chunk {chunk_count}", "chunks_received": chunk_count},
200
- })
 
 
 
 
 
201
  elif "text" in msg:
202
  import json
 
203
  data = json.loads(msg["text"])
204
  if data.get("type") == "eof":
205
  break
@@ -208,14 +207,17 @@ async def stream_diarization(websocket: WebSocket):
208
  await websocket.send_json({"type": "error", "data": {"message": "No audio received"}})
209
  return
210
 
211
- import torch
 
212
  audio_np = np.frombuffer(audio_buffer, dtype=np.float32).copy()
213
- audio_tensor = torch.from_numpy(audio_np)
214
 
215
- await websocket.send_json({
216
- "type": "progress",
217
- "data": {"message": "Running diarization pipeline..."},
218
- })
 
 
219
 
220
  loop = asyncio.get_event_loop()
221
  pipeline = get_pipeline()
@@ -227,15 +229,17 @@ async def stream_diarization(websocket: WebSocket):
227
  for seg in result.segments:
228
  await websocket.send_json({"type": "segment", "data": seg.to_dict()})
229
 
230
- await websocket.send_json({
231
- "type": "done",
232
- "data": {
233
- "num_speakers": result.num_speakers,
234
- "total_segments": len(result.segments),
235
- "audio_duration": result.audio_duration,
236
- "processing_time": result.processing_time,
237
- },
238
- })
 
 
239
 
240
  except WebSocketDisconnect:
241
  logger.info("WebSocket client disconnected")
@@ -249,26 +253,33 @@ async def stream_diarization(websocket: WebSocket):
249
 
250
  @app.get("/", response_class=HTMLResponse, include_in_schema=False)
251
  async def serve_ui():
252
- ui_path = Path("static/index.html")
253
  if ui_path.exists():
254
- return HTMLResponse(ui_path.read_text())
255
- return HTMLResponse("<h1>Speaker Diarization API</h1><p><a href='/docs'>API Docs</a></p>")
 
256
 
257
  @app.get("/debug", tags=["System"])
258
  async def debug():
259
- import speechbrain
260
- import os
261
  import inspect
 
262
  from speechbrain.inference.classifiers import EncoderClassifier
 
 
 
 
 
263
  sig = str(inspect.signature(EncoderClassifier.from_hparams))
264
  return {
265
  "speechbrain_version": speechbrain.__version__,
266
- "tmp_writable": os.access("/tmp", os.W_OK),
267
- "cache_exists": os.path.exists("/tmp/model_cache"),
 
 
268
  "from_hparams_signature": sig,
269
  }
270
 
271
 
272
- static_dir = Path("static")
273
  if static_dir.exists():
274
- app.mount("/static", StaticFiles(directory="static"), name="static")
 
1
+ """Speaker Diarization API - FastAPI Application."""
 
 
2
 
 
 
3
  import asyncio
4
  import tempfile
5
  import traceback
 
15
  from fastapi.middleware.cors import CORSMiddleware
16
  from fastapi.staticfiles import StaticFiles
17
  from fastapi.responses import HTMLResponse
18
+ from pydantic import BaseModel
19
  from loguru import logger
20
 
 
 
 
21
 
22
  class SegmentOut(BaseModel):
23
  start: float
 
42
  version: str = "1.0.0"
43
 
44
 
 
 
 
 
45
  app = FastAPI(
46
  title="Speaker Diarization API",
47
+ description="Who Spoke When - Speaker diarization using ECAPA-TDNN + AHC Clustering",
48
  version="1.0.0",
49
  )
50
 
 
58
 
59
  _pipeline = None
60
 
61
+
62
  def get_pipeline():
63
  global _pipeline
64
  if _pipeline is None:
65
  from app.pipeline import DiarizationPipeline
66
+
67
+ cache_dir = os.getenv(
68
+ "CACHE_DIR",
69
+ str(Path(tempfile.gettempdir()) / "model_cache"),
70
+ )
71
  _pipeline = DiarizationPipeline(
72
  device="auto",
73
  use_pyannote_vad=True,
 
78
  return _pipeline
79
 
80
 
 
 
 
 
81
  @app.get("/health", response_model=HealthResponse, tags=["System"])
82
  async def health_check():
83
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
125
  ):
126
  """Diarize audio from a URL."""
127
  import httpx
128
+
129
  try:
130
  async with httpx.AsyncClient(timeout=60.0) as client:
131
  resp = await client.get(audio_url)
 
159
  """Real-time streaming diarization via WebSocket."""
160
  await websocket.accept()
161
  import numpy as np
162
+
163
  audio_buffer = bytearray()
164
  sample_rate = 16000
165
  num_speakers = None
 
170
  sample_rate = config_msg.get("sample_rate", 16000)
171
  num_speakers = config_msg.get("num_speakers", None)
172
 
173
+ await websocket.send_json(
174
+ {
175
+ "type": "progress",
176
+ "data": {"message": "Config received. Send audio chunks.", "chunks_received": 0},
177
+ }
178
+ )
179
 
180
  while True:
181
  try:
 
187
  if "bytes" in msg:
188
  audio_buffer.extend(msg["bytes"])
189
  chunk_count += 1
190
+ await websocket.send_json(
191
+ {
192
+ "type": "progress",
193
+ "data": {
194
+ "message": f"Received chunk {chunk_count}",
195
+ "chunks_received": chunk_count,
196
+ },
197
+ }
198
+ )
199
  elif "text" in msg:
200
  import json
201
+
202
  data = json.loads(msg["text"])
203
  if data.get("type") == "eof":
204
  break
 
207
  await websocket.send_json({"type": "error", "data": {"message": "No audio received"}})
208
  return
209
 
210
+ import torch as torch_local
211
+
212
  audio_np = np.frombuffer(audio_buffer, dtype=np.float32).copy()
213
+ audio_tensor = torch_local.from_numpy(audio_np)
214
 
215
+ await websocket.send_json(
216
+ {
217
+ "type": "progress",
218
+ "data": {"message": "Running diarization pipeline..."},
219
+ }
220
+ )
221
 
222
  loop = asyncio.get_event_loop()
223
  pipeline = get_pipeline()
 
229
  for seg in result.segments:
230
  await websocket.send_json({"type": "segment", "data": seg.to_dict()})
231
 
232
+ await websocket.send_json(
233
+ {
234
+ "type": "done",
235
+ "data": {
236
+ "num_speakers": result.num_speakers,
237
+ "total_segments": len(result.segments),
238
+ "audio_duration": result.audio_duration,
239
+ "processing_time": result.processing_time,
240
+ },
241
+ }
242
+ )
243
 
244
  except WebSocketDisconnect:
245
  logger.info("WebSocket client disconnected")
 
253
 
254
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def serve_ui():
    """Serve the bundled single-page UI, falling back to a minimal landing page.

    The UI is looked up relative to this file (``<repo>/static/index.html``) so
    the route works regardless of the process working directory.
    """
    index_file = Path(__file__).resolve().parent.parent / "static" / "index.html"
    if not index_file.exists():
        return HTMLResponse("<h1>Speaker Diarization API</h1><p><a href='/docs'>API Docs</a></p>")
    return HTMLResponse(index_file.read_text(encoding="utf-8"))
260
+
261
 
262
@app.get("/debug", tags=["System"])
async def debug():
    """Report runtime diagnostics: speechbrain version, temp/cache paths and writability."""
    import inspect
    import speechbrain
    from speechbrain.inference.classifiers import EncoderClassifier

    tmp_root = tempfile.gettempdir()
    # Same default as get_pipeline(): CACHE_DIR env var, else <tmp>/model_cache.
    cache_dir = os.getenv("CACHE_DIR", str(Path(tmp_root) / "model_cache"))
    signature_text = str(inspect.signature(EncoderClassifier.from_hparams))
    return {
        "speechbrain_version": speechbrain.__version__,
        "temp_dir": tmp_root,
        "temp_writable": os.access(tmp_root, os.W_OK),
        "cache_dir": cache_dir,
        "cache_exists": os.path.exists(cache_dir),
        "from_hparams_signature": signature_text,
    }
281
 
282
 
283
# Mount the on-disk static asset directory only when it ships with the app;
# resolved relative to this file so it is independent of the working directory.
static_dir = Path(__file__).resolve().parent.parent / "static"
if static_dir.exists():
    app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
deploy_hf.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Deploy this project to a Hugging Face Space."""
3
+
4
+ import os
5
+ import subprocess
6
+ import sys
7
+
8
+ from huggingface_hub import HfApi
9
+
10
+
11
def require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Aborts the program with :class:`SystemExit` when the variable is unset
    or empty — both count as "missing" for deployment credentials.
    """
    if value := os.getenv(name):
        return value
    raise SystemExit(f"Missing required environment variable: {name}")
16
+
17
+
18
def main() -> None:
    """Create (or reuse) the target Hugging Face Space and push the current branch to it.

    Required env: ``HF_TOKEN``.
    Optional env: ``HF_SPACE_NAME`` (default ``who-spoke-when``), ``HF_USERNAME``
    (defaults to the token's identity), ``HF_FORCE_PUSH`` (truthy enables ``--force``).

    Raises:
        SystemExit: when the token is missing or the Space cannot be created.
        subprocess.CalledProcessError: when a git command fails (``check=True``).
    """
    token = require_env("HF_TOKEN")
    space_name = os.getenv("HF_SPACE_NAME", "who-spoke-when")

    api = HfApi(token=token)

    # Resolve the owning account; fall back to whoever the token belongs to.
    username = os.getenv("HF_USERNAME")
    if not username:
        whoami = api.whoami(token=token)
        username = whoami["name"]

    space_id = f"{username}/{space_name}"

    try:
        api.create_repo(
            repo_id=space_id,
            repo_type="space",
            space_sdk="docker",
            private=False,
            token=token,
            exist_ok=True,  # idempotent: re-running the deploy reuses the Space
        )
        print(f"Space ready: {space_id}")
    except Exception as exc:
        raise SystemExit(f"Failed to create or fetch space '{space_id}': {exc}") from exc

    # SECURITY: the token rides in the remote URL only for the duration of the
    # push. Previously the remote (and thus the plaintext token) was left
    # behind in .git/config; the finally block below scrubs it.
    remote_url = f"https://{username}:{token}@huggingface.co/spaces/{space_id}"
    subprocess.run(["git", "remote", "remove", "huggingface"], check=False, capture_output=True)
    subprocess.run(["git", "remote", "add", "huggingface", remote_url], check=True)

    try:
        push_cmd = ["git", "push", "huggingface", "main"]
        if os.getenv("HF_FORCE_PUSH", "false").lower() in {"1", "true", "yes"}:
            push_cmd.append("--force")

        subprocess.run(push_cmd, check=True)
        print(f"Pushed to https://huggingface.co/spaces/{space_id}")
    finally:
        # Remove the token-bearing remote whether or not the push succeeded,
        # so the credential never persists on disk.
        subprocess.run(["git", "remote", "remove", "huggingface"], check=False, capture_output=True)
54
+
55
+
56
if __name__ == "__main__":
    # Propagate git's exit code to the shell instead of dumping a traceback
    # when a push/remote command fails.
    try:
        main()
    except subprocess.CalledProcessError as exc:
        sys.exit(exc.returncode)
models/embedder.py CHANGED
@@ -1,14 +1,16 @@
1
- """
2
  Speaker Embedding Extraction using ECAPA-TDNN architecture via SpeechBrain.
3
  Handles audio preprocessing, feature extraction, and L2-normalized embeddings.
4
  """
5
 
6
- import os
7
- import torch
8
- import torchaudio
9
- import numpy as np
10
  from pathlib import Path
11
  from typing import Union, List, Tuple
 
 
 
 
12
  from loguru import logger
13
 
14
 
@@ -22,7 +24,7 @@ class EcapaTDNNEmbedder:
22
  SAMPLE_RATE = 16000
23
  EMBEDDING_DIM = 192
24
 
25
- def __init__(self, device: str = "auto", cache_dir: str = "/tmp/model_cache"):
26
  self.device = self._resolve_device(device)
27
  self.cache_dir = Path(cache_dir)
28
  self.cache_dir.mkdir(parents=True, exist_ok=True)
@@ -39,41 +41,46 @@ class EcapaTDNNEmbedder:
39
  return
40
 
41
  try:
42
- import shutil
43
  import speechbrain.utils.fetching as _fetching
44
  from speechbrain.utils.fetching import LocalStrategy
 
45
 
46
  def _patched_link(src, dst, local_strategy):
47
- from pathlib import Path as _Path
48
- dst = _Path(dst)
49
- src = _Path(src)
50
- dst.parent.mkdir(parents=True, exist_ok=True)
51
- if dst.exists() or dst.is_symlink():
52
- dst.unlink()
53
- shutil.copy2(str(src), str(dst))
54
 
55
  _fetching.link_with_strategy = _patched_link
56
 
57
- from speechbrain.inference.classifiers import EncoderClassifier
 
 
 
58
 
59
  logger.info(f"Loading ECAPA-TDNN from {self.MODEL_SOURCE}...")
 
60
 
61
- savedir = "/tmp/model_cache/ecapa_tdnn"
62
- os.makedirs(savedir, exist_ok=True)
63
- logger.info(f"Savedir: {savedir}, exists: {os.path.exists(savedir)}")
 
 
64
 
65
- self._model = EncoderClassifier.from_hparams(
66
- source=self.MODEL_SOURCE,
67
- savedir=savedir,
68
- run_opts={"device": self.device},
69
- huggingface_cache_dir="/tmp/hf_cache",
70
- local_strategy=LocalStrategy.COPY,
71
- )
72
  self._model.eval()
73
  logger.success("ECAPA-TDNN model loaded successfully.")
74
- except ImportError:
75
- raise ImportError("SpeechBrain not installed.")
76
-
77
  def preprocess_audio(
78
  self, audio: Union[np.ndarray, torch.Tensor], sample_rate: int
79
  ) -> torch.Tensor:
 
1
+ """
2
  Speaker Embedding Extraction using ECAPA-TDNN architecture via SpeechBrain.
3
  Handles audio preprocessing, feature extraction, and L2-normalized embeddings.
4
  """
5
 
6
+ import inspect
7
+ import shutil
 
 
8
  from pathlib import Path
9
  from typing import Union, List, Tuple
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torchaudio
14
  from loguru import logger
15
 
16
 
 
24
  SAMPLE_RATE = 16000
25
  EMBEDDING_DIM = 192
26
 
27
+ def __init__(self, device: str = "auto", cache_dir: str = "./model_cache"):
28
  self.device = self._resolve_device(device)
29
  self.cache_dir = Path(cache_dir)
30
  self.cache_dir.mkdir(parents=True, exist_ok=True)
 
41
  return
42
 
43
  try:
 
44
  import speechbrain.utils.fetching as _fetching
45
  from speechbrain.utils.fetching import LocalStrategy
46
+ from speechbrain.inference.classifiers import EncoderClassifier
47
 
48
  def _patched_link(src, dst, local_strategy):
49
+ dst_path = Path(dst)
50
+ src_path = Path(src)
51
+ dst_path.parent.mkdir(parents=True, exist_ok=True)
52
+ if dst_path.exists() or dst_path.is_symlink():
53
+ dst_path.unlink()
54
+ shutil.copy2(str(src_path), str(dst_path))
 
55
 
56
  _fetching.link_with_strategy = _patched_link
57
 
58
+ savedir = self.cache_dir / "ecapa_tdnn"
59
+ hf_cache = self.cache_dir / "hf_cache"
60
+ savedir.mkdir(parents=True, exist_ok=True)
61
+ hf_cache.mkdir(parents=True, exist_ok=True)
62
 
63
  logger.info(f"Loading ECAPA-TDNN from {self.MODEL_SOURCE}...")
64
+ logger.info(f"Savedir: {savedir}, exists: {savedir.exists()}")
65
 
66
+ kwargs = {
67
+ "source": self.MODEL_SOURCE,
68
+ "savedir": str(savedir),
69
+ "run_opts": {"device": self.device},
70
+ }
71
 
72
+ sig = inspect.signature(EncoderClassifier.from_hparams)
73
+ if "huggingface_cache_dir" in sig.parameters:
74
+ kwargs["huggingface_cache_dir"] = str(hf_cache)
75
+ if "local_strategy" in sig.parameters:
76
+ kwargs["local_strategy"] = LocalStrategy.COPY
77
+
78
+ self._model = EncoderClassifier.from_hparams(**kwargs)
79
  self._model.eval()
80
  logger.success("ECAPA-TDNN model loaded successfully.")
81
+ except ImportError as exc:
82
+ raise ImportError("SpeechBrain not installed.") from exc
83
+
84
  def preprocess_audio(
85
  self, audio: Union[np.ndarray, torch.Tensor], sample_rate: int
86
  ) -> torch.Tensor: