lsnu's picture
Add files using upload-large-folder tool
5ce8761 verified
"""Main script for training and testing."""
import argparse
import os
from pathlib import Path
import sys
import torch
from datasets import fetch_dataset_class
from modeling.policy import fetch_model_class
from utils.common_utils import str2bool, str_none
from utils.trainers import fetch_train_tester
def parse_arguments():
parser = argparse.ArgumentParser("Parse arguments for main.py")
# Tuples: (name, type, default)
arguments = [
# Dataset/loader arguments
('train_data_dir', Path, ''),
('eval_data_dir', Path, ''),
('train_instructions', Path, ''),
('val_instructions', Path, ''),
('dataset', str, "Peract"),
('num_workers', int, 4),
('batch_size', int, 64),
('batch_size_val', int, 64),
('chunk_size', int, 1),
('memory_limit', float, 8), # cache limit in GB
# Logging arguments
('base_log_dir', Path, Path(__file__).parent / "train_logs"),
('exp_log_dir', Path, "exp"),
('run_log_dir', Path, "run"),
# Training and testing arguments
('checkpoint', str_none, None),
('val_freq', int, 4000),
('interm_ckpt_freq', int, 1000000),
('eval_only', str2bool, False),
('lr', float, 1e-4),
('backbone_lr', float, 1e-4),
('lr_scheduler', str, "constant"),
('wd', float, 5e-3),
('train_iters', int, 600000),
('use_compile', str2bool, False),
('use_ema', str2bool, False),
('lv2_batch_size', int, 1),
# Model arguments: general policy type
('model_type', str, 'denoise3d'),
('bimanual', str2bool, False),
('keypose_only', str2bool, True),
('pre_tokenize', str2bool, True),
('custom_img_size', int, None),
('workspace_normalizer_buffer', float, 0.04),
# Model arguments: encoder
('backbone', str, "clip"),
('finetune_backbone', str2bool, False),
('finetune_text_encoder', str2bool, False),
('fps_subsampling_factor', int, 5),
# Model arguments: encoder and head
('embedding_dim', int, 120), # divisible by num_attn_heads
('num_attn_heads', int, 8),
('num_vis_instr_attn_layers', int, 3),
('num_history', int, 1),
# Model arguments: head
('num_shared_attn_layers', int, 4),
('relative_action', str2bool, False),
('rotation_format', str, 'quat_xyzw'),
('denoise_timesteps', int, 10),
('denoise_model', str, "rectified_flow")
]
for arg in arguments:
parser.add_argument(f'--{arg[0]}', type=arg[1], default=arg[2])
return parser.parse_args()
def suppress_output_on_non_main():
if int(os.environ.get("RANK", 0)) != 0:
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
if __name__ == '__main__':
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Arguments
args = parse_arguments()
print("Arguments:")
print(args)
print("-" * 100)
log_dir = args.base_log_dir / args.exp_log_dir / args.run_log_dir
args.log_dir = log_dir
log_dir.mkdir(exist_ok=True, parents=True)
print("Logging:", log_dir)
print(
"Available devices (CUDA_VISIBLE_DEVICES):",
os.environ.get("CUDA_VISIBLE_DEVICES")
)
print("Device count:", torch.cuda.device_count())
args.local_rank = int(os.environ["LOCAL_RANK"])
suppress_output_on_non_main()
# DDP initialization
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Select dataset and model classes
dataset_class = fetch_dataset_class(args.dataset)
model_class = fetch_model_class(args.model_type)
# Run
TrainTester = fetch_train_tester(args.dataset)
train_tester = TrainTester(args, dataset_class, model_class)
train_tester.main()
# Safe program termination
if torch.distributed.is_initialized():
torch.cuda.empty_cache()
torch.distributed.destroy_process_group()