Spaces:
Running
Running
Fix portability, model cache handling, and deploy token safety
Browse files- app/main.py +63 -52
- deploy_hf.py +60 -0
- models/embedder.py +35 -28
app/main.py
CHANGED
|
@@ -1,9 +1,5 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Speaker Diarization API - FastAPI Application
|
| 3 |
-
"""
|
| 4 |
|
| 5 |
-
import io
|
| 6 |
-
import time
|
| 7 |
import asyncio
|
| 8 |
import tempfile
|
| 9 |
import traceback
|
|
@@ -19,12 +15,9 @@ from fastapi import (
|
|
| 19 |
from fastapi.middleware.cors import CORSMiddleware
|
| 20 |
from fastapi.staticfiles import StaticFiles
|
| 21 |
from fastapi.responses import HTMLResponse
|
| 22 |
-
from pydantic import BaseModel
|
| 23 |
from loguru import logger
|
| 24 |
|
| 25 |
-
# ---------------------------------------------------------------------------
|
| 26 |
-
# Schemas
|
| 27 |
-
# ---------------------------------------------------------------------------
|
| 28 |
|
| 29 |
class SegmentOut(BaseModel):
|
| 30 |
start: float
|
|
@@ -49,13 +42,9 @@ class HealthResponse(BaseModel):
|
|
| 49 |
version: str = "1.0.0"
|
| 50 |
|
| 51 |
|
| 52 |
-
# ---------------------------------------------------------------------------
|
| 53 |
-
# App
|
| 54 |
-
# ---------------------------------------------------------------------------
|
| 55 |
-
|
| 56 |
app = FastAPI(
|
| 57 |
title="Speaker Diarization API",
|
| 58 |
-
description="Who Spoke When
|
| 59 |
version="1.0.0",
|
| 60 |
)
|
| 61 |
|
|
@@ -69,12 +58,16 @@ app.add_middleware(
|
|
| 69 |
|
| 70 |
_pipeline = None
|
| 71 |
|
|
|
|
| 72 |
def get_pipeline():
|
| 73 |
global _pipeline
|
| 74 |
if _pipeline is None:
|
| 75 |
from app.pipeline import DiarizationPipeline
|
| 76 |
-
|
| 77 |
-
cache_dir = os.getenv(
|
|
|
|
|
|
|
|
|
|
| 78 |
_pipeline = DiarizationPipeline(
|
| 79 |
device="auto",
|
| 80 |
use_pyannote_vad=True,
|
|
@@ -85,10 +78,6 @@ def get_pipeline():
|
|
| 85 |
return _pipeline
|
| 86 |
|
| 87 |
|
| 88 |
-
# ---------------------------------------------------------------------------
|
| 89 |
-
# Endpoints
|
| 90 |
-
# ---------------------------------------------------------------------------
|
| 91 |
-
|
| 92 |
@app.get("/health", response_model=HealthResponse, tags=["System"])
|
| 93 |
async def health_check():
|
| 94 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -136,6 +125,7 @@ async def diarize_from_url(
|
|
| 136 |
):
|
| 137 |
"""Diarize audio from a URL."""
|
| 138 |
import httpx
|
|
|
|
| 139 |
try:
|
| 140 |
async with httpx.AsyncClient(timeout=60.0) as client:
|
| 141 |
resp = await client.get(audio_url)
|
|
@@ -169,6 +159,7 @@ async def stream_diarization(websocket: WebSocket):
|
|
| 169 |
"""Real-time streaming diarization via WebSocket."""
|
| 170 |
await websocket.accept()
|
| 171 |
import numpy as np
|
|
|
|
| 172 |
audio_buffer = bytearray()
|
| 173 |
sample_rate = 16000
|
| 174 |
num_speakers = None
|
|
@@ -179,10 +170,12 @@ async def stream_diarization(websocket: WebSocket):
|
|
| 179 |
sample_rate = config_msg.get("sample_rate", 16000)
|
| 180 |
num_speakers = config_msg.get("num_speakers", None)
|
| 181 |
|
| 182 |
-
await websocket.send_json(
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
|
|
|
|
|
|
| 186 |
|
| 187 |
while True:
|
| 188 |
try:
|
|
@@ -194,12 +187,18 @@ async def stream_diarization(websocket: WebSocket):
|
|
| 194 |
if "bytes" in msg:
|
| 195 |
audio_buffer.extend(msg["bytes"])
|
| 196 |
chunk_count += 1
|
| 197 |
-
await websocket.send_json(
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
elif "text" in msg:
|
| 202 |
import json
|
|
|
|
| 203 |
data = json.loads(msg["text"])
|
| 204 |
if data.get("type") == "eof":
|
| 205 |
break
|
|
@@ -208,14 +207,17 @@ async def stream_diarization(websocket: WebSocket):
|
|
| 208 |
await websocket.send_json({"type": "error", "data": {"message": "No audio received"}})
|
| 209 |
return
|
| 210 |
|
| 211 |
-
import torch
|
|
|
|
| 212 |
audio_np = np.frombuffer(audio_buffer, dtype=np.float32).copy()
|
| 213 |
-
audio_tensor =
|
| 214 |
|
| 215 |
-
await websocket.send_json(
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
|
|
|
|
|
|
| 219 |
|
| 220 |
loop = asyncio.get_event_loop()
|
| 221 |
pipeline = get_pipeline()
|
|
@@ -227,15 +229,17 @@ async def stream_diarization(websocket: WebSocket):
|
|
| 227 |
for seg in result.segments:
|
| 228 |
await websocket.send_json({"type": "segment", "data": seg.to_dict()})
|
| 229 |
|
| 230 |
-
await websocket.send_json(
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
"
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
| 239 |
|
| 240 |
except WebSocketDisconnect:
|
| 241 |
logger.info("WebSocket client disconnected")
|
|
@@ -249,26 +253,33 @@ async def stream_diarization(websocket: WebSocket):
|
|
| 249 |
|
| 250 |
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
|
| 251 |
async def serve_ui():
|
| 252 |
-
ui_path = Path("static/index.html"
|
| 253 |
if ui_path.exists():
|
| 254 |
-
return HTMLResponse(ui_path.read_text())
|
| 255 |
-
return HTMLResponse("<h1>Speaker Diarization API</h1><p><a href='/docs'>API Docs
|
|
|
|
| 256 |
|
| 257 |
@app.get("/debug", tags=["System"])
|
| 258 |
async def debug():
|
| 259 |
-
import speechbrain
|
| 260 |
-
import os
|
| 261 |
import inspect
|
|
|
|
| 262 |
from speechbrain.inference.classifiers import EncoderClassifier
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
sig = str(inspect.signature(EncoderClassifier.from_hparams))
|
| 264 |
return {
|
| 265 |
"speechbrain_version": speechbrain.__version__,
|
| 266 |
-
"
|
| 267 |
-
"
|
|
|
|
|
|
|
| 268 |
"from_hparams_signature": sig,
|
| 269 |
}
|
| 270 |
|
| 271 |
|
| 272 |
-
static_dir = Path("static"
|
| 273 |
if static_dir.exists():
|
| 274 |
-
app.mount("/static", StaticFiles(directory=
|
|
|
|
| 1 |
+
"""Speaker Diarization API - FastAPI Application."""
|
|
|
|
|
|
|
| 2 |
|
|
|
|
|
|
|
| 3 |
import asyncio
|
| 4 |
import tempfile
|
| 5 |
import traceback
|
|
|
|
| 15 |
from fastapi.middleware.cors import CORSMiddleware
|
| 16 |
from fastapi.staticfiles import StaticFiles
|
| 17 |
from fastapi.responses import HTMLResponse
|
| 18 |
+
from pydantic import BaseModel
|
| 19 |
from loguru import logger
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
class SegmentOut(BaseModel):
|
| 23 |
start: float
|
|
|
|
| 42 |
version: str = "1.0.0"
|
| 43 |
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
app = FastAPI(
|
| 46 |
title="Speaker Diarization API",
|
| 47 |
+
description="Who Spoke When - Speaker diarization using ECAPA-TDNN + AHC Clustering",
|
| 48 |
version="1.0.0",
|
| 49 |
)
|
| 50 |
|
|
|
|
| 58 |
|
| 59 |
_pipeline = None
|
| 60 |
|
| 61 |
+
|
| 62 |
def get_pipeline():
|
| 63 |
global _pipeline
|
| 64 |
if _pipeline is None:
|
| 65 |
from app.pipeline import DiarizationPipeline
|
| 66 |
+
|
| 67 |
+
cache_dir = os.getenv(
|
| 68 |
+
"CACHE_DIR",
|
| 69 |
+
str(Path(tempfile.gettempdir()) / "model_cache"),
|
| 70 |
+
)
|
| 71 |
_pipeline = DiarizationPipeline(
|
| 72 |
device="auto",
|
| 73 |
use_pyannote_vad=True,
|
|
|
|
| 78 |
return _pipeline
|
| 79 |
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
@app.get("/health", response_model=HealthResponse, tags=["System"])
|
| 82 |
async def health_check():
|
| 83 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 125 |
):
|
| 126 |
"""Diarize audio from a URL."""
|
| 127 |
import httpx
|
| 128 |
+
|
| 129 |
try:
|
| 130 |
async with httpx.AsyncClient(timeout=60.0) as client:
|
| 131 |
resp = await client.get(audio_url)
|
|
|
|
| 159 |
"""Real-time streaming diarization via WebSocket."""
|
| 160 |
await websocket.accept()
|
| 161 |
import numpy as np
|
| 162 |
+
|
| 163 |
audio_buffer = bytearray()
|
| 164 |
sample_rate = 16000
|
| 165 |
num_speakers = None
|
|
|
|
| 170 |
sample_rate = config_msg.get("sample_rate", 16000)
|
| 171 |
num_speakers = config_msg.get("num_speakers", None)
|
| 172 |
|
| 173 |
+
await websocket.send_json(
|
| 174 |
+
{
|
| 175 |
+
"type": "progress",
|
| 176 |
+
"data": {"message": "Config received. Send audio chunks.", "chunks_received": 0},
|
| 177 |
+
}
|
| 178 |
+
)
|
| 179 |
|
| 180 |
while True:
|
| 181 |
try:
|
|
|
|
| 187 |
if "bytes" in msg:
|
| 188 |
audio_buffer.extend(msg["bytes"])
|
| 189 |
chunk_count += 1
|
| 190 |
+
await websocket.send_json(
|
| 191 |
+
{
|
| 192 |
+
"type": "progress",
|
| 193 |
+
"data": {
|
| 194 |
+
"message": f"Received chunk {chunk_count}",
|
| 195 |
+
"chunks_received": chunk_count,
|
| 196 |
+
},
|
| 197 |
+
}
|
| 198 |
+
)
|
| 199 |
elif "text" in msg:
|
| 200 |
import json
|
| 201 |
+
|
| 202 |
data = json.loads(msg["text"])
|
| 203 |
if data.get("type") == "eof":
|
| 204 |
break
|
|
|
|
| 207 |
await websocket.send_json({"type": "error", "data": {"message": "No audio received"}})
|
| 208 |
return
|
| 209 |
|
| 210 |
+
import torch as torch_local
|
| 211 |
+
|
| 212 |
audio_np = np.frombuffer(audio_buffer, dtype=np.float32).copy()
|
| 213 |
+
audio_tensor = torch_local.from_numpy(audio_np)
|
| 214 |
|
| 215 |
+
await websocket.send_json(
|
| 216 |
+
{
|
| 217 |
+
"type": "progress",
|
| 218 |
+
"data": {"message": "Running diarization pipeline..."},
|
| 219 |
+
}
|
| 220 |
+
)
|
| 221 |
|
| 222 |
loop = asyncio.get_event_loop()
|
| 223 |
pipeline = get_pipeline()
|
|
|
|
| 229 |
for seg in result.segments:
|
| 230 |
await websocket.send_json({"type": "segment", "data": seg.to_dict()})
|
| 231 |
|
| 232 |
+
await websocket.send_json(
|
| 233 |
+
{
|
| 234 |
+
"type": "done",
|
| 235 |
+
"data": {
|
| 236 |
+
"num_speakers": result.num_speakers,
|
| 237 |
+
"total_segments": len(result.segments),
|
| 238 |
+
"audio_duration": result.audio_duration,
|
| 239 |
+
"processing_time": result.processing_time,
|
| 240 |
+
},
|
| 241 |
+
}
|
| 242 |
+
)
|
| 243 |
|
| 244 |
except WebSocketDisconnect:
|
| 245 |
logger.info("WebSocket client disconnected")
|
|
|
|
| 253 |
|
| 254 |
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
|
| 255 |
async def serve_ui():
|
| 256 |
+
ui_path = Path(__file__).resolve().parent.parent / "static" / "index.html"
|
| 257 |
if ui_path.exists():
|
| 258 |
+
return HTMLResponse(ui_path.read_text(encoding="utf-8"))
|
| 259 |
+
return HTMLResponse("<h1>Speaker Diarization API</h1><p><a href='/docs'>API Docs</a></p>")
|
| 260 |
+
|
| 261 |
|
| 262 |
@app.get("/debug", tags=["System"])
|
| 263 |
async def debug():
|
|
|
|
|
|
|
| 264 |
import inspect
|
| 265 |
+
import speechbrain
|
| 266 |
from speechbrain.inference.classifiers import EncoderClassifier
|
| 267 |
+
|
| 268 |
+
cache_dir = os.getenv(
|
| 269 |
+
"CACHE_DIR",
|
| 270 |
+
str(Path(tempfile.gettempdir()) / "model_cache"),
|
| 271 |
+
)
|
| 272 |
sig = str(inspect.signature(EncoderClassifier.from_hparams))
|
| 273 |
return {
|
| 274 |
"speechbrain_version": speechbrain.__version__,
|
| 275 |
+
"temp_dir": tempfile.gettempdir(),
|
| 276 |
+
"temp_writable": os.access(tempfile.gettempdir(), os.W_OK),
|
| 277 |
+
"cache_dir": cache_dir,
|
| 278 |
+
"cache_exists": os.path.exists(cache_dir),
|
| 279 |
"from_hparams_signature": sig,
|
| 280 |
}
|
| 281 |
|
| 282 |
|
| 283 |
+
static_dir = Path(__file__).resolve().parent.parent / "static"
|
| 284 |
if static_dir.exists():
|
| 285 |
+
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
deploy_hf.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Deploy this project to a Hugging Face Space."""
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import subprocess
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
from huggingface_hub import HfApi
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def require_env(name: str) -> str:
|
| 12 |
+
value = os.getenv(name)
|
| 13 |
+
if not value:
|
| 14 |
+
raise SystemExit(f"Missing required environment variable: {name}")
|
| 15 |
+
return value
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main() -> None:
|
| 19 |
+
token = require_env("HF_TOKEN")
|
| 20 |
+
space_name = os.getenv("HF_SPACE_NAME", "who-spoke-when")
|
| 21 |
+
|
| 22 |
+
api = HfApi(token=token)
|
| 23 |
+
|
| 24 |
+
username = os.getenv("HF_USERNAME")
|
| 25 |
+
if not username:
|
| 26 |
+
whoami = api.whoami(token=token)
|
| 27 |
+
username = whoami["name"]
|
| 28 |
+
|
| 29 |
+
space_id = f"{username}/{space_name}"
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
api.create_repo(
|
| 33 |
+
repo_id=space_id,
|
| 34 |
+
repo_type="space",
|
| 35 |
+
space_sdk="docker",
|
| 36 |
+
private=False,
|
| 37 |
+
token=token,
|
| 38 |
+
exist_ok=True,
|
| 39 |
+
)
|
| 40 |
+
print(f"Space ready: {space_id}")
|
| 41 |
+
except Exception as exc:
|
| 42 |
+
raise SystemExit(f"Failed to create or fetch space '{space_id}': {exc}") from exc
|
| 43 |
+
|
| 44 |
+
remote_url = f"https://{username}:{token}@huggingface.co/spaces/{space_id}"
|
| 45 |
+
subprocess.run(["git", "remote", "remove", "huggingface"], check=False, capture_output=True)
|
| 46 |
+
subprocess.run(["git", "remote", "add", "huggingface", remote_url], check=True)
|
| 47 |
+
|
| 48 |
+
push_cmd = ["git", "push", "huggingface", "main"]
|
| 49 |
+
if os.getenv("HF_FORCE_PUSH", "false").lower() in {"1", "true", "yes"}:
|
| 50 |
+
push_cmd.append("--force")
|
| 51 |
+
|
| 52 |
+
subprocess.run(push_cmd, check=True)
|
| 53 |
+
print(f"Pushed to https://huggingface.co/spaces/{space_id}")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if __name__ == "__main__":
|
| 57 |
+
try:
|
| 58 |
+
main()
|
| 59 |
+
except subprocess.CalledProcessError as exc:
|
| 60 |
+
sys.exit(exc.returncode)
|
models/embedder.py
CHANGED
|
@@ -1,14 +1,16 @@
|
|
| 1 |
-
"""
|
| 2 |
Speaker Embedding Extraction using ECAPA-TDNN architecture via SpeechBrain.
|
| 3 |
Handles audio preprocessing, feature extraction, and L2-normalized embeddings.
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
import
|
| 7 |
-
import
|
| 8 |
-
import torchaudio
|
| 9 |
-
import numpy as np
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Union, List, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from loguru import logger
|
| 13 |
|
| 14 |
|
|
@@ -22,7 +24,7 @@ class EcapaTDNNEmbedder:
|
|
| 22 |
SAMPLE_RATE = 16000
|
| 23 |
EMBEDDING_DIM = 192
|
| 24 |
|
| 25 |
-
def __init__(self, device: str = "auto", cache_dir: str = "/
|
| 26 |
self.device = self._resolve_device(device)
|
| 27 |
self.cache_dir = Path(cache_dir)
|
| 28 |
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -39,41 +41,46 @@ class EcapaTDNNEmbedder:
|
|
| 39 |
return
|
| 40 |
|
| 41 |
try:
|
| 42 |
-
import shutil
|
| 43 |
import speechbrain.utils.fetching as _fetching
|
| 44 |
from speechbrain.utils.fetching import LocalStrategy
|
|
|
|
| 45 |
|
| 46 |
def _patched_link(src, dst, local_strategy):
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
shutil.copy2(str(src), str(dst))
|
| 54 |
|
| 55 |
_fetching.link_with_strategy = _patched_link
|
| 56 |
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
logger.info(f"Loading ECAPA-TDNN from {self.MODEL_SOURCE}...")
|
|
|
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
)
|
| 72 |
self._model.eval()
|
| 73 |
logger.success("ECAPA-TDNN model loaded successfully.")
|
| 74 |
-
except ImportError:
|
| 75 |
-
raise ImportError("SpeechBrain not installed.")
|
| 76 |
-
|
| 77 |
def preprocess_audio(
|
| 78 |
self, audio: Union[np.ndarray, torch.Tensor], sample_rate: int
|
| 79 |
) -> torch.Tensor:
|
|
|
|
| 1 |
+
"""
|
| 2 |
Speaker Embedding Extraction using ECAPA-TDNN architecture via SpeechBrain.
|
| 3 |
Handles audio preprocessing, feature extraction, and L2-normalized embeddings.
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
import inspect
|
| 7 |
+
import shutil
|
|
|
|
|
|
|
| 8 |
from pathlib import Path
|
| 9 |
from typing import Union, List, Tuple
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torchaudio
|
| 14 |
from loguru import logger
|
| 15 |
|
| 16 |
|
|
|
|
| 24 |
SAMPLE_RATE = 16000
|
| 25 |
EMBEDDING_DIM = 192
|
| 26 |
|
| 27 |
+
def __init__(self, device: str = "auto", cache_dir: str = "./model_cache"):
|
| 28 |
self.device = self._resolve_device(device)
|
| 29 |
self.cache_dir = Path(cache_dir)
|
| 30 |
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 41 |
return
|
| 42 |
|
| 43 |
try:
|
|
|
|
| 44 |
import speechbrain.utils.fetching as _fetching
|
| 45 |
from speechbrain.utils.fetching import LocalStrategy
|
| 46 |
+
from speechbrain.inference.classifiers import EncoderClassifier
|
| 47 |
|
| 48 |
def _patched_link(src, dst, local_strategy):
|
| 49 |
+
dst_path = Path(dst)
|
| 50 |
+
src_path = Path(src)
|
| 51 |
+
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
| 52 |
+
if dst_path.exists() or dst_path.is_symlink():
|
| 53 |
+
dst_path.unlink()
|
| 54 |
+
shutil.copy2(str(src_path), str(dst_path))
|
|
|
|
| 55 |
|
| 56 |
_fetching.link_with_strategy = _patched_link
|
| 57 |
|
| 58 |
+
savedir = self.cache_dir / "ecapa_tdnn"
|
| 59 |
+
hf_cache = self.cache_dir / "hf_cache"
|
| 60 |
+
savedir.mkdir(parents=True, exist_ok=True)
|
| 61 |
+
hf_cache.mkdir(parents=True, exist_ok=True)
|
| 62 |
|
| 63 |
logger.info(f"Loading ECAPA-TDNN from {self.MODEL_SOURCE}...")
|
| 64 |
+
logger.info(f"Savedir: {savedir}, exists: {savedir.exists()}")
|
| 65 |
|
| 66 |
+
kwargs = {
|
| 67 |
+
"source": self.MODEL_SOURCE,
|
| 68 |
+
"savedir": str(savedir),
|
| 69 |
+
"run_opts": {"device": self.device},
|
| 70 |
+
}
|
| 71 |
|
| 72 |
+
sig = inspect.signature(EncoderClassifier.from_hparams)
|
| 73 |
+
if "huggingface_cache_dir" in sig.parameters:
|
| 74 |
+
kwargs["huggingface_cache_dir"] = str(hf_cache)
|
| 75 |
+
if "local_strategy" in sig.parameters:
|
| 76 |
+
kwargs["local_strategy"] = LocalStrategy.COPY
|
| 77 |
+
|
| 78 |
+
self._model = EncoderClassifier.from_hparams(**kwargs)
|
| 79 |
self._model.eval()
|
| 80 |
logger.success("ECAPA-TDNN model loaded successfully.")
|
| 81 |
+
except ImportError as exc:
|
| 82 |
+
raise ImportError("SpeechBrain not installed.") from exc
|
| 83 |
+
|
| 84 |
def preprocess_audio(
|
| 85 |
self, audio: Union[np.ndarray, torch.Tensor], sample_rate: int
|
| 86 |
) -> torch.Tensor:
|