# Page header captured together with the source (not part of the module):
#   goabonga's picture
#   feat: add speaker recognition API with SpeechBrain ECAPA-TDNN
#   7323d5e unverified
"""Storage backends for speaker embeddings."""
import json
import logging
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, cast
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
# Module messages go through the logger named "uvicorn.error"
# (presumably so they show up in the uvicorn server log -- confirm with deployment).
logger = logging.getLogger("uvicorn.error")
# Name of the JSON file that holds the embeddings; the backends below use it
# both as the on-disk file name and as the path inside the Hub dataset repo.
EMBEDDINGS_FILENAME = "embeddings.json"
# Type alias for the embeddings mapping: string keys to lists of floats
# (keys are presumably speaker identifiers -- verify against callers).
EmbeddingsDict = dict[str, list[float]]
def _load_json(f: Any) -> EmbeddingsDict:
    """Decode the JSON document readable from ``f`` as an embeddings dict.

    The result is only *typed* as ``EmbeddingsDict``: ``cast`` is a hint for
    the type checker and performs no runtime validation of the decoded data.
    """
    decoded = json.load(f)
    return cast(EmbeddingsDict, decoded)
class StorageBackend(ABC):
    """Interface implemented by every speaker-embedding store.

    Subclasses must override both ``load`` and ``save``; until they do, the
    class cannot be instantiated.
    """

    @abstractmethod
    def load(self) -> dict[str, list[float]]:
        """Return the stored speaker embeddings."""
        ...

    @abstractmethod
    def save(self, embeddings: dict[str, list[float]]) -> None:
        """Persist the given speaker embeddings."""
        ...
class LocalStorage(StorageBackend):
    """Local file storage backend.

    Embeddings are kept as a single JSON file named ``EMBEDDINGS_FILENAME``
    inside ``storage_dir``.
    """

    def __init__(self, storage_dir: Path | str = Path("speakers")):
        """Record where the embeddings file lives; nothing is created here.

        Args:
            storage_dir: Directory that holds the embeddings file. Defaults
                to the relative path ``speakers``. A plain string is accepted
                and converted to a ``Path``.
        """
        # Coerce so that a str argument works with the "/" operator below.
        self.storage_dir = Path(storage_dir)
        self.storage_file = self.storage_dir / EMBEDDINGS_FILENAME

    def load(self) -> dict[str, list[float]]:
        """Load speaker embeddings from the local file.

        Returns:
            The decoded embeddings mapping, or an empty dict when the file
            does not exist.
        """
        if not self.storage_file.exists():
            return {}
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.storage_file, encoding="utf-8") as f:
            return _load_json(f)

    def save(self, embeddings: dict[str, list[float]]) -> None:
        """Save speaker embeddings to the local file, replacing its content.

        Args:
            embeddings: Mapping to serialize as JSON.
        """
        # parents=True so a nested storage_dir (e.g. "data/speakers") can be
        # created on first save instead of raising FileNotFoundError.
        self.storage_dir.mkdir(parents=True, exist_ok=True)
        with open(self.storage_file, "w", encoding="utf-8") as f:
            json.dump(embeddings, f)
class HuggingFaceStorage(StorageBackend):
    """Hugging Face Hub dataset storage backend.

    Embeddings are stored as ``EMBEDDINGS_FILENAME`` in a dataset repository.
    Every ``save`` first writes a copy under the relative ``speakers``
    directory; ``load`` falls back to that copy when the Hub download fails
    for a reason other than "file or repository not found".
    """

    def __init__(self, repo_id: str, token: str | None = None):
        """
        Initialize HuggingFace storage.

        Args:
            repo_id: The HuggingFace dataset repository ID
                (e.g., "username/dataset-name")
            token: HuggingFace API token. If None, uses HF_TOKEN env var
                or cached login.
        """
        self.repo_id = repo_id
        self.token = token or os.environ.get("HF_TOKEN")
        self.api = HfApi(token=self.token)
        self._local_cache = Path("speakers") / EMBEDDINGS_FILENAME

    def load(self) -> dict[str, list[float]]:
        """Load speaker embeddings from the HuggingFace dataset.

        Returns:
            The embeddings downloaded from the Hub. An empty dict when the
            file or repository does not exist. On any other failure, the
            content of the local cache file if present, else an empty dict.
        """
        try:
            local_path = hf_hub_download(
                repo_id=self.repo_id,
                filename=EMBEDDINGS_FILENAME,
                repo_type="dataset",
                token=self.token,
            )
            # JSON is UTF-8 by specification; do not depend on the locale default.
            with open(local_path, encoding="utf-8") as f:
                return _load_json(f)
        except (EntryNotFoundError, RepositoryNotFoundError):
            logger.info(
                f"No embeddings found in {self.repo_id}, starting with empty database"
            )
            return {}
        except Exception as e:
            # Best effort: any other failure falls back to the last local copy.
            logger.warning(f"Failed to load from HuggingFace: {e}, using local cache")
            if self._local_cache.exists():
                with open(self._local_cache, encoding="utf-8") as f:
                    return _load_json(f)
            return {}

    def _upload(self, commit_message: str) -> None:
        """Upload the local cache file to the dataset repository."""
        self.api.upload_file(
            path_or_fileobj=str(self._local_cache),
            path_in_repo=EMBEDDINGS_FILENAME,
            repo_id=self.repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )

    def _push_to_hub(self) -> None:
        """Upload the cache file, creating a private dataset repo if missing."""
        try:
            self._upload("Update speaker embeddings")
        except RepositoryNotFoundError:
            logger.info(f"Creating dataset repository {self.repo_id}")
            self.api.create_repo(
                repo_id=self.repo_id,
                repo_type="dataset",
                private=True,
            )
            self._upload("Initial speaker embeddings")
            logger.info(f"Created and saved embeddings to {self.repo_id}")
        else:
            logger.info(f"Embeddings saved to {self.repo_id}")

    def save(self, embeddings: dict[str, list[float]]) -> None:
        """Save speaker embeddings to the HuggingFace dataset.

        The embeddings are written to the local cache file first, then that
        file is uploaded.

        Args:
            embeddings: Mapping to serialize as JSON.

        Raises:
            Exception: Any error raised while uploading or creating the
                repository is logged and re-raised. The local cache file has
                already been written by then.
        """
        # Save locally first
        self._local_cache.parent.mkdir(parents=True, exist_ok=True)
        with open(self._local_cache, "w", encoding="utf-8") as f:
            json.dump(embeddings, f)

        # Upload to HuggingFace. The try wraps the whole push, so failures in
        # the create-repository path are logged too, not only failures of the
        # first upload attempt.
        try:
            self._push_to_hub()
        except Exception as e:
            logger.error(f"Failed to save to HuggingFace: {e}")
            raise
def get_storage_backend() -> StorageBackend:
    """
    Pick the storage backend according to the process environment.

    Environment variables:
        HF_EMBEDDINGS_REPO: HuggingFace dataset repo ID; when set to a
            non-empty value, embeddings are stored in that repository.
        HF_TOKEN: HuggingFace API token, read by HuggingFaceStorage when no
            token is passed to it explicitly.

    Returns:
        A HuggingFaceStorage for the configured repo when HF_EMBEDDINGS_REPO
        is set and non-empty, otherwise a LocalStorage with its default
        directory.
    """
    hf_repo = os.environ.get("HF_EMBEDDINGS_REPO")

    # Unset or empty variable: stay on local disk.
    if not hf_repo:
        logger.info("Using local storage")
        return LocalStorage()

    logger.info(f"Using HuggingFace storage: {hf_repo}")
    return HuggingFaceStorage(repo_id=hf_repo)