|
|
|
|
|
import os
|
|
|
import time
|
|
|
import torch
|
|
|
import torchaudio
|
|
|
from huggingface_hub import snapshot_download, hf_hub_download
|
|
|
from TTS.tts.configs.xtts_config import XttsConfig
|
|
|
from TTS.tts.models.xtts import Xtts
|
|
|
from pathlib import Path
|
|
|
import traceback
|
|
|
import logging
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("=" * 60)
|
|
|
print("Đang khởi tạo module TTS...")
|
|
|
print("=" * 60)
|
|
|
|
|
|
|
|
|
CHECKPOINT_DIR = "model/"
|
|
|
REPO_ID = "capleaf/viXTTS"
|
|
|
SAVE_DIR = "generated_content"
|
|
|
USE_DEEPSPEED = False
|
|
|
|
|
|
logger.info(f"CHECKPOINT_DIR: {CHECKPOINT_DIR}")
|
|
|
logger.info(f"REPO_ID: {REPO_ID}")
|
|
|
logger.info(f"SAVE_DIR: {SAVE_DIR}")
|
|
|
|
|
|
|
|
|
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
|
|
Path(SAVE_DIR).mkdir(exist_ok=True)
|
|
|
Path(os.path.join(SAVE_DIR, "audio")).mkdir(exist_ok=True)
|
|
|
logger.info("Đã tạo các thư mục cần thiết.")
|
|
|
|
|
|
|
|
|
logger.info("Kiểm tra và tải model viXTTS...")
|
|
|
try:
|
|
|
|
|
|
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
|
|
|
files_in_dir = os.listdir(CHECKPOINT_DIR)
|
|
|
|
|
|
if not all(file in files_in_dir for file in required_files):
|
|
|
logger.info("Một số file model bị thiếu. Bắt đầu tải xuống...")
|
|
|
snapshot_download(
|
|
|
repo_id=REPO_ID,
|
|
|
repo_type="model",
|
|
|
local_dir=CHECKPOINT_DIR,
|
|
|
)
|
|
|
|
|
|
hf_hub_download(
|
|
|
repo_id="coqui/XTTS-v2",
|
|
|
filename="speakers_xtts.pth",
|
|
|
local_dir=CHECKPOINT_DIR,
|
|
|
)
|
|
|
logger.info("Tải model hoàn tất.")
|
|
|
else:
|
|
|
logger.info("Model đã tồn tại, không cần tải lại.")
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Lỗi khi tải model: {e}")
|
|
|
logger.error(traceback.format_exc())
|
|
|
raise
|
|
|
|
|
|
|
|
|
logger.info("Đang nạp cấu hình XTTS...")
|
|
|
config_path = os.path.join(CHECKPOINT_DIR, "config.json")
|
|
|
|
|
|
if not os.path.exists(config_path):
|
|
|
raise FileNotFoundError(f"Không tìm thấy file config tại: {config_path}")
|
|
|
|
|
|
try:
|
|
|
config = XttsConfig()
|
|
|
config.load_json(config_path)
|
|
|
logger.info("Đã tải cấu hình thành công.")
|
|
|
except Exception as e:
|
|
|
logger.error(f"Lỗi khi tải config: {e}")
|
|
|
logger.error(traceback.format_exc())
|
|
|
raise
|
|
|
|
|
|
|
|
|
logger.info("Đang khởi tạo model XTTS...")
|
|
|
try:
|
|
|
XTTS_MODEL = Xtts.init_from_config(config)
|
|
|
logger.info("Đã khởi tạo model từ config thành công.")
|
|
|
logger.info("Đang tải checkpoint...")
|
|
|
XTTS_MODEL.load_checkpoint(config, checkpoint_dir=CHECKPOINT_DIR, use_deepspeed=USE_DEEPSPEED)
|
|
|
logger.info("Đã tải checkpoint thành công.")
|
|
|
|
|
|
XTTS_MODEL.eval()
|
|
|
logger.info("Đã đặt model vào chế độ inference.")
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Lỗi khi khởi tạo hoặc tải model: {e}")
|
|
|
logger.error(traceback.format_exc())
|
|
|
raise
|
|
|
|
|
|
|
|
|
try:
|
|
|
if torch.cuda.is_available():
|
|
|
logger.info("Phát hiện CUDA. Chuyển model sang GPU.")
|
|
|
XTTS_MODEL.cuda()
|
|
|
logger.info(f"Model đã được chuyển lên GPU: {torch.cuda.get_device_name(0)}")
|
|
|
else:
|
|
|
logger.info("Không phát hiện CUDA. Model sẽ chạy trên CPU.")
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Lỗi khi chuyển model lên GPU: {e}")
|
|
|
logger.error(traceback.format_exc())
|
|
|
|
|
|
print("=" * 60)
|
|
|
print("Khởi tạo module TTS hoàn tất!")
|
|
|
print("=" * 60)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_vietnamese_text(text):
|
|
|
"""Chuẩn hóa văn bản tiếng Việt."""
|
|
|
logger.debug(f"Văn bản trước khi chuẩn hóa: {text}")
|
|
|
|
|
|
text = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", text)
|
|
|
text = (
|
|
|
text.replace("..", ".").replace("!.", "!").replace("?.", "?")
|
|
|
.replace(" .", ".").replace(" ,", ",").replace('"', "").replace("'", "")
|
|
|
.replace("AI", "Ây Ai").replace("A.I", "Ây Ai")
|
|
|
)
|
|
|
logger.debug(f"Văn bản sau khi chuẩn hóa: {text}")
|
|
|
return text
|
|
|
|
|
|
def get_voice_conditioning(audio_file_path):
|
|
|
"""
|
|
|
Tính toán đặc trưng giọng nói từ file audio tham chiếu.
|
|
|
|
|
|
Args:
|
|
|
audio_file_path (str): Đường dẫn đến file audio tham chiếu
|
|
|
|
|
|
Returns:
|
|
|
tuple: (gpt_cond_latent, speaker_embedding)
|
|
|
"""
|
|
|
logger.info(f"Đang tính toán đặc trưng giọng nói từ: {audio_file_path}")
|
|
|
|
|
|
if not os.path.exists(audio_file_path):
|
|
|
raise FileNotFoundError(f"Không tìm thấy file audio tham chiếu: {audio_file_path}")
|
|
|
|
|
|
try:
|
|
|
|
|
|
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
|
|
|
audio_path=audio_file_path,
|
|
|
gpt_cond_len=30,
|
|
|
gpt_cond_chunk_len=4,
|
|
|
max_ref_length=60,
|
|
|
)
|
|
|
logger.info("✓ Đã tính toán đặc trưng giọng nói thành công.")
|
|
|
return gpt_cond_latent, speaker_embedding
|
|
|
except Exception as e:
|
|
|
logger.error(f"Lỗi khi tính toán đặc trưng giọng nói: {e}")
|
|
|
raise
|
|
|
|
|
|
def predict_tts(
|
|
|
text: str,
|
|
|
language: str = "vi",
|
|
|
audio_file_path: str = None,
|
|
|
output_filename: str = "output.wav",
|
|
|
gpt_cond_latent=None,
|
|
|
speaker_embedding=None
|
|
|
):
|
|
|
"""
|
|
|
Chuyển văn bản thành giọng nói.
|
|
|
|
|
|
Args:
|
|
|
text (str): Văn bản cần chuyển đổi
|
|
|
language (str): Mã ngôn ngữ (mặc định: 'vi')
|
|
|
audio_file_path (str): Đường dẫn file audio tham chiếu (nếu chưa có conditioning).
|
|
|
output_filename (str): Tên file đầu ra.
|
|
|
gpt_cond_latent: Đặc trưng điều kiện GPT (nếu đã có).
|
|
|
speaker_embedding: Embedding người nói (nếu đã có).
|
|
|
|
|
|
Returns:
|
|
|
str: Đường dẫn đến file audio đã tạo.
|
|
|
"""
|
|
|
logger.info("=" * 50)
|
|
|
logger.info("BẮT ĐẦU PREDICT_TTS")
|
|
|
logger.info("=" * 50)
|
|
|
|
|
|
logger.info(f"Text: '{text}'")
|
|
|
logger.info(f"Language: '{language}'")
|
|
|
logger.info(f"Output filename: '{output_filename}'")
|
|
|
|
|
|
try:
|
|
|
|
|
|
if not text or len(text.strip()) < 2:
|
|
|
raise ValueError("Văn bản quá ngắn, vui lòng nhập nội dung dài hơn.")
|
|
|
|
|
|
|
|
|
if language == "vi":
|
|
|
text = normalize_vietnamese_text(text)
|
|
|
|
|
|
|
|
|
if gpt_cond_latent is None or speaker_embedding is None:
|
|
|
if not audio_file_path:
|
|
|
|
|
|
default_audio = os.path.join(CHECKPOINT_DIR, "samples/nu-luu-loat.wav")
|
|
|
if os.path.exists(default_audio):
|
|
|
logger.warning(f"Không có file tham chiếu, sử dụng file mặc định: {default_audio}")
|
|
|
audio_file_path = default_audio
|
|
|
else:
|
|
|
raise ValueError("Cần cung cấp file audio tham chiếu hoặc đã có conditioning latents.")
|
|
|
|
|
|
gpt_cond_latent, speaker_embedding = get_voice_conditioning(audio_file_path)
|
|
|
|
|
|
|
|
|
logger.info("Đang tạo audio...")
|
|
|
start_time = time.time()
|
|
|
|
|
|
|
|
|
out = XTTS_MODEL.inference(
|
|
|
text,
|
|
|
language,
|
|
|
gpt_cond_latent,
|
|
|
speaker_embedding,
|
|
|
repetition_penalty=5.0,
|
|
|
temperature=0.75,
|
|
|
enable_text_splitting=True,
|
|
|
)
|
|
|
|
|
|
inference_time = time.time() - start_time
|
|
|
logger.info(f"✓ Inference hoàn tất trong {round(inference_time * 1000)}ms")
|
|
|
|
|
|
|
|
|
audio_save_path = os.path.join(SAVE_DIR, "audio", output_filename)
|
|
|
logger.info(f"Đang lưu audio vào: {audio_save_path}")
|
|
|
|
|
|
|
|
|
audio_tensor = torch.tensor(out["wav"]).unsqueeze(0)
|
|
|
torchaudio.save(
|
|
|
audio_save_path,
|
|
|
audio_tensor,
|
|
|
24000
|
|
|
)
|
|
|
|
|
|
|
|
|
if os.path.exists(audio_save_path):
|
|
|
file_size = os.path.getsize(audio_save_path)
|
|
|
logger.info(f"✓ Đã lưu file thành công. Kích thước: {file_size} bytes")
|
|
|
else:
|
|
|
raise IOError("File không được tạo sau khi lưu.")
|
|
|
|
|
|
logger.info("=" * 50)
|
|
|
logger.info("PREDICT_TTS HOÀN THÀNH THÀNH CÔNG")
|
|
|
logger.info("=" * 50)
|
|
|
|
|
|
return audio_save_path
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error("=" * 50)
|
|
|
logger.error("PREDICT_TTS THẤT BẠI")
|
|
|
logger.error(f"Lỗi: {e}")
|
|
|
logger.error(traceback.format_exc())
|
|
|
logger.error("=" * 50)
|
|
|
raise
|
|
|
|
|
|
def test_tts():
|
|
|
"""Hàm test để kiểm tra TTS hoạt động."""
|
|
|
try:
|
|
|
logger.info("Đang chạy test TTS...")
|
|
|
|
|
|
|
|
|
test_text = "Xin chào, tôi là một mô hình chuyển đổi văn bản thành giọng nói tiếng Việt."
|
|
|
|
|
|
|
|
|
ref_audio = os.path.join(CHECKPOINT_DIR, "samples/nam-tai-llieu.wav")
|
|
|
|
|
|
|
|
|
output_path = predict_tts(
|
|
|
text=test_text,
|
|
|
language="vi",
|
|
|
audio_file_path=ref_audio,
|
|
|
output_filename="test_output.wav"
|
|
|
)
|
|
|
|
|
|
logger.info(f"Test thành công! File audio tạo tại: {output_path}")
|
|
|
return output_path
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Test thất bại: {e}")
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
"""
|
|
|
Đây là khối lệnh sẽ chạy khi bạn thực thi file này trực tiếp.
|
|
|
python tts_module.py
|
|
|
"""
|
|
|
print("Chạy thử nghiệm module TTS...")
|
|
|
|
|
|
|
|
|
test_tts()
|
|
|
|
|
|
|
|
|
try:
|
|
|
print("\n--- Chạy test nâng cao với conditioning một lần ---")
|
|
|
|
|
|
|
|
|
custom_ref_audio = os.path.join(CHECKPOINT_DIR, "samples/nam-tram-am.wav")
|
|
|
if not os.path.exists(custom_ref_audio):
|
|
|
print(f"Không tìm thấy file {custom_ref_audio}. Bỏ qua test nâng cao.")
|
|
|
else:
|
|
|
|
|
|
print(f"Tính toán đặc trưng giọng nói từ: {custom_ref_audio}")
|
|
|
gpt_cond, speaker_emb = get_voice_conditioning(custom_ref_audio)
|
|
|
|
|
|
texts_to_generate = [
|
|
|
"Đây là câu đầu tiên sử dụng giọng đã được tính toán trước.",
|
|
|
"Và đây là câu thứ hai, quá trình sẽ nhanh hơn vì không cần đọc lại file âm thanh.",
|
|
|
"Công nghệ này thật tuyệt vời."
|
|
|
]
|
|
|
|
|
|
|
|
|
for i, text in enumerate(texts_to_generate):
|
|
|
predict_tts(
|
|
|
text=text,
|
|
|
language="vi",
|
|
|
gpt_cond_latent=gpt_cond,
|
|
|
speaker_embedding=speaker_emb,
|
|
|
output_filename=f"advanced_test_{i+1}.wav"
|
|
|
)
|
|
|
print("--- Test nâng cao hoàn thành ---")
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f"Lỗi trong quá trình test nâng cao: {e}") |