| import hashlib
|
| import os
|
| import string
|
| import subprocess
|
| import sys
|
| from datetime import datetime
|
| import torch
|
| import torchaudio
|
| from huggingface_hub import hf_hub_download, snapshot_download
|
| from underthesea import sent_tokenize
|
| from unidecode import unidecode
|
| from vinorm import TTSnorm
|
|
|
| from TTS.tts.configs.xtts_config import XttsConfig
|
| from TTS.tts.models.xtts import Xtts
|
|
|
# Global handle to the loaded XTTS model. It stays None until load_model()
# has been iterated far enough to assign it.
XTTS_MODEL = None

# Everything the script reads or writes lives next to the script itself,
# independent of the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR, OUTPUT_DIR = (os.path.join(SCRIPT_DIR, leaf) for leaf in ("model", "output"))

# run_tts() swaps ".wav" for this suffix to derive the path of the denoised
# copy of a reference audio file.
FILTER_SUFFIX = "_DeepFilterNet3.wav"

# Import-time side effect: make sure the output directory exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
def clear_gpu_cache():
    """Call ``torch.cuda.empty_cache()`` when CUDA is available; otherwise do nothing."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
|
|
|
|
def load_model(checkpoint_dir="model/", repo_id="capleaf/viXTTS", use_deepspeed=False):
    """Prepare the global ``XTTS_MODEL``, yielding progress messages along the way.

    This is a generator: no work happens until it is iterated, and the model
    is only fully loaded once the generator has been exhausted.

    Args:
        checkpoint_dir: Directory holding (or receiving) the checkpoint files.
            It is created if it does not exist.
        repo_id: Hugging Face repository to download from when any required
            file is missing from ``checkpoint_dir``.
        use_deepspeed: Forwarded unchanged to ``Xtts.load_checkpoint``.

    Yields:
        Human-readable status strings.
    """
    global XTTS_MODEL

    clear_gpu_cache()
    os.makedirs(checkpoint_dir, exist_ok=True)

    present = os.listdir(checkpoint_dir)
    expected = ("model.pth", "config.json", "vocab.json", "speakers_xtts.pth")
    if any(name not in present for name in expected):
        yield f"Missing model files! Downloading from {repo_id}..."
        snapshot_download(repo_id=repo_id, repo_type="model", local_dir=checkpoint_dir)
        # The speaker file is fetched separately from the coqui/XTTS-v2 repo.
        hf_hub_download(
            repo_id="coqui/XTTS-v2",
            filename="speakers_xtts.pth",
            local_dir=checkpoint_dir,
        )
        yield "Model download finished..."

    config = XttsConfig()
    config.load_json(os.path.join(checkpoint_dir, "config.json"))
    XTTS_MODEL = Xtts.init_from_config(config)

    yield "Loading model..."
    XTTS_MODEL.load_checkpoint(
        config,
        checkpoint_dir=checkpoint_dir,
        use_deepspeed=use_deepspeed,
    )
    if torch.cuda.is_available():
        XTTS_MODEL.cuda()

    print("Model Loaded!")
    yield "Model Loaded!"
|
|
|
|
|
|
|
# Reference-audio paths in insertion order (oldest first); drives eviction.
cache_queue = []

# Not referenced by any other code visible in this file.
speaker_audio_cache = {}

# Maps an original reference-audio path to the path of its denoised copy.
filter_cache = {}

# Maps a tuple key whose first element is the reference-audio path to the
# (gpt_cond_latent, speaker_embedding) pair computed for it.
conditioning_latents_cache = {}


def invalidate_cache(cache_limit=50):
    """Evict the oldest reference audio once the queue exceeds ``cache_limit``.

    At most one entry is evicted per call. Eviction removes the audio file
    and its ``_DeepFilterNet3.wav`` sibling from disk (when they exist) and
    drops every in-memory cache entry that belongs to that audio path.

    Args:
        cache_limit: Maximum number of paths allowed to remain in
            ``cache_queue`` before the oldest one is evicted.
    """
    if len(cache_queue) <= cache_limit:
        return

    key_to_remove = cache_queue.pop(0)
    print("Invalidating cache", key_to_remove)

    filtered_path = key_to_remove.replace(".wav", "_DeepFilterNet3.wav")
    for path in (key_to_remove, filtered_path):
        if os.path.exists(path):
            os.remove(path)

    filter_cache.pop(key_to_remove, None)

    # Latent entries are keyed by tuples that start with the audio path, so a
    # plain `key_to_remove in conditioning_latents_cache` test never matches
    # and the latents would be kept forever. Match on the first element.
    stale_keys = [
        key
        for key in conditioning_latents_cache
        if key == key_to_remove
        or (isinstance(key, tuple) and key and key[0] == key_to_remove)
    ]
    for key in stale_keys:
        del conditioning_latents_cache[key]
|
|
|
|
|
def generate_hash(data):
    """Return the hexadecimal MD5 digest of the bytes-like object ``data``."""
    return hashlib.md5(data).hexdigest()
|
|
|
|
|
def get_file_name(text, max_char=50):
    """Build a timestamped, ASCII-only file-name stem from ``text``.

    Args:
        text: Source text; only its first ``max_char`` characters are used.
        max_char: Number of leading characters of ``text`` to consider.

    Returns:
        ``"<MMDDHHMMSS>_<slug>"`` where the slug is the transliterated,
        lower-cased text with whitespace turned into underscores and ASCII
        punctuation (except ``_``) removed. No extension is appended.
    """
    # Transliterate BEFORE stripping punctuation: string.punctuation only
    # lists ASCII characters, so Unicode punctuation (curly quotes, ellipsis,
    # dashes, ...) would otherwise pass the strip untouched and then be turned
    # into ASCII punctuation by unidecode, ending up in the file name.
    slug = unidecode(text[:max_char]).lower()

    # Map every whitespace character (not just " ") to "_" so newlines and
    # tabs cannot leak into the file name.
    slug = "".join("_" if ch.isspace() else ch for ch in slug)

    unwanted = string.punctuation.replace("_", "")
    slug = slug.translate(str.maketrans("", "", unwanted))

    timestamp = datetime.now().strftime("%m%d%H%M%S")
    return f"{timestamp}_{slug}"
|
|
|
from unicodedata import normalize


def normalize_vietnamese_text(text):
    """Clean up ``text`` before it is handed to the TTS model.

    The text is NFC-normalized, then a fixed list of literal substring
    replacements is applied in sequence: doubled or stray punctuation is
    tidied, single and double quotes are removed, and "AI" / "A.I" are
    spelled out phonetically.

    Args:
        text: Input string.

    Returns:
        The rewritten string.
    """
    # Order matters: each rule operates on the output of the previous one.
    rewrites = (
        ("..", "."),
        ("!.", "!"),
        ("?.", "?"),
        (" .", "."),
        (" ,", ","),
        ('"', ""),
        ("'", ""),
        ("AI", "Ây Ai"),
        ("A.I", "Ây Ai"),
    )
    result = normalize("NFC", text)
    for old, new in rewrites:
        result = result.replace(old, new)
    return result
|
|
|
|
|
def calculate_keep_len(text, lang):
    """Heuristic length cap for the audio of a short sentence.

    Args:
        text: The sentence being synthesized.
        lang: Language code; ``"ja"`` and ``"zh-cn"`` are never capped.

    Returns:
        ``-1`` when no cap applies (those two languages, or ten or more
        whitespace-separated words). Otherwise a per-word budget (15000 for
        fewer than five words, 13000 for five to nine) times the word count,
        plus 2000 for every ``.``, ``!``, ``?`` or ``,`` in ``text``.
        The caller in this file uses the value as a slice length on the
        generated waveform.
    """
    if lang in ("ja", "zh-cn"):
        return -1

    n_words = len(text.split())
    if n_words >= 10:
        return -1

    n_marks = sum(text.count(mark) for mark in ".!?,")
    per_word = 15000 if n_words < 5 else 13000
    return per_word * n_words + 2000 * n_marks
|
|
|
|
|
def run_tts(lang, tts_text, speaker_audio_file, use_deepfilter, normalize_text):
    """Synthesize ``tts_text`` in the voice of ``speaker_audio_file`` and save it.

    Args:
        lang: Language code passed to the model. ``"ja"`` and ``"zh-cn"`` are
            split on the ideographic full stop; other languages go through
            ``sent_tokenize``.
        tts_text: Text to speak.
        speaker_audio_file: Path of the reference audio used for voice cloning.
        use_deepfilter: When true, the reference audio is denoised with the
            external ``deepFilter`` command before use (result cached per path).
        normalize_text: When true and ``lang == "vi"``, the text is passed
            through ``normalize_vietnamese_text`` first.

    Returns:
        On success the 2-tuple ``("Speech generated !", out_path)``.
        On a validation failure the 3-tuple ``(message, None, None)``.
        NOTE(review): the two shapes differ in length; this is inherited
        behavior and is left unchanged so existing callers keep working.

    Raises:
        subprocess.CalledProcessError: If ``deepFilter`` exits with a
            non-zero status.
    """
    if XTTS_MODEL is None:
        return "You need to run the previous step to load the model !!", None, None

    if not speaker_audio_file:
        return "You need to provide reference audio!!!", None, None

    # The original (unfiltered) path identifies the speaker in every cache.
    speaker_audio_key = speaker_audio_file
    if speaker_audio_key not in cache_queue:
        cache_queue.append(speaker_audio_key)
        invalidate_cache()

    if use_deepfilter:
        if speaker_audio_key in filter_cache:
            print("Using filter cache...")
        else:
            print("Running filter...")
            # check=True: a failed run must not register a denoised path that
            # was never written. `or "."` avoids passing an empty output
            # directory when the reference path has no directory component.
            subprocess.run(
                [
                    "deepFilter",
                    speaker_audio_file,
                    "-o",
                    os.path.dirname(speaker_audio_file) or ".",
                ],
                check=True,
            )
            filter_cache[speaker_audio_key] = speaker_audio_file.replace(".wav", FILTER_SUFFIX)
        speaker_audio_file = filter_cache[speaker_audio_key]

    # Include the path actually fed to the model (raw or denoised) in the key;
    # otherwise toggling use_deepfilter would silently reuse latents computed
    # from the other variant. The original path stays first so eviction can
    # still find all entries that belong to a speaker.
    cache_key = (
        speaker_audio_key,
        speaker_audio_file,
        XTTS_MODEL.config.gpt_cond_len,
        XTTS_MODEL.config.max_ref_len,
        XTTS_MODEL.config.sound_norm_refs,
    )
    if cache_key in conditioning_latents_cache:
        print("Using conditioning latents cache...")
        gpt_cond_latent, speaker_embedding = conditioning_latents_cache[cache_key]
    else:
        print("Computing conditioning latents...")
        gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
            audio_path=speaker_audio_file,
            gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
            max_ref_length=XTTS_MODEL.config.max_ref_len,
            sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
        )
        conditioning_latents_cache[cache_key] = (gpt_cond_latent, speaker_embedding)

    if normalize_text and lang == "vi":
        tts_text = normalize_vietnamese_text(tts_text)

    if lang in ["ja", "zh-cn"]:
        sentences = tts_text.split("。")
    else:
        sentences = sent_tokenize(tts_text)

    # Blank segments produce no audio; with nothing left, torch.cat would
    # raise on an empty list, so report the problem instead.
    sentences = [sentence for sentence in sentences if sentence.strip() != ""]
    if not sentences:
        return "You need to provide some text to synthesize!!!", None, None

    wav_chunks = []
    for sentence in sentences:
        wav_chunk = XTTS_MODEL.inference(
            text=sentence,
            language=lang,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=0.3,
            length_penalty=1.0,
            repetition_penalty=10.0,
            top_k=30,
            top_p=0.85,
            enable_text_splitting=True,
        )

        wav = wav_chunk["wav"]
        keep_len = calculate_keep_len(sentence, lang)
        # -1 means "do not trim"; using it as a slice end would drop the
        # final sample of every untrimmed sentence.
        if keep_len >= 0:
            wav = wav[:keep_len]

        wav_chunks.append(torch.tensor(wav))

    out_wav = torch.cat(wav_chunks, dim=0).unsqueeze(0)
    out_path = os.path.join(OUTPUT_DIR, f"{get_file_name(tts_text)}.wav")
    print("Saving output to ", out_path)
    # 24000 is the sample rate handed to torchaudio for the saved file.
    torchaudio.save(out_path, out_wav, 24000)

    return "Speech generated !", out_path
|
|
|
|
|
|
|
|
|
def create_interface():
    """Load the model, then synthesize one fixed Vietnamese demo sentence.

    Progress messages yielded by ``load_model`` are printed as they arrive.
    The first bundled sample voice is used as the reference audio.

    Returns:
        Whatever ``run_tts`` returns on success or validation failure, or
        ``(error_message, None, None)`` if either stage raises.
    """
    try:
        for message in load_model(
            checkpoint_dir=MODEL_DIR,
            repo_id="capleaf/viXTTS",
            use_deepspeed=False,
        ):
            print(message)
    except Exception as e:
        return f"Error loading model: {str(e)}", None, None

    # Build sample paths with os.path.join relative to the script directory
    # (as MODEL_DIR and OUTPUT_DIR are) instead of hard-coded backslashes,
    # which only resolve on Windows and only from one working directory.
    sample_names = (
        "nu-nhe-nhang.wav",
        "nu-nhan-nha.wav",
        "nu-luu-loat.wav",
        "nu-cham.wav",
        "nu-calm.wav",
        "nam-truyen-cam.wav",
        "nam-nhanh.wav",
        "nam-cham.wav",
        "nam-calm.wav",
    )
    speaker_audio_files = [
        os.path.join(SCRIPT_DIR, "samples", name) for name in sample_names
    ]
    speaker_audio_file = speaker_audio_files[0]

    lang = "vi"
    normalize_text = True
    use_deepfilter = False
    tts_text = "Chào bạn, tôi là một trợ lý ảo."

    try:
        return run_tts(lang, tts_text, speaker_audio_file, use_deepfilter, normalize_text)
    except Exception as e:
        # Reported separately so a synthesis failure is not mislabeled as a
        # model-loading failure.
        return f"Error generating speech: {str(e)}", None, None
|
|
|
|
|
|
|
# Runs whenever this module is executed or imported: loads the model,
# synthesizes the demo sentence and prints the resulting tuple.
# NOTE(review): consider an `if __name__ == "__main__":` guard if this file
# is ever meant to be imported without side effects.
_demo_result = create_interface()
print(_demo_result)
|
|
|