| """Main script for training and testing.""" |
|
|
| import argparse |
| import os |
| from pathlib import Path |
| import sys |
|
|
| import torch |
|
|
| from datasets import fetch_dataset_class |
| from modeling.policy import fetch_model_class |
| from utils.common_utils import str2bool, str_none |
| from utils.trainers import fetch_train_tester |
|
|
|
|
def parse_arguments():
    """Declare this script's command-line options and parse ``sys.argv``.

    Each option is described by a ``(name, type, default)`` triple and is
    registered as ``--<name>``. argparse runs ``type`` on string defaults
    (so ``''`` and ``"exp"`` become ``Path`` objects), while non-string
    defaults are stored exactly as written.

    Returns:
        argparse.Namespace: one attribute per option in the table below.
    """
    # Pass the text as ``description``. The first positional parameter of
    # ArgumentParser is ``prog``; giving the sentence there replaces the
    # script name in the ``usage:`` line of --help and error messages.
    parser = argparse.ArgumentParser(description="Parse arguments for main.py")

    arguments = [
        # Data locations and loading
        ('train_data_dir', Path, ''),
        ('eval_data_dir', Path, ''),
        ('train_instructions', Path, ''),
        ('val_instructions', Path, ''),
        ('dataset', str, "Peract"),
        ('num_workers', int, 4),
        ('batch_size', int, 64),
        ('batch_size_val', int, 64),
        ('chunk_size', int, 1),
        ('memory_limit', float, 8),
        # Log directory components
        ('base_log_dir', Path, Path(__file__).parent / "train_logs"),
        ('exp_log_dir', Path, "exp"),
        ('run_log_dir', Path, "run"),
        # Checkpointing, evaluation and optimisation
        ('checkpoint', str_none, None),
        ('val_freq', int, 4000),
        ('interm_ckpt_freq', int, 1000000),
        ('eval_only', str2bool, False),
        ('lr', float, 1e-4),
        ('backbone_lr', float, 1e-4),
        ('lr_scheduler', str, "constant"),
        ('wd', float, 5e-3),
        ('train_iters', int, 600000),
        ('use_compile', str2bool, False),
        ('use_ema', str2bool, False),
        ('lv2_batch_size', int, 1),
        # Model selection and input handling
        ('model_type', str, 'denoise3d'),
        ('bimanual', str2bool, False),
        ('keypose_only', str2bool, True),
        ('pre_tokenize', str2bool, True),
        ('custom_img_size', int, None),
        ('workspace_normalizer_buffer', float, 0.04),
        # Backbone / encoders
        ('backbone', str, "clip"),
        ('finetune_backbone', str2bool, False),
        ('finetune_text_encoder', str2bool, False),
        ('fps_subsampling_factor', int, 5),
        # Attention sizes
        ('embedding_dim', int, 120),
        ('num_attn_heads', int, 8),
        ('num_vis_instr_attn_layers', int, 3),
        ('num_history', int, 1),
        # Action representation and denoising
        ('num_shared_attn_layers', int, 4),
        ('relative_action', str2bool, False),
        ('rotation_format', str, 'quat_xyzw'),
        ('denoise_timesteps', int, 10),
        ('denoise_model', str, "rectified_flow"),
    ]
    for name, arg_type, default in arguments:
        parser.add_argument(f'--{name}', type=arg_type, default=default)

    return parser.parse_args()
|
|
|
|
def suppress_output_on_non_main():
    """Point stdout and stderr at the null device unless ``RANK`` is 0.

    The rank is read from the ``RANK`` environment variable and treated as 0
    when the variable is unset. On rank 0 nothing is changed. On any other
    rank ``sys.stdout`` and ``sys.stderr`` are each replaced by a separate
    write handle on ``os.devnull``, so everything written to them afterwards
    is discarded.
    """
    rank = int(os.environ.get("RANK", 0))
    if rank == 0:
        return

    # One independent handle per stream, stdout first.
    for stream_name in ("stdout", "stderr"):
        setattr(sys, stream_name, open(os.devnull, "w"))
|
|
|
|
if __name__ == '__main__':
    # Process-wide environment settings, exported before anything else runs.
    os.environ.update({
        "TOKENIZERS_PARALLELISM": "false",
        'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True',
    })

    # Parse the command line and echo the resulting namespace.
    args = parse_arguments()
    print("Arguments:", args, "-" * 100, sep="\n")

    # Log directory is <base_log_dir>/<exp_log_dir>/<run_log_dir>; it is
    # created if missing and stored on ``args`` for the trainer.
    log_dir = args.base_log_dir.joinpath(args.exp_log_dir, args.run_log_dir)
    args.log_dir = log_dir
    log_dir.mkdir(parents=True, exist_ok=True)
    print("Logging:", log_dir)

    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    print("Available devices (CUDA_VISIBLE_DEVICES):", visible_devices)
    print("Device count:", torch.cuda.device_count())

    # LOCAL_RANK must be present in the environment (KeyError otherwise).
    args.local_rank = int(os.environ["LOCAL_RANK"])
    # NOTE(review): suppression starts only here, so the banner printed
    # above is emitted by every rank, not just rank 0.
    suppress_output_on_non_main()

    # Bind this process to its GPU, then join the NCCL process group
    # configured through environment variables.
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    # cuDNN: enabled, benchmark mode on, deterministic mode off.
    # TF32 is allowed for both CUDA matmul and cuDNN.
    cudnn = torch.backends.cudnn
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    cudnn.allow_tf32 = True

    # Resolve the dataset and model classes from their CLI names.
    dataset_cls = fetch_dataset_class(args.dataset)
    model_cls = fetch_model_class(args.model_type)

    # Build the dataset-specific train/test driver and run it.
    trainer_cls = fetch_train_tester(args.dataset)
    trainer = trainer_cls(args, dataset_cls, model_cls)
    trainer.main()

    # Tear down: release cached CUDA memory and leave the process group.
    if torch.distributed.is_initialized():
        torch.cuda.empty_cache()
        torch.distributed.destroy_process_group()
|
|