| import os, sys |
| import torch |
| import numpy as np |
| import torchvision |
| import os |
| from PIL import Image |
|
|
| from torchvision import transforms |
| import torch.nn.functional as F |
| from moviepy.editor import * |
| import audio |
|
|
|
|
def load_image(filename, size):
    """Load an image file as a channel-first float array scaled to [0, 1].

    Args:
        filename: Path (or file object) accepted by ``PIL.Image.open``.
        size: Target edge length; the image is resized to ``size x size``
            without preserving the aspect ratio.

    Returns:
        A NumPy array of shape ``(3, size, size)`` with values in [0, 1].
    """
    # Use the context manager so the underlying file handle is released
    # deterministically; ``convert`` returns an independent in-memory image.
    with Image.open(filename) as source:
        img = source.convert('RGB')
    img = img.resize((size, size))
    img = np.asarray(img)
    # PIL yields (H, W, C); move the channel axis to the front.
    img = np.transpose(img, (2, 0, 1))
    return img / 255.0
|
|
|
|
def img_preprocessing(img_path, size):
    """Read an image and return it as a batched tensor scaled to [-1, 1].

    Args:
        img_path: Path of the image, forwarded to ``load_image``.
        size: Edge length forwarded to ``load_image``.

    Returns:
        A float tensor with a leading batch axis of length 1, obtained by
        mapping the [0, 1] output of ``load_image`` through ``(x - 0.5) * 2``.
    """
    pixels = load_image(img_path, size)
    batch = torch.from_numpy(pixels)
    batch = batch.unsqueeze(0)
    batch = batch.float()
    # Shift/scale from [0, 1] to [-1, 1].
    return (batch - 0.5) * 2.0
|
|
|
|
def vid_preprocessing(vid_path, size=256):
    """Decode a video into a normalised, resized frame tensor.

    Args:
        vid_path: Path of the video, passed to ``av.open``.
        size: Edge length every frame is resized to (default 256, the value
            that was previously hard-coded).

    Returns:
        A tuple ``(frames, fps)`` where ``frames`` has shape
        ``(1, T, C, size, size)`` with values scaled by
        ``(x / 255 - 0.5) * 2`` and ``fps`` is the first video stream's
        average frame rate as a float.
    """
    import av

    container = av.open(vid_path)
    try:
        stream = container.streams.video[0]
        fps = float(stream.average_rate)
        frames = [
            torch.from_numpy(frame.to_ndarray(format='rgb24'))
            for frame in container.decode(video=0)
        ]
    finally:
        # Release the demuxer even when decoding fails part-way through.
        container.close()

    # (T, H, W, C) -> (1, T, C, H, W)
    vid = torch.stack(frames).permute(0, 3, 1, 2).unsqueeze(0)
    vid_norm = (vid / 255.0 - 0.5) * 2.0

    resize = transforms.Resize((size, size))
    resized_frames = torch.stack(
        [resize(frame) for frame in vid_norm[0]], dim=0
    ).unsqueeze(0)
    return resized_frames, fps
|
|
|
|
def save_video(vid_target_recon, save_path, fps):
    """Write the first clip of a video tensor to ``save_path`` as H.264.

    Args:
        vid_target_recon: 5-D tensor; it is permuted with ``(0, 2, 3, 4, 1)``
            before writing, so it is presumably laid out as
            ``(B, C, T, H, W)`` - TODO confirm against callers. Only batch
            element 0 is written.
        save_path: Output file path handed to ``imageio.get_writer``.
        fps: Frame rate of the written video.

    Values are clamped to [-1, 1] and then min-max rescaled to [0, 255]
    using the actual minimum and maximum of the clamped tensor.
    """
    import imageio

    vid = vid_target_recon.permute(0, 2, 3, 4, 1)
    vid = vid.clamp(-1, 1).cpu()

    lowest = vid.min()
    span = vid.max() - lowest
    if span > 0:
        vid = ((vid - lowest) / span * 255).type('torch.ByteTensor')
    else:
        # A constant-valued clip would give 0/0 = NaN, whose cast to uint8 is
        # undefined; emit black frames instead.
        vid = torch.zeros_like(vid, dtype=torch.uint8)

    writer = imageio.get_writer(save_path, fps=fps, codec='libx264', quality=8)
    try:
        for frame in vid[0]:
            writer.append_data(frame.numpy())
    finally:
        # Always finalise the file so a failed write does not leak the
        # encoder process / file handle.
        writer.close()
|
|
|
|
def parse_audio_length(audio_length, sr, fps):
    """Trim an audio length to a whole number of video frames.

    Args:
        audio_length: Number of audio samples available.
        sr: Audio sample rate (samples per second).
        fps: Video frame rate (frames per second).

    Returns:
        ``(trimmed_length, frame_count)``: the number of whole video frames
        covered by the audio, and the sample count (truncated to int) that
        those frames span.
    """
    samples_per_frame = sr / fps
    frame_count = int(audio_length / samples_per_frame)
    trimmed_length = int(frame_count * samples_per_frame)
    return trimmed_length, frame_count
|
|
def crop_pad_audio(wav, audio_length):
    """Force ``wav`` to exactly ``audio_length`` samples.

    Longer input is truncated from the end; shorter input is zero-padded at
    the end; input that already has the requested length is returned as-is.
    """
    current = len(wav)
    if current == audio_length:
        return wav
    if current > audio_length:
        return wav[:audio_length]
    missing = audio_length - current
    return np.pad(wav, [0, missing], mode='constant', constant_values=0)
|
|
def get_mel(audio_path):
    """Build per-video-frame mel-spectrogram windows for an audio file.

    The waveform is loaded at 16 kHz, trimmed/padded to a whole number of
    25 fps video frames, converted to a mel spectrogram, and sliced into one
    16-step window per video frame. Window indices that fall outside the
    spectrogram are clamped to its first/last row.

    Args:
        audio_path: Path of the audio file, passed to ``audio.load_wav``.

    Returns:
        ``(audiox, bs, T)`` where ``audiox`` is a float tensor of shape
        ``(-1, 1, 80, 16)`` on the device named by the ``CMET_DEVICE``
        environment variable (default ``"cpu"``), ``bs`` is the batch size
        (always 1 here) and ``T`` is the number of video frames.
    """
    sample_rate = 16000
    fps = 25
    syncnet_mel_step_size = 16
    # NOTE(review): 80. is used as the number of mel steps per second when
    # mapping a video frame to a spectrogram row - confirm against the
    # hop size used by ``audio.melspectrogram``.
    mel_steps_per_second = 80.

    wav = audio.load_wav(audio_path, sample_rate)
    wav_length, num_frames = parse_audio_length(len(wav), sample_rate, fps)
    wav = crop_pad_audio(wav, wav_length)
    orig_mel = audio.melspectrogram(wav).T
    last_row = orig_mel.shape[0] - 1

    indiv_mels = []
    for i in range(num_frames):
        # Each window starts two video frames before the current one.
        start_frame_num = i - 2
        start_idx = int(mel_steps_per_second * (start_frame_num / float(fps)))
        end_idx = start_idx + syncnet_mel_step_size
        rows = [min(max(row, 0), last_row) for row in range(start_idx, end_idx)]
        # Fancy indexing already copies, so no defensive copy of orig_mel
        # is needed.
        indiv_mels.append(orig_mel[rows, :].T)
    indiv_mels = np.asarray(indiv_mels)

    _device = os.environ.get("CMET_DEVICE", "cpu")
    # torch.FloatTensor(...) is already float32; converting again with
    # .type(torch.FloatTensor) would bounce the data back through the CPU.
    mel_input = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0).to(_device)

    bs = mel_input.shape[0]
    T = mel_input.shape[1]
    audiox = mel_input.view(-1, 1, 80, 16)
    return audiox, bs, T
|
|
|
|
def audio_preprocessing(wav_path):
    """Return the mel features for ``wav_path``.

    Thin wrapper around ``get_mel``; the three values it produces (feature
    tensor, batch size, number of frames) are passed through unchanged.
    """
    mel_windows, batch_size, num_steps = get_mel(wav_path)
    return mel_windows, batch_size, num_steps
|
|
|
|
def conv_feat(features, k_size, weight=None, sigma=1.0):
    """Filter every feature channel along the first (time) axis.

    Args:
        features: 2-D tensor; axis 1 is treated as the channel axis and
            axis 0 is the axis that gets convolved.
        k_size: Kernel length used when ``weight`` is not given.
        weight: Optional explicit kernel taps; when provided its length
            overrides ``k_size``.
        sigma: Standard deviation of the Gaussian kernel built when
            ``weight`` is ``None``.

    Returns:
        The filtered tensor, laid out like ``features``. The same kernel is
        applied to each channel independently (grouped ``conv1d``) with zero
        padding of ``k_size // 2`` on both ends.

    The kernel is printed to stdout on every call.
    """
    n_channels = features.shape[1]

    if weight is not None:
        k_size = len(weight)
        kernel = np.array(weight)
        half = k_size // 2
    else:
        half = k_size // 2
        kernel = np.zeros(k_size).astype(np.float64)
        # Sample an unnormalised Gaussian centred on the middle tap.
        for slot, offset in enumerate(range(-half, k_size - half)):
            kernel[slot] = np.exp(-offset ** 2 / (2 * (sigma ** 2)))
        kernel = kernel / kernel.sum()
    print(kernel)

    # One copy of the kernel per channel: (C, 1, k) for a depthwise conv.
    kernel_t = torch.from_numpy(kernel).to(features.device).float()
    kernel_t = kernel_t[None, None].repeat(n_channels, 1, 1)

    # (T, C) -> (1, C, T) for conv1d, then back again.
    channels_first = features.unsqueeze(0).permute(0, 2, 1)
    filtered = F.conv1d(channels_first, kernel_t, padding=half, groups=n_channels)
    return filtered.permute(0, 2, 1).squeeze(0)
|
|
|
|
def _load(checkpoint_path, device):
    """Deserialize a checkpoint file with ``torch.load``.

    When ``device`` is exactly the string ``'cuda'`` the checkpoint is loaded
    with torch's default device placement; for any other value every storage
    is kept on the CPU.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints from
    trusted sources.
    """
    if device != 'cuda':
        def keep_on_cpu(storage, loc):
            # Returning the storage untouched leaves it where it was
            # deserialized, i.e. on the CPU.
            return storage

        return torch.load(checkpoint_path, map_location=keep_on_cpu)
    return torch.load(checkpoint_path)
|
|
def load_model(model, path, device='cuda'):
    """Load weights from a checkpoint into ``model`` and switch it to eval mode.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        path: Checkpoint file; it must contain a ``"state_dict"`` entry.
        device: Device the model is moved to after loading (default
            ``'cuda'``); also forwarded to ``_load``.

    Returns:
        ``model`` moved to ``device`` and put in evaluation mode.

    Keys that start with ``module.`` have that leading prefix removed before
    loading.
    """
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path, device)

    prefix = 'module.'
    # Strip only a true leading "module." - testing just the first six
    # characters and then replacing the first "module." anywhere in the key
    # could mangle names such as "modulelist.module.weight".
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["state_dict"].items()
    }

    model.load_state_dict(state_dict)
    model = model.to(device)
    return model.eval()