chenjgtea commited on
Commit ·
ccae0f6
1
Parent(s): e6c5928
新增gpu模式下chattts代码
Browse files- Chat2TTS/core.py +22 -1
- web/app_gpu.py +22 -20
Chat2TTS/core.py
CHANGED
|
@@ -12,6 +12,9 @@ from .utils.io_utils import get_latest_modified_file
|
|
| 12 |
from .infer.api import refine_text, infer_code
|
| 13 |
from dataclasses import dataclass
|
| 14 |
from typing import Literal, Optional, List, Tuple, Dict
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
from huggingface_hub import snapshot_download
|
| 17 |
|
|
@@ -167,5 +170,23 @@ class Chat:
|
|
| 167 |
|
| 168 |
return wav
|
| 169 |
|
| 170 |
-
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from .infer.api import refine_text, infer_code
|
| 13 |
from dataclasses import dataclass
|
| 14 |
from typing import Literal, Optional, List, Tuple, Dict
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pybase16384 as b14
|
| 17 |
+
import lzma
|
| 18 |
|
| 19 |
from huggingface_hub import snapshot_download
|
| 20 |
|
|
|
|
| 170 |
|
| 171 |
return wav
|
| 172 |
|
|
|
|
| 173 |
|
| 174 |
+
def sample_random_speaker(self) -> str:
|
| 175 |
+
return self._encode_spk_emb(self._sample_random_speaker())
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@staticmethod
|
| 179 |
+
def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
|
| 180 |
+
with torch.no_grad():
|
| 181 |
+
arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
|
| 182 |
+
s = b14.encode_to_string(
|
| 183 |
+
lzma.compress(
|
| 184 |
+
arr.tobytes(),
|
| 185 |
+
format=lzma.FORMAT_RAW,
|
| 186 |
+
filters=[
|
| 187 |
+
{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
|
| 188 |
+
],
|
| 189 |
+
),
|
| 190 |
+
)
|
| 191 |
+
del arr
|
| 192 |
+
return s
|
web/app_gpu.py
CHANGED
|
@@ -29,7 +29,7 @@ def init_chat(args):
|
|
| 29 |
source = "custom"
|
| 30 |
# 获取启动模式
|
| 31 |
MODEL = os.getenv('MODEL')
|
| 32 |
-
logger.info("loading
|
| 33 |
# huggingface 部署模式下,模型则直接使用hf的模型数据
|
| 34 |
if MODEL == "HF":
|
| 35 |
source = "huggingface"
|
|
@@ -253,25 +253,27 @@ def get_chat_infer_text(text,seed,refine_text_checkBox):
|
|
| 253 |
def on_audio_seed_change(audio_seed_input):
|
| 254 |
global chat
|
| 255 |
torch.manual_seed(audio_seed_input)
|
| 256 |
-
rand_spk =
|
| 257 |
-
return
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
|
|
|
|
|
|
| 275 |
|
| 276 |
|
| 277 |
# def _sample_random_speaker(self) -> torch.Tensor:
|
|
|
|
| 29 |
source = "custom"
|
| 30 |
# 获取启动模式
|
| 31 |
MODEL = os.getenv('MODEL')
|
| 32 |
+
logger.info("loading Chat2TTS model..., start MODEL:" + str(MODEL))
|
| 33 |
# huggingface 部署模式下,模型则直接使用hf的模型数据
|
| 34 |
if MODEL == "HF":
|
| 35 |
source = "huggingface"
|
|
|
|
def on_audio_seed_change(audio_seed_input):
    """Re-seed torch's RNG from the UI seed and sample a new speaker.

    Args:
        audio_seed_input: integer seed taken from the audio-seed widget.

    Returns:
        The serialized speaker-embedding string produced by
        ``chat.sample_random_speaker()``.
    """
    global chat
    # Seeding first makes the sampled speaker reproducible for a given seed.
    torch.manual_seed(audio_seed_input)
    # NOTE(review): removed a commented-out duplicate of the embedding
    # encoder that now lives in Chat2TTS/core.py as Chat._encode_spk_emb.
    return chat.sample_random_speaker()
| 277 |
|
| 278 |
|
| 279 |
# def _sample_random_speaker(self) -> torch.Tensor:
|