FineMoGen / data /database /generate_kit.py
camenduru's picture
thanks to mingyuan-zhang ❤
5e6f92f
import os
import torch
import numpy as np
import clip
from tqdm import tqdm
# Build a single compressed .npz archive from the samples listed in the
# train split under ``data_root_dir``: for every sample it stores the CLIP
# text embedding of the first caption line, the caption itself, the
# normalized and zero-padded motion array, the motion length, and the first
# entry of the matching file in ``clip_feats``.

device = 'cpu'
clip_model, _ = clip.load('ViT-B/32', device)

data_root_dir = "../datasets/kit_ml"
data_clip_dir = os.path.join(data_root_dir, "clip_feats")
data_caption_dir = os.path.join(data_root_dir, "texts")
data_motion_dir = os.path.join(data_root_dir, "motions")
train_split = os.path.join(data_root_dir, "train.txt")

# Every motion is cropped to at most this many rows along its first axis and
# zero-padded up to it, so all stored motions share one shape.
max_motion_length = 196
# Width of the padded motion buffer; motion files are expected to have this
# many columns (a mismatch raises a broadcasting error on assignment).
motion_feat_dim = 251

all_text_features = []
all_captions = []
all_motions = []
all_m_lengths = []
all_clip_seq_features = []

# Dataset statistics used to normalize every motion array.
std = np.load(os.path.join(data_root_dir, "std.npy"))
mean = np.load(os.path.join(data_root_dir, "mean.npy"))

# Context managers guarantee the split file and each caption file are closed,
# instead of leaking one handle per sample.
with open(train_split) as split_file:
    for filename in tqdm(split_file):
        filename = filename.strip()
        if not filename:
            # A blank line would otherwise resolve to ".txt" / ".npy" and crash.
            continue

        # --- caption and its CLIP text embedding ---
        caption_file = os.path.join(data_caption_dir, filename + ".txt")
        with open(caption_file) as caption_handle:
            # Only the first line of the caption file is used.
            caption = caption_handle.readlines()[0].strip()
        text = clip.tokenize([caption], truncate=True).to(device)
        with torch.no_grad():
            # .cpu() is a no-op on CPU and keeps .numpy() valid if `device`
            # is ever switched to a GPU.
            text_feature = clip_model.encode_text(text)[0].cpu().numpy()
        all_text_features.append(text_feature)
        all_captions.append(caption)

        # --- motion: normalize, crop, zero-pad ---
        motion_file = os.path.join(data_motion_dir, filename + ".npy")
        motion_data = np.load(motion_file)
        # Small epsilon avoids division by zero for constant features.
        motion_data = (motion_data - mean) / (std + 1e-9)
        motion_data = motion_data[:max_motion_length]
        m_length = motion_data.shape[0]
        motion = np.zeros((max_motion_length, motion_feat_dim))
        motion[:m_length, :] = motion_data
        all_motions.append(motion)
        # Length after cropping, i.e. the number of non-padding rows.
        all_m_lengths.append(m_length)

        # --- per-sample feature file from clip_feats ---
        # NOTE(review): presumably precomputed CLIP sequence features; only
        # the first entry along axis 0 is kept — confirm against the producer.
        clip_feat_file = os.path.join(data_clip_dir, filename + ".npy")
        clip_feat = np.load(clip_feat_file)[0]
        all_clip_seq_features.append(clip_feat)

all_text_features = np.array(all_text_features)
all_captions = np.array(all_captions)
all_motions = np.array(all_motions)
all_m_lengths = np.array(all_m_lengths)
all_clip_seq_features = np.array(all_clip_seq_features)

output = {
    'text_features': all_text_features,
    'captions': all_captions,
    'motions': all_motions,
    'm_lengths': all_m_lengths,
    'clip_seq_features': all_clip_seq_features,
}

npz_path = "kit_text_train.npz"
np.savez_compressed(npz_path, **output)