# Ro-Matcha-TTS / examples / inference_example.py
# Uploaded by adrianstanea — "Initial upload of Romanian Matcha-TTS models" (commit bca11b0)
"""
Example usage of Romanian Matcha-TTS models with HuggingFace integration
This script shows how to use the HuggingFace model loader with the original
Matcha-TTS repository for inference.
"""
import sys
import os
import torch
import soundfile as sf
from pathlib import Path
# Add the HuggingFace model loader to path
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
# Import our model loader
from model_loader import ModelLoader
def load_matcha_dependencies():
    """
    Import the Matcha-TTS runtime components and bundle them in a dict.

    Requires the main repository to be installed:
        pip install git+https://github.com/adrianstanea/Matcha-TTS.git

    Returns:
        A dict mapping component names ('MatchaTTS', 'HiFiGAN', 'v1',
        'AttrDict', 'Denoiser', 'text_to_sequence', 'intersperse') to the
        imported objects, or None when the package is not installed.
    """
    try:
        # All imports come from the original Matcha-TTS repository; they
        # either all succeed or the first failure raises ImportError.
        from matcha.models.matcha_tts import MatchaTTS
        from matcha.hifigan.models import Generator as HiFiGAN
        from matcha.hifigan.config import v1
        from matcha.hifigan.env import AttrDict
        from matcha.hifigan.denoiser import Denoiser
        from matcha.text import text_to_sequence
        from matcha.utils.utils import intersperse
    except ImportError as e:
        print(f"Error importing Matcha-TTS dependencies: {e}")
        print("Please install the main repository:")
        print("pip install git+https://github.com/adrianstanea/Matcha-TTS.git")
        return None

    return {
        'MatchaTTS': MatchaTTS,
        'HiFiGAN': HiFiGAN,
        'v1': v1,
        'AttrDict': AttrDict,
        'Denoiser': Denoiser,
        'text_to_sequence': text_to_sequence,
        'intersperse': intersperse,
    }
def _load_tts_model(matcha_deps, model_info, device):
    """Load the Matcha-TTS acoustic model checkpoint; return it or None on failure."""
    try:
        tts_model = matcha_deps['MatchaTTS'].load_from_checkpoint(
            model_info['model_path'],
            map_location=device,
            weights_only=False  # Required for PyTorch 2.6+ to load OmegaConf configs
        )
        tts_model.eval()
        print(f"✓ Loaded TTS model from {model_info['model_path']}")
        return tts_model
    except Exception as e:
        print(f"✗ Failed to load TTS model: {e}")
        return None


def _load_vocoder(matcha_deps, model_info, device):
    """Load the HiFi-GAN vocoder and its denoiser; return (vocoder, denoiser) or None."""
    try:
        h = matcha_deps['AttrDict'](matcha_deps['v1'])
        vocoder = matcha_deps['HiFiGAN'](h).to(device)
        checkpoint = torch.load(model_info['vocoder_path'], map_location=device, weights_only=False)
        vocoder.load_state_dict(checkpoint['generator'])
        vocoder.eval()
        vocoder.remove_weight_norm()
        denoiser = matcha_deps['Denoiser'](vocoder, mode='zeros')
        print(f"✓ Loaded vocoder from {model_info['vocoder_path']}")
        return vocoder, denoiser
    except Exception as e:
        print(f"✗ Failed to load vocoder: {e}")
        return None


def _encode_text(matcha_deps, text, device):
    """Clean and tokenize text into (ids, lengths) tensors; return None on failure."""
    try:
        # Romanian cleaners produce symbol ids; intersperse with 0 (pad) as
        # expected by the Matcha-TTS text encoder.
        x = torch.tensor(
            matcha_deps['intersperse'](
                matcha_deps['text_to_sequence'](text, ['romanian_cleaners'])[0], 0
            ),
            dtype=torch.long,
            device=device
        )[None]
        x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
        print("✓ Text processed successfully")
        return x, x_lengths
    except Exception as e:
        print(f"✗ Failed to process text: {e}")
        return None


def synthesize_romanian(text: str, model: str = "bas_950", repo_path: str = None):
    """
    Synthesize Romanian speech using the HuggingFace model loader.

    Args:
        text: Romanian text to synthesize
        model: Model name (swara, bas_10, bas_950, sgs_10, sgs_950)
        repo_path: Path to HuggingFace repo (local path or repo ID);
            defaults to the repository root next to this script

    Returns:
        Tuple (audio, sample_rate) where audio is a numpy waveform, or
        None if any stage (dependencies, config, checkpoints, text
        processing, synthesis) fails.
    """
    matcha_deps = load_matcha_dependencies()
    if matcha_deps is None:
        return None

    if repo_path is None:
        # Use the local repository root relative to this script
        repo_path = str(Path(__file__).parent.parent)
    try:
        loader = ModelLoader.from_pretrained(repo_path)
        print(f"✓ Loaded model configuration from {repo_path}")
    except Exception as e:
        print(f"✗ Failed to load model configuration: {e}")
        return None

    # Resolve checkpoint paths, inference parameters and device for the model
    model_info = loader.load_models(model=model)
    print(f"✓ Model info loaded: {model_info['model_name']}")
    print(f" Description: {model_info['model_info']['description']}")
    print(f" Training data: {model_info['model_info'].get('training_data', 'N/A')}")
    device = torch.device(model_info['device'])
    print(f"✓ Using device: {device}")

    # Use a distinct name for the acoustic model so it does not shadow the
    # `model` (name) parameter.
    tts_model = _load_tts_model(matcha_deps, model_info, device)
    if tts_model is None:
        return None

    vocoder_parts = _load_vocoder(matcha_deps, model_info, device)
    if vocoder_parts is None:
        return None
    vocoder, denoiser = vocoder_parts

    print(f"Processing text: '{text}'")
    encoded = _encode_text(matcha_deps, text, device)
    if encoded is None:
        return None
    x, x_lengths = encoded

    print("Generating speech...")
    try:
        with torch.inference_mode():
            # Synthesis parameters come from the repo configuration
            params = model_info['inference_params']
            output = tts_model.synthesise(
                x, x_lengths,
                n_timesteps=params['n_timesteps'],
                temperature=params['temperature'],
                length_scale=params['length_scale']
            )
            # Mel spectrogram -> waveform, clamped to valid audio range,
            # then lightly denoised (strength kept from the original example)
            mel = output['mel']
            audio = vocoder(mel).clamp(-1, 1)
            audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
        print("✓ Speech generated successfully")
        return audio.numpy(), model_info['config']['sample_rate']
    except Exception as e:
        print(f"✗ Failed to generate speech: {e}")
        return None
def main():
"""Example usage"""
# Test with local repository path
repo_path = str(Path(__file__).parent.parent) # Path to Ro-Matcha-TTS
# Sample Romanian texts
test_texts = [
"Bună ziua! Acesta este un test de sinteză vocală.",
"România are o cultură bogată și o istorie fascinantă.",
"Limba română face parte din familia limbilor romanice."
]
# Test different models for research comparison
test_models = ["bas_10", "bas_950", "sgs_10", "sgs_950"]
# Test synthesis
output_dir = Path("generated_samples")
output_dir.mkdir(exist_ok=True)
for model in test_models: # Test with first two models
print(f"\n{'='*50}")
print(f"Testing model: {model}")
print(f"{'='*50}")
for i, text in enumerate(test_texts): # Test with first text
print(f"\nText {i+1}: {text}")
result = synthesize_romanian(
text=text,
model=model,
repo_path=repo_path
)
if result is not None:
audio, sr = result
output_file = output_dir / f"sample_{model}_{i+1}.wav"
sf.write(output_file, audio, sr)
print(f"✓ Saved audio to {output_file}")
else:
print(f"✗ Failed to generate audio for {model}")
if __name__ == "__main__":
main()