#!/usr/bin/env python
import logging
import os
import time
from contextlib import nullcontext
from pprint import pformat
from typing import Any

import torch
import torch.distributed as dist
from termcolor import colored
from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.data.distributed import DistributedSampler

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
    get_step_checkpoint_dir,
    get_step_identifier,
    load_training_state,
    save_checkpoint,
    update_last_checkpoint,
)
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    has_method,
    init_logging,
)
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy


def _is_distributed() -> bool:
    """Return True when a torch.distributed process group is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def _unwrap_ddp(policy):
    """Return the wrapped module if *policy* is a DistributedDataParallel wrapper, else *policy*."""
    return policy.module if isinstance(policy, DDP) else policy


def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
    lock=None,
) -> tuple[MetricsTracker, dict]:
    """Run one forward/backward/optimizer step and record its metrics.

    Args:
        train_metrics: Tracker whose ``loss``, ``grad_norm``, ``lr`` and ``update_s``
            fields are overwritten with this step's values.
        policy: Policy to train; may be wrapped in ``DistributedDataParallel``.
        batch: Input passed to ``policy.forward``.
        optimizer: Optimizer over the policy parameters.
        grad_clip_norm: Maximum total gradient norm passed to ``clip_grad_norm_``.
        grad_scaler: AMP gradient scaler used for backward, unscale, step and update.
        lr_scheduler: Optional scheduler, stepped once per call after the optimizer step.
        use_amp: If True, the forward pass runs under ``torch.autocast``.
        lock: Optional context manager held only around ``grad_scaler.step(optimizer)``.

    Returns:
        ``(train_metrics, output_dict)`` where ``output_dict`` is the second value
        returned by ``policy.forward``.
    """
    start_time = time.perf_counter()
    device = get_device_from_parameters(policy)
    policy.train()
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        loss, output_dict = policy.forward(batch)

    grad_scaler.scale(loss).backward()
    # Unscale before clipping so grad_clip_norm applies to the unscaled gradients.
    grad_scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )

    with lock if lock is not None else nullcontext():
        grad_scaler.step(optimizer)
    grad_scaler.update()
    optimizer.zero_grad()

    if lr_scheduler is not None:
        lr_scheduler.step()

    # DDP does not forward arbitrary methods to the wrapped module, so checking
    # the wrapper would silently skip `update()` in distributed runs.
    inner_policy = _unwrap_ddp(policy)
    if has_method(inner_policy, "update"):
        inner_policy.update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - start_time
    return train_metrics, output_dict


def _setup_distributed(cfg: TrainPipelineConfig) -> tuple[torch.device, int, int, int]:
    """Initialize the NCCL process group when RANK/WORLD_SIZE are set and pick the device.

    Returns:
        ``(device, local_rank, rank, world_size)``. Without a process group the
        ranks are 0 and the world size is 1.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        local_rank = 0
        device = get_safe_torch_device(cfg.policy.device, log=True)

    if _is_distributed():
        return device, local_rank, dist.get_rank(), dist.get_world_size()
    return device, local_rank, 0, 1


def _make_dataset_rank_ordered(cfg: TrainPipelineConfig, local_rank: int, is_distributed: bool):
    """Create the dataset, letting each node's local rank 0 finish before the other ranks start.

    NOTE(review): presumably this ordering exists so any download/cache is
    populated once per node before the other processes read it — confirm.
    """
    if not is_distributed:
        return make_dataset(cfg)
    if local_rank == 0:
        dataset = make_dataset(cfg)
        dist.barrier()
    else:
        dist.barrier()
        dataset = make_dataset(cfg)
    return dataset


def _run_training(
    cfg: TrainPipelineConfig,
    device: torch.device,
    local_rank: int,
    rank: int,
    world_size: int,
) -> None:
    """Build dataset/policy/optimizer and run the offline training loop.

    Only the process with global rank 0 logs progress, writes to W&B and saves
    checkpoints; using the global (not local) rank keeps that to a single
    process on multi-node runs.
    """
    is_distributed = _is_distributed()
    is_main_process = rank == 0

    if cfg.seed is not None:
        # Offset by the global rank so every process gets a distinct seed.
        set_seed(cfg.seed + rank)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    wandb_logger = None
    if is_main_process:
        if cfg.wandb.enable and cfg.wandb.project:
            wandb_logger = WandBLogger(cfg)
        else:
            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    logging.info("Creating dataset")
    dataset = _make_dataset_rank_ordered(cfg, local_rank, is_distributed)

    logging.info("Creating policy")
    policy = make_policy(cfg=cfg.policy, ds_meta=dataset.meta).to(device)
    if is_distributed:
        policy = DDP(policy, device_ids=[device], output_device=device, find_unused_parameters=False)
    raw_policy = _unwrap_ddp(policy)

    logging.info("Creating optimizer and scheduler")
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, raw_policy)
    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)

    step = 0  # number of completed optimization steps
    if cfg.resume:
        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

    num_learnable_params = sum(p.numel() for p in raw_policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in raw_policy.parameters())

    if is_main_process:
        logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
        if cfg.env is not None:
            logging.info(f"{cfg.env.task=}")
        logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
        logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
        logging.info(f"{dataset.num_episodes=}")
        logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
        logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    sampler = DistributedSampler(dataset, shuffle=True) if is_distributed else None
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=cfg.batch_size,
        # A DataLoader cannot take shuffle=True together with a sampler; the
        # DistributedSampler does the shuffling in distributed mode.
        shuffle=sampler is None,
        num_workers=cfg.num_workers,
        pin_memory=device.type != "cpu",
        drop_last=True,
    )
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
        "lr": AverageMeter("lr", ":0.1e"),
        "update_s": AverageMeter("updt_s", ":.3f"),
        "dataloading_s": AverageMeter("data_s", ":.3f"),
    }
    # Each step consumes cfg.batch_size samples on every process.
    # NOTE(review): assumes MetricsTracker uses its first argument as the
    # per-step sample count for its samples/episodes/epochs bookkeeping.
    train_tracker = MetricsTracker(
        cfg.batch_size * world_size,
        dataset.num_frames,
        dataset.num_episodes,
        train_metrics,
        initial_step=step,
    )

    if is_main_process:
        logging.info("Start offline training on a fixed dataset")
    while step < cfg.steps:
        if sampler is not None:
            # The sampler reads its epoch when a new iterator over it is created;
            # `step` is identical on every rank, so all ranks shuffle alike.
            sampler.set_epoch(step)

        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch[key] = value.to(device, non_blocking=True)

        train_tracker, output_dict = update_policy(
            train_tracker,
            policy,
            batch,
            optimizer,
            cfg.optimizer.grad_clip_norm,
            grad_scaler,
            lr_scheduler=lr_scheduler,
            use_amp=cfg.policy.use_amp,
        )

        step += 1
        train_tracker.step()
        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
        # Guard save_freq == 0 like log_freq (modulo would raise ZeroDivisionError);
        # the final step is always a saving step.
        is_saving_step = step == cfg.steps or (cfg.save_freq > 0 and step % cfg.save_freq == 0)

        if is_log_step and is_main_process:
            logging.info(train_tracker)
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:
                    wandb_log_dict.update(output_dict)
                wandb_logger.log_dict(wandb_log_dict, step)
            train_tracker.reset_averages()

        if cfg.save_checkpoint and is_saving_step and is_main_process:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
            save_checkpoint(checkpoint_dir, step, cfg, raw_policy, optimizer, lr_scheduler)
            update_last_checkpoint(checkpoint_dir)
            if wandb_logger:
                wandb_logger.log_policy(checkpoint_dir)


@parser.wrap()
def train(cfg: TrainPipelineConfig):
    """Train a policy on a fixed offline dataset, optionally with DistributedDataParallel.

    Distributed mode is used when ``RANK`` and ``WORLD_SIZE`` are set in the
    environment: the NCCL process group is initialized and each process uses
    the CUDA device given by ``LOCAL_RANK``. The process group is destroyed
    even if training raises.
    """
    cfg.validate()
    logging.info(pformat(cfg.to_dict()))

    device, local_rank, rank, world_size = _setup_distributed(cfg)
    try:
        _run_training(cfg, device, local_rank, rank, world_size)
    finally:
        if _is_distributed():
            dist.destroy_process_group()

    logging.info("End of training")


if __name__ == "__main__":
    init_logging()
    train()