File size: 5,555 Bytes
e58b011
 
 
 
3f903dd
 
002e5d4
e58b011
 
 
 
002e5d4
e58b011
 
002e5d4
 
 
e58b011
002e5d4
 
 
 
 
 
 
 
 
 
 
 
 
 
e58b011
002e5d4
 
 
 
e58b011
3f903dd
 
 
 
 
 
 
 
e58b011
3f903dd
 
 
 
 
 
 
 
 
 
 
 
 
e58b011
3f903dd
 
 
 
 
 
 
e58b011
 
 
 
 
3f903dd
 
e58b011
 
3f903dd
e58b011
002e5d4
 
e58b011
e16bdd9
002e5d4
 
 
 
 
 
e58b011
 
e16bdd9
e58b011
 
3f903dd
e58b011
 
e16bdd9
3f903dd
e58b011
3f903dd
e58b011
 
 
 
 
 
 
 
3f903dd
e58b011
 
3f903dd
e58b011
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import re
import time
import torch
import librosa
import numpy as np
import json
from pathlib import Path
from typing import Generator
from huggingface_hub import snapshot_download

# --- MONKEY PATCH SO NEUCODEC CAN LOAD THE KTVOICE REPO ---
import neucodec.model

def _apply_robust_patch(target_cls):
    """Vá lỗi AssertionError và TypeError cho thư viện neucodec"""
    orig_func = target_cls._from_pretrained
    
    @classmethod
    def _patched_func(cls, *args, **kwargs):
        # Đảm bảo model_id luôn hợp lệ để thoả mãn lệnh assert của thư viện
        official_id = "neuphonic/distill-neucodec" if "distill" in str(cls).lower() else "neuphonic/neucodec"
        
        # Sửa lỗi: HubMixin truyền model_id ở vị trí đầu tiên
        if args:
            kwargs["model_id"] = official_id
            return orig_func(*args[1:], **kwargs)
        else:
            kwargs["model_id"] = official_id
            return orig_func(**kwargs)
            
    target_cls._from_pretrained = _patched_func

# Patch both neucodec model classes: NeuCodec first, then DistillNeuCodec.
for _codec_cls in (neucodec.model.NeuCodec, neucodec.model.DistillNeuCodec):
    _apply_robust_patch(_codec_cls)
del _codec_cls  # keep the loop variable out of the module namespace
# -------------------------------------------------------

from neucodec import NeuCodec, DistillNeuCodec
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.phonemize_text import phonemize_text, phonemize_with_dict

def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
    assert len(frames)
    dtype = frames[0].dtype
    shape = frames[0].shape[:-1]
    total_size = max(stride * i + frame.shape[-1] for i, frame in enumerate(frames))
    sum_weight = np.zeros(total_size, dtype=dtype)
    out = np.zeros(*shape, total_size, dtype=dtype)
    offset: int = 0
    for frame in frames:
        frame_length = frame.shape[-1]
        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        weight = np.abs(0.5 - (t - 0.5))
        out[..., offset : offset + frame_length] += weight * frame
        sum_weight[offset : offset + frame_length] += weight
        offset += stride
    return out / sum_weight

class VoiceEngine:
    """Text-to-speech engine: a causal-LM backbone that emits ``<|speech_N|>``
    tokens, plus a neucodec codec that converts those ids to audio.

    The backbone is loaded with ``transformers`` or, when the repo id
    contains "gguf", with ``llama_cpp``; :meth:`infer` supports both.
    """

    def __init__(self, backbone_repo="ktvoice/Backbone", backbone_device="cpu", codec_repo="ktvoice/Codec", codec_device="cpu"):
        """Load the backbone and the codec.

        Args:
            backbone_repo: Hugging Face repo id of the backbone LM. Ids
                containing "gguf" (case-insensitive) are loaded via llama_cpp.
            backbone_device: Device for the backbone, e.g. "cpu" or "cuda".
            codec_repo: Hugging Face repo id holding the codec weights. Ids
                containing "distill" select ``DistillNeuCodec``.
            codec_device: Device for the codec.
        """
        self.sample_rate = 24_000  # presumably the codec output rate (Hz); not used in this class
        self.max_context = 2048  # token budget for generation / llama.cpp context size
        self._is_quantized_model = False  # True when the backbone is a llama_cpp GGUF model
        self.tokenizer = None  # stays None on the GGUF path
        self._load_backbone(backbone_repo, backbone_device)
        self._load_codec(codec_repo, codec_device)

    def _load_backbone(self, repo, device):
        """Load the backbone LM (and, for the transformers path, its tokenizer)."""
        print(f"Loading backbone from: {repo} on {device} ...")
        if "gguf" in repo.lower():
            from llama_cpp import Llama

            # Honor the requested device: offload all layers unless CPU was asked for.
            self.backbone = Llama.from_pretrained(
                repo_id=repo,
                filename="*.gguf",
                n_ctx=self.max_context,
                n_gpu_layers=0 if str(device) == "cpu" else -1,
            )
            self._is_quantized_model = True
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(repo)
            self.backbone = AutoModelForCausalLM.from_pretrained(repo).to(torch.device(device))

    def _load_codec(self, repo, device):
        """Download the codec repo and build a (Distill)NeuCodec from it."""
        print(f"Loading codec from your repo: {repo} ...")
        local_dir = snapshot_download(repo_id=repo)

        # Write a minimal placeholder config.json if the repo has none, to
        # avoid the loader's "config.json not found" error.
        tmp_config = os.path.join(local_dir, "config.json")
        if not os.path.exists(tmp_config):
            with open(tmp_config, "w", encoding="utf-8") as f:
                json.dump({"model_type": "neucodec"}, f)

        # NOTE(review): the module-level _from_pretrained patch replaces
        # model_id with the official neuphonic id, so neucodec may load
        # weights from that repo rather than from local_dir — confirm.
        codec_cls = DistillNeuCodec if "distill" in repo.lower() else NeuCodec
        self.codec = codec_cls.from_pretrained(local_dir)
        self.codec.eval().to(device)

    def encode_reference(self, path):
        """Encode a reference audio file into codec token ids.

        The file is loaded as mono audio resampled to 16 kHz and passed to the
        codec as a ``[1, 1, T]`` float tensor; the two leading singleton
        dimensions are squeezed from the codec's result.
        """
        wav, _ = librosa.load(path, sr=16000, mono=True)
        wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0)  # [1, 1, T]
        with torch.no_grad():
            return self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0)

    def infer(self, text, ref_codes, ref_text):
        """Synthesize ``text`` in the voice of a reference clip.

        Args:
            text: Text to speak.
            ref_codes: Iterable of codec ids for the reference audio (e.g. the
                output of :meth:`encode_reference`).
            ref_text: Transcript of the reference audio.

        Returns:
            The decoded waveform as a 1-D numpy array.

        Raises:
            ValueError: If the backbone output contains no speech tokens.
        """
        if self._is_quantized_model:
            generated = self._infer_gguf(text, ref_codes, ref_text)
        else:
            generated = self._infer_torch(text, ref_codes, ref_text)
        return self._decode(generated)

    def _infer_torch(self, text, ref_codes, ref_text):
        """Generate the continuation string with the transformers backbone."""
        prompt_ids = self._apply_chat_template(ref_codes, ref_text, text)
        prompt_tensor = torch.tensor(prompt_ids).unsqueeze(0).to(self.backbone.device)
        with torch.no_grad():
            out = self.backbone.generate(prompt_tensor, max_length=self.max_context, do_sample=True, temperature=1)
        # Decode only the newly generated tokens. Keep special tokens in case the
        # <|speech_N|> markers are registered as special, so _decode can see them.
        return self.tokenizer.decode(out[0, prompt_tensor.shape[-1]:], skip_special_tokens=False)

    def _infer_gguf(self, text, ref_codes, ref_text):
        """Generate the continuation string with the llama.cpp backbone."""
        output = self.backbone(
            self._build_prompt(ref_codes, ref_text, text),
            max_tokens=self.max_context,
            temperature=1.0,
            stop=["<|SPEECH_GENERATION_END|>"],
        )
        return output["choices"][0]["text"]

    def _decode(self, codes_str):
        """Convert generated text into audio via the codec.

        Every ``<|speech_N|>`` id in ``codes_str`` is collected, in order, and
        decoded as a ``[1, 1, N]`` long tensor on the codec's device.

        Raises:
            ValueError: If ``codes_str`` contains no speech tokens.
        """
        speech_ids = [int(n) for n in re.findall(r"<\|speech_(\d+)\|>", codes_str)]
        if not speech_ids:
            # Otherwise an empty [1, 1, 0] tensor would be handed to the codec.
            raise ValueError("No valid speech tokens found in the model output.")
        with torch.no_grad():
            codes_tensor = torch.tensor(speech_ids, dtype=torch.long)[None, None, :].to(self.codec.device)
            return self.codec.decode_code(codes_tensor).cpu().numpy()[0, 0, :]

    def _build_prompt(self, ref_codes, ref_text, text):
        """Build the prompt string: phonemized reference + target text, then the
        reference speech tokens that generation continues from."""
        input_text = phonemize_with_dict(ref_text) + " " + phonemize_with_dict(text)
        chat = f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{input_text}<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>"
        c_str = "".join([f"<|speech_{i}|>" for i in ref_codes])
        return chat + c_str

    def _apply_chat_template(self, ref_codes, ref_text, text):
        """Tokenize the prompt from :meth:`_build_prompt` (transformers path)."""
        return self.tokenizer.encode(self._build_prompt(ref_codes, ref_text, text))