Spaces:
Running on Zero
Running on Zero
| import os, sys | |
| import torch | |
| import numpy as np | |
| import torchvision | |
| import os | |
| from PIL import Image | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
| from moviepy.editor import * | |
| import audio | |
def load_image(filename, size):
    """Load an image as a channel-first array scaled to [0, 1].

    Args:
        filename: Value handed to ``PIL.Image.open``.
        size: Edge length; the image is resized to ``size x size``.

    Returns:
        ``numpy.ndarray`` of shape ``(3, size, size)`` holding the RGB pixel
        values divided by 255.
    """
    # ``convert`` copies the pixel data into a new image, so the source
    # handle can be released as soon as the ``with`` block ends.
    with Image.open(filename) as src:
        rgb = src.convert('RGB')
    rgb = rgb.resize((size, size))
    pixels = np.asarray(rgb)  # size x size x 3
    pixels = np.transpose(pixels, (2, 0, 1))  # 3 x size x size
    return pixels / 255.0
def img_preprocessing(img_path, size):
    """Read an image and return it as a batched float tensor in [-1, 1].

    Args:
        img_path: Image location forwarded to ``load_image``.
        size: Edge length forwarded to ``load_image``.

    Returns:
        Float tensor with a leading batch dimension of 1, rescaled from the
        [0, 1] range produced by ``load_image`` to [-1, 1].
    """
    pixels = load_image(img_path, size)  # values in [0, 1]
    batch = torch.from_numpy(pixels).unsqueeze(0).float()
    # Shift and stretch: [0, 1] -> [-1, 1].
    return (batch - 0.5) * 2.0
def vid_preprocessing(vid_path):
    """Decode a video into a normalized stack of 256x256 frames.

    Args:
        vid_path: Value handed to ``av.open``.

    Returns:
        Tuple ``(frames, fps)``. ``frames`` is a float tensor of shape
        ``(1, T, 3, 256, 256)`` built from pixel values rescaled from
        [0, 255] to [-1, 1] before resizing; ``fps`` is the first video
        stream's ``average_rate`` as a float.
    """
    import av  # local import, as in the original

    container = av.open(vid_path)
    try:
        stream = container.streams.video[0]
        fps = float(stream.average_rate)
        frames = [
            torch.from_numpy(frame.to_ndarray(format='rgb24'))
            for frame in container.decode(video=0)
        ]
    finally:
        # Always release the container, even if decoding fails midway.
        container.close()

    vid = torch.stack(frames).permute(0, 3, 1, 2).unsqueeze(0)  # (1, T, 3, H, W)
    vid_norm = (vid / 255.0 - 0.5) * 2.0  # [-1, 1]

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
    ])
    # Resize frame by frame, then restore the batch dimension.
    resized_frames = torch.stack(
        [transform(frame) for frame in vid_norm[0]], dim=0
    ).unsqueeze(0)
    return resized_frames, fps
def save_video(vid_target_recon, save_path, fps):
    """Write the first clip of a batch to ``save_path`` as a video file.

    Args:
        vid_target_recon: 5-D tensor whose dimension 1 is moved last before
            writing — presumably ``(B, C, T, H, W)``; confirm against caller.
        save_path: Output path handed to ``torchvision.io.write_video``.
        fps: Frame rate handed to ``torchvision.io.write_video``.

    Values are clamped to [-1, 1], then min/max-stretched to [0, 255] and
    cast to bytes. Only batch element 0 is written.
    """
    vid = vid_target_recon.permute(0, 2, 3, 4, 1)
    vid = vid.clamp(-1, 1).cpu()

    lo, hi = vid.min(), vid.max()
    span = hi - lo
    if span > 0:
        scaled = (vid - lo) / span
    else:
        # A constant-valued clip has zero range; the min/max stretch would
        # divide by zero and yield NaNs. Fall back to the fixed
        # [-1, 1] -> [0, 1] mapping so the output stays well-defined.
        scaled = (vid + 1.0) / 2.0

    frames = (scaled * 255).type('torch.ByteTensor')
    torchvision.io.write_video(save_path, frames[0], fps=fps)
def parse_audio_length(audio_length, sr, fps):
    """Trim an audio length down to a whole number of video frames.

    Args:
        audio_length: Length of the audio in samples.
        sr: Audio sample rate.
        fps: Video frame rate.

    Returns:
        Tuple ``(audio_length, num_frames)``: the number of whole frames the
        audio covers, and the sample count corresponding to exactly that many
        frames (both truncated to ``int``).
    """
    samples_per_frame = sr / fps
    num_frames = int(audio_length / samples_per_frame)
    trimmed_length = int(num_frames * samples_per_frame)
    return trimmed_length, num_frames
def crop_pad_audio(wav, audio_length):
    """Force ``wav`` to ``audio_length`` samples by cropping or zero-padding.

    Args:
        wav: Sequence of samples; the padding branch uses ``np.pad`` with a
            single ``[before, after]`` pair.
        audio_length: Desired number of samples.

    Returns:
        ``wav`` truncated to ``audio_length`` if it is longer, padded with
        trailing zeros if it is shorter, or returned as-is if it matches.
    """
    shortfall = audio_length - len(wav)
    if shortfall < 0:
        return wav[:audio_length]
    if shortfall > 0:
        return np.pad(wav, [0, shortfall], mode='constant', constant_values=0)
    return wav
def get_mel(audio_path):
    """Build per-video-frame mel-spectrogram windows for an audio file.

    The audio is loaded at 16 kHz, trimmed to a whole number of 25 fps video
    frames, and converted to a mel spectrogram. For each video frame a window
    of 16 consecutive spectrogram rows is taken, with out-of-range row
    indices clamped to the first/last row.

    Args:
        audio_path: Path handed to ``audio.load_wav``.

    Returns:
        Tuple ``(audiox, bs, T)``: ``audiox`` is a float tensor viewed as
        ``(bs * T, 1, 80, 16)`` on CUDA when available (CPU otherwise),
        ``bs`` is 1 and ``T`` is the number of video frames.

    NOTE(review): audio shorter than one video frame yields zero windows and
    the final ``view`` presumably fails — confirm how callers handle this.
    """
    sample_rate = 16000
    fps = 25
    syncnet_mel_step_size = 16

    wav = audio.load_wav(audio_path, sample_rate)
    wav_length, num_frames = parse_audio_length(len(wav), sample_rate, fps)
    wav = crop_pad_audio(wav, wav_length)
    # presumably (n_mel_rows, 80) after the transpose — matches the view below
    mel = audio.melspectrogram(wav).T
    last_row = mel.shape[0] - 1

    indiv_mels = []
    for i in range(num_frames):
        # Each window starts two video frames before frame i; the factor 80
        # is presumably spectrogram rows per second — TODO confirm.
        start_idx = int(80. * ((i - 2) / float(fps)))
        end_idx = start_idx + syncnet_mel_step_size
        rows = [min(max(row, 0), last_row) for row in range(start_idx, end_idx)]
        # Fancy indexing already copies, so no separate copy of ``mel`` is needed.
        indiv_mels.append(mel[rows, :].T)
    indiv_mels = np.asarray(indiv_mels)  # T 80 16

    _device = "cuda" if torch.cuda.is_available() else "cpu"
    # Single host-to-device transfer; the tensor is already float32.
    mel_input = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0).to(_device)  # bs T 1 80 16
    bs = mel_input.shape[0]
    T = mel_input.shape[1]
    audiox = mel_input.view(-1, 1, 80, 16)  # bs*T 1 80 16
    return audiox, bs, T
def audio_preprocessing(wav_path):
    """Return the ``(features, bs, T)`` triple computed by ``get_mel``.

    Args:
        wav_path: Audio path forwarded unchanged to ``get_mel``.
    """
    features, batch_size, num_frames = get_mel(wav_path)
    return features, batch_size, num_frames
def conv_feat(features, k_size, weight=None, sigma=1.0):
    """Smooth every feature channel along the first axis with a 1-D kernel.

    Args:
        features: 2-D tensor of shape ``(n, c)``; each of the ``c`` columns
            is filtered independently along ``n``.
        k_size: Kernel length used when ``weight`` is None.
        weight: Optional explicit kernel taps; when given, its length
            replaces ``k_size``.
        sigma: Standard deviation of the Gaussian kernel built when
            ``weight`` is None.

    Returns:
        Tensor of filtered features. With an odd kernel length the shape is
        ``(n, c)``; zero padding of ``k_size // 2`` is applied on both ends.
    """
    c = features.shape[1]
    if weight is None:
        pad = k_size // 2
        # Normalized Gaussian taps centred on the middle element.
        offsets = np.arange(-pad, k_size - pad, dtype=np.float64)
        k = np.exp(-offsets ** 2 / (2 * (sigma ** 2)))
        k = k / k.sum()
    else:
        k = np.array(weight)
        pad = len(weight) // 2

    k = torch.from_numpy(k).to(features.device).float().unsqueeze(0).unsqueeze(0)
    k = k.repeat(c, 1, 1)  # one identical kernel per channel: [c, 1, k]
    features = features.unsqueeze(0).permute(0, 2, 1)  # [1, c, n]
    # groups=c makes this a depthwise convolution: channels are not mixed.
    features = F.conv1d(features, k, padding=pad, groups=c)
    features = features.permute(0, 2, 1).squeeze(0)
    return features
def _load(checkpoint_path, device):
    """Deserialize a checkpoint with ``torch.load``.

    Args:
        checkpoint_path: File handed to ``torch.load``.
        device: When equal to the string ``'cuda'`` the checkpoint is loaded
            with default placement; for any other value storages are kept
            where deserialization puts them (the identity ``map_location``).

    Returns:
        Whatever object ``torch.load`` produces for the file.

    NOTE(review): ``torch.load`` unpickles the file — only load checkpoints
    from trusted sources.
    """
    if device != 'cuda':
        return torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    return torch.load(checkpoint_path)
def load_model(model, path, device='cuda'):
    """Load checkpoint weights into ``model`` and return it in eval mode.

    Args:
        model: Module exposing ``load_state_dict``, ``to`` and ``eval``.
        path: Checkpoint file; it must contain a ``"state_dict"`` entry.
        device: Target device, passed to ``_load`` and ``model.to``.

    Returns:
        The model moved to ``device`` with ``eval()`` applied.
    """
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path, device)

    # Strip only a genuine leading "module." prefix. Matching on the first
    # six characters and replacing the first "module." anywhere would
    # corrupt keys such as "modules.module.weight".
    prefix = 'module.'
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["state_dict"].items()
    }

    model.load_state_dict(state)
    model = model.to(device)
    return model.eval()