chenjgtea commited on
Commit ·
3fd53ff
1
Parent(s): 8825975
新增gpu模式下chattts代码
Browse files- README.md +1 -1
- web/{app.py → app_gpu.py} +51 -12
README.md
CHANGED
|
@@ -6,7 +6,7 @@ colorTo: purple
|
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.41.0
|
| 8 |
#app_port: 8080
|
| 9 |
-
app_file: web/
|
| 10 |
pinned: false
|
| 11 |
---
|
| 12 |
|
|
|
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.41.0
|
| 8 |
#app_port: 8080
|
| 9 |
+
app_file: web/app_gpu.py
|
| 10 |
pinned: false
|
| 11 |
---
|
| 12 |
|
web/{app.py → app_gpu.py}
RENAMED
|
@@ -203,13 +203,20 @@ def get_chat_infer_audio(chat_txt,
|
|
| 203 |
spk_emb_text):
|
| 204 |
logger.info("========开始生成音频文件=====")
|
| 205 |
#音频参数设置
|
| 206 |
-
params_infer_code = Chat2TTS.Chat.InferCodeParams(
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
)
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
torch.manual_seed(audio_seed_input)
|
| 214 |
wav = chat.infer(
|
| 215 |
text=chat_txt,
|
|
@@ -227,10 +234,11 @@ def get_chat_infer_text(text,seed,refine_text_checkBox):
|
|
| 227 |
logger.info("========文本内容无需优化=====")
|
| 228 |
return text
|
| 229 |
|
| 230 |
-
params_refine_text = Chat2TTS.Chat.RefineTextParams(
|
| 231 |
-
|
| 232 |
-
)
|
| 233 |
|
|
|
|
| 234 |
torch.manual_seed(seed)
|
| 235 |
chat_text = chat.infer(
|
| 236 |
text=text,
|
|
@@ -245,9 +253,40 @@ def get_chat_infer_text(text,seed,refine_text_checkBox):
|
|
| 245 |
def on_audio_seed_change(audio_seed_input):
|
| 246 |
global chat
|
| 247 |
torch.manual_seed(audio_seed_input)
|
| 248 |
-
rand_spk =
|
| 249 |
return rand_spk
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
if __name__ == "__main__":
|
| 253 |
parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
|
|
|
|
| 203 |
spk_emb_text):
|
| 204 |
logger.info("========开始生成音频文件=====")
|
| 205 |
#音频参数设置
|
| 206 |
+
# params_infer_code = Chat2TTS.Chat.InferCodeParams(
|
| 207 |
+
# spk_emb=spk_emb_text, # add sampled speaker
|
| 208 |
+
# temperature=temperature_slider, # using custom temperature
|
| 209 |
+
# top_P=top_p_slider, # top P decode
|
| 210 |
+
# top_K=top_k_slider, # top K decode
|
| 211 |
+
# )
|
| 212 |
+
torch.manual_seed(audio_seed_input)
|
| 213 |
+
rand_spk = torch.randn(768)
|
| 214 |
+
params_infer_code = {
|
| 215 |
+
'spk_emb': rand_spk,
|
| 216 |
+
'temperature': temperature_slider,
|
| 217 |
+
'top_P': top_p_slider,
|
| 218 |
+
'top_K': top_k_slider,
|
| 219 |
+
}
|
| 220 |
torch.manual_seed(audio_seed_input)
|
| 221 |
wav = chat.infer(
|
| 222 |
text=chat_txt,
|
|
|
|
| 234 |
logger.info("========文本内容无需优化=====")
|
| 235 |
return text
|
| 236 |
|
| 237 |
+
# params_refine_text = Chat2TTS.Chat.RefineTextParams(
|
| 238 |
+
# prompt='[oral_2][laugh_0][break_6]',
|
| 239 |
+
# )
|
| 240 |
|
| 241 |
+
params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
|
| 242 |
torch.manual_seed(seed)
|
| 243 |
chat_text = chat.infer(
|
| 244 |
text=text,
|
|
|
|
| 253 |
def on_audio_seed_change(audio_seed_input):
|
| 254 |
global chat
|
| 255 |
torch.manual_seed(audio_seed_input)
|
| 256 |
+
rand_spk = torch.randn(audio_seed_input)
|
| 257 |
return rand_spk
|
| 258 |
+
#return encode_spk_emb(rand_spk)
|
| 259 |
+
|
| 260 |
+
def encode_spk_emb(spk_emb: torch.Tensor) -> str:
|
| 261 |
+
import pybase16384 as b14
|
| 262 |
+
import lzma
|
| 263 |
+
with torch.no_grad():
|
| 264 |
+
arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
|
| 265 |
+
s = b14.encode_to_string(
|
| 266 |
+
lzma.compress(
|
| 267 |
+
arr.tobytes(),
|
| 268 |
+
format=lzma.FORMAT_RAW,
|
| 269 |
+
filters=[
|
| 270 |
+
{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
|
| 271 |
+
],
|
| 272 |
+
),
|
| 273 |
+
)
|
| 274 |
+
del arr
|
| 275 |
+
return s
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# def _sample_random_speaker(self) -> torch.Tensor:
|
| 279 |
+
# with torch.no_grad():
|
| 280 |
+
# dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
|
| 281 |
+
# out: torch.Tensor = self.pretrain_models["spk_stat"]
|
| 282 |
+
# std, mean = out.chunk(2)
|
| 283 |
+
# spk = (
|
| 284 |
+
# torch.randn(dim, device=std.device, dtype=torch.float16)
|
| 285 |
+
# .mul_(std)
|
| 286 |
+
# .add_(mean)
|
| 287 |
+
# )
|
| 288 |
+
# del out, std, mean
|
| 289 |
+
# return spk
|
| 290 |
|
| 291 |
if __name__ == "__main__":
|
| 292 |
parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
|