from copy import deepcopy import os import random import numpy as np import torch from torch import optim from torch.utils.data.distributed import DistributedSampler from torch import nn import torch.distributed as dist from torch.utils.data import DataLoader from torch.nn.parallel import DistributedDataParallel from torch.utils.tensorboard import SummaryWriter from tqdm import trange, tqdm from modeling.encoder.text import fetch_tokenizers from ..common_utils import count_parameters from ..depth2cloud import fetch_depth2cloud from ..data_preprocessors import fetch_data_preprocessor from ..ema import EMA from ..schedulers import fetch_scheduler from .utils import compute_metrics class BaseTrainTester: """Train/test a trajectory optimization algorithm.""" def __init__(self, args, dataset_cls, model_cls): """Initialize.""" self.args = args self.dataset_cls = dataset_cls self.model_cls = model_cls self.preprocessor = fetch_data_preprocessor(self.args.dataset)( self.args.keypose_only, self.args.num_history, custom_imsize=self.args.custom_img_size, depth2cloud=fetch_depth2cloud(self.args.dataset) ) if dist.get_rank() == 0 and not self.args.eval_only: self.writer = SummaryWriter(log_dir=args.log_dir) def get_datasets(self): """Initialize datasets.""" # Initialize datasets with arguments train_dataset = self.dataset_cls( root=self.args.train_data_dir, instructions=self.args.train_instructions, relative_action=self.args.relative_action, mem_limit=self.args.memory_limit, chunk_size=self.args.chunk_size ) val_dataset = self.dataset_cls( root=self.args.eval_data_dir, instructions=self.args.val_instructions, copies=1, relative_action=self.args.relative_action, mem_limit=0.1, chunk_size=self.args.chunk_size ) return train_dataset, val_dataset def get_loaders(self): """Initialize data loaders.""" def seed_worker(worker_id): worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) # Datasets train_dataset, val_dataset = self.get_datasets() # Samplers and loaders g = torch.Generator() g.manual_seed(0) train_sampler = DistributedSampler(train_dataset, drop_last=True) train_loader = DataLoader( train_dataset, batch_size=self.args.batch_size // self.args.chunk_size, shuffle=False, num_workers=self.args.num_workers, worker_init_fn=seed_worker, collate_fn=base_collate_fn, pin_memory=True, sampler=train_sampler, drop_last=True, generator=g, prefetch_factor=4, persistent_workers=True ) # No sampler for val! if dist.get_rank() == 0: val_loader = DataLoader( val_dataset, batch_size=self.args.batch_size_val // self.args.chunk_size, shuffle=False, num_workers=self.args.num_workers, collate_fn=base_collate_fn, pin_memory=True, sampler=None, drop_last=False, prefetch_factor=4, persistent_workers=True ) else: val_loader = None return train_loader, val_loader, train_sampler def get_model(self): """Initialize the model.""" # Initialize model with arguments _model = self.model_cls( backbone=self.args.backbone, finetune_backbone=self.args.finetune_backbone, finetune_text_encoder=self.args.finetune_text_encoder, num_vis_instr_attn_layers=self.args.num_vis_instr_attn_layers, fps_subsampling_factor=self.args.fps_subsampling_factor, embedding_dim=self.args.embedding_dim, num_attn_heads=self.args.num_attn_heads, nhist=self.args.num_history, nhand=2 if self.args.bimanual else 1, num_shared_attn_layers=self.args.num_shared_attn_layers, relative=self.args.relative_action, rotation_format=self.args.rotation_format, denoise_timesteps=self.args.denoise_timesteps, denoise_model=self.args.denoise_model, lv2_batch_size=self.args.lv2_batch_size ) # Print basic modules' parameters if dist.get_rank() == 0: count_parameters(_model) # Useful for some models to ensure parameters are contiguous for name, param in _model.named_parameters(): if param.requires_grad and param.ndim > 1 and not param.is_contiguous(): print(f"Fixing layout for: {name}") param.data = param.contiguous() return _model @torch.no_grad() def get_workspace_normalizer(self, ndims=3): print("Computing workspace normalizer...") # Initialize datasets with arguments train_dataset = self.dataset_cls( root=self.args.train_data_dir, instructions=self.args.train_instructions, copies=1, relative_action=self.args.relative_action, mem_limit=0.1, actions_only=True, chunk_size=self.args.chunk_size ) data_loader = DataLoader( train_dataset, batch_size=max(self.args.batch_size, 64) // self.args.chunk_size, collate_fn=actions_collate_fn, shuffle=False, num_workers=self.args.num_workers ) # Loop and compute action min-max min_, max_ = torch.ones(ndims) * 10000, -torch.ones(ndims) * 10000 for sample in tqdm(data_loader): action = sample["action"][..., :ndims].reshape([-1, ndims]) min_ = torch.min(min_, action.min(0).values) max_ = torch.max(max_, action.max(0).values) min_ = min_ - self.args.workspace_normalizer_buffer max_ = max_ + self.args.workspace_normalizer_buffer return nn.Parameter(torch.stack([min_, max_]), requires_grad=False) def get_optimizer(self, model): """Initialize optimizer.""" optimizer_grouped_parameters = [ {"params": [], "weight_decay": 0.0, "lr": self.args.lr}, {"params": [], "weight_decay": self.args.wd, "lr": self.args.lr} ] if self.args.finetune_backbone: optimizer_grouped_parameters.append({ "params": [], "weight_decay": self.args.wd, "lr": self.args.backbone_lr }) # Collect names of all norm parameters norm_types = ( torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, torch.nn.LocalResponseNorm, torch.nn.RMSNorm ) norm_param_names = set() for module_name, module in model.named_modules(): if isinstance(module, norm_types): for param_name, _ in module.named_parameters(recurse=False): norm_param_names.add(f"{module_name}.{param_name}") # Now split parameters based on name for name, param in model.named_parameters(): if not param.requires_grad: continue if name in norm_param_names or name.endswith(".bias"): optimizer_grouped_parameters[0]["params"].append(param) elif self.args.finetune_backbone and 'backbone' in name: optimizer_grouped_parameters[2]["params"].append(param) else: optimizer_grouped_parameters[1]["params"].append(param) optimizer = optim.AdamW( optimizer_grouped_parameters, betas=(0.9, 0.95) ) return optimizer def main(self): """Run main training/testing pipeline.""" # Get loaders train_loader, val_loader, train_sampler = self.get_loaders() # Get model model = self.get_model() self.tokenizer = fetch_tokenizers(self.args.backbone) if not os.path.exists(self.args.checkpoint): normalizer = self.get_workspace_normalizer() model.workspace_normalizer.copy_(normalizer) dist.barrier(device_ids=[torch.cuda.current_device()]) # Get optimizer optimizer = self.get_optimizer(model) lr_scheduler = fetch_scheduler( self.args.lr_scheduler, optimizer, self.args.train_iters ) scaler = torch.GradScaler() # Move model to devices if torch.cuda.is_available(): model = model.cuda() # make sure to compile before DDP! if self.args.use_compile: model.compute_loss = torch.compile(model.compute_loss, fullgraph=True) model = DistributedDataParallel( model, device_ids=[self.args.local_rank], broadcast_buffers=False, find_unused_parameters=True ) # Initialize EMA copy ema_model = deepcopy(model) self.ema = EMA() # Check for a checkpoint start_iter, best_loss = 0, None if self.args.checkpoint: start_iter, best_loss = self.load_checkpoint(model, ema_model, optimizer) print(model.module.workspace_normalizer) # Eval only if self.args.eval_only: if dist.get_rank() == 0: print("Test evaluation.......") model.eval() self.evaluate_nsteps( ema_model if self.args.use_ema else model, val_loader, step_id=-1, val_iters=-1 ) dist.barrier(device_ids=[torch.cuda.current_device()]) return ema_model if self.args.use_ema else model # Step the lr scheduler to the current step for _ in range(start_iter): lr_scheduler.step() # Step the sampler to the currect "epoch" samples_per_epoch = len(train_loader) epoch = start_iter // samples_per_epoch + 1 train_sampler.set_epoch(epoch) # ensures new batches are sampled # Training loop model.train() iter_loader = iter(train_loader) for step_id in trange(start_iter, self.args.train_iters): try: sample = next(iter_loader) except StopIteration: # when the iterator is exhausted, we need to reset it # and increment the epoch epoch += 1 train_sampler.set_epoch(epoch) iter_loader = iter(train_loader) sample = next(iter_loader) self.train_one_step(model, optimizer, scaler, lr_scheduler, sample) self.ema.step(model, ema_model, self.args.use_ema, step_id) if (step_id + 1) % self.args.val_freq == 0 and dist.get_rank() == 0: print("Train evaluation.......") model.eval() self.evaluate_nsteps( ema_model if self.args.use_ema else model, train_loader, step_id, val_iters=10, split='train' ) print("Test evaluation.......") new_loss = self.evaluate_nsteps( ema_model if self.args.use_ema else model, val_loader, step_id, val_iters=1250 ) # save model best_loss = self.save_checkpoint( model, ema_model, optimizer, step_id, new_loss, best_loss ) model.train() dist.barrier(device_ids=[torch.cuda.current_device()]) return ema_model if self.args.use_ema else model @torch.no_grad() def prepare_batch(self, sample, augment=False): pass # implement in children def _model_forward(self, model, sample, training=True): action, action_mask, rgbs, rgb2d, pcds, instr, prop = self.prepare_batch( sample, augment=training ) if self.args.pre_tokenize: instr = self.tokenizer(instr).cuda(non_blocking=True) with torch.autocast(device_type="cuda", dtype=torch.bfloat16): out = model( action, action_mask, rgbs, rgb2d, pcds, instr, prop, run_inference=not training ) return out # loss if training, else action def train_one_step(self, model, optimizer, scaler, lr_scheduler, sample): """Run a single training step.""" optimizer.zero_grad() # Forward pass loss = self._model_forward(model, sample) # Backward pass scaler.scale(loss).backward() # Clip gradients scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0) # Update scaler.step(optimizer) scaler.update() # Step the lr scheduler lr_scheduler.step() @torch.inference_mode() def evaluate_nsteps(self, model, loader, step_id, val_iters, split='val'): """Run a given number of evaluation steps.""" values = {} device = next(model.parameters()).device model.eval() for i, sample in tqdm(enumerate(loader)): if i == val_iters: break pred_action = self._model_forward(model, sample, training=False) gt_action = sample["action"].cuda(non_blocking=True) if self.args.relative_action: pred_action = relative_to_absolute( pred_action[:, :, 0], sample["proprioception"].cuda(non_blocking=True)[:, :, 0] ) gt_action = relative_to_absolute( gt_action[:, :, 0], sample["proprioception"].cuda(non_blocking=True)[:, :, 0] ) losses, losses_B = compute_metrics(pred_action, gt_action) # Gather global statistics for n, l in losses.items(): key = f"{split}-losses/mean/{n}" if key not in values: values[key] = torch.Tensor([]).to(device) values[key] = torch.cat([values[key], l.unsqueeze(0)]) # Gather per-task statistics tasks = np.array(sample["task"]) for n, l in losses_B.items(): for task in np.unique(tasks): key = f"{split}-loss/{task}/{n}" l_task = l[tasks == task].mean() if key not in values: values[key] = torch.Tensor([]).to(device) values[key] = torch.cat([values[key], l_task.unsqueeze(0)]) # Log all statistics values = {k: v.mean().item() for k, v in values.items()} if dist.get_rank() == 0: if step_id > -1: for key, val in values.items(): self.writer.add_scalar(key, val, step_id) # Also log to terminal print(f"Step {step_id}:") for key, value in values.items(): print(f"{key}: {value:.03f}") return -values[f'{split}-losses/mean/traj_pos_acc_001'] def load_checkpoint(self, model, ema_model, optimizer): """Load from checkpoint.""" print("=> trying checkpoint '{}'".format(self.args.checkpoint)) if not os.path.exists(self.args.checkpoint): print('Warning: checkpoint was not found, starting from scratch') print('The main process will compute workspace bounds') return 0, None model_dict = torch.load( self.args.checkpoint, map_location="cpu", weights_only=True ) # Load weights flexibly msn, unxpct = model.load_state_dict(model_dict["weight"], strict=False) if msn: print(f"Missing keys (not found in checkpoint): {len(msn)}") print(msn) if unxpct: print(f"Unexpected keys (ignored): {len(unxpct)}") print(unxpct) if not msn and not unxpct: print("All keys matched successfully!") # EMA weights if model_dict.get("ema_weight") is not None: ema_model.load_state_dict(model_dict["ema_weight"], strict=True) # Useful for resuming training if 'optimizer' in model_dict and not self.args.eval_only: optimizer.load_state_dict(model_dict["optimizer"]) start_iter = model_dict.get("iter", 0) best_loss = model_dict.get("best_loss", None) print("=> loaded successfully '{}' (step {})".format( self.args.checkpoint, model_dict.get("iter", 0) )) del model_dict torch.cuda.empty_cache() return start_iter, best_loss def save_checkpoint(self, model, ema_model, optimizer, step_id, new_loss, best_loss): """Save checkpoint if requested.""" model_state = model.state_dict() ema_state = ema_model.state_dict() if self.args.use_ema else None # Best checkpoint if best_loss is None or new_loss <= best_loss: best_loss = new_loss torch.save({ "weight": model_state, "ema_weight": ema_state, "iter": step_id + 1, "best_loss": best_loss }, self.args.log_dir / "best.pth") # Last checkpoint (always saved) torch.save({ "weight": model_state, "ema_weight": ema_state, "optimizer": optimizer.state_dict(), "iter": step_id + 1, "best_loss": best_loss }, self.args.log_dir / "last.pth") # Save intermediate checkpoints if (step_id + 1) % self.args.interm_ckpt_freq == 0: torch.save({ "weight": model_state, "ema_weight": ema_state, "iter": step_id + 1, "best_loss": best_loss }, self.args.log_dir / f"interm{step_id + 1}.pth") return best_loss def base_collate_fn(batch): """Custom collate_fn, measured to be faster than default.""" _dict = {} # Values for these come as lists list_keys = ["task", "instr"] for key in list_keys: if key not in batch[0].keys(): continue _dict[key] = [] for item in batch: _dict[key].extend(item[key]) # Treat rest as tensors _dict.update({ k_: ( torch.cat([item[k_] for item in batch]) if batch[0][k_] is not None else None ) for k_ in batch[0].keys() if k_ not in list_keys }) return _dict def actions_collate_fn(batch): return {"action": torch.cat([item["action"] for item in batch])} def relative_to_absolute(action, proprio): # action (B, T, 8), proprio (B, 1, 7) pos = proprio[..., :3] + action[..., :3].cumsum(1) orn = proprio[..., 3:6] + action[..., 3:6].cumsum(1) orn = (orn + torch.pi) % (2 * torch.pi) - torch.pi return torch.cat([pos, orn, action[..., 6:]], -1)