# motion-stream / humanml3d_272 / dataset_eval_t2m.py
# zirobtc's picture
# Initial upload of MotionStreamer code, excluding large extracted data and output folders.
# 0e267a7 verified
import torch
from torch.utils import data
import numpy as np
from os.path import join as pjoin
import random
import codecs as cs
from tqdm import tqdm
import os
import utils.paramUtil as paramUtil
from torch.utils.data._utils.collate import default_collate
def collate_fn(batch):
    """Collate a batch after ordering its samples longest-first.

    Each sample is indexed positionally; element ``2`` is used as the sort
    key (the motion length in the tuples returned by
    ``Text2MotionDataset.__getitem__``).  The incoming list is reordered in
    place, descending by that key, and then handed to ``default_collate``.

    Args:
        batch: list of samples; sorted in place.

    Returns:
        Whatever ``default_collate`` produces for the reordered batch.
    """
    def _sample_length(sample):
        return sample[2]

    # In-place sort on purpose: the caller's list ends up reordered too.
    batch.sort(key=_sample_length, reverse=True)
    collated = default_collate(batch)
    return collated
class Text2MotionDataset(data.Dataset):
    """Evaluation dataset that pairs text captions with normalized motion clips.

    On construction every motion listed in the chosen split file is loaded
    into memory together with its caption file.  Samples are kept sorted by
    frame count so that ``reset_max_len`` can hide all clips shorter than a
    threshold by moving ``self.pointer``.

    Each item is a ``(caption, motion, m_length)`` tuple where ``motion`` is
    normalized with the stored mean/std and zero-padded along axis 0 up to
    ``self.max_motion_length`` rows.
    """

    def __init__(self, dataset_name, is_test, max_text_len = 20, unit_length = 4):
        """Load the split, filter clips by length and index them by name.

        Args:
            dataset_name: dataset identifier; only ``'t2m_272'`` is configured.
            is_test: read ``split/test.txt`` when truthy, else ``split/val.txt``.
            max_text_len: stored on the instance (not otherwise used here).
            unit_length: motion lengths are rounded down to a multiple of
                this value in ``__getitem__``.

        Raises:
            ValueError: if ``dataset_name`` is not supported, or if no clip
                survives the length filtering.
        """
        self.max_length = 20
        self.pointer = 0
        self.dataset_name = dataset_name
        self.is_test = is_test
        self.max_text_len = max_text_len
        self.unit_length = unit_length

        # Fail early with a clear message instead of tripping later over
        # attributes that are only defined for the supported dataset.
        if dataset_name != 't2m_272':
            raise ValueError(
                "Unsupported dataset_name %r; only 't2m_272' is configured."
                % (dataset_name,))

        self.data_root = './humanml3d_272'
        self.motion_dir = pjoin(self.data_root, 'motion_data')
        self.text_dir = pjoin(self.data_root, 'texts')
        self.joints_num = 22
        self.max_motion_length = 300
        fps = 30
        self.meta_dir = './humanml3d_272/mean_std'
        if is_test:
            split_file = pjoin(self.data_root, 'split', 'test.txt')
        else:
            split_file = pjoin(self.data_root, 'split', 'val.txt')

        mean = np.load(pjoin(self.meta_dir, 'Mean.npy'))
        std = np.load(pjoin(self.meta_dir, 'Std.npy'))

        # Clips must have at least 60 frames (2 seconds at 30 fps).
        min_motion_len = 60

        data_dict = {}
        with cs.open(split_file, 'r') as f:
            # Blank lines (e.g. a trailing newline) are not sample ids.
            id_list = [name for name in (line.strip() for line in f) if name]

        new_name_list = []
        length_list = []
        for name in tqdm(id_list):
            motion = np.load(pjoin(self.motion_dir, name + '.npy'))
            if (len(motion)) < min_motion_len or (len(motion) >= self.max_motion_length):
                continue

            text_data = []
            flag = False
            with cs.open(pjoin(self.text_dir, name + '.txt')) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # An empty line carries no caption record.
                        continue
                    # Record layout: caption#tokens#from_seconds#to_seconds
                    line_split = line.split('#')
                    caption = line_split[0]
                    tokens = line_split[1].split(' ')
                    f_tag = float(line_split[2])
                    to_tag = float(line_split[3])
                    f_tag = 0.0 if np.isnan(f_tag) else f_tag
                    to_tag = 0.0 if np.isnan(to_tag) else to_tag

                    text_dict = {'caption': caption, 'tokens': tokens}
                    if f_tag == 0.0 and to_tag == 0.0:
                        # Caption describes the whole clip.
                        flag = True
                        text_data.append(text_dict)
                        continue

                    # Caption describes a sub-range: store it as its own
                    # sample under a randomly prefixed, unique name.
                    n_motion = motion[int(f_tag*fps) : int(to_tag*fps)]
                    if (len(n_motion)) < min_motion_len or (len(n_motion) >= self.max_motion_length):
                        continue
                    new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                    while new_name in data_dict:
                        new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                    data_dict[new_name] = {'motion': n_motion,
                                           'length': len(n_motion),
                                           'text': [text_dict]}
                    new_name_list.append(new_name)
                    length_list.append(len(n_motion))

            if flag:
                data_dict[name] = {'motion': motion,
                                   'length': len(motion),
                                   'text': text_data}
                new_name_list.append(name)
                length_list.append(len(motion))

        if not new_name_list:
            # zip(*[]) below would otherwise fail with an opaque unpack error.
            raise ValueError(
                "No usable motion samples found for split file %r." % (split_file,))

        # Sort ascending by length; reset_max_len relies on this ordering.
        name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))

        self.mean = mean
        self.std = std
        self.length_arr = np.array(length_list)
        self.data_dict = data_dict
        self.name_list = name_list
        self.reset_max_len(self.max_length)

    def reset_max_len(self, length):
        """Hide every sample shorter than ``length`` frames.

        Moves ``self.pointer`` to the first index in the length-sorted
        sample list whose length is >= ``length`` and records ``length`` as
        ``self.max_length``.  ``length`` must not exceed
        ``self.max_motion_length``.
        """
        assert length <= self.max_motion_length
        self.pointer = np.searchsorted(self.length_arr, length)
        print("Pointer Pointing at %d"%self.pointer)
        self.max_length = length

    def inv_transform(self, data):
        """Undo normalization: ``data * std + mean``."""
        return data * self.std + self.mean

    def forward_transform(self, data):
        """Apply normalization: ``(data - mean) / std``."""
        return (data - self.mean) / self.std

    def __len__(self):
        """Number of samples at or after the current pointer."""
        # pointer may be a NumPy integer after reset_max_len; return a plain int.
        return int(len(self.data_dict) - self.pointer)

    def __getitem__(self, item):
        """Return ``(caption, motion, m_length)`` for the ``item``-th visible sample.

        A caption is drawn at random from the sample's texts.  The length is
        rounded down to a multiple of ``unit_length`` (with one extra unit
        removed one time in three when ``unit_length < 10``), a random window
        of that length is cropped, normalized, and zero-padded to
        ``max_motion_length`` rows.
        """
        sample = self.data_dict[self.name_list[self.pointer + item]]
        motion, m_length, text_list = sample['motion'], sample['length'], sample['text']

        caption = random.choice(text_list)['caption']

        if self.unit_length < 10:
            coin2 = np.random.choice(['single', 'single', 'double'])
        else:
            coin2 = 'single'

        units = m_length // self.unit_length
        if coin2 == 'double':
            m_length = (units - 1) * self.unit_length
        elif coin2 == 'single':
            m_length = units * self.unit_length

        # Random crop of the rounded length.
        start = random.randint(0, len(motion) - m_length)
        motion = self.forward_transform(motion[start:start+m_length])

        if m_length < self.max_motion_length:
            padding = np.zeros((self.max_motion_length - m_length, motion.shape[1]))
            motion = np.concatenate([motion, padding], axis=0)

        return caption, motion, m_length
def DATALoader(dataset_name, is_test,
                batch_size,
                num_workers = 64, unit_length = 4, drop_last=True) :
    """Build a shuffled ``DataLoader`` over a ``Text2MotionDataset``.

    Args:
        dataset_name: forwarded to ``Text2MotionDataset``.
        is_test: forwarded to ``Text2MotionDataset``.
        batch_size: number of samples per batch.
        num_workers: worker process count for the loader.
        unit_length: forwarded to ``Text2MotionDataset``.
        drop_last: whether the final incomplete batch is dropped.

    Returns:
        A ``torch.utils.data.DataLoader`` that uses ``collate_fn`` to
        assemble batches.
    """
    eval_dataset = Text2MotionDataset(dataset_name, is_test, unit_length=unit_length)
    loader_options = dict(
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=drop_last,
    )
    return torch.utils.data.DataLoader(eval_dataset, batch_size, **loader_options)
def cycle(iterable):
    """Yield the items of ``iterable`` over and over again.

    The iterable is re-iterated from the start each time it is exhausted.
    If a full pass produces no items at all (an empty iterable, or a
    one-shot iterator that has already been consumed) the generator stops
    instead of spinning forever without yielding.

    Args:
        iterable: any object that can be iterated repeatedly.

    Yields:
        The items of ``iterable``, repeated while each pass yields something.
    """
    while True:
        produced_any = False
        for x in iterable:
            produced_any = True
            yield x
        if not produced_any:
            # Nothing to repeat: looping again would hang the consumer.
            return