# hg-5090 / codes / main_multithreading_digitalhuman.py
import threading
import time
import traceback
# from multiprocessing.dummy import Queue
import cv2
import librosa
import numpy as np
from face_detect_utils.face_detect import FaceDetect, pfpld
from landmark2face_wy.digitalhuman_interface import DigitalHumanModel
from preprocess_audio_and_3dmm import op
from wenet.compute_ctc_att_bnf import get_weget
from wenet.compute_ctc_att_bnf import load_ppg_model
from multiprocessing import Pool, Queue, Process, set_start_method
from Deep3DFaceRecon_pytorch.util.load_mats import load_lm3d
# import pyaudio
# from twisted.internet import reactor
import subprocess
import os
import queue  # queue.Empty is raised by multiprocessing queues on Queue.get timeout
def feature_extraction_wenet(audio_file, fps, wenet_model, mfccnorm=True, section=560000):
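    """Extract wenet PPG features for `audio_file` (a path or a 16 kHz waveform).

    Returns (indexs, f_wenet_all): one [start, end] window of 20 feature frames
    per output video frame at `fps`, plus the axis-swapped feature matrix.
    Note: the `mfccnorm` argument is currently unused.
    """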
rate = 16000
win_size = 20
    if isinstance(audio_file, str):
sig, rate = librosa.load(audio_file, sr=rate, duration=None)
else:
sig = audio_file
time_duration = len(sig) / rate
cnts = range(int(time_duration * fps))
indexs = []
f_wenet_all = get_weget(audio_file, wenet_model, section)
f_wenet_all = np.swapaxes(f_wenet_all, 0, 1)
for cnt in cnts:
c_count = int((cnt - 1) * (f_wenet_all.shape[1] - 20) / int(time_duration * fps)) + 10
if f_wenet_all.T[c_count - win_size // 2:c_count + win_size // 2].shape[0] == 20:
indexs.append([c_count - win_size // 2, c_count + win_size // 2])
return indexs, f_wenet_all
def get_aud_feat1(wav_fragment, fps, wenet_model):
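    """Thin wrapper around feature_extraction_wenet for a raw wav fragment."""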
return feature_extraction_wenet(wav_fragment, fps, wenet_model)
def warp_imgs(imgs_data):
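    """Wrap a list of frames into {idx: {'imgs_data': frame, 'idx': idx}} as consumed by op()."""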
    caped_img2 = {idx: {'imgs_data': img, 'idx': idx} for idx, img in enumerate(imgs_data)}
    return caped_img2
def get_complete_imgs(output_img_list, start_index, params):
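    """Paste each generated face patch back into its full driven frame.

    Axis convention in this file: x indexes rows (height) and y indexes columns
    (width), so cv2.resize receives (y2 - y1, x2 - x1) as (width, height).
    Boxes that fall outside the image are clipped before pasting.
    """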
out_shape, output_resize, drivered_imgs_data, Y1_list, Y2_list, X1_list, X2_list = params
complete_imgs = []
for i, mask_B_pre in enumerate(output_img_list):
img_idx = start_index + i
image = drivered_imgs_data[img_idx]
y1, y2, x1, x2 = Y1_list[img_idx], Y2_list[img_idx], X1_list[img_idx], X2_list[img_idx]
mask_B_pre_resize = cv2.resize(mask_B_pre, (y2 - y1, x2 - x1))
if y1 < 0:
mask_B_pre_resize = mask_B_pre_resize[:, -y1:]
y1 = 0
if y2 > image.shape[1]:
mask_B_pre_resize = mask_B_pre_resize[:, :-(y2 - image.shape[1])]
y2 = image.shape[1]
if x1 < 0:
mask_B_pre_resize = mask_B_pre_resize[-x1:, :]
x1 = 0
if x2 > image.shape[0]:
mask_B_pre_resize = mask_B_pre_resize[:-(x2 - image.shape[0]), :]
x2 = image.shape[0]
image[x1:x2, y1:y2] = mask_B_pre_resize
image = cv2.resize(image, (out_shape[1] // output_resize, out_shape[0] // output_resize))
complete_imgs.append(image)
return complete_imgs
# Get the generated head-image sequence
def get_blend_imgs(batch_size, audio_data, wenet_feature, face_data_dict, params, digital_human_model):
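    """Run model inference in batches of `batch_size` (plus a final partial batch)
    and return the full-size blended frames."""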
result_img_list = []
for idx in range(len(audio_data) // batch_size + 1):
print('\r{}/{}'.format((idx + 1) * batch_size, len(audio_data)), end='')
if idx < len(audio_data) // batch_size:
start_index = idx * batch_size
output_img_list = digital_human_model.inference([audio_data, wenet_feature], face_data_dict,
batch_size,
start_index, params)
complete_imgs = get_complete_imgs(output_img_list, start_index, params)
result_img_list += complete_imgs
else:
this_batch = len(audio_data) % batch_size
if this_batch > 0:
start_index = idx * batch_size
output_img_list = digital_human_model.inference([audio_data, wenet_feature],
face_data_dict,
this_batch, start_index, params)
complete_imgs = get_complete_imgs(output_img_list, start_index, params)
result_img_list += complete_imgs
return result_img_list
def load_wav(audio_queue, audio_path):
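    """Stream a wav file into `audio_queue`, one second (RATE samples) per item,
    split into CHUNK-sized numpy slices."""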
    print('in load_wav function')
    wav_arr, _ = librosa.load(audio_path, sr=16000)
    CHUNK = 500
    RATE = 16000
    # Emit one second of audio per iteration, then drop the consumed samples;
    # stop once the whole file has been queued.
    while len(wav_arr) > 0:
        try:
            audio_list = []
            t0 = time.time()
            for i in range(0, int(RATE / CHUNK)):
                data = wav_arr[i * CHUNK:(i + 1) * CHUNK]
                audio_list.append(data)
            audio_queue.put(audio_list)
            wav_arr = wav_arr[RATE:]
            print(f'append wav cost {time.time() - t0}')
        except Exception:
            print(traceback.format_exc())
def microphone_audio(audio_queue):
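    """Capture live microphone audio and queue one second per item.

    Unused in the current __main__ pipeline; needs the optional pyaudio package
    (its top-level import is commented out above).
    """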
    import pyaudio  # optional dependency; its top-level import is commented out above
    while True:
        CHUNK = 500
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 16000
        p = pyaudio.PyAudio()
stream = p.open(format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=True,
frames_per_buffer=CHUNK)
try:
while True:
audio_list = []
# fps 32
for i in range(0, int(RATE / CHUNK)):
data = stream.read(CHUNK)
audio_list.append(data)
audio_queue.put(audio_list)
        except Exception:
print(traceback.format_exc())
finally:
stream.stop_stream()
stream.close()
p.terminate()
def drivered_video(drivered_queue, video_file=None):
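    """Read the driven video (or webcam 0 when video_file is None) and push
    batches of frames to `drivered_queue`, batched via the module-level `fps`."""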
while True:
        print('in drivered_video function')
cap = cv2.VideoCapture(0) if video_file is None else cv2.VideoCapture(video_file)
        # cap0 = cv2.VideoCapture(1 + cv2.CAP_DSHOW)  # video stream
drivered_list = []
count_f = 0
t0 = time.time()
while cap.isOpened():
            count_f += 1
            # Skip odd counts. cap.read() has not run yet, so no frame is
            # actually discarded here; the loop just spins one extra time.
            if count_f % 2 == 1:
                continue
            t_st = time.time()
            ret, frame = cap.read()
            # print(f'read img cost {time.time() - t_st} {count_f}')
# cv2.imshow('drivered_video', frame)
# cv2.waitKey(1)
if ret:
drivered_list.append(frame)
if count_f % fps == 0:
drivered_queue.put(drivered_list)
count_f = 0
drivered_list = []
print(f'append imgs cost {time.time() - t0}')
t0 = time.time()
else:
cap.release()
cap.release()
def capture_video(driver_queue):
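    """Capture frames from a DirectShow camera and push batches to `driver_queue`."""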
while True:
        # create the camera capture object
        cap = cv2.VideoCapture(1 + cv2.CAP_DSHOW)  # video stream
capture_list = []
count_f = 0
while cap.isOpened():
count_f += 1
            # Skip odd counts; as in drivered_video, cap.read() has not run yet,
            # so this only halves the loop rate without dropping frames.
            if count_f % 2 == 1:
                continue
ret, frame = cap.read()
if ret:
# cv2.imshow('capture_video', frame)
# cv2.waitKey(24)
capture_list.append(frame)
if count_f % fps == 0:
# print('count_f----', count_f)
# print('capture_list----', len(capture_list))
driver_queue.put(capture_list)
count_f = 0
capture_list = []
        cap.release()  # release the camera
def audio_features(audio_features_queue, audio_queue):
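    """Load the wenet PPG model, then convert queued wav fragments into
    window-index/feature pairs for the transfer worker."""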
wenet_model = load_ppg_model('wenet/examples/aishell/aidata/conf/train_conformer_multi_cn.yaml',
'wenet/examples/aishell/aidata/exp/conformer/wenetmodel.pt',
'cuda')
while True:
if not audio_queue.empty():
crop_st = time.time()
wav_fragment = audio_queue.get()
wav_fragment = np.array(wav_fragment).reshape(-1)
# wav_fragment = np.frombuffer(wav_fragment, dtype=np.int16)
audio_data, audio_wenet_feature = get_aud_feat1(wav_fragment, fps=fps, wenet_model=wenet_model)
audio_features_queue.put([audio_data, audio_wenet_feature])
crop_et = time.time()
            print('audio_features elapsed:================', crop_et - crop_st)
def audio_transfer(audio_features_queue, drivered_queue, output_imgs_queue, pth_path):
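    """Core worker: detect and align faces in the driven frames, run the
    digital-human model on the audio features, and queue the blended frames.
    Frames with no detected face are backfilled from neighbouring frames."""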
    output_resize = 1  # output video downscale factor
face2face_batch_size = 4
digital_human_model = DigitalHumanModel(pth_path)
scrfd_detector = FaceDetect(cpu=False, model_path='face_detect_utils/resources/')
scrfd_predictor = pfpld(cpu=False, model_path='face_detect_utils/resources/')
lm3d_std = load_lm3d('Deep3DFaceRecon_pytorch/BFM')
while True:
try:
if not audio_features_queue.empty():
                print('start synthesizing expression images----------')
face_st = time.time()
audio_data, audio_wenet_feature = audio_features_queue.get()
drivered_imgs_data = drivered_queue.get()
out_shape = drivered_imgs_data[0].shape
drivered_st = time.time()
caped_drivered_img2 = warp_imgs(drivered_imgs_data)
drivered_op = op(caped_drivered_img2, digital_human_model.drivered_wh, scrfd_detector, scrfd_predictor,
lm3d_std, digital_human_model.img_size, False)
drivered_op.flow()
drivered_face_dict = drivered_op.mp_dict
x1_list, x2_list, y1_list, y2_list = [], [], [], []
for idx in range(len(drivered_face_dict)):
facebox = drivered_face_dict[idx]['bounding_box']
x1_list.append(facebox[0])
x2_list.append(facebox[1])
y1_list.append(facebox[2])
y2_list.append(facebox[3])
                ############## handle leading driven frames where no face was detected
                drivered_exceptlist = []
frame_len = len(drivered_face_dict.keys())
for i in range(frame_len):
if len(drivered_face_dict[i]['bounding_box_p']) == 4:
break
drivered_exceptlist.append(i)
                print(drivered_exceptlist, '-------------------------------------')
                # backfill from index len(drivered_exceptlist), the first frame with a valid detection
                for i in drivered_exceptlist:
drivered_face_dict[i]['bounding_box_p'] = drivered_face_dict[len(drivered_exceptlist)][
'bounding_box_p']
drivered_face_dict[i]['bounding_box'] = drivered_face_dict[len(drivered_exceptlist)][
'bounding_box']
drivered_face_dict[i]['crop_lm'] = drivered_face_dict[len(drivered_exceptlist)]['crop_lm']
drivered_face_dict[i]['crop_img'] = drivered_face_dict[len(drivered_exceptlist)]['crop_img']
                keylist = list(drivered_face_dict.keys())
                keylist.sort()
                # any remaining frame without a valid box reuses the previous frame's data
                for it in keylist:
# print(it)
if len(drivered_face_dict[it]['bounding_box_p']) != 4:
print(it, '++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
# bbx_smooth[it, :] = bbx_smooth[it - 2, :]
drivered_face_dict[it]['bounding_box_p'] = drivered_face_dict[it - 1]['bounding_box_p']
drivered_face_dict[it]['bounding_box'] = drivered_face_dict[it - 1]['bounding_box']
drivered_face_dict[it]['crop_lm'] = drivered_face_dict[it - 1]['crop_lm']
drivered_face_dict[it]['crop_img'] = drivered_face_dict[it - 1]['crop_img']
drivered_et = time.time()
                print('driven-person preprocessing elapsed:================', drivered_et - drivered_st)
                #################### driven-person video frame processing ####################
params = [out_shape, output_resize, drivered_imgs_data, y1_list, y2_list, x1_list, x2_list]
output_imgs = get_blend_imgs(face2face_batch_size, audio_data, audio_wenet_feature,
drivered_face_dict, params, digital_human_model)
cv2.imwrite('result.jpg', output_imgs[0])
face_et = time.time()
                print('expression_transfer elapsed:================', face_et - face_st)
output_imgs_queue.put(output_imgs)
except Exception as e:
print(traceback.format_exc())
def write_video(result_queue, save_dir, audio_path):
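    """Drain `result_queue` into an XVID .avi; after a 10 s timeout with no new
    frames, mux in the audio with ffmpeg and write the final .mp4."""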
output_name = os.path.basename(audio_path).split('_add_valume.wav')[0] + '_stream.avi'
fourcc = cv2.VideoWriter_fourcc(*'XVID')
output_path = os.path.join(save_dir, output_name)
    # NOTE: the frame size (width=1080, height=1920) is hardcoded; frames of a
    # different size may be silently dropped by OpenCV's VideoWriter.
    video_write = cv2.VideoWriter(output_path, fourcc, 25.0, (1080, 1920))
while True:
try:
value_ = result_queue.get(True, timeout=10)
for result_img in value_:
video_write.write(result_img)
        except queue.Empty:
            # no new frames within 10 s: treat the stream as finished and mux in the audio
command = 'ffmpeg -y -i {} -i {} -b:v 5000k -strict -2 -q:v 1 {}'.format(audio_path, output_path,
output_path.replace('avi', 'mp4'))
subprocess.call(command, shell=True)
print(f'###### write over')
break
def video_synthesis(output_imgs_queue):
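    """Debug preview: display generated frames with cv2.imshow at roughly 32 fps."""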
img_id = 0
st = time.time()
while True:
while not output_imgs_queue.empty():
et = time.time()
            print('time until first expression-transfer output======================:', et - st)
output_imgs = output_imgs_queue.get()
# if not os.path.exists('test_data/output'):
# os.makedirs('test_data/output')
            # ### write test images
# for temp_img in output_imgs:
# cv2.imwrite(os.path.join('test_data/output', str(img_id) + '.jpg'), temp_img)
# img_id += 1
for img in output_imgs:
                time.sleep(0.03125)  # pace the preview at ~32 fps
cv2.imshow('output_imgs', img)
cv2.waitKey(1)
st = time.time()
fps = 25  # global frame rate used by the capture and audio-feature workers
if __name__ == '__main__':
##############
set_start_method('spawn', force=True)
audio_queue = Queue(10)
drivered_queue = Queue(10)
audio_features_queue = Queue(10)
face3d_features_queue = Queue(10)
output_imgs_queue = Queue(10)
process_list = []
process_list.append(Process(target=load_wav, args=(audio_queue, 'test_data/audio/driver_add_valume.wav',)))
process_list.append(Process(target=audio_features, args=(audio_features_queue, audio_queue)))
process_list.append(Process(target=drivered_video, args=(drivered_queue, './landmark2face_wy/checkpoints/hy/1.mp4')))
process_list.append(Process(target=audio_transfer, args=(
audio_features_queue, drivered_queue, output_imgs_queue, './landmark2face_wy/checkpoints/hy/11.pth')))
process_list.append(Process(target=write_video,
args=(output_imgs_queue, 'test_data/result/', 'test_data/audio/driver_add_valume.wav')))
[p.start() for p in process_list]
[p.join() for p in process_list]
################# reactor
# reactor.callInThread(load_wav, audio_queue)
# reactor.callInThread(audio_features, audio_features_queue, audio_queue)
# reactor.callInThread(drivered_video, drivered_queue, 'test_data/video/爱夏4k_test.mp4')
# reactor.callInThread(audio_transfer, audio_features_queue, drivered_queue, output_imgs_queue)
# reactor.callInThread(video_synthesis, output_imgs_queue)
# reactor.run()
##################### origin
# set_start_method('spawn', force=True)
# face_st = time.time()
# pool = Pool(10)
# # audio
# pool.apply_async(load_wav, args=(audio_queue,))
# # driven video
# pool.apply_async(drivered_video, args=(drivered_queue, 'test_data/video/爱夏4k_test.mp4'))
# # audio feature extraction
# pool.apply_async(audio_features, args=(audio_features_queue, audio_queue))
# # lip-sync driving
# pool.apply_async(audio_transfer, args=(audio_features_queue, drivered_queue, output_imgs_queue))
# # video synthesis
# pool.apply_async(video_synthesis, args=(output_imgs_queue,))
#
# pool.close()
# pool.join()
#
# while True:
# print('audio_queue size={}'.format(audio_queue.qsize()))
# print('face3d_features_queue size={}'.format(face3d_features_queue.qsize()))
# print('output_imgs_queue size={}'.format(output_imgs_queue.qsize()))
# face_et = time.time()
# print('total runtime elapsed:================', face_et - face_st)
# print('=======================================================')
# time.sleep(3)