|
|
import os |
|
|
import random |
|
|
import sys |
|
|
from pathlib import Path |
|
|
import re, regex |
|
|
import soundfile as sf |
|
|
import tqdm |
|
|
from hydra.utils import get_class |
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
from lemas_tts.infer.utils_infer import ( |
|
|
load_model, |
|
|
load_vocoder, |
|
|
transcribe, |
|
|
preprocess_ref_audio_text, |
|
|
infer_process, |
|
|
remove_silence_for_generated_wav, |
|
|
save_spectrogram, |
|
|
) |
|
|
from lemas_tts.model.utils import seed_everything |
|
|
from lemas_tts.model.backbones.dit import DiT |
|
|
|
|
|
|
|
|
|
|
|
# Resolved absolute path of this module. It anchors the upward search for the
# `pretrained_models` folder and is used to locate the sibling `configs`
# directory and the bundled example audio.
THIS_FILE = Path(__file__).resolve()
|
|
|
|
|
|
|
|
def _find_repo_root(start: Path) -> Path:
    """Return the directory that contains a `pretrained_models` folder.

    The search order is:

    1. `start` itself and then each of its parents, nearest first.
    2. The current working directory.

    If none of them contains a `pretrained_models` directory, `start` is
    returned unchanged as a last-resort fallback.
    """
    marker = "pretrained_models"

    def _holds_models(folder: Path) -> bool:
        return (folder / marker).is_dir()

    # Walk upwards lazily and stop at the first ancestor that matches.
    ancestor = next(
        (folder for folder in (start, *start.parents) if _holds_models(folder)),
        None,
    )
    if ancestor is not None:
        return ancestor

    # Only consult the working directory once the upward walk has failed.
    working_dir = Path.cwd()
    return working_dir if _holds_models(working_dir) else start
|
|
|
|
|
|
|
|
def _find_pretrained_root(start: Path) -> Path:
    """Resolve the `pretrained_models` directory.

    Candidates are tried in this order and the first existing directory wins:

    1. The path in the `LEMAS_PRETRAINED_ROOT` environment variable.
    2. A model mount under `/models` (Hugging Face Spaces layout): first the
       `LEMAS-Project__LEMAS-TTS` entry, then any other child directory that
       contains `pretrained_models`.
    3. `pretrained_models` inside the repo root found by `_find_repo_root`.
    4. `pretrained_models` inside the current working directory.

    When nothing exists, the repo-root candidate from step 3 is returned even
    though it is not present on disk.
    """
    folder_name = "pretrained_models"

    # 1) Explicit override; ignored when unset, empty, or not a directory.
    override = os.environ.get("LEMAS_PRETRAINED_ROOT")
    if override and Path(override).is_dir():
        return Path(override)

    # 2) Mounted models directory.
    mount = Path("/models")
    if mount.is_dir():
        preferred = mount / "LEMAS-Project__LEMAS-TTS" / folder_name
        if preferred.is_dir():
            return preferred

        # Fallback: the first child that ships a `pretrained_models` folder.
        for entry in mount.iterdir():
            if entry.is_dir() and (entry / folder_name).is_dir():
                return entry / folder_name

    # 3) Local source tree.
    repo_candidate = _find_repo_root(start) / folder_name
    if repo_candidate.is_dir():
        return repo_candidate

    # 4) Current working directory.
    cwd_candidate = Path.cwd() / folder_name
    if cwd_candidate.is_dir():
        return cwd_candidate

    # Nothing found: hand back the conventional location anyway.
    return repo_candidate
|
|
|
|
|
|
|
|
# Directory treated as the repository root; the demo entry point writes its
# outputs below it.
REPO_ROOT = _find_repo_root(THIS_FILE)
# Location of the `pretrained_models` folder (env override, /models mount, or
# local source tree, in that order of precedence).
PRETRAINED_ROOT = _find_pretrained_root(THIS_FILE)
# Checkpoint folder inside the pretrained root; the default vocoder path is
# derived from it.
CKPTS_ROOT = PRETRAINED_ROOT / "ckpts"
|
|
|
|
|
class TTS:
    """Inference wrapper that bundles a DiT model, a vocoder and a text frontend.

    Construction loads the model configuration, the vocoder and the model
    weights; `infer` then turns a reference clip plus text into a waveform and
    spectrogram via `infer_process`.
    """

    # Tokens that `process_phone_list` treats as punctuation / prosody marks
    # and therefore never prefixes with a language id. Both ASCII and
    # full-width CJK forms are listed.
    # NOTE(review): the original literal repeated ",", "?", "!" next to the
    # ideographic full stop; the full-width forms are assumed to be what was
    # intended -- confirm against the vocabulary.
    _PUNCTUATION = frozenset(
        {
            "#1", "#2", "#3", "#4", "_",
            "!", ",", ".", "?", '"', "'", "^",
            "。", ",", "?", "!",
        }
    )

    def __init__(
        self,
        model="multilingual",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=False,
        vocoder_local_path=str(CKPTS_ROOT / "vocos-mel-24khz"),
        use_prosody_encoder=False,
        prosody_cfg_path="",
        prosody_ckpt_path="",
        device=None,
        hf_cache_dir=None,
        frontend="phone",
    ):
        """Load configuration, vocoder, text frontend and model weights.

        Args:
            model: Stem of the YAML file under the `configs` directory next to
                this module (`configs/<model>.yaml`).
            ckpt_file: Checkpoint path, forwarded to `load_model`.
            vocab_file: Vocabulary path, forwarded to `load_model`.
            ode_method: ODE solver name, forwarded to `load_model`.
            use_ema: Forwarded to `load_model`.
            vocoder_local_path: Directory of a local vocoder. The vocoder is
                treated as local only if this path is an existing directory.
            use_prosody_encoder: Forwarded to `load_model`.
            prosody_cfg_path: Forwarded to `load_model`.
            prosody_ckpt_path: Forwarded to `load_model`.
            device: Device string; when None the first available of
                cuda / xpu / mps is chosen, falling back to cpu.
            hf_cache_dir: Forwarded to `load_vocoder`.
            frontend: `dtype` passed to `TextNorm` ("phone" or "char" are the
                values `infer` understands), or None to skip creating a
                frontend (in which case `infer` cannot be used).
        """
        config_dir = THIS_FILE.parent / "configs"
        model_cfg = OmegaConf.load(config_dir / f"{model}.yaml")

        model_arc = model_cfg.model.arch

        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
        self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate

        self.ode_method = ode_method
        self.use_ema = use_ema
        # Language codes accepted as `(code)` tags by `process_phone_list`.
        # Only the keys are consulted in this class.
        self.langs = {"cmn":"zh", "zh":"zh", "en":"en-us", "it":"it", "es":"es", "pt":"pt-br", "fr":"fr-fr", "de":"de", "ru":"ru", "id":"id", "vi":"vi", "th":"th"}
        # Seed used by the most recent `infer` call; None until `infer` has run.
        self.seed = None

        if device is not None:
            self.device = device
        else:
            import torch

            # `torch.xpu` is absent in some torch builds, so look it up
            # defensively instead of assuming the attribute exists.
            xpu = getattr(torch, "xpu", None)
            if torch.cuda.is_available():
                self.device = "cuda"
            elif xpu is not None and xpu.is_available():
                self.device = "xpu"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        # Use the local vocoder only when the given path really is a directory;
        # a value that cannot be converted to a path counts as "not local".
        vocoder_is_local = False
        if vocoder_local_path is not None:
            try:
                vocoder_is_local = Path(vocoder_local_path).is_dir()
            except TypeError:
                vocoder_is_local = False

        self.vocoder = load_vocoder(
            self.mel_spec_type, vocoder_is_local, vocoder_local_path, self.device, hf_cache_dir
        )

        if frontend is not None:
            # Imported lazily so the frontend is only required when requested.
            from lemas_tts.infer.frontend import TextNorm

            self.frontend = TextNorm(dtype=frontend)
        else:
            self.frontend = None

        self.ema_model = load_model(
            DiT, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device,
            use_prosody_encoder=use_prosody_encoder, prosody_cfg_path=prosody_cfg_path, prosody_ckpt_path=prosody_ckpt_path,
        )

    def transcribe(self, ref_audio, language=None):
        """Transcribe `ref_audio` by delegating to the module-level `transcribe`."""
        return transcribe(ref_audio, language)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write `wav` to `file_wave` at the model's target sample rate.

        When `remove_silence` is true the written file is post-processed with
        `remove_silence_for_generated_wav`.
        """
        sf.write(file_wave, wav, self.target_sample_rate)

        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spec, file_spec):
        """Save `spec` to `file_spec` via `save_spectrogram`."""
        save_spectrogram(spec, file_spec)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        use_acc_grl=False,
        ref_ratio=None,
        no_ref_audio=False,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        sway_sampling_coef=5,
        separate_langs=False,
        fix_duration=None,
        use_prosody_encoder=True,
        file_wave=None,
        file_spec=None,
        seed=None,
    ):
        """Synthesise `gen_text` in the voice of the reference clip.

        Args:
            ref_file: Reference audio, passed through `preprocess_ref_audio_text`.
            ref_text: Transcript of the reference audio, passed through
                `preprocess_ref_audio_text` together with `ref_file`.
            gen_text: Text to synthesise; it is split on newlines and each
                line is run through the frontend separately.
            separate_langs: When true, token lists are post-processed with
                `process_phone_list` so every phone carries a language prefix.
            file_wave: If given, the waveform is written there.
            file_spec: If given, the spectrogram is written there.
            seed: Random seed; a random one is drawn when None. The seed that
                was used is stored on `self.seed`.
            show_info, progress, target_rms, cross_fade_duration, use_acc_grl,
            ref_ratio, no_ref_audio, cfg_strength, nfe_step, speed,
            sway_sampling_coef, fix_duration, use_prosody_encoder: Forwarded
                unchanged to `infer_process`.

        Returns:
            The `(wav, sr, spec)` tuple produced by `infer_process`.

        Raises:
            ValueError: If the instance was created with `frontend=None`.
        """
        # Fail early and clearly: the text handling below needs a frontend,
        # and without this check the failure is an opaque AttributeError.
        if self.frontend is None:
            raise ValueError(
                "TTS.infer requires a text frontend; construct TTS with "
                "frontend='phone' or frontend='char'."
            )

        if seed is None:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed

        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)
        print("preprocesss:\n", "ref_file:", ref_file, "\nref_text:", ref_text)

        if self.frontend.dtype == "phone":
            # Phone frontend returns a "|"-separated string; "(cmn)" tags are
            # rewritten to "(zh)" before splitting into tokens.
            ref_text = self.frontend.text2phn(ref_text+". ").replace("(cmn)", "(zh)").split("|")
            gen_text = gen_text.split("\n")
            gen_text = [self.frontend.text2phn(x+". ").replace("(cmn)", "(zh)").split("|") for x in gen_text]
        elif self.frontend.dtype == "char":
            # Char frontend returns (language, normalised text); the language
            # becomes a leading "(lang)" token followed by single characters.
            src_lang, ref_text = self.frontend.text2norm(ref_text+". ")
            ref_text = ["("+src_lang.replace("cmn", "zh")+")"] + list(ref_text)
            gen_text = gen_text.split("\n")
            gen_text = [self.frontend.text2norm(x+". ") for x in gen_text]
            gen_text = [["("+x[0].replace("cmn", "zh")+")"] + list(x[1]) for x in gen_text]
        print("after frontend:\n", "ref_text:", ref_text, "\ngen_text:", gen_text)

        if separate_langs:
            ref_text = self.process_phone_list(ref_text)
            gen_text = [self.process_phone_list(x) for x in gen_text]

        print("gen_text:", gen_text, "\nref_text:", ref_text)

        wav, sr, spec = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            self.vocoder,
            self.mel_spec_type,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            use_prosody_encoder=use_prosody_encoder,
            use_acc_grl=use_acc_grl,
            ref_ratio=ref_ratio,
            no_ref_audio=no_ref_audio,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence=False)

        if file_spec is not None:
            self.export_spectrogram(spec, file_spec)

        return wav, sr, spec

    def process_phone_list(self, parts):
        """(vocab756 ver) Prefix phones with the language id currently in effect.

        The token list is scanned left to right:

        * A `(code)` token whose code is a key of `self.langs` switches the
          current language. The tag itself is not emitted.
        * Punctuation / prosody tokens are emitted as-is, with two clean-ups:
          a trailing "_" is dropped when another punctuation token follows it,
          and a "_" directly after punctuation is skipped.
        * Every other token is emitted with the current language tag glued in
          front (e.g. "(en)" + "AH" -> "(en)AH"). Tokens seen before any
          language tag are emitted without a prefix.

        Args:
            parts: Iterable of string tokens.

        Returns:
            A new list of processed tokens.
        """
        puncs = self._PUNCTUATION
        processed = []
        current_lang = ""
        for part in parts:
            if part.startswith("(") and part.endswith(")") and part[1:-1] in self.langs:
                current_lang = part
            elif part in puncs:
                if processed and processed[-1] == "_":
                    # Punctuation supersedes the separator just before it.
                    processed.pop()
                elif processed and processed[-1] in puncs and part == "_":
                    # A separator right after punctuation is redundant.
                    continue
                processed.append(part)
            else:
                processed.append(f"{current_lang}{part}")
        return processed
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo / smoke test: synthesise one sentence with the default model and
    # write the waveform and spectrogram under <repo>/outputs.
    # The class defined in this module is `TTS`; there is no `F5TTS` here.
    tts = TTS()

    # Make sure the output directory exists before anything is written to it.
    output_dir = REPO_ROOT / "outputs"
    output_dir.mkdir(parents=True, exist_ok=True)

    wav, sr, spec = tts.infer(
        ref_file=str((THIS_FILE.parent / "infer" / "examples" / "basic" / "basic_ref_en.wav").resolve()),
        ref_text="some call me nature, others call me mother nature.",
        gen_text=(
            "I don't really care what you call me. I've been a silent spectator, watching species evolve, "
            "empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture "
            "you; ignore me and you shall face the consequences."
        ),
        file_wave=str((output_dir / "api_out.wav").resolve()),
        file_spec=str((output_dir / "api_out.png").resolve()),
        seed=None,
    )

    print("seed :", tts.seed)
|
|
|