"""
Example usage of Romanian Matcha-TTS models with HuggingFace integration
This script shows how to use the HuggingFace model loader with the original
Matcha-TTS repository for inference.
"""
import sys
import os
import torch
import soundfile as sf
from pathlib import Path
# Add the HuggingFace model loader to path
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
# Import our model loader
from model_loader import ModelLoader
def load_matcha_dependencies():
    """
    Attempt to import the Matcha-TTS components needed for synthesis.

    The main repository must be installed first:
        pip install git+https://github.com/adrianstanea/Matcha-TTS.git

    Returns:
        A dict mapping component names ('MatchaTTS', 'HiFiGAN', 'v1',
        'AttrDict', 'Denoiser', 'text_to_sequence', 'intersperse') to the
        imported objects, or None (after printing install instructions)
        when the package is not available.
    """
    try:
        # Import the modules; attribute access happens only after all
        # imports succeed, so a missing package is caught below.
        import matcha.models.matcha_tts as _matcha_tts
        import matcha.hifigan.models as _hifigan_models
        import matcha.hifigan.config as _hifigan_config
        import matcha.hifigan.env as _hifigan_env
        import matcha.hifigan.denoiser as _hifigan_denoiser
        import matcha.text as _matcha_text
        import matcha.utils.utils as _matcha_utils
    except ImportError as e:
        print(f"Error importing Matcha-TTS dependencies: {e}")
        print("Please install the main repository:")
        print("pip install git+https://github.com/adrianstanea/Matcha-TTS.git")
        return None

    return {
        'MatchaTTS': _matcha_tts.MatchaTTS,
        'HiFiGAN': _hifigan_models.Generator,
        'v1': _hifigan_config.v1,
        'AttrDict': _hifigan_env.AttrDict,
        'Denoiser': _hifigan_denoiser.Denoiser,
        'text_to_sequence': _matcha_text.text_to_sequence,
        'intersperse': _matcha_utils.intersperse
    }
def _load_tts_model(matcha_deps, model_info, device):
    """Load the Matcha-TTS acoustic model checkpoint; return it in eval mode, or None on failure."""
    try:
        tts_model = matcha_deps['MatchaTTS'].load_from_checkpoint(
            model_info['model_path'],
            map_location=device,
            weights_only=False  # Required for PyTorch 2.6+ to load OmegaConf configs
        )
        tts_model.eval()
        print(f"✓ Loaded TTS model from {model_info['model_path']}")
        return tts_model
    except Exception as e:
        print(f"✗ Failed to load TTS model: {e}")
        return None


def _load_vocoder(matcha_deps, model_info, device):
    """Load the HiFi-GAN vocoder and its denoiser; return (vocoder, denoiser) or None on failure."""
    try:
        h = matcha_deps['AttrDict'](matcha_deps['v1'])
        vocoder = matcha_deps['HiFiGAN'](h).to(device)
        checkpoint = torch.load(model_info['vocoder_path'], map_location=device, weights_only=False)
        vocoder.load_state_dict(checkpoint['generator'])
        vocoder.eval()
        # Weight norm is a training-time reparameterization; fold it for inference.
        vocoder.remove_weight_norm()
        denoiser = matcha_deps['Denoiser'](vocoder, mode='zeros')
        print(f"✓ Loaded vocoder from {model_info['vocoder_path']}")
        return vocoder, denoiser
    except Exception as e:
        print(f"✗ Failed to load vocoder: {e}")
        return None


def _encode_text(matcha_deps, text, device):
    """Convert text to a padded phoneme-id tensor; return (x, x_lengths) or None on failure."""
    print(f"Processing text: '{text}'")
    try:
        # Use Romanian cleaners; intersperse blanks (0) between symbols as
        # Matcha-TTS expects, then add a batch dimension.
        x = torch.tensor(
            matcha_deps['intersperse'](
                matcha_deps['text_to_sequence'](text, ['romanian_cleaners'])[0], 0
            ),
            dtype=torch.long,
            device=device
        )[None]
        x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
        print("✓ Text processed successfully")
        return x, x_lengths
    except Exception as e:
        print(f"✗ Failed to process text: {e}")
        return None


def synthesize_romanian(text: str, model: str = "bas_950", repo_path: str = None):
    """
    Synthesize Romanian speech using the HuggingFace model loader.

    Args:
        text: Romanian text to synthesize.
        model: Model name (swara, bas_10, bas_950, sgs_10, sgs_950).
        repo_path: Path to HuggingFace repo (local or repo ID). Defaults to
            the parent directory of this script.

    Returns:
        (audio, sample_rate) on success — audio is a 1-D numpy array — or
        None if any stage (deps, config, models, text, synthesis) fails.
    """
    # Load Matcha-TTS dependencies
    matcha_deps = load_matcha_dependencies()
    if matcha_deps is None:
        return None

    # Initialize model loader
    if repo_path is None:
        # Use local path relative to this script
        repo_path = str(Path(__file__).parent.parent)
    try:
        loader = ModelLoader.from_pretrained(repo_path)
        print(f"✓ Loaded model configuration from {repo_path}")
    except Exception as e:
        print(f"✗ Failed to load model configuration: {e}")
        return None

    # Get model paths and configuration
    model_info = loader.load_models(model=model)
    print(f"✓ Model info loaded: {model_info['model_name']}")
    print(f"  Description: {model_info['model_info']['description']}")
    print(f"  Training data: {model_info['model_info'].get('training_data', 'N/A')}")

    device = torch.device(model_info['device'])
    print(f"✓ Using device: {device}")

    # Load TTS model. NOTE: bound to `tts_model`, not `model`, so the
    # model-name argument is not shadowed by the loaded network.
    tts_model = _load_tts_model(matcha_deps, model_info, device)
    if tts_model is None:
        return None

    # Load vocoder + denoiser
    vocoder_pair = _load_vocoder(matcha_deps, model_info, device)
    if vocoder_pair is None:
        return None
    vocoder, denoiser = vocoder_pair

    # Process text into model inputs
    encoded = _encode_text(matcha_deps, text, device)
    if encoded is None:
        return None
    x, x_lengths = encoded

    # Generate speech
    print("Generating speech...")
    try:
        with torch.inference_mode():
            # Synthesis parameters from config
            params = model_info['inference_params']
            output = tts_model.synthesise(
                x, x_lengths,
                n_timesteps=params['n_timesteps'],
                temperature=params['temperature'],
                length_scale=params['length_scale']
            )
            # Convert mel spectrogram to waveform, clamp to valid range,
            # then denoise vocoder artifacts.
            mel = output['mel']
            audio = vocoder(mel).clamp(-1, 1)
            audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
            print("✓ Speech generated successfully")
            return audio.numpy(), model_info['config']['sample_rate']
    except Exception as e:
        print(f"✗ Failed to generate speech: {e}")
        return None
def main():
"""Example usage"""
# Test with local repository path
repo_path = str(Path(__file__).parent.parent) # Path to Ro-Matcha-TTS
# Sample Romanian texts
test_texts = [
"Bună ziua! Acesta este un test de sinteză vocală.",
"România are o cultură bogată și o istorie fascinantă.",
"Limba română face parte din familia limbilor romanice."
]
# Test different models for research comparison
test_models = ["bas_10", "bas_950", "sgs_10", "sgs_950"]
# Test synthesis
output_dir = Path("generated_samples")
output_dir.mkdir(exist_ok=True)
for model in test_models: # Test with first two models
print(f"\n{'='*50}")
print(f"Testing model: {model}")
print(f"{'='*50}")
for i, text in enumerate(test_texts): # Test with first text
print(f"\nText {i+1}: {text}")
result = synthesize_romanian(
text=text,
model=model,
repo_path=repo_path
)
if result is not None:
audio, sr = result
output_file = output_dir / f"sample_{model}_{i+1}.wav"
sf.write(output_file, audio, sr)
print(f"✓ Saved audio to {output_file}")
else:
print(f"✗ Failed to generate audio for {model}")
if __name__ == "__main__":
main() |