Spaces:
Sleeping
Sleeping
| """Storage backends for speaker embeddings.""" | |
| import json | |
| import logging | |
| import os | |
| from abc import ABC, abstractmethod | |
| from pathlib import Path | |
| from typing import Any, cast | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError | |
# Log through uvicorn's error logger.
logger = logging.getLogger("uvicorn.error")

# File name used for the embeddings payload by every backend in this module.
EMBEDDINGS_FILENAME = "embeddings.json"

# Alias for the embeddings mapping: string keys to lists of floats.
EmbeddingsDict = dict[str, list[float]]


def _load_json(f: Any) -> EmbeddingsDict:
    """Parse JSON from the open file object *f* and return it typed as embeddings.

    The cast is for the type checker only; the parsed payload is not
    validated at runtime.
    """
    parsed = json.load(f)
    return cast(EmbeddingsDict, parsed)
class StorageBackend(ABC):
    """Abstract base class for storage backends.

    Concrete backends must implement both ``load`` and ``save``. The methods
    are declared abstract so that an incomplete subclass fails at
    instantiation time instead of silently returning ``None``.
    """

    @abstractmethod
    def load(self) -> dict[str, list[float]]:
        """Load speaker embeddings.

        Returns:
            The stored embeddings as a mapping of string keys to float lists.
        """

    @abstractmethod
    def save(self, embeddings: dict[str, list[float]]) -> None:
        """Save speaker embeddings.

        Args:
            embeddings: The embeddings mapping to persist.
        """
class LocalStorage(StorageBackend):
    """Local file storage backend.

    Embeddings are kept as a single JSON file named ``EMBEDDINGS_FILENAME``
    inside ``storage_dir``.
    """

    def __init__(self, storage_dir: Path = Path("speakers")):
        """Remember the storage directory and the JSON file path inside it.

        Args:
            storage_dir: Directory that holds the embeddings file. It is not
                created here; ``save`` creates it on demand.
        """
        self.storage_dir = storage_dir
        self.storage_file = storage_dir / EMBEDDINGS_FILENAME

    def load(self) -> dict[str, list[float]]:
        """Load speaker embeddings from local file.

        Returns:
            The parsed embeddings, or an empty dict when the file does not
            exist yet.
        """
        if not self.storage_file.exists():
            return {}
        # Explicit UTF-8 so reads do not depend on the platform's locale.
        with open(self.storage_file, encoding="utf-8") as f:
            return _load_json(f)

    def save(self, embeddings: dict[str, list[float]]) -> None:
        """Save speaker embeddings to local file.

        Args:
            embeddings: The embeddings mapping to write as JSON.
        """
        # parents=True lets a nested storage_dir (e.g. "data/speakers") be
        # created in one step instead of failing on the missing parent.
        self.storage_dir.mkdir(parents=True, exist_ok=True)
        with open(self.storage_file, "w", encoding="utf-8") as f:
            json.dump(embeddings, f)
class HuggingFaceStorage(StorageBackend):
    """Hugging Face Hub dataset storage backend.

    Embeddings are stored as ``EMBEDDINGS_FILENAME`` in a dataset repository.
    Every ``save`` also writes a local copy, which ``load`` falls back to
    when the download fails for a reason other than "not found".
    """

    def __init__(self, repo_id: str, token: str | None = None):
        """
        Initialize HuggingFace storage.

        Args:
            repo_id: The HuggingFace dataset repository ID
                (e.g., "username/dataset-name")
            token: HuggingFace API token. If None, uses HF_TOKEN env var
                or cached login.
        """
        self.repo_id = repo_id
        self.token = token or os.environ.get("HF_TOKEN")
        self.api = HfApi(token=self.token)
        self._local_cache = Path("speakers") / EMBEDDINGS_FILENAME

    def load(self) -> dict[str, list[float]]:
        """Load speaker embeddings from HuggingFace dataset.

        Returns:
            The embeddings from the repository. An empty dict is returned
            when the repository or the file is not found. On any other
            failure the local cache file is used if it exists, otherwise an
            empty dict is returned.
        """
        try:
            local_path = hf_hub_download(
                repo_id=self.repo_id,
                filename=EMBEDDINGS_FILENAME,
                repo_type="dataset",
                token=self.token,
            )
            with open(local_path, encoding="utf-8") as f:
                return _load_json(f)
        except (EntryNotFoundError, RepositoryNotFoundError):
            logger.info(
                f"No embeddings found in {self.repo_id}, starting with empty database"
            )
            return {}
        except Exception as e:
            # Best-effort fallback: any other failure is logged and the copy
            # written by the last save() is used instead.
            logger.warning(f"Failed to load from HuggingFace: {e}, using local cache")
            if self._local_cache.exists():
                with open(self._local_cache, encoding="utf-8") as f:
                    return _load_json(f)
            return {}

    def save(self, embeddings: dict[str, list[float]]) -> None:
        """Save speaker embeddings to HuggingFace dataset.

        The embeddings are written to the local cache file first and then
        uploaded. If the repository is not found it is created as a private
        dataset and the upload is retried once.

        Args:
            embeddings: The embeddings mapping to persist.

        Raises:
            Exception: Any error from the upload or repository creation is
                logged and re-raised. The local cache file has already been
                written at that point.
        """
        # Save locally first
        self._local_cache.parent.mkdir(parents=True, exist_ok=True)
        with open(self._local_cache, "w", encoding="utf-8") as f:
            json.dump(embeddings, f)

        # Upload to HuggingFace. The outer try also covers the create-and-retry
        # path, so a failure there is logged as well.
        try:
            try:
                self._upload_cache("Update speaker embeddings")
                logger.info(f"Embeddings saved to {self.repo_id}")
            except RepositoryNotFoundError:
                logger.info(f"Creating dataset repository {self.repo_id}")
                self.api.create_repo(
                    repo_id=self.repo_id,
                    repo_type="dataset",
                    private=True,
                )
                self._upload_cache("Initial speaker embeddings")
                logger.info(f"Created and saved embeddings to {self.repo_id}")
        except Exception as e:
            logger.error(f"Failed to save to HuggingFace: {e}")
            raise

    def _upload_cache(self, commit_message: str) -> None:
        """Upload the local cache file to the dataset repository."""
        self.api.upload_file(
            path_or_fileobj=str(self._local_cache),
            path_in_repo=EMBEDDINGS_FILENAME,
            repo_id=self.repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
def get_storage_backend() -> StorageBackend:
    """
    Select a storage backend from the process environment.

    Environment variables:
        HF_EMBEDDINGS_REPO: HuggingFace dataset repo ID for storing embeddings
        HF_TOKEN: HuggingFace API token (optional if logged in via CLI)

    Returns:
        HuggingFaceStorage when HF_EMBEDDINGS_REPO is set to a non-empty
        value, otherwise LocalStorage.
    """
    repo_id = os.environ.get("HF_EMBEDDINGS_REPO")

    # Unset or empty variable: fall back to the local backend.
    if not repo_id:
        logger.info("Using local storage")
        return LocalStorage()

    logger.info(f"Using HuggingFace storage: {repo_id}")
    return HuggingFaceStorage(repo_id=repo_id)