Approximetal's picture
Upload folder using huggingface_hub
f36e46d verified
import os
import random
import sys
from pathlib import Path
import re, regex
import soundfile as sf
import tqdm
from hydra.utils import get_class
from omegaconf import OmegaConf
from lemas_tts.infer.utils_infer import (
load_model,
load_vocoder,
transcribe,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
save_spectrogram,
)
from lemas_tts.model.utils import seed_everything
from lemas_tts.model.backbones.dit import DiT
# Resolve repository layout so we can find pretrained assets (ckpts, vocoder, etc.)
THIS_FILE = Path(__file__).resolve()
print("THIS_FILE:", THIS_FILE)
def _find_repo_root(start: Path) -> Path:
"""Locate the repo root by looking for a `pretrained_models` folder upwards."""
for p in [start, *start.parents]:
if (p / "pretrained_models").is_dir():
return p
cwd = Path.cwd()
if (cwd / "pretrained_models").is_dir():
return cwd
return start
def _find_pretrained_root(start: Path) -> Path:
"""
Locate the `pretrained_models` root, with support for:
1) Explicit env override (LEMAS_PRETRAINED_ROOT)
2) Hugging Face Spaces model mount under /models
3) Local source tree (searching upwards from this file)
"""
# 1) Explicit override
env_root = os.environ.get("LEMAS_PRETRAINED_ROOT")
if env_root:
p = Path(env_root)
if p.is_dir():
return p
# 2) HF Spaces model mount: /models/<model_id>/pretrained_models
models_dir = Path("/models")
if models_dir.is_dir():
# Try the expected model name first
specific = models_dir / "LEMAS-Project__LEMAS-TTS"
if (specific / "pretrained_models").is_dir():
return specific / "pretrained_models"
# Otherwise, pick the first model that has a pretrained_models subdir
for child in models_dir.iterdir():
if child.is_dir() and (child / "pretrained_models").is_dir():
return child / "pretrained_models"
# 3) Local repo layout
repo_root = _find_repo_root(start)
if (repo_root / "pretrained_models").is_dir():
return repo_root / "pretrained_models"
cwd = Path.cwd()
if (cwd / "pretrained_models").is_dir():
return cwd / "pretrained_models"
# Fallback: assume under repo root even if directory is missing
return repo_root / "pretrained_models"
REPO_ROOT = _find_repo_root(THIS_FILE)
PRETRAINED_ROOT = _find_pretrained_root(THIS_FILE)
CKPTS_ROOT = PRETRAINED_ROOT / "ckpts"
class TTS:
def __init__(
self,
model="multilingual",
ckpt_file="",
vocab_file="",
use_prosody_encoder=False,
prosody_cfg_path="",
prosody_ckpt_path="",
ode_method="euler",
use_ema=False,
vocoder_local_path=str(CKPTS_ROOT / "vocos-mel-24khz"),
device=None,
hf_cache_dir=None,
frontend="phone",
):
# Load model architecture config from bundled yaml
config_dir = THIS_FILE.parent / "configs"
model_cfg = OmegaConf.load(config_dir / f"{model}.yaml")
# model_cls = get_class(f"lemas_tts.model.dit.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
self.ode_method = ode_method
self.use_ema = use_ema
# remember whether this TTS instance is configured with a prosody encoder
self.use_prosody_encoder = use_prosody_encoder
self.langs = {"cmn":"zh", "zh":"zh", "en":"en-us", "it":"it", "es":"es", "pt":"pt-br", "fr":"fr-fr", "de":"de", "ru":"ru", "id":"id", "vi":"vi", "th":"th"}
if device is not None:
self.device = device
else:
import torch
self.device = (
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
# # Load models
# Prefer local vocoder directory if it exists; otherwise let `load_vocoder`
# fall back to downloading from the default HF repo (charactr/vocos-mel-24khz).
vocoder_is_local = False
if vocoder_local_path is not None:
try:
vocoder_is_local = Path(vocoder_local_path).is_dir()
except TypeError:
vocoder_is_local = False
self.vocoder = load_vocoder(
self.mel_spec_type, vocoder_is_local, vocoder_local_path, self.device, hf_cache_dir
)
# self.vocoder = load_vocoder(vocoder_name="vocos", is_local=True, local_path=vocoder_local_path, device=self.device)
if frontend is not None:
from lemas_tts.infer.frontend import TextNorm
# try:
# Try requested frontend first (typically "phone")
self.frontend = TextNorm(dtype=frontend)
# except Exception as e:
# # If espeak/phonemizer is not available, gracefully fall back to char frontend
# print(f"[TTS] Failed to init TextNorm with dtype='{frontend}': {e}")
# print("[TTS] Falling back to char frontend (no espeak required).")
# self.frontend = TextNorm(dtype="char")
else:
self.frontend = None
self.ema_model = load_model(
DiT,
model_arc,
ckpt_file,
self.mel_spec_type,
vocab_file,
self.ode_method,
self.use_ema,
self.device,
use_prosody_encoder=use_prosody_encoder,
prosody_cfg_path=prosody_cfg_path,
prosody_ckpt_path=prosody_ckpt_path,
)
def transcribe(self, ref_audio, language=None):
return transcribe(ref_audio, language)
def export_wav(self, wav, file_wave, remove_silence=False):
sf.write(file_wave, wav, self.target_sample_rate)
if remove_silence:
remove_silence_for_generated_wav(file_wave)
def export_spectrogram(self, spec, file_spec):
save_spectrogram(spec, file_spec)
def infer(
self,
ref_file,
ref_text,
gen_text,
show_info=print,
progress=tqdm,
target_rms=0.1,
cross_fade_duration=0.15,
use_acc_grl=False,
ref_ratio=None,
no_ref_audio=False,
cfg_strength=2,
nfe_step=32,
speed=1.0,
sway_sampling_coef=5,
separate_langs=False,
fix_duration=None,
use_prosody_encoder=True,
file_wave=None,
file_spec=None,
seed=None,
):
if seed is None:
seed = random.randint(0, sys.maxsize)
seed_everything(seed)
self.seed = seed
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)
print("preprocesss:\n", "ref_file:", ref_file, "\nref_text:", ref_text)
if self.frontend.dtype == "phone":
ref_text = self.frontend.text2phn(ref_text+". ").replace("(cmn)", "(zh)").split("|")
gen_text = gen_text.split("\n")
gen_text = [self.frontend.text2phn(x+". ").replace("(cmn)", "(zh)").split("|") for x in gen_text]
elif self.frontend.dtype == "char":
src_lang, ref_text = self.frontend.text2norm(ref_text+". ")
ref_text = ["("+src_lang.replace("cmn", "zh")+")"] + list(ref_text)
gen_text = gen_text.split("\n")
gen_text = [self.frontend.text2norm(x+". ") for x in gen_text]
gen_text = [["("+x[0].replace("cmn", "zh")+")"] + list(x[1]) for x in gen_text]
print("after frontend:\n", "ref_text:", ref_text, "\ngen_text:", gen_text)
if separate_langs:
ref_text = self.process_phone_list(ref_text) # Optional
gen_text = [self.process_phone_list(x) for x in gen_text]
print("gen_text:", gen_text, "\nref_text:", ref_text)
wav, sr, spec = infer_process(
ref_file,
ref_text,
gen_text,
self.ema_model,
self.vocoder,
self.mel_spec_type,
show_info=show_info,
progress=progress,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
use_prosody_encoder=use_prosody_encoder,
use_acc_grl=use_acc_grl,
ref_ratio=ref_ratio,
no_ref_audio=no_ref_audio,
speed=speed,
fix_duration=fix_duration,
device=self.device,
)
if file_wave is not None:
self.export_wav(wav, file_wave, remove_silence=False)
if file_spec is not None:
self.export_spectrogram(spec, file_spec)
return wav, sr, spec
def process_phone_list(self, parts):
puncs = {"#1", "#2", "#3", "#4", "_", "!", ",", ".", "?", '"', "'", "^", "。", ",", "?", "!"}
"""(vocab756 ver)处理phone list,给不带language id的phone添加当前language id前缀"""
# parts = phn_str.split('|')
processed = []
current_lang = ""
for i in range(len(parts)):
part = parts[i]
if part.startswith('(') and part.endswith(')') and part[1:-1] in self.langs:
# 这是一个language id
current_lang = part
# processed.append(part)
elif part in puncs: # not bool(regex.search(r'\p{L}', part[0])): # 匹配非字母数字、非空格的字符
# 是停顿符或标点
if len(processed) > 0 and processed[-1] == "_":
processed.pop()
elif len(processed) > 0 and processed[-1] in puncs and part == "_":
continue
processed.append(part)
# if i < len(parts) - 1 and parts[i+1] != "_":
# processed.append("_")
elif current_lang is not None:
# 不是language id且有当前language id,添加前缀
processed.append(f"{current_lang}{part}")
return processed
if __name__ == "__main__":
f5tts = F5TTS()
wav, sr, spec = f5tts.infer(
ref_file=str((THIS_FILE.parent / "infer" / "examples" / "basic" / "basic_ref_en.wav").resolve()),
ref_text="some call me nature, others call me mother nature.",
gen_text=(
"I don't really care what you call me. I've been a silent spectator, watching species evolve, "
"empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture "
"you; ignore me and you shall face the consequences."
),
file_wave=str((REPO_ROOT / "outputs" / "api_out.wav").resolve()),
file_spec=str((REPO_ROOT / "outputs" / "api_out.png").resolve()),
seed=None,
)
print("seed :", f5tts.seed)