Upload 3 files
Browse files- options/get_eval_option.py +83 -0
- options/option_transformer.py +68 -0
- options/option_vq.py +61 -0
options/get_eval_option.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from argparse import Namespace
|
| 2 |
+
import re
|
| 3 |
+
from os.path import join as pjoin
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Signed decimal literal with a fractional part and/or an exponent, e.g.
# "0.5", "-3.25", "1e-05", "2.5E+3".  Plain integers deliberately do not match
# so that callers can route them to integer parsing instead.
_FLOAT_PATTERN = re.compile(
    r'[-+]?(?:[0-9]+\.[0-9]+(?:[eE][-+]?[0-9]+)?|[0-9]+[eE][-+]?[0-9]+)'
)


def is_float(numStr):
    """Return True if *numStr* is text that should be parsed with ``float()``.

    The value is converted with ``str()`` and surrounding whitespace is
    ignored.  Accepted forms are an optional single sign followed by either
    ``digits.digits`` (optionally with an exponent) or ``digits`` with an
    exponent.  Integers without a decimal point or exponent return False.

    Args:
        numStr: Any object; its ``str()`` form is inspected.

    Returns:
        bool: True when the text is a float literal as described above.
    """
    # At most ONE leading sign is allowed: inputs such as "--1.5" or "+-1.5"
    # would make float() raise, so they must not be reported as floats.
    return _FLOAT_PATTERN.fullmatch(str(numStr).strip()) is not None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def is_number(numStr):
    """Return True if *numStr* is text that should be parsed with ``int()``.

    The value is converted with ``str()`` and surrounding whitespace is
    ignored.  Accepted text is an optional single ``+`` or ``-`` followed by
    one or more decimal digits.

    Args:
        numStr: Any object; its ``str()`` form is inspected.

    Returns:
        bool: True when the text is a (possibly signed) integer literal.
    """
    text = str(numStr).strip()
    # Drop at most one sign: "--5" or "-+5" would make int() raise.
    if text[:1] in ('-', '+'):
        text = text[1:]
    # isdecimal() (not isdigit()) so that characters such as superscript
    # digits, which int() cannot convert, are rejected.
    return text.isdecimal()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_opt(opt_path, device):
    """Rebuild an options ``Namespace`` from a saved options text file.

    Every non-blank, non-banner line of the file is read as ``key: value``.
    Values are converted in this order: ``'True'``/``'False'`` to bool, text
    accepted by ``is_float`` to float, text accepted by ``is_number`` to int,
    anything else is kept as a string.  Evaluation-time fields (paths, dataset
    dimensions, flags, ``device``) are then set on the namespace.

    Args:
        opt_path: Path of the options text file to read.
        device: Stored unchanged as ``opt.device``.

    Returns:
        argparse.Namespace: The populated options object.

    Raises:
        ValueError: If a line is neither blank, a banner, nor ``key: value``.
        KeyError: If ``dataset_name`` is neither ``'t2m'`` nor ``'kit'``.
        AttributeError: If the file does not define ``checkpoints_dir``,
            ``dataset_name``, ``name`` or ``unit_length``.
    """
    opt = Namespace()
    opt_dict = vars(opt)

    # Banner lines that carry no key/value pair.
    skip = ('-------------- End ----------------',
            '------------ Options -------------')
    print('Reading', opt_path)
    with open(opt_path) as f:
        for line in f:
            entry = line.strip()
            # Blank lines strip to '' (never to '\n'), so test for emptiness.
            if not entry or entry in skip:
                continue

            # Split on the FIRST separator only so values that themselves
            # contain ': ' are kept intact.
            key, sep, value = entry.partition(': ')
            if not sep:
                # "key: " with an empty value loses its trailing space to
                # strip() and arrives here as "key:".
                if not entry.endswith(':'):
                    raise ValueError(
                        'Malformed option line in %s: %r' % (opt_path, entry))
                key, value = entry[:-1], ''

            if value in ('True', 'False'):
                opt_dict[key] = (value == 'True')
            elif is_float(value):
                opt_dict[key] = float(value)
            elif is_number(value):
                opt_dict[key] = int(value)
            else:
                opt_dict[key] = str(value)

    opt_dict['which_epoch'] = 'finest'
    opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
    opt.model_dir = pjoin(opt.save_root, 'model')
    opt.meta_dir = pjoin(opt.save_root, 'meta')

    if opt.dataset_name == 't2m':
        opt.data_root = './dataset/Sample1/'
        opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
        opt.text_dir = pjoin(opt.data_root, 'texts')
        opt.joints_num = 22
        opt.dim_pose = 263
        opt.max_motion_length = 196
        opt.max_motion_frame = 196
        opt.max_motion_token = 55
    elif opt.dataset_name == 'kit':
        opt.data_root = './dataset/KIT-ML/'
        opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
        opt.text_dir = pjoin(opt.data_root, 'texts')
        opt.joints_num = 21
        opt.dim_pose = 251
        opt.max_motion_length = 196
        opt.max_motion_frame = 196
        opt.max_motion_token = 55
    else:
        raise KeyError('Dataset not recognized')

    opt.dim_word = 300
    opt.num_classes = 200 // opt.unit_length
    # Options loaded this way are always marked as non-training, non-resuming.
    opt.is_train = False
    opt.is_continue = False
    opt.device = device

    return opt
|
options/option_transformer.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
def get_args_parser(argv=None):
    """Build the transformer option parser and return the parsed arguments.

    Despite its name, this returns the parsed ``argparse.Namespace``, not the
    parser object.

    Args:
        argv: Optional list of argument strings to parse.  When None (the
            default) argparse reads ``sys.argv[1:]``, which is the original
            behaviour.  Passing a list (e.g. ``[]``) allows the options to be
            built from code that does not own the process command line.

    Returns:
        argparse.Namespace: Parsed options; dashes in option names become
        underscores in attribute names (``--batch-size`` -> ``batch_size``).
    """
    parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder training for Amass',
                                     add_help=True,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    ## dataloader

    parser.add_argument('--dataname', type=str, default='kit', help='dataset directory')
    parser.add_argument('--batch-size', default=128, type=int, help='batch size')
    parser.add_argument('--fps', default=[20], nargs="+", type=int, help='frames per second')
    parser.add_argument('--seq-len', type=int, default=64, help='training motion length')

    ## optimization
    parser.add_argument('--total-iter', default=100000, type=int, help='number of total iterations to run')
    parser.add_argument('--warm-up-iter', default=1000, type=int, help='number of total iterations for warmup')
    parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate')
    parser.add_argument('--lr-scheduler', default=[60000], nargs="+", type=int, help="learning rate schedule (iterations)")
    parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay")

    parser.add_argument('--weight-decay', default=1e-6, type=float, help='weight decay')
    parser.add_argument('--decay-option',default='all', type=str, choices=['all', 'noVQ'], help='disable weight decay on codebook')
    parser.add_argument('--optimizer',default='adamw', type=str, choices=['adam', 'adamw'], help='optimizer type')

    ## vqvae arch
    parser.add_argument("--code-dim", type=int, default=512, help="embedding dimension")
    parser.add_argument("--nb-code", type=int, default=512, help="nb of embedding")
    parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook")
    parser.add_argument("--down-t", type=int, default=3, help="downsampling rate")
    parser.add_argument("--stride-t", type=int, default=2, help="stride size")
    parser.add_argument("--width", type=int, default=512, help="width of the network")
    parser.add_argument("--depth", type=int, default=3, help="depth of the network")
    parser.add_argument("--dilation-growth-rate", type=int, default=3, help="dilation growth rate")
    parser.add_argument("--output-emb-width", type=int, default=512, help="output embedding width")
    parser.add_argument('--vq-act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='activation function for the VQ-VAE')

    ## gpt arch
    parser.add_argument("--block-size", type=int, default=25, help="seq len")
    parser.add_argument("--embed-dim-gpt", type=int, default=512, help="embedding dimension")
    parser.add_argument("--clip-dim", type=int, default=512, help="latent dimension in the clip feature")
    parser.add_argument("--num-layers", type=int, default=2, help="nb of transformer layers")
    parser.add_argument("--n-head-gpt", type=int, default=8, help="nb of heads")
    parser.add_argument("--ff-rate", type=int, default=4, help="feedforward size")
    parser.add_argument("--drop-out-rate", type=float, default=0.1, help="dropout ratio in the pos encoding")

    ## quantizer
    parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="quantizer variant")
    parser.add_argument('--quantbeta', type=float, default=1.0, help='beta value for the quantizer')

    ## resume
    parser.add_argument("--resume-pth", type=str, default=None, help='resume vq pth')
    parser.add_argument("--resume-trans", type=str, default=None, help='resume gpt pth')


    ## output directory
    parser.add_argument('--out-dir', type=str, default='output_GPT_Final/', help='output directory')
    parser.add_argument('--exp-name', type=str, default='exp_debug', help='name of the experiment, will create a file inside out-dir')
    parser.add_argument('--vq-name', type=str, default='exp_debug', help='name of the generated dataset .npy, will create a file inside out-dir')
    ## other
    parser.add_argument('--print-iter', default=200, type=int, help='print frequency')
    parser.add_argument('--eval-iter', default=5000, type=int, help='evaluation frequency')
    parser.add_argument('--seed', default=123, type=int, help='seed for initializing training. ')
    parser.add_argument("--if-maxtest", action='store_true', help="test in max")
    parser.add_argument('--pkeep', type=float, default=1.0, help='keep rate for gpt training')


    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
|
options/option_vq.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
def get_args_parser(argv=None):
    """Build the VQ option parser and return the parsed arguments.

    Despite its name, this returns the parsed ``argparse.Namespace``, not the
    parser object.

    Args:
        argv: Optional list of argument strings to parse.  When None (the
            default) argparse reads ``sys.argv[1:]``, which is the original
            behaviour.  Passing a list (e.g. ``[]``) allows the options to be
            built from code that does not own the process command line.

    Returns:
        argparse.Namespace: Parsed options; dashes in option names become
        underscores in attribute names (``--window-size`` -> ``window_size``).
    """
    parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder training for AIST',
                                     add_help=True,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    ## dataloader
    parser.add_argument('--dataname', type=str, default='kit', help='dataset directory')
    parser.add_argument('--batch-size', default=128, type=int, help='batch size')
    parser.add_argument('--window-size', type=int, default=64, help='training motion length')

    ## optimization
    parser.add_argument('--total-iter', default=200000, type=int, help='number of total iterations to run')
    parser.add_argument('--warm-up-iter', default=1000, type=int, help='number of total iterations for warmup')
    parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate')
    parser.add_argument('--lr-scheduler', default=[50000, 400000], nargs="+", type=int, help="learning rate schedule (iterations)")
    parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay")

    parser.add_argument('--weight-decay', default=0.0, type=float, help='weight decay')
    parser.add_argument("--commit", type=float, default=0.02, help="hyper-parameter for the commitment loss")
    parser.add_argument('--loss-vel', type=float, default=0.1, help='hyper-parameter for the velocity loss')
    parser.add_argument('--recons-loss', type=str, default='l2', help='reconstruction loss')

    ## vqvae arch
    parser.add_argument("--code-dim", type=int, default=512, help="embedding dimension")
    parser.add_argument("--nb-code", type=int, default=512, help="nb of embedding")
    parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook")
    parser.add_argument("--down-t", type=int, default=2, help="downsampling rate")
    parser.add_argument("--stride-t", type=int, default=2, help="stride size")
    parser.add_argument("--width", type=int, default=512, help="width of the network")
    parser.add_argument("--depth", type=int, default=3, help="depth of the network")
    parser.add_argument("--dilation-growth-rate", type=int, default=3, help="dilation growth rate")
    parser.add_argument("--output-emb-width", type=int, default=512, help="output embedding width")
    parser.add_argument('--vq-act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='activation function for the VQ-VAE')
    parser.add_argument('--vq-norm', type=str, default=None, help='normalization option for the VQ-VAE')

    ## quantizer
    parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="quantizer variant")
    parser.add_argument('--beta', type=float, default=1.0, help='commitment loss in standard VQ')

    ## resume
    parser.add_argument("--resume-pth", type=str, default=None, help='resume pth for VQ')
    parser.add_argument("--resume-gpt", type=str, default=None, help='resume pth for GPT')


    ## output directory
    parser.add_argument('--out-dir', type=str, default='output_vqfinal/', help='output directory')
    parser.add_argument('--results-dir', type=str, default='visual_results/', help='output directory')
    parser.add_argument('--visual-name', type=str, default='baseline', help='output directory')
    parser.add_argument('--exp-name', type=str, default='exp_debug', help='name of the experiment, will create a file inside out-dir')
    ## other
    parser.add_argument('--print-iter', default=200, type=int, help='print frequency')
    parser.add_argument('--eval-iter', default=1000, type=int, help='evaluation frequency')
    parser.add_argument('--seed', default=123, type=int, help='seed for initializing training.')

    parser.add_argument('--vis-gt', action='store_true', help='whether visualize GT motions')
    parser.add_argument('--nb-vis', default=20, type=int, help='nb of visualizations')


    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
|