# Ro-Matcha-TTS — src/model_loader.py
# Initial upload of Romanian Matcha-TTS models (adrianstanea, commit bca11b0)
"""
HuggingFace-compatible model loader for Romanian Matcha-TTS
"""
import json
import os
import torch
from pathlib import Path
from typing import Optional, Dict, Any
# Optional dependency: Hub downloads work only when huggingface_hub is
# installed; with only the standard library, local repository paths still work.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_AVAILABLE = False
else:
    HF_AVAILABLE = True
class ModelLoader:
    """
    HuggingFace-compatible loader for Romanian Matcha-TTS models.

    Usage:
        loader = ModelLoader.from_pretrained("adrianstanea/Ro-Matcha-TTS")
        model, vocoder = loader.load_models(speaker="BAS")
    """

    def __init__(self, repo_path: str) -> None:
        """
        Initialize with local repository path or HuggingFace repo ID.

        Args:
            repo_path: Path to local repo or HuggingFace repo ID.

        Raises:
            FileNotFoundError: If configs/config.json is missing under repo_path.
        """
        self.repo_path = repo_path
        self.config = self._load_config()

    @classmethod
    def from_pretrained(cls, repo_id: str, cache_dir: Optional[str] = None) -> "ModelLoader":
        """
        Load from HuggingFace Hub or local path.

        Args:
            repo_id: HuggingFace repo ID (e.g., "adrianstanea/Ro-Matcha-TTS") or local path.
            cache_dir: Optional cache directory for downloads.

        Returns:
            ModelLoader instance.

        Raises:
            ValueError: If the Hub download fails.
            ImportError: If huggingface_hub is needed but not installed.
        """
        if os.path.exists(repo_id):
            # Local path — use it directly.
            return cls(repo_id)
        if not HF_AVAILABLE:
            raise ImportError(
                "huggingface_hub is required for downloading from HF Hub. "
                "Install with: pip install huggingface_hub"
            )
        try:
            config_path = hf_hub_download(
                repo_id=repo_id,
                filename="configs/config.json",
                cache_dir=cache_dir,
            )
        except Exception as e:
            # Chain the original error so the download failure stays debuggable.
            raise ValueError(f"Could not download from HuggingFace Hub: {e}") from e
        # config.json lives under <repo>/configs/, so the repo root is two levels up.
        repo_cache_path = Path(config_path).parent.parent
        return cls(str(repo_cache_path))

    def _load_config(self) -> Dict[str, Any]:
        """Load model configuration from configs/config.json in the repo."""
        config_path = os.path.join(self.repo_path, "configs", "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found at {config_path}")
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _resolve_checkpoint(self, relative_file: str, kind: str) -> str:
        """
        Resolve a checkpoint path locally, falling back to a Hub download.

        Shared by get_model_path() and get_vocoder_path(), which previously
        duplicated this logic.

        Args:
            relative_file: Checkpoint path relative to the repo root.
            kind: Label used in error messages ("Model" or "Vocoder").

        Returns:
            Absolute path to an existing checkpoint file.

        Raises:
            FileNotFoundError: If the file is neither local nor downloadable.
        """
        path = os.path.join(self.repo_path, relative_file)
        if os.path.exists(path):
            return path
        # repo_path may actually be a Hub repo ID rather than a directory;
        # in that case try downloading the checkpoint directly.
        if HF_AVAILABLE and not os.path.exists(self.repo_path):
            try:
                return hf_hub_download(repo_id=self.repo_path, filename=relative_file)
            except Exception as e:
                raise FileNotFoundError(
                    f"{kind} file not found locally and could not download: {e}"
                ) from e
        raise FileNotFoundError(f"{kind} file not found: {path}")

    def get_model_path(self, model: Optional[str] = None) -> str:
        """
        Get path to model checkpoint for specified model.

        Args:
            model: Model name (swara, bas_10, bas_950, sgs_10, sgs_950).
                If None, uses the config's default model.

        Returns:
            Absolute path to model checkpoint.

        Raises:
            ValueError: If the model name is not in the config.
            FileNotFoundError: If the checkpoint cannot be located.
        """
        if model is None:
            model = self.config["default_model"]
        if model not in self.config["available_models"]:
            available = list(self.config["available_models"].keys())
            raise ValueError(f"Model '{model}' not available. Available: {available}")
        model_file = self.config["available_models"][model]["file"]
        return self._resolve_checkpoint(model_file, "Model")

    def get_vocoder_path(self) -> str:
        """
        Get path to vocoder checkpoint.

        Returns:
            Absolute path to vocoder checkpoint.

        Raises:
            FileNotFoundError: If the checkpoint cannot be located.
        """
        vocoder_file = self.config["available_models"]["vocoder"]["file"]
        return self._resolve_checkpoint(vocoder_file, "Vocoder")

    def load_models(self, model: Optional[str] = None, device: str = "auto") -> Dict[str, Any]:
        """
        Load TTS model and vocoder for inference.

        NOTE: This returns paths for use with the original Matcha-TTS
        repository. You'll need to import and use the original loading
        functions.

        Args:
            model: Model to load (swara, bas_10, bas_950, sgs_10, sgs_950).
            device: Device to load on ("auto", "cpu", "cuda").

        Returns:
            Dict with model and vocoder paths and configurations.
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        model_path = self.get_model_path(model)
        vocoder_path = self.get_vocoder_path()
        model_name = model or self.config["default_model"]
        model_info = self.config["available_models"][model_name]
        return {
            "model_path": model_path,
            "vocoder_path": vocoder_path,
            "config": self.config,
            "model_name": model_name,
            "model_info": model_info,
            "device": device,
            "inference_params": self.config["inference_defaults"],
        }

    def list_models(self) -> Dict[str, Dict[str, Any]]:
        """List available TTS models (vocoder excluded) with summary details."""
        return {
            name: {
                "type": info["type"],
                "description": info["description"],
                "speaker": info.get("speaker", "multi_speaker"),
                "training_data": info.get("training_data", "N/A"),
            }
            for name, info in self.config["available_models"].items()
            if name != "vocoder"
        }

    def list_research_variants(self) -> Dict[str, Any]:
        """List research comparison variants declared in the config."""
        return self.config["research_variants"]

    def get_model_info(self, model: Optional[str] = None) -> Dict[str, Any]:
        """
        Get detailed information about a specific model.

        Args:
            model: Model name; defaults to the config's default model.

        Raises:
            ValueError: If the model name is not in the config.
        """
        model_name = model or self.config["default_model"]
        if model_name not in self.config["available_models"]:
            raise ValueError(f"Model '{model_name}' not available")
        return self.config["available_models"][model_name]

    def get_sample_texts(self) -> list:
        """Get Romanian sample texts for testing."""
        return [
            "Bună ziua! Acesta este un test de sinteză vocală în limba română.",
            "Matcha-TTS funcționează foarte bine pentru limba română.",
            "Sistemul de sinteză vocală poate genera vorbire naturală.",
            "Această tehnologie folosește inteligența artificială avansată.",
            "Vorbirea sintetizată sună foarte realistă și naturală.",
        ]