chenjgtea commited on
Commit ·
214ea91
1
Parent(s): 394c436
提交代码
Browse files- .gitignore +5 -0
- .idea/.gitignore +10 -0
- README.md +1 -1
- requirements.txt +29 -0
- test/__init__.py +0 -0
- test/api.py +72 -0
- test/common_test.py +24 -0
- tool/__init__.py +4 -0
- tool/av.py +79 -0
- tool/ctx.py +14 -0
- tool/func.py +35 -0
- tool/logger/__init__.py +1 -0
- tool/logger/log.py +73 -0
- tool/np.py +11 -0
- tool/pcm.py +21 -0
- web/__init__.py +0 -0
- web/app.py +246 -0
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/.idea/misc.xml
|
| 2 |
+
/.idea/modules.xml
|
| 3 |
+
/.idea/inspectionProfiles/profiles_settings.xml
|
| 4 |
+
/.idea/inspectionProfiles/Project_Default.xml
|
| 5 |
+
/.idea/vcs.xml
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
| 9 |
+
/.idea/
|
| 10 |
+
/chat-tts.iml
|
README.md
CHANGED
|
@@ -5,7 +5,7 @@ colorFrom: blue
|
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.41.0
|
| 8 |
-
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
|
|
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.41.0
|
| 8 |
+
app_file: web/app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
requirements.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PyTorch and related libraries
|
| 2 |
+
torch
|
| 3 |
+
torchvision
|
| 4 |
+
torchaudio
|
| 5 |
+
|
| 6 |
+
# Hugging Face transformers library
|
| 7 |
+
transformers
|
| 8 |
+
|
| 9 |
+
# Configuration management with OmegaConf
|
| 10 |
+
omegaconf
|
| 11 |
+
|
| 12 |
+
# Interactive widgets for Jupyter Notebooks
|
| 13 |
+
ipywidgets
|
| 14 |
+
|
| 15 |
+
# Gradio for creating web UIs
|
| 16 |
+
gradio
|
| 17 |
+
|
| 18 |
+
# Vector quantization for PyTorch
|
| 19 |
+
vector_quantize_pytorch
|
| 20 |
+
# Hugging Face Hub client
|
| 21 |
+
huggingface_hub
|
| 22 |
+
|
| 23 |
+
vocos
|
| 24 |
+
|
| 25 |
+
spaces
|
| 26 |
+
|
| 27 |
+
ChatTTS
|
| 28 |
+
|
| 29 |
+
av
|
test/__init__.py
ADDED
|
File without changes
|
test/api.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Import necessary libraries and configure settings
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import ChatTTS
|
| 5 |
+
import os,sys
|
| 6 |
+
from common_test import *
|
| 7 |
+
|
| 8 |
+
now_dir = os.getcwd()
|
| 9 |
+
sys.path.append(now_dir)
|
| 10 |
+
from tool.logger import get_logger
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
torch._dynamo.config.cache_size_limit = 64
|
| 14 |
+
torch._dynamo.config.suppress_errors = True
|
| 15 |
+
torch.set_float32_matmul_precision('high')
|
| 16 |
+
|
| 17 |
+
logger= get_logger("api")
|
| 18 |
+
# Initialize and load the model:
|
| 19 |
+
chat = ChatTTS.Chat()
|
| 20 |
+
if chat.load(source="custom", custom_path="D:\\chenjgspace\\ai-model\\chattts",coef=None):
|
| 21 |
+
print("Models loaded successfully.")
|
| 22 |
+
else:
|
| 23 |
+
print("Models load failed.")
|
| 24 |
+
sys.exit(1)
|
| 25 |
+
|
| 26 |
+
# Define the text input for inference (Support Batching)
|
| 27 |
+
texts = [
|
| 28 |
+
"我真的不敢相信,他那么年轻武功居然这么好",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
#使用随机种子数,会导致每次生成的音频文件都是随机的音色
|
| 33 |
+
rand_spk = chat.sample_random_speaker()
|
| 34 |
+
print(rand_spk) # save it for later timbre recovery
|
| 35 |
+
|
| 36 |
+
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
| 37 |
+
spk_emb = rand_spk, # add sampled speaker
|
| 38 |
+
temperature = .3, # using custom temperature
|
| 39 |
+
top_P = 0.7, # top P decode
|
| 40 |
+
top_K = 20, # top K decode
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
###################################
|
| 44 |
+
# For sentence level manual control.
|
| 45 |
+
|
| 46 |
+
# use oral_(0-9), laugh_(0-2), break_(0-7)
|
| 47 |
+
# to generate special token in text to synthesize.
|
| 48 |
+
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
| 49 |
+
prompt='[oral_2][laugh_0][break_6]',
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
wavs = chat.infer(
|
| 53 |
+
texts,
|
| 54 |
+
params_refine_text=params_refine_text,
|
| 55 |
+
params_infer_code=params_infer_code,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Perform inference and play the generated audio
|
| 60 |
+
#wavs = chat.infer(texts)
|
| 61 |
+
#Audio(wavs[0], rate=24_000, autoplay=True)
|
| 62 |
+
|
| 63 |
+
# Save the generated audio
|
| 64 |
+
#torchaudio.save("D:\\Download\\output.wav", torch.from_numpy(wavs[0]), 24000)
|
| 65 |
+
prefix_name = "D:\\Download\\" + get_date_time()
|
| 66 |
+
|
| 67 |
+
for index, wav in enumerate(wavs):
|
| 68 |
+
save_mp3_file(wav, index, prefix_name)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
test/common_test.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import time
|
| 3 |
+
import os,sys
|
| 4 |
+
|
| 5 |
+
now_dir = os.getcwd()
|
| 6 |
+
sys.path.append(now_dir)
|
| 7 |
+
from tool.logger import get_logger
|
| 8 |
+
|
| 9 |
+
logger=get_logger("common-test")
|
| 10 |
+
def save_mp3_file(wav, index, prefix_name):
    """Encode one waveform as MP3 and write it to ``<prefix_name>_<index>.mp3``.

    Args:
        wav: Sample array handed to ``tool.pcm.pcm_arr_to_mp3_view``.
        index: Number appended to the file name after an underscore.
        prefix_name: Leading part of the output path (directory and base name).
    """
    # Imported at call time rather than at module import.
    from tool.pcm import pcm_arr_to_mp3_view

    encoded = pcm_arr_to_mp3_view(wav)
    target_path = "".join((prefix_name, "_", str(index), ".mp3"))
    with open(target_path, "wb") as out_file:
        out_file.write(encoded)
    logger.info(f"Audio saved to {target_path}")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_date_time():
    """Return the current local time formatted as ``YYYY-MM-DD-HH-MM-SS``."""
    # Work from a whole-second timestamp, then render it in one expression.
    whole_seconds = int(time.time())
    return datetime.datetime.fromtimestamp(whole_seconds).strftime("%Y-%m-%d-%H-%M-%S")
|
tool/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .av import load_audio
|
| 2 |
+
from .pcm import pcm_arr_to_mp3_view
|
| 3 |
+
from .np import float_to_int16
|
| 4 |
+
from .ctx import TorchSeedContext
|
tool/av.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from io import BufferedWriter, BytesIO
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
import av
|
| 6 |
+
from av.audio.resampler import AudioResampler
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
video_format_dict: Dict[str, str] = {
|
| 11 |
+
"m4a": "mp4",
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
audio_format_dict: Dict[str, str] = {
|
| 15 |
+
"ogg": "libvorbis",
|
| 16 |
+
"mp4": "aac",
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def wav2(i: BytesIO, o: BufferedWriter, format: str):
    """Transcode the first audio stream of ``i`` into ``o``.

    Adapted from
    https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L20

    Args:
        i: Readable binary stream holding the source audio.
        o: Writable binary stream that receives the encoded output.
        format: Requested output format name. It is looked up in
            ``video_format_dict`` to choose the container, and the result is
            looked up in ``audio_format_dict`` to choose the codec; names
            missing from either table are used unchanged.
    """
    # Resolve into separate locals instead of rebinding the ``format`` parameter.
    container_format = video_format_dict.get(format, format)
    codec_name = audio_format_dict.get(container_format, container_format)

    # Context managers close both containers even if decoding or encoding raises;
    # the output is closed first, then the input, as before.
    with av.open(i, "r") as inp, av.open(o, "w", format=container_format) as out:
        ostream = out.add_stream(codec_name)

        for frame in inp.decode(audio=0):
            for packet in ostream.encode(frame):
                out.mux(packet)

        # Encoding ``None`` drains whatever the encoder still buffers.
        for packet in ostream.encode(None):
            out.mux(packet)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def load_audio(file: str, sr: int) -> np.ndarray:
    """Decode the first audio stream of ``file`` into a mono float32 array.

    Adapted from
    https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39

    Args:
        file: Path of the audio file to read.
        sr: Sample rate passed to the resampler.

    Returns:
        One-dimensional ``np.float32`` array with the resampled samples.

    Raises:
        FileNotFoundError: If ``file`` does not exist.
        RuntimeError: If opening, decoding or resampling fails; the underlying
            exception is attached as ``__cause__``.
    """
    if not Path(file).exists():
        raise FileNotFoundError(f"File not found: {file}")

    try:
        # The context manager closes the container on success and on failure.
        with av.open(file) as container:
            resampler = AudioResampler(format="fltp", layout="mono", rate=sr)

            # AV stores length in microseconds by default. The duration may be
            # missing (None) for some inputs; start from an empty estimate then
            # and let the buffer grow as frames arrive.
            duration_us = container.duration or 0
            estimated_total_samples = int(duration_us * sr // 1_000_000)
            decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32)

            offset = 0
            for frame in container.decode(audio=0):
                frame.pts = None  # Clear presentation timestamp to avoid resampling issues
                for resampled_frame in resampler.resample(frame):
                    frame_data = resampled_frame.to_ndarray()[0]
                    end_index = offset + len(frame_data)

                    if end_index > decoded_audio.shape[0]:
                        # Grow geometrically so a low estimate does not force a
                        # full reallocation for every frame.
                        new_size = max(end_index + 1, 2 * decoded_audio.shape[0])
                        decoded_audio = np.resize(decoded_audio, new_size)

                    decoded_audio[offset:end_index] = frame_data
                    offset = end_index

            # Truncate the array to the number of samples actually written.
            decoded_audio = decoded_audio[:offset]
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Failed to load audio: {e}") from e

    return decoded_audio
|
tool/ctx.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TorchSeedContext:
    """Context manager that seeds torch on entry and restores the RNG state on exit.

    On entry the state returned by ``torch.random.get_rng_state()`` is saved and
    ``torch.manual_seed(seed)`` is called; on exit the saved state is put back
    with ``torch.random.set_rng_state``. Exceptions raised inside the block are
    not suppressed.
    """

    def __init__(self, seed):
        # Seed applied on entry; ``state`` holds the saved RNG state while active.
        self.seed = seed
        self.state = None

    def __enter__(self):
        self.state = torch.random.get_rng_state()
        torch.manual_seed(self.seed)
        # Return the manager so ``with TorchSeedContext(s) as ctx`` binds it.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Parameter renamed from ``type`` to avoid shadowing the builtin.
        torch.random.set_rng_state(self.state)
|
tool/func.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
seed_min = 1
|
| 6 |
+
seed_max = 4294967295
|
| 7 |
+
|
| 8 |
+
seeds = {
|
| 9 |
+
"旁白": {"seed": 4444},
|
| 10 |
+
"中年女性": {"seed": 7869},
|
| 11 |
+
"年轻女性": {"seed": 6615},
|
| 12 |
+
"中年男性": {"seed": 4099},
|
| 13 |
+
"年轻男性": {"seed": 6653},
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
# 音色选项:用于预置合适的音色
|
| 17 |
+
voices = {
|
| 18 |
+
"旁白": {"seed": 2},
|
| 19 |
+
"Timbre1": {"seed": 1111},
|
| 20 |
+
"Timbre2": {"seed": 2222},
|
| 21 |
+
"Timbre3": {"seed": 3333},
|
| 22 |
+
"Timbre4": {"seed": 4444},
|
| 23 |
+
"Timbre5": {"seed": 5555},
|
| 24 |
+
"Timbre6": {"seed": 6666},
|
| 25 |
+
"Timbre7": {"seed": 7777},
|
| 26 |
+
"Timbre8": {"seed": 8888},
|
| 27 |
+
"Timbre9": {"seed": 9999},
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
def on_voice_change(vocie_selection):
    """Return the seed stored in ``voices`` for the selected timbre name."""
    preset = voices.get(vocie_selection)
    return preset["seed"]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def generate_seed():
    """Return a Gradio update whose value is a random integer in ``[seed_min, seed_max]``."""
    new_seed = random.randint(seed_min, seed_max)
    return gr.update(value=new_seed)
|
tool/logger/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .log import get_logger
|
tool/logger/log.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import platform, sys
|
| 2 |
+
import logging
|
| 3 |
+
from datetime import datetime, timezone
|
| 4 |
+
|
| 5 |
+
logging.getLogger("numba").setLevel(logging.WARNING)
|
| 6 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 7 |
+
logging.getLogger("wetext-zh_normalizer").setLevel(logging.WARNING)
|
| 8 |
+
logging.getLogger("NeMo-text-processing").setLevel(logging.WARNING)
|
| 9 |
+
|
| 10 |
+
# from https://github.com/FloatTech/ZeroBot-Plugin/blob/c70766a989698452e60e5e48fb2f802a2444330d/console/console_windows.go#L89-L96
|
| 11 |
+
colorCodePanic = "\x1b[1;31m"
colorCodeFatal = "\x1b[1;31m"
colorCodeError = "\x1b[31m"
colorCodeWarn = "\x1b[33m"
colorCodeInfo = "\x1b[37m"
colorCodeDebug = "\x1b[32m"
colorCodeTrace = "\x1b[36m"
colorReset = "\x1b[0m"

# ANSI colour prefix used for each standard logging level.
log_level_color_code = {
    logging.DEBUG: colorCodeDebug,
    logging.INFO: colorCodeInfo,
    logging.WARN: colorCodeWarn,
    logging.ERROR: colorCodeError,
    logging.FATAL: colorCodeFatal,
}

# Four-letter tag printed for each standard logging level.
log_level_msg_str = {
    logging.DEBUG: "DEBU",
    logging.INFO: "INFO",
    logging.WARN: "WARN",
    logging.ERROR: "ERRO",
    logging.FATAL: "FATL",
}


class Formatter(logging.Formatter):
    """Log formatter producing ``[<offset> <date> <time>] [<LEVEL>] <name> | <file> | <message>``.

    The level tag is wrapped in ANSI colour codes when ``color`` is true; by
    default colour is enabled on every platform except Windows.
    """

    def __init__(self, color=platform.system().lower() != "windows"):
        # Initialise the base class so inherited helpers have their attributes.
        super().__init__()
        # https://stackoverflow.com/questions/2720319/python-figure-out-local-timezone
        self.tz = datetime.now(timezone.utc).astimezone().tzinfo
        self.color = color

    def format(self, record: logging.LogRecord):
        """Render ``record`` as a single line, followed by a traceback if one is attached."""
        # Stamp with the time the record was created rather than the time it is formatted.
        stamp = datetime.fromtimestamp(record.created, self.tz)
        logstr = "[" + stamp.strftime("%z %Y%m%d %H:%M:%S") + "] ["
        if self.color:
            logstr += log_level_color_code.get(record.levelno, colorCodeInfo)
        logstr += log_level_msg_str.get(record.levelno, record.levelname)
        if self.color:
            logstr += colorReset

        # Strip a trailing ".py" on every Python version; other names pass
        # through unchanged, so ``fn`` is always bound.
        fn = record.filename
        if fn.endswith(".py"):
            fn = fn[:-3]

        # getMessage() applies %-formatting only when args are present, so a
        # literal "%" in an argument-free message no longer raises.
        logstr += f"] {str(record.name)} | {fn} | {record.getMessage()}"

        # Keep tracebacks from logger.exception() / exc_info=True visible.
        if record.exc_info:
            logstr += "\n" + self.formatException(record.exc_info)
        return logstr
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_logger(name: str, lv=logging.INFO, remove_exist=False, format_root=False):
    """Return the logger called ``name`` with level ``lv`` and this module's ``Formatter``.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        lv: Level set on the logger.
        remove_exist: When true and ``hasHandlers()`` reports handlers, clear
            the logger's own handler list first.
        format_root: When true, also install a ``Formatter`` on every handler
            of the root logger.

    Returns:
        The configured ``logging.Logger``.
    """
    log = logging.getLogger(name)
    log.setLevel(lv)

    if remove_exist and log.hasHandlers():
        log.handlers.clear()

    # hasHandlers() also considers ancestor loggers, so a new stream handler is
    # attached only when nothing along the chain would handle records.
    if log.hasHandlers():
        for handler in log.handlers:
            handler.setFormatter(Formatter())
    else:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(Formatter())
        log.addHandler(stream_handler)

    if format_root:
        for handler in log.root.handlers:
            handler.setFormatter(Formatter())

    return log
|
tool/np.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from numba import jit
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@jit
def float_to_int16(audio: np.ndarray) -> np.ndarray:
    """Scale floating-point audio samples and cast them to ``np.int16``.

    The multiplier is ``32767 * 32768 // (ceil(max(|audio|)) * 32768)``, i.e.
    32767 when the peak magnitude is at most 1 and proportionally smaller when
    the peak rounds up to a larger integer.

    Args:
        audio: Array of floating-point samples.

    Returns:
        ``np.int16`` array with the same shape as ``audio``.
    """
    if audio.size == 0:
        # Nothing to scale; max() on an empty array would raise.
        return audio.astype(np.int16)

    am = int(math.ceil(float(np.abs(audio).max())) * 32768)
    if am == 0:
        # All-zero (silent) input would otherwise divide by zero below; any
        # multiplier yields silence, so use the unit-peak denominator.
        am = 32768
    am = 32767 * 32768 // am
    return np.multiply(audio, am).astype(np.int16)
|
tool/pcm.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import wave
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from .np import float_to_int16
|
| 7 |
+
from .av import wav2
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def pcm_arr_to_mp3_view(wav: np.ndarray):
    """Encode a waveform as MP3 and return a view over the encoded bytes.

    The samples are converted with ``float_to_int16``, written into an
    in-memory WAV file (1 channel, 2-byte samples, 24000 Hz) and transcoded to
    MP3 by ``wav2``.

    Args:
        wav: Sample array accepted by ``float_to_int16``.

    Returns:
        The ``memoryview`` returned by ``BytesIO.getbuffer()`` on the MP3 buffer.
    """
    mp3_buf = BytesIO()
    # The intermediate WAV buffer is released as soon as transcoding finishes.
    with BytesIO() as wav_buf:
        with wave.open(wav_buf, "wb") as wf:
            wf.setnchannels(1)  # Mono channel
            wf.setsampwidth(2)  # Sample width in bytes
            wf.setframerate(24000)  # Sample rate in Hz
            wf.writeframes(float_to_int16(wav))
        # Rewind so the transcoder reads the WAV data from its start.
        wav_buf.seek(0, 0)
        wav2(wav_buf, mp3_buf, "mp3")
    # No second rewind of the WAV buffer: nothing reads it after transcoding.
    return mp3_buf.getbuffer()
|
web/__init__.py
ADDED
|
File without changes
|
web/app.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
|
| 3 |
+
if sys.platform == "darwin":
|
| 4 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
| 5 |
+
|
| 6 |
+
now_dir = os.getcwd()
|
| 7 |
+
sys.path.append(now_dir)
|
| 8 |
+
|
| 9 |
+
from tool.logger import get_logger
|
| 10 |
+
import ChatTTS
|
| 11 |
+
import argparse
|
| 12 |
+
import gradio as gr
|
| 13 |
+
from tool.func import *
|
| 14 |
+
from tool.ctx import TorchSeedContext
|
| 15 |
+
from tool.np import *
|
| 16 |
+
|
| 17 |
+
logger = get_logger("app")
|
| 18 |
+
|
| 19 |
+
# Initialize and load the model:
|
| 20 |
+
chat = ChatTTS.Chat()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def init_chat(args):
    """Load the ChatTTS model into the module-level ``chat`` object.

    The ``MODEL`` environment variable selects the source passed to
    ``chat.load``: ``"HF"`` selects ``"huggingface"``, any other value selects
    ``"custom"``. The process exits with status 1 when loading fails.

    Args:
        args: Parsed command-line arguments. ``custom_path`` and ``coef`` are
            forwarded to ``chat.load``; when either attribute is absent, the
            previously hard-coded value is used instead.
    """
    # Read the launch mode.
    MODEL = os.getenv('MODEL')
    logger.info("loading ChatTTS model..., start MODEL:" + str(MODEL))

    # In Hugging Face deployment mode the model data comes from the hub.
    source = "huggingface" if MODEL == "HF" else "custom"

    # Honour the CLI options instead of ignoring them; the fallbacks keep
    # callers that pass a namespace without these attributes working.
    custom_path = getattr(args, "custom_path", "D:\\chenjgspace\\ai-model\\chattts")
    coef = getattr(args, "coef", None)

    if chat.load(source=source, custom_path=custom_path, coef=coef):
        print("Models loaded successfully.")
    else:
        print("Models load failed.")
        sys.exit(1)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def main(args):
|
| 41 |
+
with gr.Blocks() as demo:
|
| 42 |
+
gr.Markdown("# ChatTTS demo")
|
| 43 |
+
with gr.Row():
|
| 44 |
+
with gr.Column(scale=1):
|
| 45 |
+
text_input = gr.Textbox(
|
| 46 |
+
label="转换内容",
|
| 47 |
+
lines=4,
|
| 48 |
+
max_lines=4,
|
| 49 |
+
placeholder="Please Input Text...",
|
| 50 |
+
value="柔柔的,浓浓的,痴痴的风,牵引起心底灵动的思潮;情愫悠悠,思情绵绵,风里默坐,红尘中的浅醉,诗词中的优柔,任那自在飞花轻似梦的情怀,裁一束霓衣,织就清浅淡薄的安寂。",
|
| 51 |
+
interactive=True,
|
| 52 |
+
)
|
| 53 |
+
with gr.Row():
|
| 54 |
+
refine_text_checkBox = gr.Checkbox(
|
| 55 |
+
label="是否优化文本,如是则先对文本内容做优化分词",
|
| 56 |
+
interactive=True,
|
| 57 |
+
value=True
|
| 58 |
+
)
|
| 59 |
+
temperature_slider = gr.Slider(
|
| 60 |
+
minimum=0.00001,
|
| 61 |
+
maximum=1.0,
|
| 62 |
+
step=0.00001,
|
| 63 |
+
value=0.3,
|
| 64 |
+
interactive=True,
|
| 65 |
+
label="模型 Temperature 参数设置"
|
| 66 |
+
)
|
| 67 |
+
top_p_slider = gr.Slider(
|
| 68 |
+
minimum=0.1,
|
| 69 |
+
maximum=0.9,
|
| 70 |
+
step=0.05,
|
| 71 |
+
value=0.7,
|
| 72 |
+
label="模型 top_P 参数设置",
|
| 73 |
+
interactive=True,
|
| 74 |
+
)
|
| 75 |
+
top_k_slider = gr.Slider(
|
| 76 |
+
minimum=1,
|
| 77 |
+
maximum=20,
|
| 78 |
+
step=1,
|
| 79 |
+
value=20,
|
| 80 |
+
label="模型 top_K 参数设置",
|
| 81 |
+
interactive=True,
|
| 82 |
+
)
|
| 83 |
+
with gr.Row():
|
| 84 |
+
voice_selection = gr.Dropdown(
|
| 85 |
+
label="Timbre",
|
| 86 |
+
choices=voices.keys(),
|
| 87 |
+
value="旁白",
|
| 88 |
+
interactive=True,
|
| 89 |
+
show_label=True
|
| 90 |
+
)
|
| 91 |
+
audio_seed_input = gr.Number(
|
| 92 |
+
value=2,
|
| 93 |
+
label="音色种子",
|
| 94 |
+
interactive=True,
|
| 95 |
+
minimum=seed_min,
|
| 96 |
+
maximum=seed_max,
|
| 97 |
+
)
|
| 98 |
+
generate_audio_seed = gr.Button("随机生成音色种子", interactive=True)
|
| 99 |
+
text_seed_input = gr.Number(
|
| 100 |
+
value=42,
|
| 101 |
+
label="文本种子",
|
| 102 |
+
interactive=True,
|
| 103 |
+
minimum=seed_min,
|
| 104 |
+
maximum=seed_max,
|
| 105 |
+
)
|
| 106 |
+
generate_text_seed = gr.Button("随机生成文本种子", interactive=True)
|
| 107 |
+
|
| 108 |
+
with gr.Row():
|
| 109 |
+
spk_emb_text = gr.Textbox(
|
| 110 |
+
label="Speaker Embedding",
|
| 111 |
+
max_lines=3,
|
| 112 |
+
show_copy_button=True,
|
| 113 |
+
interactive=False,
|
| 114 |
+
scale=2,
|
| 115 |
+
|
| 116 |
+
)
|
| 117 |
+
reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
|
| 118 |
+
|
| 119 |
+
with gr.Row():
|
| 120 |
+
generate_button = gr.Button("生成音频文件", scale=1, interactive=True)
|
| 121 |
+
|
| 122 |
+
with gr.Row():
|
| 123 |
+
text_output = gr.Textbox(
|
| 124 |
+
label="输出文本",
|
| 125 |
+
interactive=False,
|
| 126 |
+
show_copy_button=True,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
audio_output = gr.Audio(
|
| 130 |
+
label="输出音频",
|
| 131 |
+
value=None,
|
| 132 |
+
format="wav",
|
| 133 |
+
autoplay=False,
|
| 134 |
+
streaming=False,
|
| 135 |
+
interactive=False,
|
| 136 |
+
show_label=True,
|
| 137 |
+
waveform_options=gr.WaveformOptions(
|
| 138 |
+
sample_rate=24000,
|
| 139 |
+
),
|
| 140 |
+
)
|
| 141 |
+
# 针对页面元素新增 监听事件
|
| 142 |
+
voice_selection.change(fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
|
| 143 |
+
|
| 144 |
+
audio_seed_input.change(fn=on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)
|
| 145 |
+
|
| 146 |
+
generate_audio_seed.click(fn=generate_seed, outputs=audio_seed_input)
|
| 147 |
+
|
| 148 |
+
generate_text_seed.click(fn=generate_seed,outputs=text_seed_input)
|
| 149 |
+
|
| 150 |
+
# reload_chat_button.click()
|
| 151 |
+
|
| 152 |
+
generate_button.click(fn=get_chat_infer_text,
|
| 153 |
+
inputs=[text_input,
|
| 154 |
+
text_seed_input,
|
| 155 |
+
refine_text_checkBox
|
| 156 |
+
],
|
| 157 |
+
outputs=[text_output]
|
| 158 |
+
).then(fn=get_chat_infer_audio,
|
| 159 |
+
inputs=[text_output,
|
| 160 |
+
temperature_slider,
|
| 161 |
+
top_p_slider,
|
| 162 |
+
top_k_slider,
|
| 163 |
+
audio_seed_input,
|
| 164 |
+
spk_emb_text
|
| 165 |
+
],
|
| 166 |
+
outputs=[audio_output])
|
| 167 |
+
# 初始化 spk_emb_text 数值
|
| 168 |
+
spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
|
| 169 |
+
logger.info("元素初始化完成,启动gradio服务=======")
|
| 170 |
+
|
| 171 |
+
# 运行gradio服务
|
| 172 |
+
demo.launch(
|
| 173 |
+
server_name=args.server_name,
|
| 174 |
+
server_port=args.server_port,
|
| 175 |
+
inbrowser=True,
|
| 176 |
+
show_api=False)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def get_chat_infer_audio(chat_txt,
                         temperature_slider,
                         top_p_slider,
                         top_k_slider,
                         audio_seed_input,
                         spk_emb_text):
    """Synthesize audio for ``chat_txt`` and yield it once as ``(24000, samples)``.

    Args:
        chat_txt: Text passed to ``chat.infer``; text refinement is skipped.
        temperature_slider: Value forwarded as ``temperature``.
        top_p_slider: Value forwarded as ``top_P``.
        top_k_slider: Value forwarded as ``top_K``.
        audio_seed_input: Seed applied through ``TorchSeedContext`` while inferring.
        spk_emb_text: Speaker embedding forwarded as ``spk_emb``.

    Yields:
        A tuple of the integer 24000 (presumably the sample rate in Hz) and the
        transposed result of ``float_to_int16`` on the first inferred waveform.
    """
    logger.info("========开始生成音频文件=====")
    # Audio generation parameters.
    params_infer_code = ChatTTS.Chat.InferCodeParams(
        spk_emb=spk_emb_text,  # add sampled speaker
        temperature=temperature_slider,  # using custom temperature
        top_P=top_p_slider,  # top P decode
        top_K=top_k_slider,  # top K decode
    )

    with TorchSeedContext(audio_seed_input):
        wav = chat.infer(
            text=chat_txt,
            skip_refine_text=True,  # skip text refinement
            params_infer_code=params_infer_code,
        )

    # Yield only after the seed context has exited, so the previous RNG state
    # is already restored while the consumer holds the suspended generator.
    yield 24000, float_to_int16(wav[0]).T
|
| 202 |
+
|
| 203 |
+
def get_chat_infer_text(text,seed,refine_text_checkBox):
    """Return ``text`` refined by the model, or unchanged when refinement is off.

    Args:
        text: Input text.
        seed: Seed applied through ``TorchSeedContext`` during refinement.
        refine_text_checkBox: Truthy to run refinement, falsy to return ``text`` as-is.

    Returns:
        The first element when ``chat.infer`` returns a list, otherwise its
        return value; or ``text`` itself when refinement is disabled.
    """
    logger.info("========开始优化文本内容=====")

    if refine_text_checkBox:
        refine_params = ChatTTS.Chat.RefineTextParams(
            prompt='[oral_2][laugh_0][break_6]',
        )
        with TorchSeedContext(seed):
            refined = chat.infer(
                text=text,
                skip_refine_text=False,
                refine_text_only=True,  # return only the refined text
                params_refine_text=refine_params,
            )
        if isinstance(refined, list):
            return refined[0]
        return refined

    logger.info("========文本内容无需优化=====")
    return text
|
| 224 |
+
|
| 225 |
+
def on_audio_seed_change(audio_seed_input):
    """Return the speaker sampled by ``chat.sample_random_speaker()`` under the given seed."""
    with TorchSeedContext(audio_seed_input):
        return chat.sample_random_speaker()
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
if __name__ == "__main__":
    # Script entry point: parse launch options, load the model, start the UI.
    parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
    # Options are registered in a fixed order so the generated help is unchanged.
    option_specs = (
        ("--server_name", dict(type=str, default="0.0.0.0", help="server name")),
        ("--server_port", dict(type=int, default=8080, help="server port")),
        ("--custom_path", dict(type=str, default="D:\\chenjgspace\\ai-model\\chattts", help="custom model path")),
        ("--coef", dict(type=str, default=None, help="custom dvae coefficient")),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    init_chat(args)
    main(args)
|