import os
import sys
import torch
import numpy as np
import torchvision
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
# NOTE: this wildcard import stays after the other third-party imports (as in
# the original file) so any names it exports shadow the same things as before.
from moviepy.editor import *
import audio


def load_image(filename, size):
    """Load an image as RGB and return it as a CHW float array in [0, 1].

    Args:
        filename: Path of the image file to open with PIL.
        size: Output edge length; the image is resized to ``(size, size)``
            without preserving aspect ratio.

    Returns:
        ``np.ndarray`` of dtype float64 and shape ``(3, size, size)``.
    """
    # Context manager closes the underlying file handle deterministically;
    # ``convert`` returns an independent, fully loaded image.
    with Image.open(filename) as src:
        img = src.convert('RGB')
    img = img.resize((size, size))
    img = np.asarray(img)
    img = np.transpose(img, (2, 0, 1))  # (H, W, 3) -> (3, size, size)
    return img / 255.0


def img_preprocessing(img_path, size):
    """Load an image as a ``(1, 3, size, size)`` float32 tensor scaled to [-1, 1]."""
    img = load_image(img_path, size)  # [0, 1]
    img = torch.from_numpy(img).unsqueeze(0).float()  # [0, 1]
    imgs_norm = (img - 0.5) * 2.0  # [-1, 1]
    return imgs_norm


def vid_preprocessing(vid_path, size=256):
    """Decode a video into a normalized, resized frame tensor.

    Args:
        vid_path: Path to a video file readable by PyAV.
        size: Edge length every frame is resized to (default 256, the value
            that used to be hard-coded).

    Returns:
        Tuple ``(frames, fps)``: ``frames`` is a float tensor of shape
        ``(1, T, 3, size, size)`` whose pixels were scaled to [-1, 1] before
        resizing; ``fps`` is the first video stream's ``average_rate`` as float.
    """
    import av

    resize = transforms.Resize((size, size))
    container = av.open(vid_path)
    try:
        stream = container.streams.video[0]
        fps = float(stream.average_rate)
        frames = []
        for frame in container.decode(video=0):
            rgb = torch.from_numpy(frame.to_ndarray(format='rgb24'))  # (H, W, 3) uint8
            rgb = rgb.permute(2, 0, 1)  # (3, H, W)
            # Normalize and resize one frame at a time so the full-resolution
            # float clip is never held in memory all at once.
            frames.append(resize((rgb / 255.0 - 0.5) * 2.0))
    finally:
        # Release the demuxer even if decoding fails part-way through.
        container.close()
    resized_frames = torch.stack(frames, dim=0).unsqueeze(0)
    return resized_frames, fps


def save_video(vid_target_recon, save_path, fps):
    """Encode a video tensor to an H.264 file via imageio.

    The permute ``(0, 2, 3, 4, 1)`` implies the input is laid out as
    ``(B, C, T, H, W)`` and only batch element 0 is written.
    NOTE(review): that differs from the ``(1, T, 3, H, W)`` layout returned by
    ``vid_preprocessing`` -- confirm at the call sites.

    Values are clamped to [-1, 1] and then min-max rescaled over the whole
    clip to 0..255 (not mapped from [-1, 1] directly). A constant clip, whose
    min equals its max, is written as all zeros instead of NaN garbage.
    """
    import imageio

    vid = vid_target_recon.permute(0, 2, 3, 4, 1)  # -> (B, T, H, W, C)
    vid = vid.clamp(-1, 1).cpu()
    vid_min, vid_max = vid.min(), vid.max()
    value_range = vid_max - vid_min
    if value_range > 0:
        vid = (vid - vid_min) / value_range * 255
    else:
        # Guard against 0/0 on a constant clip.
        vid = torch.zeros_like(vid)
    vid = vid.type('torch.ByteTensor')

    # ``with`` guarantees the writer is finalized even if a frame fails.
    with imageio.get_writer(save_path, fps=fps, codec='libx264', quality=8) as writer:
        for frame in vid[0]:
            writer.append_data(frame.numpy())


def parse_audio_length(audio_length, sr, fps):
    """Trim an audio sample count down to a whole number of video frames.

    Args:
        audio_length: Number of audio samples.
        sr: Audio sample rate (samples per second).
        fps: Video frame rate.

    Returns:
        Tuple ``(audio_length, num_frames)``: the sample count spanning exactly
        ``num_frames`` whole video frames, and that frame count.
    """
    bit_per_frames = sr / fps  # audio samples per video frame
    num_frames = int(audio_length / bit_per_frames)
    audio_length = int(num_frames * bit_per_frames)
    return audio_length, num_frames


def crop_pad_audio(wav, audio_length):
    """Truncate or zero-pad (at the end) a 1-D waveform to ``audio_length`` samples."""
    if len(wav) > audio_length:
        wav = wav[:audio_length]
    elif len(wav) < audio_length:
        wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
    return wav


def get_mel(audio_path):
    """Build per-video-frame mel-spectrogram windows for an audio file.

    The audio is loaded at 16 kHz and trimmed to a whole number of 25 fps
    video frames. For each video frame a window of 16 consecutive mel frames
    is taken, with indices clamped to the spectrogram's range.

    The tensors are placed on the device named by the ``CMET_DEVICE``
    environment variable (default ``"cpu"``).

    Returns:
        Tuple ``(audiox, bs, T)``:
        - ``audiox`` is viewed as ``(bs * T, 1, 80, 16)``.
        - ``bs`` is 1.
        - ``T`` is the number of video frames.
    """
    sample_rate = 16000
    fps = 25
    syncnet_mel_step_size = 16

    wav = audio.load_wav(audio_path, sample_rate)
    wav_length, num_frames = parse_audio_length(len(wav), sample_rate, fps)
    wav = crop_pad_audio(wav, wav_length)
    orig_mel = audio.melspectrogram(wav).T  # (mel_frames, n_mels)
    last_mel_idx = orig_mel.shape[0] - 1

    indiv_mels = []
    for i in range(num_frames):
        start_frame_num = i - 2
        # 80. is presumably the mel frame rate (mel frames per second) of the
        # ``audio`` module's hparams -- TODO confirm; with it, the 16-step
        # window spans 0.2 s starting two video frames before frame i.
        start_idx = int(80. * (start_frame_num / float(fps)))
        seq = [
            min(max(item, 0), last_mel_idx)
            for item in range(start_idx, start_idx + syncnet_mel_step_size)
        ]
        indiv_mels.append(orig_mel[seq, :].T)  # fancy indexing copies the window
    indiv_mels = np.asarray(indiv_mels)  # (T, n_mels, 16)

    _device = os.environ.get("CMET_DEVICE", "cpu")
    # Build the float32 tensor once and move it to the target device once
    # (the old ``.type(torch.FloatTensor).to(dev)`` bounced through the CPU).
    mel_input = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0).to(_device)  # bs T 1 n_mels 16
    bs = mel_input.shape[0]
    T = mel_input.shape[1]
    # NOTE(review): the view hard-codes 80 mel bins; assumes the audio
    # module produces n_mels == 80.
    audiox = mel_input.view(-1, 1, 80, 16)  # bs*T 1 80 16
    return audiox, bs, T


def audio_preprocessing(wav_path):
    """Return ``(audiox, bs, T)`` mel windows for ``wav_path`` (see ``get_mel``)."""
    source_audio_feature, bs, T = get_mel(wav_path)
    return source_audio_feature, bs, T


def conv_feat(features, k_size, weight=None, sigma=1.0):
    """Smooth a ``(n, c)`` feature sequence along its first axis, per channel.

    Args:
        features: 2-D tensor of shape ``(n, c)``; each of the ``c`` channels
            is filtered independently along the ``n`` axis.
        k_size: Kernel length for the default Gaussian kernel. It is ignored
            when ``weight`` is given.
        weight: Optional explicit 1-D kernel (sequence of numbers). It is used
            as-is, without normalization.
        sigma: Standard deviation of the Gaussian kernel. With ``k_size=3``
            and ``sigma=1.0`` the kernel is ``[0.27406862 0.45186276 0.27406862]``.

    Returns:
        Tensor of shape ``(n, c)`` with the same dtype and device as
        ``features``. The borders are zero-padded, and the output length
        equals the input length for odd and even kernel sizes alike.
    """
    c = features.shape[1]
    if weight is None:
        pad = k_size // 2
        offsets = np.arange(-pad, k_size - pad, dtype=np.float64)
        k = np.exp(-offsets ** 2 / (2 * (sigma ** 2)))
        k = k / k.sum()
    else:
        k_size = len(weight)
        k = np.array(weight)
        pad = k_size // 2

    # Match the input's dtype/device so conv1d does not fail on non-float32 input.
    k = torch.from_numpy(k).to(device=features.device, dtype=features.dtype)
    k = k.unsqueeze(0).unsqueeze(0).repeat(c, 1, 1)  # (c, 1, k_size): one filter per channel
    x = features.unsqueeze(0).permute(0, 2, 1)  # (1, c, n)
    # Asymmetric padding keeps the output length n for even kernels too; for
    # odd kernels it equals the former symmetric ``padding=pad``.
    x = F.pad(x, (pad, k_size - 1 - pad))
    x = F.conv1d(x, k, groups=c)
    return x.permute(0, 2, 1).squeeze(0)


def _load(checkpoint_path, device):
    """Load a checkpoint with ``torch.load``.

    The checkpoint is mapped onto CPU storage unless ``device == 'cuda'``.
    """
    # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if device == 'cuda':
        return torch.load(checkpoint_path)
    return torch.load(checkpoint_path, map_location=lambda storage, loc: storage)


def load_model(model, path, device='cuda'):
    """Load ``checkpoint["state_dict"]`` from ``path`` into ``model``.

    A leading ``module.`` prefix, which ``DataParallel`` adds, is stripped from
    each key. Only an exact prefix is removed; inner occurrences are left alone.

    Returns:
        The model moved to ``device`` and switched to eval mode.
    """
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path, device)
    prefix = 'module.'
    new_s = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint["state_dict"].items()
    }
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()