|
|
import os |
|
|
import random |
|
|
import gradio as gr |
|
|
|
|
|
from src.cost_time import calculate_time |
|
|
|
|
|
from configs import * |
|
|
# Keep Gradio's temporary artifacts inside the project tree instead of the
# system temp directory.
os.environ["GRADIO_TEMP_DIR"] = './temp'

# HTML banner shown at the top of the web UI.
# FIX: the homepage link originally used the misspelled attribute `herf`,
# which browsers ignore, so the "个人主页" link was dead — corrected to `href`.
description = """<p style="text-align: center; font-weight: bold;">
<span style="font-size: 28px;">Linly 智能对话系统 (Linly-Talker)</span>
<br>
<span style="font-size: 18px;" id="paper-info">
[<a href="https://zhuanlan.zhihu.com/p/671006998" target="_blank">知乎</a>]
[<a href="https://www.bilibili.com/video/BV1rN4y1a76x/" target="_blank">bilibili</a>]
[<a href="https://github.com/Kedreamix/Linly-Talker" target="_blank">GitHub</a>]
[<a href="https://kedreamix.github.io/" target="_blank">个人主页</a>]
</span>
<br>
<span>Linly-Talker 是一款智能 AI 对话系统,结合了大型语言模型 (LLMs) 与视觉模型,是一种新颖的人工智能交互方式。</span>
</p>
"""

# --- Default SadTalker rendering options (forwarded to sadtalker.test2) ---
blink_every = True           # animate eye blinks in the generated face
size_of_image = 256          # face-model resolution
preprocess_type = 'crop'     # source-image preprocessing mode
facerender = 'facevid2vid'   # face renderer backend
enhancer = False             # face enhancement disabled by default
is_still_mode = False        # "still" (reduced head motion) mode off

exp_weight = 1               # expression intensity weight

use_ref_video = False        # no reference video by default
ref_video = None
ref_info = 'pose'            # what to borrow from the reference video
use_idle_mode = False        # idle (no-audio) animation mode off
length_of_audio = 5          # idle-mode audio length, seconds
|
|
|
|
|
@calculate_time
def TTS_response(text,
                 voice, rate, volume, pitch,
                 am, voc, lang, male,
                 tts_method = 'PaddleTTS', save_path = 'answer.wav'):
    """Synthesize `text` to speech and return the path of the generated wav.

    Args:
        text: Text to synthesize.
        voice, rate, volume, pitch: Edge-TTS voice name and prosody controls.
        am, voc, lang, male: PaddleTTS acoustic model, vocoder, language,
            and male-voice flag.
        tts_method: 'Edge-TTS' (online, Microsoft) or 'PaddleTTS' (offline).
        save_path: Output path used by the PaddleTTS branch.

    Returns:
        Path of the generated audio file, or None for an unknown method.
    """
    print(text, voice, rate, volume, pitch, am, voc, lang, male, tts_method, save_path)
    if tts_method == 'Edge-TTS':
        try:
            edgetts.predict(text, voice, rate, volume, pitch , 'answer.wav', 'answer.vtt')
        # FIX: was a bare `except:` that silently swallowed every error
        # (including KeyboardInterrupt); narrowed and logged.
        except Exception as e:
            print("EdgeTTS predict failed, falling back to the CLI:", e)
            # NOTE: rate/volume/pitch are not forwarded here, and `text` is
            # interpolated into a shell command — fine for a local demo,
            # unsafe for untrusted input.
            os.system(f'edge-tts --text "{text}" --voice {voice} --write-media answer.wav')
        return 'answer.wav'
    elif tts_method == 'PaddleTTS':
        paddletts.predict(text, am, voc, lang = lang, male=male, save_path = save_path)
        return save_path
    # Unknown TTS method: nothing was generated.
    return None
|
|
|
|
|
@calculate_time
def Talker_response(source_image, source_video, method = 'SadTalker', driven_audio = '', batch_size = 2):
    """Drive the selected talking-head backend with an avatar and audio clip.

    A supplied video avatar takes precedence over the still image. Returns
    the generated video, or None when `method` is unsupported.
    """
    if source_video:
        source_image = source_video
    print(source_image, method , driven_audio, batch_size)

    # Randomize head pose so repeated generations vary.
    pose_style = random.randint(0, 45)

    if method == 'SadTalker':
        return sadtalker.test2(
            source_image, driven_audio, preprocess_type, is_still_mode,
            enhancer, batch_size, size_of_image, pose_style, facerender,
            exp_weight, use_ref_video, ref_video, ref_info,
            use_idle_mode, length_of_audio, blink_every, fps=20)
    if method == 'Wav2Lip':
        return wav2lip.predict(source_image, driven_audio, batch_size)
    if method == 'ER-NeRF':
        return ernerf.predict(driven_audio)

    gr.Warning("不支持的方法:" + method)
    return None
|
|
|
|
|
def main():
    """Build and return the Gradio Blocks UI for Linly-Talker.

    Layout: left column = avatar source (image/video) plus TTS controls;
    right column = generated-video panel; an examples gallery sits below.
    """
    with gr.Blocks(analytics_enabled=False, title = 'Linly-Talker') as inference:
        # Banner HTML defined at module level.
        gr.HTML(description)
        with gr.Row(equal_height=False):
            with gr.Column(variant='panel'):
                # Avatar source: a still image or a driving video.
                with gr.Tabs():
                    with gr.Tab("图片人物"):
                        source_image = gr.Image(label='Source image', type = 'filepath')

                    with gr.Tab("视频人物"):
                        source_video = gr.Video(label="Source video")

                # Driving audio: uploaded/recorded directly, or synthesized
                # from `input_text` by the TTS button below.
                with gr.Tabs():
                    input_audio = gr.Audio(sources=['upload', 'microphone'], type="filepath", label = '语音')
                    input_text = gr.Textbox(label="Input Text", lines=3)
                    with gr.Column():
                        tts_method = gr.Radio(["Edge-TTS", "PaddleTTS"], label="Text To Speech Method (Edge-TTS利用微软的TTS,PaddleSpeech是离线的TTS,不过第一次运行会自动下载模型)",
                                              value = 'Edge-TTS')

                # Per-method TTS settings; only the tab matching `tts_method`
                # is relevant at generation time.
                with gr.Tabs("TTS Method"):
                    with gr.Tab("Edge-TTS"):
                        voice = gr.Dropdown(edgetts.SUPPORTED_VOICE,
                                            value='zh-CN-XiaoxiaoNeural',
                                            label="Voice")
                        rate = gr.Slider(minimum=-100,
                                         maximum=100,
                                         value=0,
                                         step=1.0,
                                         label='Rate')
                        volume = gr.Slider(minimum=0,
                                           maximum=100,
                                           value=100,
                                           step=1,
                                           label='Volume')
                        pitch = gr.Slider(minimum=-100,
                                          maximum=100,
                                          value=0,
                                          step=1,
                                          label='Pitch')
                    with gr.Tab("PaddleTTS"):
                        am = gr.Dropdown(["FastSpeech2"], label="声学模型选择", value = 'FastSpeech2')
                        voc = gr.Dropdown(["PWGan", "HifiGan"], label="声码器选择", value = 'PWGan')
                        lang = gr.Dropdown(["zh", "en", "mix", "canton"], label="语言选择", value = 'zh')
                        male = gr.Checkbox(label="男声(Male)", value=False)
                    with gr.Column(variant='panel'):
                        batch_size = gr.Slider(minimum=1,
                                               maximum=10,
                                               value=2,
                                               step=1,
                                               label='Talker Batch size')
                    button_text = gr.Button('语音生成')
                    # TTS writes its output into `input_audio`, which then
                    # drives the talker below.
                    button_text.click(fn=TTS_response,inputs=[input_text, voice, rate, volume, pitch, am, voc, lang, male, tts_method],
                                      outputs=[input_audio])

            with gr.Column(variant='panel'):
                with gr.Tabs():
                    with gr.TabItem('数字人问答'):
                        method = gr.Radio(choices = ['SadTalker', 'Wav2Lip', 'ER-NeRF'], value = 'SadTalker', label = '模型选择')
                        gen_video = gr.Video(label="Generated video", format="mp4", scale=1, autoplay=True)
                        video_button = gr.Button("提交", variant='primary')

                    # Generate the talking-head video from the chosen avatar
                    # and the (possibly TTS-produced) audio.
                    video_button.click(fn=Talker_response,inputs=[source_image, source_video, method, input_audio, batch_size] ,
                                       outputs=[gen_video])

                # Clickable example (avatar image, question text) pairs.
                with gr.Row():
                    examples = [
                        [
                            'examples/source_image/full_body_2.png',
                            '应对压力最有效的方法是什么?',
                        ],
                        [
                            'examples/source_image/full_body_1.png',
                            '如何进行时间管理?',
                        ],
                        [
                            'examples/source_image/full3.png',
                            '为什么有些人选择使用纸质地图或寻求方向,而不是依赖GPS设备或智能手机应用程序?',
                        ],
                        [
                            'examples/source_image/full4.jpeg',
                            '近日,苹果公司起诉高通公司,状告其未按照相关合约进行合作,高通方面尚未回应。这句话中“其”指的是谁?',
                        ],
                        [
                            'examples/source_image/art_13.png',
                            '三年级同学种树80颗,四、五年级种的棵树比三年级种的2倍多14棵,三个年级共种树多少棵?',
                        ],
                        [
                            'examples/source_image/art_5.png',
                            '撰写一篇交响乐音乐会评论,讨论乐团的表演和观众的整体体验。',
                        ],
                    ]
                    gr.Examples(examples=examples,
                                inputs=[
                                    source_image,
                                    input_text,
                                ],
                                )
    return inference
|
|
|
|
|
if __name__ == "__main__":
    # Each heavy model backend is loaded independently: a failure to import
    # or initialize one backend only disables that backend (with a hint to
    # download its checkpoints) instead of killing the whole app.
    try:
        from TFG import SadTalker
        sadtalker = SadTalker(lazy_load=True)
    except Exception as e:
        print("SadTalker Error: ", e)
        print("如果使用SadTalker,请先下载SadTalker模型")

    try:
        from TFG import Wav2Lip
        wav2lip = Wav2Lip("checkpoints/wav2lip_gan.pth")
    except Exception as e:
        print("Wav2Lip Error: ", e)
        print("如果使用Wav2Lip,请先下载Wav2Lip模型")

    try:
        from TFG import ERNeRF
        ernerf = ERNeRF()
        ernerf.init_model('checkpoints/Obama_ave.pth', 'checkpoints/Obama.json')
    except Exception as e:
        print("ERNeRF Error: ", e)
        print("如果使用ERNeRF,请先下载ERNeRF模型")

    try:
        from TTS import EdgeTTS
        edgetts = EdgeTTS()
    except Exception as e:
        print("EdgeTTS Error: ", e)
        print("如果使用EdgeTTS,请先下载EdgeTTS模型")

    try:
        from TTS import PaddleTTS
        paddletts = PaddleTTS()
    except Exception as e:
        print("PaddleTTS Error: ", e)
        print("如果使用PaddleTTS,请先下载PaddleTTS模型")

    # Close any stale Gradio servers from previous runs, then launch.
    gr.close_all()
    demo = main()
    demo.queue()

    # ip/port/ssl_certfile/ssl_keyfile come from the wildcard import of
    # `configs` at the top of the file.
    demo.launch(server_name=ip,
                server_port=port,
                ssl_certfile=ssl_certfile,
                ssl_keyfile=ssl_keyfile,
                ssl_verify=False,
                debug=True)