chenjgtea commited on
Commit
ccae0f6
·
1 Parent(s): e6c5928

新增gpu模式下chattts代码

Browse files
Files changed (2) hide show
  1. Chat2TTS/core.py +22 -1
  2. web/app_gpu.py +22 -20
Chat2TTS/core.py CHANGED
@@ -12,6 +12,9 @@ from .utils.io_utils import get_latest_modified_file
12
  from .infer.api import refine_text, infer_code
13
  from dataclasses import dataclass
14
  from typing import Literal, Optional, List, Tuple, Dict
 
 
 
15
 
16
  from huggingface_hub import snapshot_download
17
 
@@ -167,5 +170,23 @@ class Chat:
167
 
168
  return wav
169
 
170
-
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from .infer.api import refine_text, infer_code
13
  from dataclasses import dataclass
14
  from typing import Literal, Optional, List, Tuple, Dict
15
+ import numpy as np
16
+ import pybase16384 as b14
17
+ import lzma
18
 
19
  from huggingface_hub import snapshot_download
20
 
 
170
 
171
  return wav
172
 
 
173
 
174
def sample_random_speaker(self) -> str:
    """Draw one random speaker embedding and return its string-encoded form.

    Delegates sampling to ``_sample_random_speaker`` and serialization to
    ``_encode_spk_emb``; the resulting string is what callers pass back in
    as a speaker identity.
    """
    raw_embedding = self._sample_random_speaker()
    return self._encode_spk_emb(raw_embedding)
176
+
177
+
178
@staticmethod
def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
    """Serialize a speaker-embedding tensor into a compact string.

    The tensor is cast to float16 on CPU, compressed with raw LZMA2 at the
    maximum (extreme) preset, and the compressed bytes are rendered as a
    base16384 string via ``b14.encode_to_string``.
    """
    # Raw LZMA2 stream (no container) at preset 9 | EXTREME — must match
    # the filter chain used when decoding the string back into a tensor.
    lzma_filters = [{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}]
    with torch.no_grad():
        half_precision = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
        compressed = lzma.compress(
            half_precision.tobytes(),
            format=lzma.FORMAT_RAW,
            filters=lzma_filters,
        )
        encoded = b14.encode_to_string(compressed)
        del half_precision  # drop the numpy copy promptly
        return encoded
web/app_gpu.py CHANGED
@@ -29,7 +29,7 @@ def init_chat(args):
29
  source = "custom"
30
  # 获取启动模式
31
  MODEL = os.getenv('MODEL')
32
- logger.info("loading ChatTTS model..., start MODEL:" + str(MODEL))
33
  # huggingface 部署模式下,模型则直接使用hf的模型数据
34
  if MODEL == "HF":
35
  source = "huggingface"
@@ -253,25 +253,27 @@ def get_chat_infer_text(text,seed,refine_text_checkBox):
253
  def on_audio_seed_change(audio_seed_input):
254
  global chat
255
  torch.manual_seed(audio_seed_input)
256
- rand_spk = torch.randn(audio_seed_input)
257
- return encode_spk_emb(rand_spk)
258
-
259
- def encode_spk_emb(spk_emb: torch.Tensor) -> str:
260
- import pybase16384 as b14
261
- import lzma
262
- with torch.no_grad():
263
- arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
264
- s = b14.encode_to_string(
265
- lzma.compress(
266
- arr.tobytes(),
267
- format=lzma.FORMAT_RAW,
268
- filters=[
269
- {"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
270
- ],
271
- ),
272
- )
273
- del arr
274
- return s
 
 
275
 
276
 
277
  # def _sample_random_speaker(self) -> torch.Tensor:
 
29
  source = "custom"
30
  # 获取启动模式
31
  MODEL = os.getenv('MODEL')
32
+ logger.info("loading Chat2TTS model..., start MODEL:" + str(MODEL))
33
  # huggingface 部署模式下,模型则直接使用hf的模型数据
34
  if MODEL == "HF":
35
  source = "huggingface"
 
253
def on_audio_seed_change(audio_seed_input):
    """Re-seed torch's RNG and draw a fresh string-encoded speaker embedding.

    Args:
        audio_seed_input: integer seed controlling which speaker is sampled.

    Returns:
        The string-encoded speaker embedding produced by the Chat model
        (see Chat2TTS.core.Chat.sample_random_speaker).
    """
    # Speaker sampling/encoding now lives on the Chat model itself; the
    # former local encode_spk_emb helper was removed as dead code.
    torch.manual_seed(audio_seed_input)
    return chat.sample_random_speaker()
277
 
278
 
279
  # def _sample_random_speaker(self) -> torch.Tensor: