# kt005 / tts_engine.py
# (Hugging Face upload metadata: uploaded by ktvoice, "Upload 2 files", commit 002e5d4, verified)
import os
import re
import time
import torch
import librosa
import numpy as np
import json
from pathlib import Path
from typing import Generator
from huggingface_hub import snapshot_download
# --- ADVANCED FIX (MONKEY PATCH) TO RUN THE KTVOICE REPO ---
import neucodec.model
def _apply_robust_patch(target_cls):
"""Vá lỗi AssertionError và TypeError cho thư viện neucodec"""
orig_func = target_cls._from_pretrained
@classmethod
def _patched_func(cls, *args, **kwargs):
# Đảm bảo model_id luôn hợp lệ để thoả mãn lệnh assert của thư viện
official_id = "neuphonic/distill-neucodec" if "distill" in str(cls).lower() else "neuphonic/neucodec"
# Sửa lỗi: HubMixin truyền model_id ở vị trí đầu tiên
if args:
kwargs["model_id"] = official_id
return orig_func(*args[1:], **kwargs)
else:
kwargs["model_id"] = official_id
return orig_func(**kwargs)
target_cls._from_pretrained = _patched_func
# Apply the patch to both neucodec loader classes.
_apply_robust_patch(neucodec.model.NeuCodec)
_apply_robust_patch(neucodec.model.DistillNeuCodec)
# -------------------------------------------------------
from neucodec import NeuCodec, DistillNeuCodec
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.phonemize_text import phonemize_text, phonemize_with_dict
def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
assert len(frames)
dtype = frames[0].dtype
shape = frames[0].shape[:-1]
total_size = max(stride * i + frame.shape[-1] for i, frame in enumerate(frames))
sum_weight = np.zeros(total_size, dtype=dtype)
out = np.zeros(*shape, total_size, dtype=dtype)
offset: int = 0
for frame in frames:
frame_length = frame.shape[-1]
t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
weight = np.abs(0.5 - (t - 0.5))
out[..., offset : offset + frame_length] += weight * frame
sum_weight[offset : offset + frame_length] += weight
offset += stride
return out / sum_weight
class VoiceEngine:
    """Two-stage TTS engine: an LLM backbone generates <|speech_N|> token
    strings that a neural codec then decodes into a 24 kHz waveform."""

    def __init__(self, backbone_repo="ktvoice/Backbone", backbone_device="cpu", codec_repo="ktvoice/Codec", codec_device="cpu"):
        # Output sample rate of the codec decoder, in Hz.
        self.sample_rate = 24_000
        # Token budget shared by the prompt and the generated speech tokens.
        self.max_context = 2048
        self._is_quantized_model = False
        # Stays None for GGUF backbones (llama_cpp does its own tokenizing).
        self.tokenizer = None
        self._load_backbone(backbone_repo, backbone_device)
        self._load_codec(codec_repo, codec_device)

    def _load_backbone(self, repo, device):
        """Load the causal-LM backbone: llama_cpp for GGUF repos,
        transformers otherwise."""
        print(f"Loading backbone from: {repo} on {device} ...")
        if "gguf" in repo.lower():
            from llama_cpp import Llama
            self.backbone = Llama.from_pretrained(repo_id=repo, filename="*.gguf", n_ctx=self.max_context)
            self._is_quantized_model = True
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(repo)
            self.backbone = AutoModelForCausalLM.from_pretrained(repo).to(torch.device(device))

    def _load_codec(self, repo, device):
        """Download codec weights and instantiate the (Distill)NeuCodec
        decoder in eval mode on `device`."""
        print(f"Loading codec from your repo: {repo} ...")
        # Download the weights from the given hub repo.
        local_dir = snapshot_download(repo_id=repo)
        # Write a minimal config file to avoid a "config.json not found"
        # error; it exists only to satisfy the Hugging Face loader.
        tmp_config = os.path.join(local_dir, "config.json")
        if not os.path.exists(tmp_config):
            with open(tmp_config, "w") as f: json.dump({"model_type": "neucodec"}, f)
        if "distill" in repo.lower():
            self.codec = DistillNeuCodec.from_pretrained(local_dir)
        else:
            self.codec = NeuCodec.from_pretrained(local_dir)
        self.codec.eval().to(device)

    def encode_reference(self, path):
        """Encode a reference audio file into a 1-D tensor of codec codes."""
        # The codec encoder consumes 16 kHz mono audio.
        wav, _ = librosa.load(path, sr=16000, mono=True)
        # 1-D waveform -> (1, 1, T) batch/channel layout.
        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            return self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)

    def infer(self, text, ref_codes, ref_text):
        """Synthesize `text` as a waveform, cloning the voice described by
        the reference codes/text pair.

        NOTE(review): this path assumes a transformers backbone — it uses
        self.tokenizer and backbone.generate, neither of which exists for
        the GGUF/llama_cpp backbone loaded in _load_backbone; confirm GGUF
        inference is handled elsewhere.
        """
        prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
        prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
        with torch.no_grad():
            out = self.backbone.generate(prompt_tensor, max_length=self.max_context, do_sample=True, temperature=1)
        # Decode only the newly generated tokens (slice off the prompt).
        # NOTE(review): `add_special_tokens` is not a decode() parameter
        # (decode takes `skip_special_tokens`), so it is silently ignored;
        # harmless here since the <|speech_N|> tokens must be kept for the
        # regex in _decode, which is the default behavior.
        tokens = self.tokenizer.decode(out[0, prompt_tensor.shape[-1]:], add_special_tokens=False)
        return self._decode(tokens)

    def _decode(self, codes_str):
        """Turn a string of <|speech_N|> tokens into a 1-D numpy waveform."""
        speech_ids = [int(n) for n in re.findall(r"<\|speech_(\d+)\|>", codes_str)]
        with torch.no_grad():
            # (T,) -> (1, 1, T) as expected by decode_code.
            codes_tensor = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(self.codec.device)
            return self.codec.decode_code(codes_tensor).cpu().numpy()[0, 0, :]

    def _apply_chat_template(self, ref_codes, ref_text, text):
        """Build prompt token ids: phonemized reference + target text inside
        the chat template, followed by the reference speech tokens."""
        input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(text)
        chat = f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{input_text}<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>"
        c_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
        return self.tokenizer.encode(chat + c_str)