import os
import re
import time
import torch
import librosa
import numpy as np
import json
from pathlib import Path
from typing import Generator
from huggingface_hub import snapshot_download
# --- ADVANCED MONKEY PATCH TO RUN THE KTVOICE REPO ---
import neucodec.model
def _apply_robust_patch(target_cls):
    """Patch the AssertionError and TypeError raised by the neucodec library."""
    orig_func = target_cls._from_pretrained

    @classmethod
    def _patched_func(cls, *args, **kwargs):
        # Make sure model_id is always valid so the library's internal assert passes
        official_id = "neuphonic/distill-neucodec" if "distill" in str(cls).lower() else "neuphonic/neucodec"
        # Fix: HubMixin passes model_id as the first positional argument
        if args:
            kwargs["model_id"] = official_id
            return orig_func(*args[1:], **kwargs)
        else:
            kwargs["model_id"] = official_id
            return orig_func(**kwargs)

    target_cls._from_pretrained = _patched_func
# Apply the patch to both neucodec classes
_apply_robust_patch(neucodec.model.NeuCodec)
_apply_robust_patch(neucodec.model.DistillNeuCodec)
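# Note: the patch mutates the class objects inside neucodec.model in place, so the
# names imported from neucodec below refer to the already-patched classes.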
# -------------------------------------------------------
from neucodec import NeuCodec, DistillNeuCodec
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.phonemize_text import phonemize_text, phonemize_with_dict
def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
    """Stitch overlapping frames into one signal using a triangular cross-fade."""
    assert len(frames)
    dtype = frames[0].dtype
    shape = frames[0].shape[:-1]
    total_size = max(stride * i + frame.shape[-1] for i, frame in enumerate(frames))
    sum_weight = np.zeros(total_size, dtype=dtype)
    # Fix: np.zeros takes the shape as a single tuple, not unpacked positional args
    out = np.zeros((*shape, total_size), dtype=dtype)
    offset: int = 0
    for frame in frames:
        frame_length = frame.shape[-1]
        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        # Triangular window: 0 at the frame edges, peaking at 0.5 in the middle
        weight = 0.5 - np.abs(t - 0.5)
        out[..., offset : offset + frame_length] += weight * frame
        sum_weight[offset : offset + frame_length] += weight
        offset += stride
    return out / sum_weight
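
# A minimal sketch of how _linear_overlap_add behaves (illustrative values only):
# two 4-sample frames overlapped with stride=2 are cross-faded into 6 samples,
# and the weight normalisation keeps a constant signal constant.
#   _frames = [np.ones(4, dtype=np.float32), np.ones(4, dtype=np.float32)]
#   _linear_overlap_add(_frames, stride=2)  # -> array of shape (6,), all ones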
class VoiceEngine:
    def __init__(self, backbone_repo="ktvoice/Backbone", backbone_device="cpu", codec_repo="ktvoice/Codec", codec_device="cpu"):
        self.sample_rate = 24_000
        self.max_context = 2048
        self._is_quantized_model = False
        self.tokenizer = None
        self._load_backbone(backbone_repo, backbone_device)
        self._load_codec(codec_repo, codec_device)
    def _load_backbone(self, repo, device):
        print(f"Loading backbone from: {repo} on {device} ...")
        if "gguf" in repo.lower():
            from llama_cpp import Llama
            self.backbone = Llama.from_pretrained(repo_id=repo, filename="*.gguf", n_ctx=self.max_context)
            self._is_quantized_model = True
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(repo)
            self.backbone = AutoModelForCausalLM.from_pretrained(repo).to(torch.device(device))
    def _load_codec(self, repo, device):
        print(f"Loading codec from your repo: {repo} ...")
        # Download the weights from your ktvoice repo
        local_dir = snapshot_download(repo_id=repo)
        # Create a temporary config file to avoid a "config.json not found" error;
        # it exists only to satisfy the Hugging Face loader
        tmp_config = os.path.join(local_dir, "config.json")
        if not os.path.exists(tmp_config):
            with open(tmp_config, "w") as f:
                json.dump({"model_type": "neucodec"}, f)
        if "distill" in repo.lower():
            self.codec = DistillNeuCodec.from_pretrained(local_dir)
        else:
            self.codec = NeuCodec.from_pretrained(local_dir)
        self.codec.eval().to(device)
    def encode_reference(self, path):
        # The codec encoder expects 16 kHz mono input
        wav, _ = librosa.load(path, sr=16000, mono=True)
        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            return self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)
    def infer(self, text, ref_codes, ref_text):
        prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
        prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
        with torch.no_grad():
            out = self.backbone.generate(prompt_tensor, max_length=self.max_context, do_sample=True, temperature=1.0)
        # Keep special tokens: the generated <|speech_N|> markers are parsed in _decode
        tokens = self.tokenizer.decode(out[0, prompt_tensor.shape[-1]:], skip_special_tokens=False)
        return self._decode(tokens)
    def _decode(self, codes_str):
        speech_ids = [int(n) for n in re.findall(r"<\|speech_(\d+)\|>", codes_str)]
        with torch.no_grad():
            codes_tensor = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(self.codec.device)
            return self.codec.decode_code(codes_tensor).cpu().numpy()[0, 0, :]
    def _apply_chat_template(self, ref_codes, ref_text, text):
        input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(text)
        chat = f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{input_text}<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>"
        c_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
        return self.tokenizer.encode(chat + c_str)
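
# A minimal usage sketch, not part of the original file: it assumes a reference clip
# at "ref.wav" with a known transcript, and that the soundfile package is installed.
# The repo names come from the constructor defaults above.
if __name__ == "__main__":
    import soundfile as sf

    engine = VoiceEngine()
    ref_codes = engine.encode_reference("ref.wav")  # hypothetical reference recording
    audio = engine.infer(
        text="Hello from the voice engine.",
        ref_codes=ref_codes.tolist(),
        ref_text="Transcript of the reference clip.",  # must match ref.wav's speech
    )
    sf.write("output.wav", audio, engine.sample_rate)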