File size: 3,679 Bytes
bd546bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
# cd /mnt/pfs-gv8sxa/tts/dhg/zqy/code/AMD2
# env CUDA_VISIBLE_DEVICES=1 python extract_motion.py --data_pkl /mnt/pfs-gv8sxa/tts/dhg/zqy/code/AMD2/dataset/path/a2m_split/mead_test.pkl
from model import AMD_models,AMDModel
from model.utils import save_cfg, vae_encode, vae_decode, freeze, print_param_num,model_load_pretrain
import torchvision.transforms as transforms
from diffusers import AutoencoderKL
from decord import VideoReader
from decord import cpu, gpu
import os
import sys
import torch
import argparse
import pickle
import time
import gc
# ---- command-line arguments ----
# Single required argument: path to the dataset manifest pickle.
cli = argparse.ArgumentParser(description='Process audio files.')
cli.add_argument('--data_pkl', type=str, required=True, help='data pkl')
args = cli.parse_args()
# ---- compute device ----
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- frame preprocessing ----
# Resize the short side to 256, center-crop to 256x256, then map [0,1] -> [-1,1].
sample_size = (256, 256)
pixel_transforms = transforms.Compose(
    [
        transforms.Resize(min(sample_size)),
        transforms.CenterCrop(sample_size),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
)
# ---- checkpoint locations (hard-coded cluster paths) ----
# AMD model architecture config and trained weights.
amd_config = '/mnt/pfs-gv8sxa/tts/dhg/zqy/exp/amd-s-t4-d128-nonorm/config.json'
amd_ckpt = '/mnt/pfs-gv8sxa/tts/dhg/zqy/exp/amd-s-t4-d128-nonorm/checkpoints/checkpoint-188000/model.safetensors'
# Stable Diffusion VAE checkpoint (sd-vae-ft-mse) used to encode frames into latents.
vae_version = '/mnt/pfs-mc0p4k/tts/team/digital_avatar_group/sunwenzhang/qiyuan/model-checkpoints/sd-vae-ft-mse'
# vae: frozen (requires_grad disabled) -- this script is inference-only.
vae = AutoencoderKL.from_pretrained(vae_version, subfolder="vae").requires_grad_(False)
vae.to(device)
# amd_model: instantiate from config, then load pretrained weights strictly.
amd_model = AMDModel.from_config(AMDModel.load_config(amd_config))
# HACK(review): 'abcabcacbd' is presumably a never-matching sentinel so that
# no weight key is excluded from loading -- confirm against model_load_pretrain.
model_load_pretrain(amd_model,amd_ckpt,not_load_keyword='abcabcacbd',strict=True)
amd_model.to(device)
print(f'######### load AMD weight from {amd_ckpt} #############')
# ---- load dataset manifest ----
# Each entry is a dict with at least 'video_path' and 'motion_path'.
with open(args.data_pkl, 'rb') as f:
    datas = pickle.load(f)

# log: one record per newly extracted clip, written next to the input pkl.
log = []
# Use splitext so directories containing '.' in their names don't truncate
# the path (the previous `split('.')[0]` did).
log_path = os.path.splitext(args.data_pkl)[0] + '_log.pkl'
total_num = len(datas)

for i, data in enumerate(datas):
    try:
        video_path = data['video_path']
        motion_path = data['motion_path']
        # Skip clips whose motion tensor was already extracted (resumable job).
        if os.path.exists(motion_path):
            continue
        # Read every frame: (T,H,W,C) uint8 -> (T,C,H,W) float in [0,1].
        video_reader = VideoReader(video_path, ctx=cpu(0))
        idx = list(range(len(video_reader)))
        videos = torch.from_numpy(video_reader.get_batch(idx).asnumpy()).permute(0, 3, 1, 2).contiguous()
        videos = videos / 255.0
        videos = pixel_transforms(videos)
        videos = videos.to(device)
        videos = videos.unsqueeze(0)  # (N,T,C,H,W) with N == 1
        # Inference only: keep BOTH the VAE encode and the motion extraction
        # inside no_grad (extract_motion previously ran outside it, building
        # a useless autograd graph and inflating GPU memory).
        with torch.no_grad():
            z = vae_encode(vae, videos).to(device)  # latents: (N,T,c,h,w)
            motion = amd_model.extract_motion(z)
            motion = motion.squeeze(0)
        # log
        log.append({
            'video_path': video_path,
            'motion_path': motion_path,
            'num_frames': len(video_reader),
        })
        # Save; create the target directory first so torch.save can't fail
        # on a missing parent.
        motion_dir = os.path.dirname(motion_path)
        if motion_dir:
            os.makedirs(motion_dir, exist_ok=True)
        torch.save(motion, motion_path)
        # Drop references BEFORE collecting, otherwise the large tensors are
        # still reachable when gc.collect()/empty_cache() run.
        del video_reader, z, motion, videos
        gc.collect()
        torch.cuda.empty_cache()
        print(f'{i} has been processed, total:{total_num}')
    except Exception as e:
        # Best-effort batch job: report which item failed and keep going.
        print(f'error at item {i}: {e}')
        gc.collect()
        torch.cuda.empty_cache()
        continue

with open(log_path, 'wb') as file:
    # Persist the processing log with pickle.
    pickle.dump(log, file)
print('******** All Finished **********')
|