|
|
|
|
|
|
|
|
|
|
|
from model import AMD_models,AMDModel |
|
|
from model.utils import save_cfg, vae_encode, vae_decode, freeze, print_param_num,model_load_pretrain |
|
|
import torchvision.transforms as transforms |
|
|
from diffusers import AutoencoderKL |
|
|
from decord import VideoReader |
|
|
from decord import cpu, gpu |
|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
import argparse |
|
|
import pickle |
|
|
import time |
|
|
import gc |
|
|
|
|
|
|
|
|
# CLI: the only input is a pickle file listing work items
# (dicts with 'video_path' / 'motion_path' keys, see the main loop below).
# NOTE: description previously said "audio files" — this script processes videos.
parser = argparse.ArgumentParser(description='Process video files.')
parser.add_argument('--data_pkl', type=str, required=True, help='data pkl')
args = parser.parse_args()
|
|
|
|
|
|
|
|
# Run on the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Frame preprocessing: resize the shortest side to 256, center-crop to
# 256x256, then map [0, 1] pixels to [-1, 1] via (x - 0.5) / 0.5
# (the input range the SD VAE expects).
sample_size = (256, 256)
shortest_side = min(sample_size)
pixel_transforms = transforms.Compose(
    [
        transforms.Resize(shortest_side),
        transforms.CenterCrop(sample_size),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
)
|
|
|
|
|
|
|
|
# Checkpoint locations for the AMD motion model and the SD VAE used for latents.
amd_config = '/mnt/pfs-gv8sxa/tts/dhg/zqy/exp/amd-s-t4-d128-nonorm/config.json'
amd_ckpt = '/mnt/pfs-gv8sxa/tts/dhg/zqy/exp/amd-s-t4-d128-nonorm/checkpoints/checkpoint-188000/model.safetensors'
vae_version = '/mnt/pfs-mc0p4k/tts/team/digital_avatar_group/sunwenzhang/qiyuan/model-checkpoints/sd-vae-ft-mse'

# Frozen VAE used purely to encode video frames into latents.
vae = AutoencoderKL.from_pretrained(vae_version, subfolder="vae").requires_grad_(False)
vae.eval()  # inference mode: fix dropout/norm-layer behavior
vae.to(device)

# AMD motion model: build from config, load checkpoint weights, then freeze
# and switch to eval mode — nn.Module defaults to train mode, which would
# leave dropout/norm layers in training behavior during extraction.
amd_model = AMDModel.from_config(AMDModel.load_config(amd_config))
model_load_pretrain(amd_model, amd_ckpt, not_load_keyword='abcabcacbd', strict=True)
amd_model.requires_grad_(False)
amd_model.eval()
amd_model.to(device)
print(f'######### load AMD weight from {amd_ckpt} #############')
|
|
|
|
|
|
|
|
# Load the work list: a pickled list of dicts with 'video_path'/'motion_path'.
with open(args.data_pkl, 'rb') as f:
    datas = pickle.load(f)

# Progress log written next to the input list. Use os.path.splitext rather
# than split('.')[0]: the latter truncates at the FIRST dot anywhere in the
# path (e.g. '/mnt/data.v2/list.pkl' -> '/mnt/data'), not at the extension.
log = []
log_path = os.path.splitext(args.data_pkl)[0] + '_log.pkl'

total_num = len(datas)
|
|
# Main loop: for each item, decode the video, encode frames to VAE latents,
# extract the motion representation, and save it to motion_path.
for i, data in enumerate(datas):
    try:
        video_path = data['video_path']
        motion_path = data['motion_path']

        # Skip items whose motion file already exists (makes the run resumable).
        if os.path.exists(motion_path):
            continue

        # Decode every frame on CPU: (T, H, W, C) uint8 -> (T, C, H, W) float in [0, 1].
        video_reader = VideoReader(video_path, ctx=cpu(0))
        idx = list(range(len(video_reader)))
        videos = torch.from_numpy(video_reader.get_batch(idx).asnumpy()).permute(0, 3, 1, 2).contiguous()
        videos = videos / 255.0
        videos = pixel_transforms(videos)
        videos = videos.to(device)
        videos = videos.unsqueeze(0)  # add batch dimension

        # Inference only: keep BOTH the VAE encode and the motion extraction
        # inside no_grad. Previously extract_motion ran outside the context,
        # so an autograd graph was retained every iteration and GPU memory grew.
        with torch.no_grad():
            z = vae_encode(vae, videos).to(device)
            motion = amd_model.extract_motion(z)
        motion = motion.squeeze(0)  # drop the batch dimension

        log.append({
            'video_path': video_path,
            'motion_path': motion_path,
            'num_frames': len(video_reader),
        })

        torch.save(motion, motion_path)

        # Drop references BEFORE collecting — the old order (collect, then del)
        # meant the collector could never free these tensors.
        del video_reader, z, motion, videos
        gc.collect()
        torch.cuda.empty_cache()

        print(f'{i} has been processed, total:{total_num}')

    except Exception as e:
        # Best-effort batch job: report the failure and move on to the next item.
        print('error', e)
        gc.collect()
        torch.cuda.empty_cache()
        continue
|
|
|
|
|
# Persist the per-video processing log for downstream bookkeeping.
with open(log_path, 'wb') as log_file:
    pickle.dump(log, log_file)

print('******** All Finished **********')
|
|
|
|
|
|