| | """ |
| | HuggingFace-compatible model loader for Romanian Matcha-TTS |
| | """ |
| |
|
| | import json |
| | import os |
| | import torch |
| | from pathlib import Path |
| | from typing import Optional, Dict, Any |
| |
|
# Optional dependency: hf_hub_download is only needed when fetching files
# from the HuggingFace Hub. When huggingface_hub is not installed the
# loader still works with local repository paths (HF_AVAILABLE gates the
# download code paths below).
try:
    from huggingface_hub import hf_hub_download
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
| |
|
| |
|
class ModelLoader:
    """
    HuggingFace-compatible loader for Romanian Matcha-TTS models

    Usage:
        loader = ModelLoader.from_pretrained("adrianstanea/Ro-Matcha-TTS")
        model, vocoder = loader.load_models(speaker="BAS")
    """

    def __init__(self, repo_path: str):
        """
        Initialize with local repository path or HuggingFace repo ID

        NOTE: the config is read from disk immediately, so when repo_path is
        a bare HF repo ID (not yet downloaded) construction fails — use
        ``from_pretrained`` for Hub repos.

        Args:
            repo_path: Path to local repo or HuggingFace repo ID

        Raises:
            FileNotFoundError: If configs/config.json is not present under repo_path.
        """
        self.repo_path = repo_path
        self.config = self._load_config()

    @classmethod
    def from_pretrained(cls, repo_id: str, cache_dir: Optional[str] = None) -> "ModelLoader":
        """
        Load from HuggingFace Hub or local path

        Args:
            repo_id: HuggingFace repo ID (e.g., "adrianstanea/Ro-Matcha-TTS") or local path
            cache_dir: Optional cache directory for downloads

        Returns:
            ModelLoader instance

        Raises:
            ValueError: If the config could not be downloaded from the Hub.
            ImportError: If huggingface_hub is needed but not installed.
        """
        if os.path.exists(repo_id):
            # Local checkout — no download needed.
            return cls(repo_id)
        if not HF_AVAILABLE:
            raise ImportError("huggingface_hub is required for downloading from HF Hub. Install with: pip install huggingface_hub")
        # Only the download itself sits in the try: a failure while
        # constructing the loader afterwards should surface unmasked.
        try:
            config_path = hf_hub_download(
                repo_id=repo_id,
                filename="configs/config.json",
                cache_dir=cache_dir,
            )
        except Exception as e:
            raise ValueError(f"Could not download from HuggingFace Hub: {e}") from e
        # config.json lives at <cache_root>/configs/config.json, so the repo
        # root in the HF cache is two levels up.
        repo_cache_path = Path(config_path).parent.parent
        return cls(str(repo_cache_path))

    def _load_config(self) -> Dict[str, Any]:
        """Load model configuration from configs/config.json under the repo root."""
        config_path = os.path.join(self.repo_path, "configs", "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found at {config_path}")

        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _resolve_checkpoint(self, relative_file: str, kind: str) -> str:
        """
        Resolve a checkpoint file to an absolute local path.

        Checks the local repo first; if missing and repo_path looks like a
        HF repo ID (i.e. not an existing directory), falls back to
        downloading the file from the Hub.

        Args:
            relative_file: Checkpoint path relative to the repo root.
            kind: Human-readable label ("Model" or "Vocoder") for error messages.

        Returns:
            Absolute path to the checkpoint.

        Raises:
            FileNotFoundError: If the file exists neither locally nor on the Hub.
        """
        local_path = os.path.join(self.repo_path, relative_file)
        if os.path.exists(local_path):
            return local_path
        if HF_AVAILABLE and not os.path.exists(self.repo_path):
            try:
                return hf_hub_download(
                    repo_id=self.repo_path,
                    filename=relative_file
                )
            except Exception as e:
                raise FileNotFoundError(f"{kind} file not found locally and could not download: {e}") from e
        raise FileNotFoundError(f"{kind} file not found: {local_path}")

    def get_model_path(self, model: Optional[str] = None) -> str:
        """
        Get path to model checkpoint for specified model

        Args:
            model: Model name (swara, bas_10, bas_950, sgs_10, sgs_950). If None, uses default.

        Returns:
            Absolute path to model checkpoint

        Raises:
            ValueError: If the model name is not listed in the config.
            FileNotFoundError: If the checkpoint cannot be located or downloaded.
        """
        if model is None:
            model = self.config["default_model"]

        if model not in self.config["available_models"]:
            available = list(self.config["available_models"].keys())
            raise ValueError(f"Model '{model}' not available. Available: {available}")

        model_file = self.config["available_models"][model]["file"]
        return self._resolve_checkpoint(model_file, "Model")

    def get_vocoder_path(self) -> str:
        """
        Get path to vocoder checkpoint

        Returns:
            Absolute path to vocoder checkpoint

        Raises:
            FileNotFoundError: If the checkpoint cannot be located or downloaded.
        """
        vocoder_file = self.config["available_models"]["vocoder"]["file"]
        return self._resolve_checkpoint(vocoder_file, "Vocoder")

    def load_models(self, model: Optional[str] = None, device: str = "auto") -> Dict[str, Any]:
        """
        Load TTS model and vocoder for inference

        NOTE: This returns paths for use with the original Matcha-TTS repository.
        You'll need to import and use the original loading functions.

        Args:
            model: Model to load (swara, bas_10, bas_950, sgs_10, sgs_950)
            device: Device to load on ("auto", "cpu", "cuda")

        Returns:
            Dict with model and vocoder paths and configurations
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"

        model_path = self.get_model_path(model)
        vocoder_path = self.get_vocoder_path()

        model_name = model or self.config["default_model"]
        model_info = self.config["available_models"][model_name]

        return {
            "model_path": model_path,
            "vocoder_path": vocoder_path,
            "config": self.config,
            "model_name": model_name,
            "model_info": model_info,
            "device": device,
            "inference_params": self.config["inference_defaults"]
        }

    def list_models(self) -> Dict[str, Dict[str, Any]]:
        """List available models with details (the vocoder entry is excluded)."""
        models = {}
        for name, info in self.config["available_models"].items():
            if name != "vocoder":
                models[name] = {
                    "type": info["type"],
                    "description": info["description"],
                    # Single-speaker checkpoints carry a "speaker" key;
                    # absent means multi-speaker.
                    "speaker": info.get("speaker", "multi_speaker"),
                    "training_data": info.get("training_data", "N/A")
                }
        return models

    def list_research_variants(self) -> Any:
        """List research comparison variants as declared in the config."""
        return self.config["research_variants"]

    def get_model_info(self, model: Optional[str] = None) -> Dict[str, Any]:
        """
        Get detailed information about a specific model.

        Args:
            model: Model name; defaults to the config's default model.

        Raises:
            ValueError: If the model name is not listed in the config.
        """
        model_name = model or self.config["default_model"]
        if model_name not in self.config["available_models"]:
            raise ValueError(f"Model '{model_name}' not available")

        return self.config["available_models"][model_name]

    def get_sample_texts(self) -> list:
        """Get Romanian sample texts for testing"""
        return [
            "Bună ziua! Acesta este un test de sinteză vocală în limba română.",
            "Matcha-TTS funcționează foarte bine pentru limba română.",
            "Sistemul de sinteză vocală poate genera vorbire naturală.",
            "Această tehnologie folosește inteligența artificială avansată.",
            "Vorbirea sintetizată sună foarte realistă și naturală."
        ]