File size: 5,578 Bytes
9c79341 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import os
from tqdm import tqdm
import torch
import numpy as np
import random
import scipy.io as scio
import src.utils.audio as audio
def crop_pad_audio(wav, audio_length):
    """Trim or zero-pad ``wav`` at its end so it has exactly ``audio_length`` samples.

    Args:
        wav: Sequence of audio samples (presumably a 1-D NumPy array).
        audio_length: Desired number of samples.

    Returns:
        ``wav`` itself when it already has the target length, its first
        ``audio_length`` samples when it is longer, otherwise a new array with
        trailing zeros appended.
    """
    current_length = len(wav)
    if current_length == audio_length:
        return wav
    if current_length > audio_length:
        return wav[:audio_length]
    shortfall = audio_length - current_length
    return np.pad(wav, (0, shortfall), mode='constant', constant_values=0)
def parse_audio_length(audio_length, sr, fps, min_frames=30):
    """Convert a sample count into a whole number of video frames.

    Args:
        audio_length: Number of audio samples.
        sr: Audio sample rate in samples per second.
        fps: Video frame rate in frames per second.
        min_frames: Lower bound on the returned frame count (default 30, the
            value previously hard-coded). Clips shorter than this are reported
            as ``min_frames`` frames long, so callers that crop/pad the audio to
            the returned sample count end up padding it with silence.

    Returns:
        ``(num_samples, num_frames)``: ``num_frames`` is
        ``max(floor(audio_length / (sr / fps)), min_frames)`` and
        ``num_samples`` is ``num_frames * sr / fps`` truncated to an int.
    """
    samples_per_frame = sr / fps
    # Enforce a minimum frame count (at least ``min_frames`` frames).
    num_frames = max(int(audio_length / samples_per_frame), min_frames)
    return int(num_frames * samples_per_frame), num_frames
def generate_blink_seq(num_frames):
    """Build a deterministic blink-intensity curve with one blink every 89 frames.

    Each blink is a fixed 9-frame ramp. The first begins at frame 80 and each
    later one begins 80 frames after the previous blink ends (onsets 80, 169,
    258, ...). A blink is only placed when its last frame falls at index
    ``num_frames - 2`` or earlier.

    Args:
        num_frames: Length of the sequence.

    Returns:
        ``np.ndarray`` of shape ``(num_frames, 1)``, zero outside blinks.
    """
    blink_curve = [0.5, 0.6, 0.7, 0.9, 1, 0.9, 0.7, 0.6, 0.5]
    blink_len = len(blink_curve)
    idle_gap = 80
    curve = np.zeros((num_frames, 1))
    # onset < num_frames - 9  <=>  onset + 9 <= num_frames - 1
    for onset in range(idle_gap, num_frames - blink_len, idle_gap + blink_len):
        curve[onset:onset + blink_len, 0] = blink_curve
    return curve
def generate_blink_seq_randomly(num_frames):
    """Build a blink-intensity curve whose blinks occur at random intervals.

    Each blink is a 5-frame ramp ``[0.5, 0.9, 1.0, 0.9, 0.5]``. Before every
    blink a fresh random pause is drawn (via the global ``random`` module, so
    ``random.seed`` makes the result reproducible). Blinks are therefore not
    evenly spaced. A blink is only placed when its last frame falls at index
    ``num_frames - 2`` or earlier.

    Args:
        num_frames: Length of the sequence.

    Returns:
        ``np.ndarray`` of shape ``(num_frames, 1)``, zero outside blinks. All
        zeros when ``num_frames <= 20``.
    """
    blink_curve = [0.5, 0.9, 1.0, 0.9, 0.5]
    blink_len = len(blink_curve)
    ratio = np.zeros((num_frames, 1))
    if num_frames <= 20:
        return ratio

    # Pause lengths are drawn from [min_gap, max_gap). num_frames > 20 here, so
    # min_gap is always 10. For very short clips num_frames // 2 can be <= 10,
    # which would leave an empty range, so widen it by a small buffer.
    min_gap = 10
    max_gap = min(num_frames // 2, 70)
    if max_gap <= min_gap:
        max_gap = min_gap + 5

    frame_id = 0
    while True:
        # Draw a new pause for every blink; drawing it once outside the loop
        # made every blink interval identical.
        onset = frame_id + random.randrange(min_gap, max_gap)
        if onset + blink_len > num_frames - 1:
            break
        ratio[onset:onset + blink_len, 0] = blink_curve
        frame_id = onset + blink_len
    return ratio
def _compute_indiv_mels(audio_path, fps, mel_step_size, sample_rate=16000):
    """Load ``audio_path`` and cut its mel spectrogram into one window per video frame.

    Returns:
        ``(indiv_mels, num_frames)``. ``indiv_mels`` is a NumPy array holding,
        for each frame, a ``mel_step_size``-step window of the spectrogram,
        transposed back so the step axis is last.

    Raises:
        ValueError: if fewer than 5 frames are produced.
        Any exception raised by the audio loader or spectrogram helper.
    """
    wav = audio.load_wav(audio_path, sample_rate)
    wav_length, num_frames = parse_audio_length(len(wav), sample_rate, fps)
    # Defensive guard: parse_audio_length pads short clips up to a minimum frame
    # count, so this only fires if that floor is lowered below 5.
    if num_frames < 5:
        raise ValueError(f"Audio too short: only {num_frames} frames generated")
    wav = crop_pad_audio(wav, wav_length)
    orig_mel = audio.melspectrogram(wav).T  # after .T, rows are presumably time steps
    last_row = orig_mel.shape[0] - 1
    indiv_mels = []
    for i in tqdm(range(num_frames), 'mel:'):
        # Window starts two video frames before frame i. The factor 80 converts
        # seconds to spectrogram rows (presumably the mel frame rate used by
        # src.utils.audio -- confirm).
        start_frame_num = i - 2
        start_idx = int(80. * (start_frame_num / float(fps)))
        # Clamp so windows hanging off either end repeat the boundary row.
        seq = np.clip(np.arange(start_idx, start_idx + mel_step_size), 0, last_row)
        indiv_mels.append(orig_mel[seq, :].T)
    return np.asarray(indiv_mels), num_frames


def _load_source_coeff(first_coeff_path, num_frames):
    """Return the first row's first 70 ``coeff_3dmm`` values, repeated ``num_frames`` times."""
    source_semantics_dict = scio.loadmat(first_coeff_path)
    ref_coeff = source_semantics_dict['coeff_3dmm'][:1, :70]
    return np.repeat(ref_coeff, num_frames, axis=0)


def _load_ref_eyeblink_coeff(ref_eyeblink_coeff_path, num_frames):
    """Return the first 64 reference ``coeff_3dmm`` columns, looped/truncated to ``num_frames`` rows.

    Raises:
        ValueError: if the reference contains no frames.
    """
    refeyeblink_coeff = scio.loadmat(ref_eyeblink_coeff_path)['coeff_3dmm'][:, :64]
    ref_num_frames = refeyeblink_coeff.shape[0]
    if ref_num_frames == 0:
        raise ValueError("Reference eyeblink coefficients contain no frames")
    # Row i of the result is reference row i % ref_num_frames: a short
    # reference is looped, a long one is truncated.
    return refeyeblink_coeff[np.arange(num_frames) % ref_num_frames]


def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
    """Assemble the audio/coefficient batch for one source image and one audio clip.

    Args:
        first_coeff_path: ``.mat`` file whose ``coeff_3dmm`` array provides the
            source coefficients (first row, first 70 columns).
        audio_path: Audio file to drive the animation. Only its base name is
            used in idle mode.
        device: Target passed to ``Tensor.to``.
        ref_eyeblink_coeff_path: Optional ``.mat`` file whose first 64
            ``coeff_3dmm`` columns replace those of the source, frame by frame.
        still: Accepted for API compatibility; not used here.
        idlemode: If true, skip audio loading and use all-zero mel windows for
            ``int(length_of_audio * fps)`` frames.
        length_of_audio: Clip length in seconds (idle mode only).
        use_blink: If false, ``ratio_gt`` is all zeros.

    Returns:
        dict with keys ``indiv_mels`` (``(1, num_frames, 1, n_mels, 16)``, with
        ``n_mels`` = 80 in idle mode), ``ref`` (``(1, num_frames, 70)``),
        ``num_frames``, ``ratio_gt`` (``(1, num_frames, 1)``), ``audio_name``
        and ``pic_name`` (file base names without extension).

    Raises:
        RuntimeError: if audio processing, source coefficient loading or tensor
            conversion fails (the original exception is chained as the cause).
    """
    syncnet_mel_step_size = 16
    fps = 25
    pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
    audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]

    if idlemode:
        num_frames = int(length_of_audio * fps)
        indiv_mels = np.zeros((num_frames, 80, syncnet_mel_step_size))
    else:
        try:
            indiv_mels, num_frames = _compute_indiv_mels(audio_path, fps, syncnet_mel_step_size)
        except Exception as e:
            raise RuntimeError(f"Audio processing failed: {str(e)}") from e

    # Blink generation is best-effort: fall back to no blinks.
    try:
        ratio = generate_blink_seq_randomly(num_frames)
    except Exception as e:
        print(f"Warning: Blink sequence generation failed, using zeros: {str(e)}")
        ratio = np.zeros((num_frames, 1))

    try:
        ref_coeff = _load_source_coeff(first_coeff_path, num_frames)
    except Exception as e:
        raise RuntimeError(f"Failed to load source semantics: {str(e)}") from e

    if ref_eyeblink_coeff_path is not None:
        try:
            ref_coeff[:, :64] = _load_ref_eyeblink_coeff(ref_eyeblink_coeff_path, num_frames)
        except Exception as e:
            # Keep the generated blinks: zeroing them before a failed load would
            # silently disable blinking entirely.
            print(f"Warning: Eyeblink reference processing failed: {str(e)}")
        else:
            # The reference now drives these coefficients, so drop synthetic blinks.
            ratio[:num_frames] = 0

    try:
        indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0).to(device)
        ratio = torch.FloatTensor(ratio).unsqueeze(0)
        if not use_blink:
            ratio = ratio.fill_(0.)
        ratio = ratio.to(device)
        ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0).to(device)
    except Exception as e:
        raise RuntimeError(f"Tensor conversion failed: {str(e)}") from e

    return {
        'indiv_mels': indiv_mels,
        'ref': ref_coeff,
        'num_frames': num_frames,
        'ratio_gt': ratio,
        'audio_name': audio_name,
        'pic_name': pic_name,
    }
|