"""
HuggingFace-compatible model loader for Romanian Matcha-TTS
"""

import json
import os
import torch
from pathlib import Path
from typing import Optional, Dict, Any

try:
    from huggingface_hub import hf_hub_download
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False


class ModelLoader:
    """
    HuggingFace-compatible loader for Romanian Matcha-TTS models.

    The loader reads ``configs/config.json`` from a repository directory and
    resolves checkpoint paths from the ``available_models`` table found there.
    It does not instantiate any network itself: ``load_models`` returns paths
    and configuration for use with the original Matcha-TTS loading functions.

    Usage:
        loader = ModelLoader.from_pretrained("adrianstanea/Ro-Matcha-TTS")
        models = loader.load_models(model="bas_950")
        models["model_path"], models["vocoder_path"]
    """

    def __init__(
        self,
        repo_path: str,
        repo_id: Optional[str] = None,
        cache_dir: Optional[str] = None,
    ):
        """
        Initialize from a local repository directory.

        Args:
            repo_path: Directory that contains ``configs/config.json``.
            repo_id: HuggingFace repo ID this directory was fetched from, if
                any. When set, checkpoints missing from ``repo_path`` are
                downloaded from this repo on demand.
            cache_dir: Optional cache directory passed to ``hf_hub_download``.

        Raises:
            FileNotFoundError: If the config file does not exist.
        """
        self.repo_path = repo_path
        self.repo_id = repo_id
        self.cache_dir = cache_dir
        self.config = self._load_config()

    @classmethod
    def from_pretrained(cls, repo_id: str, cache_dir: Optional[str] = None) -> "ModelLoader":
        """
        Load from HuggingFace Hub or local path.

        Args:
            repo_id: HuggingFace repo ID (e.g., "adrianstanea/Ro-Matcha-TTS") or local path
            cache_dir: Optional cache directory for downloads

        Returns:
            ModelLoader instance

        Raises:
            ValueError: If the config could not be downloaded from the Hub.
            ImportError: If ``repo_id`` is not a local path and
                ``huggingface_hub`` is not installed.
        """
        # An existing filesystem path always wins over a Hub lookup.
        if os.path.exists(repo_id):
            return cls(repo_id)

        if not HF_AVAILABLE:
            raise ImportError(
                "huggingface_hub is required for downloading from HF Hub. "
                "Install with: pip install huggingface_hub"
            )

        try:
            config_path = hf_hub_download(
                repo_id=repo_id,
                filename="configs/config.json",
                cache_dir=cache_dir,
            )
        except Exception as e:
            raise ValueError(f"Could not download from HuggingFace Hub: {e}") from e

        # The config lives at <snapshot>/configs/config.json, so two levels up
        # is the snapshot root. Only the config has been fetched at this point;
        # the repo ID is kept so checkpoints can be downloaded lazily later.
        repo_cache_path = Path(config_path).parent.parent
        return cls(str(repo_cache_path), repo_id=repo_id, cache_dir=cache_dir)

    def _load_config(self) -> Dict[str, Any]:
        """Load the model configuration from ``<repo_path>/configs/config.json``."""
        config_path = os.path.join(self.repo_path, "configs", "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found at {config_path}")

        # Explicit UTF-8: the config may hold Romanian text and must not depend
        # on the platform's default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _resolve_file(self, relative_file: str, label: str) -> str:
        """Return a usable path for ``relative_file``, downloading it if needed.

        The file is looked up under ``repo_path`` first. If it is missing and
        this loader was created from a Hub repo ID, it is fetched with
        ``hf_hub_download``.

        Args:
            relative_file: File path relative to the repository root.
            label: Capitalized noun used in error messages ("Model"/"Vocoder").

        Raises:
            FileNotFoundError: If the file is neither local nor downloadable.
        """
        local_path = os.path.join(self.repo_path, relative_file)
        if os.path.exists(local_path):
            return local_path

        # getattr keeps instances built without __init__ (no repo_id) working.
        repo_id = getattr(self, "repo_id", None)
        if repo_id is None or not HF_AVAILABLE:
            raise FileNotFoundError(f"{label} file not found: {local_path}")

        try:
            return hf_hub_download(
                repo_id=repo_id,
                filename=relative_file,
                cache_dir=getattr(self, "cache_dir", None),
            )
        except Exception as e:
            raise FileNotFoundError(
                f"{label} file not found locally and could not download: {e}"
            ) from e

    def get_model_path(self, model: Optional[str] = None) -> str:
        """
        Get path to model checkpoint for specified model.

        Args:
            model: Model name (swara, bas_10, bas_950, sgs_10, sgs_950).
                   If None, uses the config's ``default_model``.

        Returns:
            Path to the model checkpoint (under ``repo_path`` or the HF cache)

        Raises:
            ValueError: If ``model`` is not listed in ``available_models``.
            FileNotFoundError: If the checkpoint cannot be found or downloaded.
        """
        if model is None:
            model = self.config["default_model"]

        available_models = self.config["available_models"]
        if model not in available_models:
            available = list(available_models.keys())
            raise ValueError(f"Model '{model}' not available. Available: {available}")

        return self._resolve_file(available_models[model]["file"], "Model")

    def get_vocoder_path(self) -> str:
        """
        Get path to vocoder checkpoint.

        Returns:
            Path to the vocoder checkpoint (under ``repo_path`` or the HF cache)

        Raises:
            FileNotFoundError: If the checkpoint cannot be found or downloaded.
        """
        vocoder_file = self.config["available_models"]["vocoder"]["file"]
        return self._resolve_file(vocoder_file, "Vocoder")

    def load_models(self, model: Optional[str] = None, device: str = "auto"):
        """
        Resolve the TTS model and vocoder checkpoints for inference.

        NOTE: This returns paths for use with the original Matcha-TTS repository.
        You'll need to import and use the original loading functions.

        Args:
            model: Model to load (swara, bas_10, bas_950, sgs_10, sgs_950)
            device: Device to load on ("auto", "cpu", "cuda"). "auto" picks
                "cuda" when ``torch.cuda.is_available()`` and "cpu" otherwise.

        Returns:
            Dict with model and vocoder paths and configurations
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"

        model_path = self.get_model_path(model)
        vocoder_path = self.get_vocoder_path()

        model_name = model or self.config["default_model"]
        model_info = self.config["available_models"][model_name]

        return {
            "model_path": model_path,
            "vocoder_path": vocoder_path,
            "config": self.config,
            "model_name": model_name,
            "model_info": model_info,
            "device": device,
            "inference_params": self.config["inference_defaults"],
        }

    def list_models(self):
        """List available models (everything except the vocoder) with details."""
        return {
            name: {
                "type": info["type"],
                "description": info["description"],
                "speaker": info.get("speaker", "multi_speaker"),
                "training_data": info.get("training_data", "N/A"),
            }
            for name, info in self.config["available_models"].items()
            if name != "vocoder"
        }

    def list_research_variants(self):
        """List research comparison variants as stored in the config."""
        return self.config["research_variants"]

    def get_model_info(self, model: Optional[str] = None):
        """Get the config entry for a specific model (default model if omitted).

        Raises:
            ValueError: If the model is not listed in ``available_models``.
        """
        model_name = model or self.config["default_model"]
        if model_name not in self.config["available_models"]:
            raise ValueError(f"Model '{model_name}' not available")
        return self.config["available_models"][model_name]

    def get_sample_texts(self) -> list:
        """Get Romanian sample texts for testing."""
        return [
            "Bună ziua! Acesta este un test de sinteză vocală în limba română.",
            "Matcha-TTS funcționează foarte bine pentru limba română.",
            "Sistemul de sinteză vocală poate genera vorbire naturală.",
            "Această tehnologie folosește inteligența artificială avansată.",
            "Vorbirea sintetizată sună foarte realistă și naturală.",
        ]