import os
import sys
import re
import torch
import soundfile as sf
import argparse
from tiny_tts.text.english import normalize_text, grapheme_to_phoneme
from tiny_tts.text import phonemes_to_ids
from tiny_tts.nn import commons
from tiny_tts.models import VoiceSynthesizer
from tiny_tts.text.symbols import symbols
from tiny_tts.utils import (
SAMPLING_RATE, SEGMENT_FRAMES, ADD_BLANK, SPEC_CHANNELS,
N_SPEAKERS, SPK2ID, MODEL_PARAMS,
)
def load_engine(checkpoint_path, device='cuda'):
    """Build a VoiceSynthesizer, load weights from *checkpoint_path*, and prepare it for inference.

    Args:
        checkpoint_path: Path to a .pth checkpoint whose 'model' entry holds the state dict.
        device: Torch device string the model is moved to (default 'cuda').

    Returns:
        The synthesizer in eval mode, with weight norm folded out of the decoder.
    """
    print(f"Loading model from {checkpoint_path}")
    net_g = VoiceSynthesizer(
        len(symbols),
        SPEC_CHANNELS,
        SEGMENT_FRAMES,
        n_speakers=N_SPEAKERS,
        **MODEL_PARAMS
    ).to(device)

    # Report parameter counts up front so checkpoint/model mismatches are easier to spot.
    total_params = sum(p.numel() for p in net_g.parameters())
    trainable_params = sum(p.numel() for p in net_g.parameters() if p.requires_grad)
    print(f"Model parameters: {total_params/1e6:.2f}M total, {trainable_params/1e6:.2f}M trainable")

    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint['model']

    # Strip DataParallel's 'module.' prefix and drop tensors whose shapes disagree
    # with the freshly built model (e.g. trained with a different speaker count).
    model_state = net_g.state_dict()
    new_state_dict = {}
    skipped = []
    for ckpt_key, tensor in state_dict.items():
        key = ckpt_key[7:] if ckpt_key.startswith('module.') else ckpt_key
        if key in model_state and tensor.shape != model_state[key].shape:
            skipped.append(f"{key}: ckpt{tensor.shape} vs model{model_state[key].shape}")
        else:
            # Keys unknown to the model are kept; strict=False ignores them below.
            new_state_dict[key] = tensor
    if skipped:
        print(f"Skipped {len(skipped)} mismatched keys:")
        for s in skipped[:5]:
            print(f" {s}")
        if len(skipped) > 5:
            print(f" ... and {len(skipped)-5} more")
    net_g.load_state_dict(new_state_dict, strict=False)

    net_g.eval()
    # Fold weight_norm into weight tensors for faster inference (~18% speedup)
    net_g.dec.remove_weight_norm()
    return net_g
def synthesize(text, output_path, model, speaker="MALE", device='cuda', speed=1.0):
    """Run TTS for *text* and write the resulting waveform to *output_path*.

    speed > 1.0 speaks faster, < 1.0 slower. Unknown speaker names fall back to ID 0.
    """
    print(f"Synthesizing: {text}")

    # Text front-end: normalize -> phonemize -> integer id sequences.
    normalized = normalize_text(text)
    phones, tones, word2ph = grapheme_to_phoneme(normalized)
    phone_ids, tone_ids, lang_ids = phonemes_to_ids(phones, tones, "EN")

    # Optionally interleave blank tokens between phonemes (must match training config).
    if ADD_BLANK:
        phone_ids = commons.insert_blanks(phone_ids, 0)
        tone_ids = commons.insert_blanks(tone_ids, 0)
        lang_ids = commons.insert_blanks(lang_ids, 0)

    seq_len = len(phone_ids)
    x = torch.LongTensor(phone_ids).unsqueeze(0).to(device)
    x_lengths = torch.LongTensor([seq_len]).to(device)
    tone = torch.LongTensor(tone_ids).unsqueeze(0).to(device)
    language = torch.LongTensor(lang_ids).unsqueeze(0).to(device)

    # Resolve the speaker id, warning and defaulting to 0 for unknown names.
    if speaker in SPK2ID:
        sid = torch.LongTensor([SPK2ID[speaker]]).to(device)
    else:
        print(f"Warning: Speaker {speaker} not found, using ID 0")
        sid = torch.LongTensor([0]).to(device)

    # BERT features (disabled - using zero tensors)
    bert = torch.zeros(1024, seq_len).to(device).unsqueeze(0)
    ja_bert = torch.zeros(768, seq_len).to(device).unsqueeze(0)

    # speed > 1.0 = faster speech, < 1.0 = slower speech
    length_scale = 1.0 / speed

    with torch.no_grad():
        audio, *_ = model.infer(
            x, x_lengths, sid, tone, language, bert, ja_bert,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=length_scale
        )
    audio = audio[0, 0].cpu().numpy()

    sf.write(output_path, audio, SAMPLING_RATE)
    print(f"Saved audio to {output_path}")
def get_latest_checkpoint(checkpoint_dir):
    """Return the path of the highest-step G_*.pth file in *checkpoint_dir*, or None if absent."""
    def step_of(fname):
        # Files without a parseable step sort below every numbered checkpoint.
        m = re.search(r'_(\d+)\.pth', fname)
        return int(m.group(1)) if m else -1

    candidates = [
        name for name in os.listdir(checkpoint_dir)
        if name.startswith('G_') and name.endswith('.pth')
    ]
    if not candidates:
        return None
    return os.path.join(checkpoint_dir, max(candidates, key=step_of))
def main():
    """CLI entry point: parse arguments, resolve a checkpoint, and synthesize speech."""
    parser = argparse.ArgumentParser(description="TinyTTS — English Text-to-Speech Inference")
    parser.add_argument("--text", "-t", type=str, default="The weather is nice today, and I feel very relaxed.", help="Text to synthesize")
    parser.add_argument("--checkpoint", "-c", type=str, default=None, help="Path to checkpoint. Auto-downloads if not provided.")
    parser.add_argument("--output", "-o", type=str, default="output.wav", help="Output audio file path")
    parser.add_argument("--speaker", "-s", type=str, default="MALE", help="Speaker ID")
    parser.add_argument("--speed", type=float, default=1.0, help="Speech speed (1.0=normal, 1.5=faster, 0.7=slower)")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda or cpu)")
    args = parser.parse_args()

    # No checkpoint supplied: fetch the released weights from the Hugging Face Hub.
    if args.checkpoint is None:
        try:
            from huggingface_hub import hf_hub_download
            print("Downloading/Loading checkpoint from Hugging Face Hub (backtracking/tiny-tts)...")
            args.checkpoint = hf_hub_download(repo_id="backtracking/tiny-tts", filename="G.pth")
        except ImportError:
            print("Error: huggingface_hub is required for auto-download. Run: pip install huggingface_hub")
            sys.exit(1)
        except Exception as e:
            print(f"Error downloading checkpoint: {e}")
            sys.exit(1)

    if not os.path.exists(args.checkpoint):
        print(f"Error: Checkpoint or directory not found at {args.checkpoint}")
        sys.exit(1)

    # A directory argument means "use the newest G_*.pth inside it".
    if os.path.isdir(args.checkpoint):
        latest_ckpt = get_latest_checkpoint(args.checkpoint)
        if not latest_ckpt:
            print(f"Error: No G_*.pth checkpoints found in directory {args.checkpoint}")
            sys.exit(1)
        args.checkpoint = latest_ckpt
        print(f"Auto-detected latest checkpoint: {args.checkpoint}")

    # Tag output filenames with the training step parsed from the checkpoint name.
    match = re.search(r'_(\d+)\.pth', os.path.basename(args.checkpoint))
    step_str = match.group(1) if match else "unknown"

    # All audio is written under a fixed output folder.
    out_dir = "infer_outputs"
    os.makedirs(out_dir, exist_ok=True)
    name, ext = os.path.splitext(os.path.basename(args.output))

    model = load_engine(args.checkpoint, args.device)

    if args.speaker.lower() == "all":
        # Render one file per known speaker.
        if not SPK2ID:
            print("Error: No speakers found")
            sys.exit(1)
        print(f"Synthesizing for all {len(SPK2ID)} speakers...")
        for spk in SPK2ID.keys():
            final_output = os.path.join(out_dir, f"{name}_step{step_str}_spk{spk}{ext}")
            synthesize(args.text, final_output, model, speaker=spk, device=args.device, speed=args.speed)
    else:
        final_output = os.path.join(out_dir, f"{name}_step{step_str}_spk{args.speaker}{ext}")
        synthesize(args.text, final_output, model, speaker=args.speaker, device=args.device, speed=args.speed)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()