Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import argparse | |
def _add_dataset_args(parser):
    """Register the dataset-related command-line options on ``parser``.

    Options added, in order: ``--dataname``, ``--data_root``,
    ``--max_seq_length``, ``--add_frame_cond``, ``--data_format``,
    ``--dataset`` and ``--batch_size``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or argument group) that is
            mutated in place.

    Returns:
        None.
    """
    # (flag, keyword arguments) pairs; order is kept because it determines
    # the order in which the options show up in ``--help``.
    option_table = (
        ("--dataname", dict(type=str, default="kit", help="dataset directory")),
        ("--data_root", dict(type=str, default=None, help="dataset directory")),
        ("--max_seq_length", dict(type=int, default=600)),
        # The value is parsed as a float and then checked against choices;
        # 1.0 == 1, so "1" (or "1.0") is the only accepted input.
        ("--add_frame_cond", dict(type=float, default=None, choices=[1])),
        ("--data_format", dict(type=str, default="pose", choices=["pose", "face"])),
        ("--dataset", dict(type=str, default="social")),
        ("--batch_size", dict(type=int, default=64, help="batch size")),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
def _add_optim_args(parser):
    """Register the optimisation options on ``parser``.

    Covers the iteration budget (``--total_iter``, ``--warm_up_iter``), the
    learning-rate settings (``--lr``, ``--lr_scheduler``, ``--gamma``),
    ``--weight_decay`` and the loss weights (``--commit``, ``--loss_vel``).

    Args:
        parser: an ``argparse.ArgumentParser`` (or argument group) that is
            mutated in place.

    Returns:
        None.
    """
    register = parser.add_argument

    # Iteration budget.
    register(
        "--total_iter", type=int, default=300_000,
        help="number of total iterations to run",
    )
    register(
        "--warm_up_iter", type=int, default=1000,
        help="number of total iterations for warmup",
    )

    # Learning-rate settings. ``--lr_scheduler`` takes one or more integers;
    # its default list is created anew on every call of this function.
    register("--lr", type=float, default=2e-4, help="max learning rate")
    register(
        "--lr_scheduler", type=int, nargs="+", default=[300_000],
        help="learning rate schedule (iterations)",
    )
    register("--gamma", type=float, default=0.05, help="learning rate decay")
    register("--weight_decay", type=float, default=0.0, help="weight decay")

    # Loss weights.
    register(
        "--commit", type=float, default=0.02,
        help="hyper-parameter for the commitment loss",
    )
    register(
        "--loss_vel", type=float, default=0.1,
        help="hyper-parameter for the velocity loss",
    )
def _add_model_args(parser):
    """Register the model options on ``parser``.

    Adds ``--code_dim``, ``--depth``, ``--output_emb_width`` and
    ``--resume_pth``, in that order.

    Args:
        parser: an ``argparse.ArgumentParser`` (or argument group) that is
            mutated in place.

    Returns:
        None.
    """
    # Every model option has the same shape: (flag, type, default, help).
    model_options = (
        ("--code_dim", int, 512, "embedding dimension"),
        ("--depth", int, 3, "depth of the network"),
        ("--output_emb_width", int, 512, "output embedding width"),
        ("--resume_pth", str, None, "resume pth for VQ"),
    )
    for flag, value_type, fallback, description in model_options:
        parser.add_argument(
            flag, type=value_type, default=fallback, help=description
        )
def train_args(argv=None):
    """Build the training argument parser and parse the arguments.

    The parser combines the dataset, optimisation and model option groups
    with the output directory (``--out_dir``, required) and the
    logging/evaluation/seed options defined here.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what
            this function always did before the parameter existed. Passing
            a list allows calling it from tests or notebooks without
            touching ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args``.

    Raises:
        SystemExit: raised by argparse when ``--out_dir`` is missing, an
            argument is invalid, or ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(
        description="Optimal Transport AutoEncoder training for AIST",
        add_help=True,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Option groups defined by the helper functions in this module.
    _add_dataset_args(parser)
    _add_optim_args(parser)
    _add_model_args(parser)

    ## output directory
    parser.add_argument("--out_dir", type=str, required=True, help="output directory")

    ## other
    parser.add_argument("--print_iter", default=200, type=int, help="print frequency")
    parser.add_argument(
        "--eval_iter", default=1000, type=int, help="evaluation frequency"
    )
    parser.add_argument(
        "--seed", default=123, type=int, help="seed for initializing training."
    )

    # argv=None lets argparse fall back to the process command line.
    return parser.parse_args(argv)