File size: 3,679 Bytes
55f3ab3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# cd /mnt/pfs-gv8sxa/tts/dhg/zqy/code/AMD2
# env CUDA_VISIBLE_DEVICES=1 python extract_motion.py --data_pkl /mnt/pfs-gv8sxa/tts/dhg/zqy/code/AMD2/dataset/path/a2m_split/mead_test.pkl

from model import AMD_models,AMDModel
from model.utils import save_cfg, vae_encode, vae_decode, freeze, print_param_num,model_load_pretrain
import torchvision.transforms as transforms
from diffusers import AutoencoderKL
from decord import VideoReader
from decord import cpu, gpu
import os
import sys
import torch
import argparse
import pickle
import time
import gc

# --- CLI arguments -----------------------------------------------------------
# --data_pkl points at a pickled sequence of items; the processing loop reads
# data['video_path'] and data['motion_path'] from each one.
parser = argparse.ArgumentParser(
    description='Extract AMD motion latents for every video listed in a pkl file.')
parser.add_argument('--data_pkl', type=str, required=True,
                    help='pickled list of dicts with "video_path" and "motion_path" keys')
args = parser.parse_args()

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Frame preprocessing. The loop below feeds (T, C, H, W) frames already divided
# by 255, so this pipeline: resizes with an int size (torchvision matches the
# shorter edge to it), center-crops to 256x256, and normalizes every channel
# with mean 0.5 / std 0.5, i.e. maps [0, 1] to [-1, 1]. Normalize runs in place.
sample_size = (256, 256)
pixel_transforms = transforms.Compose(
    [
        transforms.Resize(min(sample_size)),
        transforms.CenterCrop(sample_size),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3, inplace=True),
    ]
)

# --- Model / checkpoint locations (cluster-specific absolute paths) ----------
amd_config = '/mnt/pfs-gv8sxa/tts/dhg/zqy/exp/amd-s-t4-d128-nonorm/config.json'
amd_ckpt = '/mnt/pfs-gv8sxa/tts/dhg/zqy/exp/amd-s-t4-d128-nonorm/checkpoints/checkpoint-188000/model.safetensors'
vae_version = '/mnt/pfs-mc0p4k/tts/team/digital_avatar_group/sunwenzhang/qiyuan/model-checkpoints/sd-vae-ft-mse'


# VAE used to encode frames into latents: frozen and used for inference only.
vae = AutoencoderKL.from_pretrained(vae_version, subfolder="vae").requires_grad_(False)
vae.to(device)
vae.eval()

# AMD model: build the architecture from its config, then load checkpoint weights.
amd_model = AMDModel.from_config(AMDModel.load_config(amd_config))
# NOTE(review): 'abcabcacbd' looks like a sentinel that matches no parameter
# name, i.e. "exclude nothing" from loading -- confirm in model_load_pretrain.
model_load_pretrain(amd_model,amd_ckpt,not_load_keyword='abcabcacbd',strict=True)
amd_model.to(device)
# A freshly constructed nn.Module is in training mode; switch to eval so any
# dropout / normalization layers behave deterministically during extraction.
amd_model.eval()
print(f'######### load AMD weight from {amd_ckpt} #############')

# Load the items to process.
# NOTE: pickle.load can execute arbitrary code -- only pass trusted --data_pkl files.
with open(args.data_pkl, 'rb') as f:
    datas = pickle.load(f)


# Per-video log entries collected by the processing loop; saved next to the
# input pkl as "<input path without extension>_log.pkl". os.path.splitext
# strips only the final extension, so dots in directory names (e.g.
# "./data/x.pkl") no longer truncate the path.
log = []
log_path = os.path.splitext(args.data_pkl)[0] + '_log.pkl'

def _process_one(video_path, motion_path):
    """Extract motion for one video, save it to *motion_path*, and return its frame count.

    All intermediate tensors (decoded frames, VAE latents, motion) are locals of
    this call, so they become unreferenced once it returns or its exception has
    been handled; the caller only keeps the integer frame count.
    """
    video_reader = VideoReader(video_path, ctx=cpu(0))
    num_frames = len(video_reader)

    # Decode every frame at once, (T,H,W,C) -> (T,C,H,W), then scale by 1/255.
    frames = video_reader.get_batch(list(range(num_frames))).asnumpy()
    videos = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous()
    videos = pixel_transforms(videos / 255.0)
    videos = videos.to(device).unsqueeze(0)  # add batch dim -> (N=1,T,C,H,W)

    with torch.no_grad():
        z = vae_encode(vae, videos).to(device)  # N,T,c,h,w
        # squeeze(0) drops a leading size-1 dim (presumably the batch dim).
        motion = amd_model.extract_motion(z).squeeze(0)

    # Create the output directory so a missing folder does not fail the item.
    out_dir = os.path.dirname(motion_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(motion, motion_path)
    return num_frames


total_num = len(datas)
for i, data in enumerate(datas):
    video_path = None  # bound up front so the error message can always use it
    try:
        video_path = data['video_path']
        motion_path = data['motion_path']

        # Skip items whose motion file already exists so an interrupted run can resume.
        if os.path.exists(motion_path):
            continue

        num_frames = _process_one(video_path, motion_path)
    except Exception as e:
        # Best-effort batch job: report which item failed and move on.
        print(f'error at {i} ({video_path}): {e}')
    else:
        # Log only items whose motion file was actually written.
        log.append({
            'video_path': video_path,
            'motion_path': motion_path,
            'num_frames': num_frames,
        })
        print(f'{i} has been processed, total:{total_num}')

    # Collect garbage and release cached CUDA allocator blocks before the next video
    # (not reached for skipped items, which allocate nothing).
    gc.collect()
    torch.cuda.empty_cache()

# Persist the collected log entries with pickle, then report completion.
with open(log_path, 'wb') as log_file:
    pickle.dump(log, log_file)

print('******** All Finished **********')