# API_MC_AI / SadTalker / src / generate_batch.py
# Uploaded by duyv ("Upload 139 files"), commit 9c79341 (verified), 5.58 kB.
# (Hosting-page navigation residue: raw · history · blame)
import os
from tqdm import tqdm
import torch
import numpy as np
import random
import scipy.io as scio
import src.utils.audio as audio
def crop_pad_audio(wav, audio_length):
    """Return ``wav`` trimmed or right-padded with zeros to exactly ``audio_length`` samples.

    When the length already matches, the input object itself is returned unchanged.
    """
    current_length = len(wav)
    if current_length > audio_length:
        return wav[:audio_length]
    if current_length < audio_length:
        missing = audio_length - current_length
        return np.pad(wav, [0, missing], mode='constant', constant_values=0)
    return wav
def parse_audio_length(audio_length, sr, fps, min_frames=30):
    """Convert an audio length in samples into a whole number of video frames.

    Args:
        audio_length: Number of audio samples.
        sr: Audio sample rate (samples per second).
        fps: Video frame rate (frames per second).
        min_frames: Lower bound on the returned frame count. Defaults to 30,
            the minimum this function has always enforced.

    Returns:
        ``(num_samples, num_frames)`` where ``num_frames`` is
        ``max(int(audio_length / (sr / fps)), min_frames)`` and ``num_samples``
        is the (int-truncated) number of samples spanning exactly that many frames.
    """
    samples_per_frame = sr / fps
    # Enforce at least `min_frames` frames (originally a fixed 30).
    num_frames = max(int(audio_length / samples_per_frame), min_frames)
    return int(num_frames * samples_per_frame), num_frames
def generate_blink_seq(num_frames):
    """Build a deterministic ``(num_frames, 1)`` blink-ratio track.

    Starting at frame 0, the track repeatedly skips 80 frames and then writes a
    9-frame blink profile. Placement stops as soon as the next blink would end
    too close to the clip end (``cursor + 89 > num_frames - 1``). All other
    entries stay 0.
    """
    blink_profile = [0.5, 0.6, 0.7, 0.9, 1, 0.9, 0.7, 0.6, 0.5]
    gap = 80
    period = gap + len(blink_profile)
    ratio = np.zeros((num_frames, 1))
    cursor = 0
    while cursor < num_frames and cursor + period <= num_frames - 1:
        ratio[cursor + gap:cursor + period, 0] = blink_profile
        cursor += period
    return ratio
def generate_blink_seq_randomly(num_frames):
    """Build a ``(num_frames, 1)`` blink-ratio track with one randomly drawn gap.

    Clips of 20 frames or fewer get an all-zero track. Otherwise a single gap is
    drawn with ``random.choice`` from ``[10, min(num_frames // 2, 70))``; that
    range is widened to ``[10, 15)`` when it would be empty. A 5-frame blink
    profile is then written after every gap, back to back, until the next blink
    would end too close to the clip end.
    """
    blink_profile = [0.5, 0.9, 1.0, 0.9, 0.5]
    ratio = np.zeros((num_frames, 1))
    if num_frames <= 20:
        return ratio

    low = min(10, num_frames)
    high = min(int(num_frames / 2), 70)
    if not low < high:
        # Degenerate range (only possible for very short clips): widen it.
        high = low + 5
    # low == 10 and high > 10 at this point, so the range is never empty.
    gap = random.choice(range(low, high))

    cursor = 0
    while cursor < num_frames:
        blink_end = cursor + gap + len(blink_profile)
        if blink_end > num_frames - 1:
            break
        ratio[cursor + gap:blink_end, 0] = blink_profile
        cursor = blink_end
    return ratio
def _build_indiv_mels(audio_path, fps, mel_step_size, sample_rate=16000):
    """Load ``audio_path`` and cut one ``mel_step_size``-wide mel window per video frame.

    Returns:
        ``(indiv_mels, num_frames)``. ``indiv_mels`` stacks one window per frame,
        each of shape ``(n_mels, mel_step_size)``, where ``n_mels`` is the column
        count of ``audio.melspectrogram(wav).T``.

    Raises:
        ValueError: if fewer than 5 frames can be produced.
    """
    wav = audio.load_wav(audio_path, sample_rate)
    wav_length, num_frames = parse_audio_length(len(wav), sample_rate, fps)
    # Defensive: parse_audio_length's default minimum currently makes this unreachable.
    if num_frames < 5:
        raise ValueError(f"Audio too short: only {num_frames} frames generated")
    wav = crop_pad_audio(wav, wav_length)
    orig_mel = audio.melspectrogram(wav).T  # rows are indexed by mel time step below
    last_row = orig_mel.shape[0] - 1

    indiv_mels = []
    for i in tqdm(range(num_frames), 'mel:'):
        # Window starts 2 video frames before frame i. The factor 80 presumably
        # means 80 mel steps per second of audio -- TODO confirm against the
        # hop size used by audio.melspectrogram.
        start_idx = int(80. * ((i - 2) / float(fps)))
        # Clamp indices so windows at the clip edges repeat the first/last mel row.
        rows = np.clip(np.arange(start_idx, start_idx + mel_step_size), 0, last_row)
        indiv_mels.append(orig_mel[rows, :].T)
    return np.asarray(indiv_mels), num_frames


def _load_ref_eyeblink(ref_eyeblink_coeff_path, num_frames):
    """Return the first 64 columns of ``coeff_3dmm`` from a reference .mat file.

    The result has exactly ``num_frames`` rows: it is truncated when the reference
    is longer, and tiled (repeated from the start) when it is shorter.

    Raises:
        ValueError: if tiling is needed but the reference has no rows.
    """
    coeff = scio.loadmat(ref_eyeblink_coeff_path)['coeff_3dmm'][:, :64]
    ref_frames = coeff.shape[0]
    if ref_frames < num_frames:
        if ref_frames == 0:
            raise ValueError("reference eyeblink coefficients are empty")
        reps = -(-num_frames // ref_frames)  # ceil division
        coeff = np.tile(coeff, (reps, 1))
    return coeff[:num_frames]


def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
    """Assemble the per-clip inputs: mel windows, reference coefficients and blink ratio.

    Args:
        first_coeff_path: .mat file whose ``coeff_3dmm`` first row (up to 70
            columns) is repeated for every frame as the reference coefficients.
        audio_path: Driving audio file. Ignored in ``idlemode`` except for its name.
        device: Target device passed to ``Tensor.to``.
        ref_eyeblink_coeff_path: Optional .mat file whose first 64 ``coeff_3dmm``
            columns replace those of the reference coefficients. When this is
            applied successfully, the synthetic blink ratio is zeroed.
        still: Accepted for interface compatibility; not used here.
        idlemode: If true, no audio is loaded and all-zero mel windows are used.
        length_of_audio: In ``idlemode``, the duration in seconds (multiplied by
            fps to obtain the frame count).
        use_blink: If false, the returned blink ratio is all zeros.

    Returns:
        dict with keys ``indiv_mels`` (1, T, 1, n_mels, 16), ``ref`` (1, T, <=70),
        ``num_frames`` (T), ``ratio_gt`` (1, T, 1), ``audio_name`` and ``pic_name``.
        All tensors are moved to ``device``.

    Raises:
        RuntimeError: if audio processing, source-coefficient loading or tensor
            conversion fails. The original exception is chained as ``__cause__``.
    """
    syncnet_mel_step_size = 16
    fps = 25
    pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
    audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]

    if idlemode:
        # No audio: silent (all-zero) mel windows for the requested duration.
        num_frames = int(length_of_audio * fps)
        indiv_mels = np.zeros((num_frames, 80, syncnet_mel_step_size))
    else:
        try:
            indiv_mels, num_frames = _build_indiv_mels(audio_path, fps, syncnet_mel_step_size)
        except Exception as e:
            raise RuntimeError(f"Audio processing failed: {str(e)}") from e

    # Best effort: fall back to no blinks rather than failing the whole clip.
    try:
        ratio = generate_blink_seq_randomly(num_frames)
    except Exception as e:
        print(f"Warning: Blink sequence generation failed, using zeros: {str(e)}")
        ratio = np.zeros((num_frames, 1))

    try:
        source_semantics_dict = scio.loadmat(first_coeff_path)
        ref_coeff = source_semantics_dict['coeff_3dmm'][:1, :70]
        ref_coeff = np.repeat(ref_coeff, num_frames, axis=0)
    except Exception as e:
        raise RuntimeError(f"Failed to load source semantics: {str(e)}") from e

    if ref_eyeblink_coeff_path is not None:
        try:
            ref_coeff[:, :64] = _load_ref_eyeblink(ref_eyeblink_coeff_path, num_frames)
        except Exception as e:
            print(f"Warning: Eyeblink reference processing failed: {str(e)}")
        else:
            # Disable synthetic blinks only once the reference motion is in place.
            # On failure the synthetic track is kept, instead of leaving no blinks at all.
            ratio[:num_frames] = 0

    # Convert to batched float tensors on the target device.
    try:
        indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0)
        ratio = torch.FloatTensor(ratio).unsqueeze(0)
        if not use_blink:
            ratio.fill_(0.)
        ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0)
        indiv_mels = indiv_mels.to(device)
        ratio = ratio.to(device)
        ref_coeff = ref_coeff.to(device)
    except Exception as e:
        raise RuntimeError(f"Tensor conversion failed: {str(e)}") from e

    return {
        'indiv_mels': indiv_mels,
        'ref': ref_coeff,
        'num_frames': num_frames,
        'ratio_gt': ratio,
        'audio_name': audio_name,
        'pic_name': pic_name
    }