"""
HuggingFace-compatible model loader for Romanian Matcha-TTS
"""
import json
import os
import torch
from pathlib import Path
from typing import Optional, Dict, Any
# Optional dependency: downloading from the Hugging Face Hub needs
# ``huggingface_hub``; purely local repositories work without it.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_AVAILABLE = False
else:
    HF_AVAILABLE = True
class ModelLoader:
    """
    HuggingFace-compatible loader for Romanian Matcha-TTS models.

    Usage:
        loader = ModelLoader.from_pretrained("adrianstanea/Ro-Matcha-TTS")
        model, vocoder = loader.load_models(speaker="BAS")
    """

    def __init__(self, repo_path: str):
        """
        Initialize with local repository path or HuggingFace repo ID.

        Args:
            repo_path: Path to local repo or HuggingFace repo ID.

        Raises:
            FileNotFoundError: If no ``configs/config.json`` exists under
                ``repo_path``.
        """
        self.repo_path = repo_path
        self.config = self._load_config()

    @classmethod
    def from_pretrained(cls, repo_id: str, cache_dir: Optional[str] = None) -> "ModelLoader":
        """
        Load from HuggingFace Hub or local path.

        Args:
            repo_id: HuggingFace repo ID (e.g., "adrianstanea/Ro-Matcha-TTS")
                or local path.
            cache_dir: Optional cache directory for downloads.

        Returns:
            ModelLoader instance.

        Raises:
            ValueError: If the Hub download fails.
            ImportError: If ``huggingface_hub`` is not installed and the
                path does not exist locally.
        """
        if os.path.exists(repo_id):
            # Local checkout: use it directly.
            return cls(repo_id)
        if not HF_AVAILABLE:
            raise ImportError(
                "huggingface_hub is required for downloading from HF Hub. "
                "Install with: pip install huggingface_hub"
            )
        # Download from HuggingFace Hub.
        try:
            config_path = hf_hub_download(
                repo_id=repo_id,
                filename="configs/config.json",
                cache_dir=cache_dir,
            )
        except Exception as e:
            raise ValueError(f"Could not download from HuggingFace Hub: {e}") from e
        # config.json lives in <repo>/configs/, so the repo root is its
        # grandparent directory inside the HF cache.
        return cls(str(Path(config_path).parent.parent))

    def _load_config(self) -> Dict[str, Any]:
        """
        Load the model configuration from ``configs/config.json``.

        Raises:
            FileNotFoundError: If the config file is missing.
        """
        config_path = os.path.join(self.repo_path, "configs", "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found at {config_path}")
        # Config may contain Romanian diacritics; read explicitly as UTF-8
        # so behavior does not depend on the platform's default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _resolve_checkpoint(self, rel_file: str, kind: str) -> str:
        """
        Resolve a checkpoint file locally, falling back to a Hub download.

        Args:
            rel_file: Checkpoint path relative to the repository root.
            kind: Label used in error messages ("Model" or "Vocoder").

        Returns:
            Path to the checkpoint (local repo path or HF cache path).

        Raises:
            FileNotFoundError: If the file is neither local nor downloadable.
        """
        path = os.path.join(self.repo_path, rel_file)
        if os.path.exists(path):
            return path
        # Not on disk: try to download, treating repo_path as a repo ID
        # (only plausible when repo_path is not an existing directory).
        if HF_AVAILABLE and not os.path.exists(self.repo_path):
            try:
                return hf_hub_download(repo_id=self.repo_path, filename=rel_file)
            except Exception as e:
                raise FileNotFoundError(
                    f"{kind} file not found locally and could not download: {e}"
                ) from e
        raise FileNotFoundError(f"{kind} file not found: {path}")

    def get_model_path(self, model: Optional[str] = None) -> str:
        """
        Get path to model checkpoint for specified model.

        Args:
            model: Model name (swara, bas_10, bas_950, sgs_10, sgs_950).
                If None, uses the config's ``default_model``.

        Returns:
            Absolute path to model checkpoint.

        Raises:
            ValueError: If ``model`` is not listed in ``available_models``.
            FileNotFoundError: If the checkpoint cannot be resolved.
        """
        if model is None:
            model = self.config["default_model"]
        if model not in self.config["available_models"]:
            available = list(self.config["available_models"].keys())
            raise ValueError(f"Model '{model}' not available. Available: {available}")
        model_file = self.config["available_models"][model]["file"]
        return self._resolve_checkpoint(model_file, "Model")

    def get_vocoder_path(self) -> str:
        """
        Get path to vocoder checkpoint.

        Returns:
            Absolute path to vocoder checkpoint.

        Raises:
            FileNotFoundError: If the checkpoint cannot be resolved.
        """
        vocoder_file = self.config["available_models"]["vocoder"]["file"]
        return self._resolve_checkpoint(vocoder_file, "Vocoder")

    def load_models(self, model: Optional[str] = None, device: str = "auto"):
        """
        Load TTS model and vocoder for inference.

        NOTE: This returns paths for use with the original Matcha-TTS
        repository. You'll need to import and use the original loading
        functions.

        Args:
            model: Model to load (swara, bas_10, bas_950, sgs_10, sgs_950).
            device: Device to load on ("auto", "cpu", "cuda").

        Returns:
            Dict with model and vocoder paths and configurations.
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        model_path = self.get_model_path(model)
        vocoder_path = self.get_vocoder_path()
        model_name = model or self.config["default_model"]
        model_info = self.config["available_models"][model_name]
        return {
            "model_path": model_path,
            "vocoder_path": vocoder_path,
            "config": self.config,
            "model_name": model_name,
            "model_info": model_info,
            "device": device,
            "inference_params": self.config["inference_defaults"],
        }

    def list_models(self):
        """List available models with details (excludes the vocoder entry)."""
        models = {}
        for name, info in self.config["available_models"].items():
            if name != "vocoder":
                models[name] = {
                    "type": info["type"],
                    "description": info["description"],
                    "speaker": info.get("speaker", "multi_speaker"),
                    "training_data": info.get("training_data", "N/A"),
                }
        return models

    def list_research_variants(self):
        """List research comparison variants from the config."""
        return self.config["research_variants"]

    def get_model_info(self, model: Optional[str] = None):
        """
        Get detailed information about a specific model.

        Args:
            model: Model name; defaults to the config's ``default_model``.

        Raises:
            ValueError: If the model is not listed in ``available_models``.
        """
        model_name = model or self.config["default_model"]
        if model_name not in self.config["available_models"]:
            raise ValueError(f"Model '{model_name}' not available")
        return self.config["available_models"][model_name]

    def get_sample_texts(self) -> list:
        """Get Romanian sample texts for testing."""
        return [
            "Bună ziua! Acesta este un test de sinteză vocală în limba română.",
            "Matcha-TTS funcționează foarte bine pentru limba română.",
            "Sistemul de sinteză vocală poate genera vorbire naturală.",
            "Această tehnologie folosește inteligența artificială avansată.",
            "Vorbirea sintetizată sună foarte realistă și naturală.",
        ]