# NOTE(review): removed web-scrape residue ("Spaces: Running on Zero") that is
# not part of this module's source.
| """ | |
| ModelManager: HuggingFace model download, caching, and lazy loading. | |
| Handles automatic model downloads on first run, local caching, and singleton | |
| pattern to ensure models are loaded only once per process. | |
| """ | |
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional
| logger = logging.getLogger(__name__) | |
| class ModelManager: | |
| """ | |
| Manages HuggingFace model downloads and caching. | |
| Implements singleton pattern to ensure models are loaded once and reused | |
| across all processing operations. | |
| """ | |
| _instance: Optional["ModelManager"] = None | |
| _models: Dict[str, Any] = {} | |
| _models_loaded: bool = False | |
| def __new__(cls): | |
| """Singleton pattern - one instance per process.""" | |
| if cls._instance is None: | |
| cls._instance = super().__new__(cls) | |
| return cls._instance | |
| def __init__(self, cache_dir: str = "./models"): | |
| """ | |
| Initialize ModelManager. | |
| Args: | |
| cache_dir: Directory for caching downloaded models (default: ./models) | |
| """ | |
| self.cache_dir = Path(cache_dir) | |
| self.cache_dir.mkdir(parents=True, exist_ok=True) | |
| # Set HuggingFace cache environment variable | |
| os.environ["HF_HOME"] = str(self.cache_dir) | |
| os.environ["TRANSFORMERS_CACHE"] = str(self.cache_dir) | |
| def get_hf_token(self) -> Optional[str]: | |
| """ | |
| Get HuggingFace authentication token. | |
| Checks multiple sources in order: | |
| 1. HF_TOKEN environment variable | |
| 2. ~/.cache/huggingface/token file (from huggingface-cli login) | |
| Returns: | |
| HuggingFace token or None if not found | |
| """ | |
| # Check environment variable | |
| token = os.environ.get("HF_TOKEN") | |
| if token: | |
| return token | |
| # Check huggingface-cli token file | |
| token_file = Path.home() / ".cache" / "huggingface" / "token" | |
| if token_file.exists(): | |
| return token_file.read_text().strip() | |
| return None | |
| def load_speaker_diarization(self, progress_callback=None) -> Any: | |
| """ | |
| Load pyannote speaker diarization model. | |
| Downloads on first run (~150MB), then loads from cache. | |
| Requires HuggingFace authentication and license acceptance at: | |
| https://huggingface.co/pyannote/speaker-diarization-3.1 | |
| Args: | |
| progress_callback: Optional callback(progress: float, message: str) | |
| Returns: | |
| Loaded pyannote Pipeline object | |
| Raises: | |
| RuntimeError: If authentication fails or model cannot be loaded | |
| """ | |
| if "diarization" in self._models: | |
| return self._models["diarization"] | |
| if progress_callback: | |
| progress_callback(0.0, "Loading speaker diarization model...") | |
| logger.info("Loading speaker diarization model (first run downloads ~150MB)") | |
| try: | |
| import torch | |
| from pyannote.audio import Pipeline | |
| token = self.get_hf_token() | |
| if not token: | |
| raise RuntimeError("HuggingFace token not found. Please run: huggingface-cli login") | |
| pipeline = Pipeline.from_pretrained( | |
| "pyannote/speaker-diarization-3.1", token=token, cache_dir=str(self.cache_dir) | |
| ) | |
| # Force CPU execution | |
| pipeline.to(torch.device("cpu")) | |
| self._models["diarization"] = pipeline | |
| if progress_callback: | |
| progress_callback(1.0, "Speaker diarization model loaded") | |
| logger.info("✓ Speaker diarization model loaded successfully") | |
| return pipeline | |
| except Exception as e: | |
| error_msg = str(e) | |
| if "401" in error_msg or "authentication" in error_msg.lower(): | |
| raise RuntimeError( | |
| "Authentication failed for pyannote/speaker-diarization-3.1. " | |
| "Please:\n" | |
| "1. Run: huggingface-cli login\n" | |
| "2. Accept license at: https://huggingface.co/pyannote/speaker-diarization-3.1" | |
| ) | |
| elif "disk" in error_msg.lower() or "space" in error_msg.lower(): | |
| raise RuntimeError( | |
| f"Insufficient disk space to download model. " | |
| f"Need ~600MB free in {self.cache_dir}" | |
| ) | |
| else: | |
| raise RuntimeError(f"Failed to load speaker diarization model: {error_msg}") | |
| def load_embedding_model(self, progress_callback=None) -> Any: | |
| """ | |
| Load pyannote embedding model for voice matching. | |
| Downloads on first run (~17MB), then loads from cache. | |
| Args: | |
| progress_callback: Optional callback(progress: float, message: str) | |
| Returns: | |
| Loaded pyannote Model object | |
| """ | |
| if "embedding" in self._models: | |
| return self._models["embedding"] | |
| if progress_callback: | |
| progress_callback(0.0, "Loading voice embedding model...") | |
| logger.info("Loading voice embedding model (first run downloads ~17MB)") | |
| try: | |
| from pyannote.audio import Model | |
| token = self.get_hf_token() | |
| if not token: | |
| raise RuntimeError("HuggingFace token not found. Please run: huggingface-cli login") | |
| model = Model.from_pretrained( | |
| "pyannote/embedding", token=token, cache_dir=str(self.cache_dir) | |
| ) | |
| self._models["embedding"] = model | |
| if progress_callback: | |
| progress_callback(1.0, "Voice embedding model loaded") | |
| logger.info("✓ Voice embedding model loaded successfully") | |
| return model | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load embedding model: {str(e)}") | |
| def load_ast_classifier(self, progress_callback=None) -> Any: | |
| """ | |
| Load Audio Spectrogram Transformer for speech classification. | |
| Downloads on first run (~340MB), then loads from cache. | |
| Args: | |
| progress_callback: Optional callback(progress: float, message: str) | |
| Returns: | |
| Tuple of (feature_extractor, classifier_model) | |
| """ | |
| if "ast" in self._models: | |
| return self._models["ast"] | |
| if progress_callback: | |
| progress_callback(0.0, "Loading audio classifier model...") | |
| logger.info("Loading audio classifier model (first run downloads ~340MB)") | |
| try: | |
| from transformers import ASTFeatureExtractor, ASTForAudioClassification | |
| feature_extractor = ASTFeatureExtractor.from_pretrained( | |
| "MIT/ast-finetuned-audioset-10-10-0.4593", cache_dir=str(self.cache_dir) | |
| ) | |
| classifier = ASTForAudioClassification.from_pretrained( | |
| "MIT/ast-finetuned-audioset-10-10-0.4593", cache_dir=str(self.cache_dir) | |
| ) | |
| classifier.eval() # Set to inference mode | |
| self._models["ast"] = (feature_extractor, classifier) | |
| if progress_callback: | |
| progress_callback(1.0, "Audio classifier model loaded") | |
| logger.info("✓ Audio classifier model loaded successfully") | |
| return self._models["ast"] | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load AST classifier: {str(e)}") | |
| def load_vad_model(self, progress_callback=None) -> Any: | |
| """ | |
| Load Silero VAD model for voice activity detection. | |
| Downloads on first run (~1.5MB), then loads from cache. | |
| Args: | |
| progress_callback: Optional callback(progress: float, message: str) | |
| Returns: | |
| Loaded Silero VAD model | |
| """ | |
| if "vad" in self._models: | |
| return self._models["vad"] | |
| if progress_callback: | |
| progress_callback(0.0, "Loading voice activity detection model...") | |
| logger.info("Loading VAD model (first run downloads ~1.5MB)") | |
| try: | |
| import torch | |
| # Silero VAD uses torch.hub | |
| model, utils = torch.hub.load( | |
| repo_or_dir="snakers4/silero-vad", | |
| model="silero_vad", | |
| force_reload=False, | |
| onnx=False, | |
| ) | |
| model.eval() | |
| self._models["vad"] = (model, utils) | |
| if progress_callback: | |
| progress_callback(1.0, "VAD model loaded") | |
| logger.info("✓ VAD model loaded successfully") | |
| return self._models["vad"] | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load VAD model: {str(e)}") | |
| def models_are_cached(self) -> bool: | |
| """ | |
| Check if all required models are already downloaded. | |
| Returns: | |
| True if all models are cached locally, False otherwise | |
| """ | |
| required_models = [ | |
| "pyannote--speaker-diarization-3.1", | |
| "pyannote--embedding", | |
| "MIT--ast-finetuned-audioset-10-10-0.4593", | |
| ] | |
| hub_cache = self.cache_dir / "hub" | |
| if not hub_cache.exists(): | |
| return False | |
| for model_name in required_models: | |
| model_path = hub_cache / f"models--{model_name}" | |
| if not model_path.exists(): | |
| return False | |
| return True | |
| def get_cache_size(self) -> int: | |
| """ | |
| Get total size of cached models in bytes. | |
| Returns: | |
| Total cache size in bytes | |
| """ | |
| total_size = 0 | |
| for path in self.cache_dir.rglob("*"): | |
| if path.is_file(): | |
| total_size += path.stat().st_size | |
| return total_size | |
| def clear_cache(self): | |
| """ | |
| Clear all cached models. | |
| WARNING: This will force re-download on next use (~600MB). | |
| """ | |
| import shutil | |
| if self.cache_dir.exists(): | |
| shutil.rmtree(self.cache_dir) | |
| self.cache_dir.mkdir(parents=True, exist_ok=True) | |
| logger.info(f"Cleared model cache at {self.cache_dir}") | |
| # Reset loaded models | |
| self._models.clear() | |
| self._models_loaded = False | |