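# NMR / data / main.py
# Author: Xxx999 · commit 45950ff ("upload")
# Builds shuffled train/test pickle splits from paired files under
# data/smplx_data/ and data/gmr_data/.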
import os
import numpy as np
from tqdm import tqdm
import joblib
import random
from joblib import Parallel, delayed
def func(smplx_path):
    """Load one sample file and its counterpart from the ``gmr_data`` tree.

    The counterpart path is derived from ``smplx_path`` by replacing every
    occurrence of ``'smplx_data'`` with ``'gmr_data'``.

    Args:
        smplx_path: path of a file loadable with ``np.load``, located under a
            ``smplx_data`` directory.

    Returns:
        A 3-element list ``[smplx_obj, gmr_obj, key]``. The first two entries
        are the ``np.load`` results for the two files. ``key`` is
        ``smplx_path`` with every ``'data/smplx_data/'`` substring removed.

    Raises:
        FileNotFoundError: if either the input file or its ``gmr_data``
            counterpart does not exist.
    """
    gmr_path = smplx_path.replace('smplx_data', 'gmr_data')
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would defer the failure to a less specific error.
    for path in (smplx_path, gmr_path):
        if not os.path.exists(path):
            raise FileNotFoundError(f'missing data file: {path}')
    return [np.load(smplx_path), np.load(gmr_path), smplx_path.replace('data/smplx_data/', '')]
# Number of joblib workers used to load the sample files.
N_JOBS = 48
# Size of the held-out test tail; every train split is drawn from the rest.
TEST_SIZE = 20_000


def main(seed=0):
    """Load every paired sample, shuffle, and dump train/test pickle splits.

    Reads the list of input paths from ``data/smplx_data/smplx_path.pkl``,
    loads each one with ``func`` in parallel, and shuffles the results. It
    then writes 15 pickles under ``data/``: five splits (``train_400k``,
    ``train_200k``, ``train_100k``, ``test_20k``, ``test_5k``) for each of
    three views:

    - the full ``func`` records (no infix);
    - only their first element (``smplx_`` infix);
    - only their second element (``gmr_`` infix).

    The last ``TEST_SIZE`` shuffled samples form the test set. ``test_5k`` is
    the final 5 000 of those, so it is a subset of ``test_20k``.
    ``train_400k`` is every non-test sample, whatever its actual count.

    Args:
        seed: seed for the shuffle so the split is reproducible; pass
            ``None`` for a non-deterministic shuffle.

    Raises:
        ValueError: if there are not more than ``TEST_SIZE`` samples, which
            would leave the training split empty.
    """
    data_paths = joblib.load('data/smplx_data/smplx_path.pkl')
    all_data = Parallel(n_jobs=N_JOBS)(
        delayed(func)(file_path) for file_path in tqdm(data_paths)
    )
    random.Random(seed).shuffle(all_data)

    n_total = len(all_data)
    n_train = n_total - TEST_SIZE
    if n_train <= 0:
        raise ValueError(
            f'need more than {TEST_SIZE} samples to build the splits, got {n_total}'
        )

    # (phase, size tag, slice). The fixed-size train subsets are clipped to
    # the train partition so they can never overlap the test tail when the
    # dataset holds fewer than TEST_SIZE + 200_000 samples.
    splits = [
        ('train', '400k', slice(None, n_train)),
        ('train', '200k', slice(None, min(200_000, n_train))),
        ('train', '100k', slice(None, min(100_000, n_train))),
        ('test', '20k', slice(n_train, None)),
        ('test', '5k', slice(n_total - 5_000, None)),
    ]
    # File-name infix -> list that gets dumped for each split.
    views = {
        '': all_data,
        'smplx_': [item[0] for item in all_data],
        'gmr_': [item[1] for item in all_data],
    }
    for infix, items in views.items():
        for phase, size_tag, part in splits:
            joblib.dump(items[part], f'data/{phase}_{infix}{size_tag}.pkl')


if __name__ == '__main__':
    main()