# parlerTTSmini / app.py
# Ryanus's picture
# Update app.py
# 6c24caa verified
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
import gradio as gr
import os
import time
import glob
# Checkpoint to load (original author's note: v1 is the fastest variant).
MODEL_ID = "parler-tts/parler-tts-mini-v1"

# All tensors and the model itself are placed on this device.
device = "cpu"

# Load the pretrained weights, then move the module to the target device.
model = ParlerTTSForConditionalGeneration.from_pretrained(MODEL_ID)
model = model.to(device)

# Tokenizer comes from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Output directory for audio files; created at import time if it is missing.
SAVE_DIR = "saved_audios"
os.makedirs(SAVE_DIR, exist_ok=True)
def tts(text, description, progress=gr.Progress()):
    """Synthesize speech for ``text`` and save it as a WAV file in ``SAVE_DIR``.

    Args:
        text: The sentence to be spoken. It is tokenized and handed to the
            model as ``prompt_input_ids``.
        description: Free-form description of the desired voice. It is
            tokenized and handed to the model as ``input_ids``.
        progress: Gradio progress tracker used to report each stage. The
            ``gr.Progress()`` default is the Gradio idiom for receiving a
            tracker for the running event.

    Returns:
        The path of the written ``.wav`` file (inside ``SAVE_DIR``).

    Raises:
        gr.Error: If ``text`` is empty or contains only whitespace.
    """
    # Fail fast with a user-visible message instead of running a slow
    # generation on a blank prompt.
    if not text or not text.strip():
        raise gr.Error("請輸入要合成的文字")

    progress(0, desc="開始處理輸入")
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    progress(0.2, desc="描述編碼完成")

    prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    progress(0.4, desc="文本編碼完成")

    generation = model.generate(
        input_ids=input_ids,
        prompt_input_ids=prompt_input_ids,
    )
    progress(0.8, desc="語音生成完成,正在寫入檔案")

    # Drop singleton dimensions so soundfile receives a plain sample array.
    audio_arr = generation.cpu().numpy().squeeze()

    # Nanosecond timestamp: a whole-second timestamp lets two generations that
    # finish within the same second overwrite each other's file.
    filename = f"tts_{time.time_ns()}.wav"
    save_path = os.path.join(SAVE_DIR, filename)
    sf.write(save_path, audio_arr, model.config.sampling_rate)

    progress(1, desc="完成")
    return save_path
def list_saved_audios():
    """Return the paths of all ``.wav`` files in ``SAVE_DIR``, newest first.

    Ordering is by file modification time as reported by
    ``os.path.getmtime``, most recently modified file first.
    """
    wav_paths = glob.glob(os.path.join(SAVE_DIR, "*.wav"))
    wav_paths.sort(key=os.path.getmtime, reverse=True)
    return wav_paths
# Two-tab UI: one tab runs `tts`, the other lists the files in SAVE_DIR.
with gr.Blocks() as demo:
    gr.Markdown("## Parler-TTS Mini v1 (CPU)|進度條+自動儲存+音檔檢視")

    # Tab 1: text + voice description in, generated audio file out.
    with gr.Tab("語音生成"):
        text = gr.Textbox(
            label="輸入文字",
            value="Hello, this is Parler-TTS mini running on CPU.",
        )
        description = gr.Textbox(
            label="語音描述",
            value="A young female speaker, neutral tone, high quality audio.",
        )
        out_audio = gr.Audio(
            label="生成語音",
            type="filepath",
        )
        btn = gr.Button("生成語音")
        btn.click(
            fn=tts,
            inputs=[text, description],
            outputs=out_audio,
        )

    # Tab 2: browse previously saved audio files, with a manual refresh button.
    with gr.Tab("檢視已儲存音檔"):
        audio_list = gr.Files(
            label="已儲存音檔",
            file_count="multiple",
            type="filepath",
        )
        refresh_btn = gr.Button("重新整理列表")
        refresh_btn.click(
            fn=list_saved_audios,
            inputs=[],
            outputs=audio_list,
        )

    # Also fill the file list on the app's load event.
    demo.load(
        fn=list_saved_audios,
        inputs=[],
        outputs=audio_list,
    )

# Start the app (with request queueing enabled) only when run as a script.
if __name__ == "__main__":
    demo.queue().launch()