C-MET / src /util.py
coldhyuk's picture
Upload folder using huggingface_hub
78165cb verified
import os, sys
import torch
import numpy as np
import torchvision
import os
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from moviepy.editor import *
import audio
def load_image(filename, size):
    """Load an image file as a channel-first float array scaled to [0, 1].

    The image is converted to RGB and resized to ``size`` x ``size`` pixels,
    so the original aspect ratio is not preserved.

    Args:
        filename: Path of the image file to read.
        size: Edge length, in pixels, of the square output image.

    Returns:
        A NumPy float array of shape ``(3, size, size)`` with values in
        [0, 1].
    """
    # Image.open keeps the underlying file open; the context manager closes it
    # as soon as convert() has produced an independent in-memory copy.
    with Image.open(filename) as src:
        img = src.convert('RGB')
    img = img.resize((size, size))
    img = np.asarray(img)
    img = np.transpose(img, (2, 0, 1))  # H x W x 3 -> 3 x size x size
    return img / 255.0
def img_preprocessing(img_path, size):
    """Read an image and return it as a batched tensor scaled to [-1, 1].

    Args:
        img_path: Path of the image file, forwarded to ``load_image``.
        size: Edge length of the square image produced by ``load_image``.

    Returns:
        A float32 tensor with a leading batch dimension of 1, i.e.
        ``(1, 3, size, size)``, whose values lie in [-1, 1].
    """
    pixels = load_image(img_path, size)  # channel-first array in [0, 1]
    batch = torch.from_numpy(pixels).unsqueeze(0).float()
    # Shift and scale [0, 1] -> [-1, 1].
    return (batch - 0.5) * 2.0
def vid_preprocessing(vid_path, size=256):
    """Decode a video file into a normalized, resized frame tensor.

    Every frame of the first video stream is decoded as RGB, scaled from
    [0, 255] to [-1, 1], and resized to ``size`` x ``size``.

    Args:
        vid_path: Path of the video file, passed to ``av.open``.
        size: Edge length, in pixels, of the square output frames.
            Defaults to 256.

    Returns:
        A tuple ``(frames, fps)`` where ``frames`` is a float tensor of shape
        ``(1, T, 3, size, size)`` and ``fps`` is the stream's average frame
        rate as a float.
    """
    import av

    container = av.open(vid_path)
    try:
        fps = float(container.streams.video[0].average_rate)
        frames = [
            torch.from_numpy(frame.to_ndarray(format='rgb24'))
            for frame in container.decode(video=0)
        ]
    finally:
        # Release the demuxer even if stream lookup or decoding fails.
        container.close()

    vid = torch.stack(frames).permute(0, 3, 1, 2)  # (T, H, W, 3) -> (T, 3, H, W)
    vid_norm = (vid / 255.0 - 0.5) * 2.0  # [-1, 1]

    resize = transforms.Resize((size, size))
    resized_frames = torch.stack([resize(frame) for frame in vid_norm], dim=0)
    return resized_frames.unsqueeze(0), fps  # add the batch dimension
def save_video(vid_target_recon, save_path, fps):
    """Write the first clip of a batched video tensor to a video file.

    Values are clamped to [-1, 1] and then min-max stretched to [0, 255]
    before being written with ``torchvision.io.write_video``.

    Args:
        vid_target_recon: 5-D video tensor. Assumed to be laid out as
            ``(B, C, T, H, W)`` so that the permutation below yields
            ``(B, T, H, W, C)`` -- TODO confirm against callers.
        save_path: Destination path of the video file.
        fps: Frame rate passed to the video writer.
    """
    vid = vid_target_recon.permute(0, 2, 3, 4, 1)
    vid = vid.clamp(-1, 1).cpu()

    lo, hi = vid.min(), vid.max()
    if hi > lo:
        vid = (vid - lo) / (hi - lo)
    else:
        # A constant-valued clip has zero range; the min-max stretch would
        # divide by zero and produce NaNs, so map [-1, 1] -> [0, 1] directly.
        vid = (vid + 1.0) / 2.0
    vid = (vid * 255).type('torch.ByteTensor')

    # Only the first element of the batch is written.
    torchvision.io.write_video(save_path, vid[0], fps=fps)
def parse_audio_length(audio_length, sr, fps):
    """Truncate an audio length to a whole number of video frames.

    Args:
        audio_length: Length of the audio in samples.
        sr: Audio sample rate in samples per second.
        fps: Video frame rate in frames per second.

    Returns:
        A tuple ``(audio_length, num_frames)``: the sample count covered by
        the whole frames that fit, and the number of those frames.
    """
    samples_per_frame = sr / fps
    num_frames = int(audio_length / samples_per_frame)
    return int(num_frames * samples_per_frame), num_frames
def crop_pad_audio(wav, audio_length):
    """Force a waveform to exactly ``audio_length`` samples.

    Longer input is truncated at the end; shorter input is zero-padded at the
    end; input that already has the requested length is returned unchanged.

    Args:
        wav: Waveform samples.
        audio_length: Desired number of samples.

    Returns:
        The cropped, padded, or original waveform.
    """
    missing = audio_length - len(wav)
    if missing < 0:
        return wav[:audio_length]
    if missing > 0:
        return np.pad(wav, [0, missing], mode='constant', constant_values=0)
    return wav
def get_mel(audio_path):
    """Build one mel-spectrogram window per video frame from an audio file.

    The waveform is loaded at 16 kHz, trimmed to a whole number of 25 fps
    video frames, and converted to a mel spectrogram. For every video frame a
    window of 16 spectrogram steps is cut out, starting two video frames
    before the current one; indices that fall outside the spectrogram are
    clamped to its first/last step.

    Args:
        audio_path: Path of the audio file, passed to ``audio.load_wav``.

    Returns:
        A tuple ``(audiox, bs, T)``: ``audiox`` is a float tensor of shape
        ``(bs * T, 1, 80, 16)`` on CUDA when available (CPU otherwise),
        ``bs`` is the batch size (always 1 here) and ``T`` is the number of
        video frames.
    """
    sample_rate = 16000
    fps = 25
    # Spectrogram steps per second of audio. NOTE(review): presumably
    # sample_rate / hop_size of the `audio` module -- confirm there.
    mel_steps_per_sec = 80.
    syncnet_mel_step_size = 16

    wav = audio.load_wav(audio_path, sample_rate)
    wav_length, num_frames = parse_audio_length(len(wav), sample_rate, fps)
    wav = crop_pad_audio(wav, wav_length)
    orig_mel = audio.melspectrogram(wav).T  # nframes 80
    last_idx = orig_mel.shape[0] - 1

    indiv_mels = []
    for i in range(num_frames):
        start_frame_num = i - 2
        start_idx = int(mel_steps_per_sec * (start_frame_num / float(fps)))
        end_idx = start_idx + syncnet_mel_step_size
        # Clamp so windows near the start/end repeat the boundary step.
        seq = [min(max(idx, 0), last_idx) for idx in range(start_idx, end_idx)]
        # Fancy indexing copies, so the spectrogram itself is never modified.
        indiv_mels.append(orig_mel[seq, :].T)
    indiv_mels = np.asarray(indiv_mels)  # T 80 16

    _device = "cuda" if torch.cuda.is_available() else "cpu"
    # Already float32 after FloatTensor(); move to the target device once.
    mel_input = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0).to(_device)  # bs T 1 80 16
    bs = mel_input.shape[0]
    T = mel_input.shape[1]
    audiox = mel_input.view(-1, 1, 80, 16)  # bs*T 1 80 16
    return audiox, bs, T
def audio_preprocessing(wav_path):
    """Return the per-frame mel features for an audio file.

    Thin wrapper around ``get_mel``.

    Args:
        wav_path: Path of the audio file.

    Returns:
        The ``(source_audio_feature, bs, T)`` tuple produced by ``get_mel``.
    """
    return get_mel(wav_path)
def conv_feat(features, k_size, weight=None, sigma=1.0):
    """Smooth a feature sequence along its first axis with a 1-D kernel.

    Each feature channel (column) is filtered independently with the same
    kernel, using zero padding of ``kernel_length // 2`` on both ends. With
    an odd kernel length the output has the same length as the input.

    Args:
        features: 2-D float tensor of shape ``(N, C)``; filtering runs over
            ``N``.
        k_size: Length of the Gaussian kernel. Ignored when ``weight`` is
            given.
        weight: Optional explicit kernel taps. When provided, they are used
            as-is (no normalization) instead of a Gaussian.
        sigma: Standard deviation of the Gaussian kernel, in samples.

    Returns:
        The filtered tensor with ``C`` columns, on the same device as
        ``features``.
    """
    c = features.shape[1]
    if weight is None:
        pad = k_size // 2
        # Gaussian taps centred on zero, normalized to sum to one.
        k = np.array(
            [np.exp(-x ** 2 / (2 * (sigma ** 2))) for x in range(-pad, k_size - pad)],
            dtype=np.float64,
        )
        k = k / k.sum()
    else:
        k = np.array(weight)
        pad = len(weight) // 2

    kernel = torch.from_numpy(k).to(features.device).float()
    # One identical single-channel kernel per feature channel (depthwise conv).
    kernel = kernel.unsqueeze(0).unsqueeze(0).repeat(c, 1, 1)  # [C, 1, K]

    out = features.unsqueeze(0).permute(0, 2, 1)  # [1, C, N]
    out = F.conv1d(out, kernel, padding=pad, groups=c)
    return out.permute(0, 2, 1).squeeze(0)
def _load(checkpoint_path, device):
    """Deserialize a checkpoint file with ``torch.load``.

    Args:
        checkpoint_path: Path of the checkpoint file.
        device: Target device. Only the exact string ``'cuda'`` selects the
            default loading path; any other value loads with an identity
            ``map_location``.

    Returns:
        The object stored in the checkpoint file.
    """
    # NOTE(review): torch.load is pickle-based; only load checkpoints that
    # come from a trusted source.
    if device != 'cuda':
        # Identity map_location: storages are returned as deserialized rather
        # than being moved back to the device they were saved from.
        return torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    return torch.load(checkpoint_path)
def load_model(model, path, device='cuda'):
    """Load weights from a checkpoint into ``model`` and switch it to eval mode.

    The checkpoint must contain a ``"state_dict"`` entry. A leading
    ``module.`` prefix on parameter names is removed before the weights are
    loaded.

    Args:
        model: Module whose parameters are to be populated.
        path: Path of the checkpoint file.
        device: Device the model is moved to after loading.

    Returns:
        The model on ``device`` in evaluation mode.
    """
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path, device)

    prefix = 'module.'
    # Strip only an exact leading "module." so keys such as "modules.x" or
    # keys with an interior "module." are left untouched.
    new_s = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint["state_dict"].items()
    }

    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()