| | """ |
| | PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support. |
| | This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs |
| | entirely in PyTorch using the `PI0Pytorch` model and your existing config/data |
| | pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. |
| | |
| | Usage |
| | Single GPU: |
| | python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval> |
| | Example: |
| | python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test |
| | python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint |
| | Multi-GPU (single node): |
| | torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name> |
| | Example: |
| | torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test |
| | torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume |
| | Multi-Node Training: |
| | torchrun \ |
| | --nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \ |
| | --master_addr=<master_ip> --master_port=<port> \ |
| | scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval> |
| | |
| | """ |
| |
|
| | import dataclasses |
| | import gc |
| | import logging |
| | import os |
| | import platform |
| | import shutil |
| | import time |
| |
|
| | import jax |
| | import numpy as np |
| | import safetensors.torch |
| | import torch |
| | import torch.distributed as dist |
| | import torch.nn.parallel |
| | import tqdm |
| |
|
| | import openpi.models.pi0_config |
| | import openpi.shared.normalize as _normalize |
| | import openpi.training.config as _config |
| | import openpi.training.data_loader as _data |
| | import openpi.transforms as _transforms |
| |
|
| |
|
def get_wandb():
    """Lazily import and return the wandb module.

    Deferring the import keeps wandb out of the critical path when logging
    is disabled.
    """
    import wandb as _wandb

    return _wandb
| |
|
| |
|
def init_logging():
    """Configure the root logger with a compact, single-letter level format."""
    abbrev = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class _ShortLevelFormatter(logging.Formatter):
        # Rewrites the record's level name to its one-letter abbreviation
        # before delegating to the standard formatter.
        def format(self, record):
            record.levelname = abbrev.get(record.levelname, record.levelname)
            return super().format(record)

    fmt = _ShortLevelFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    if root.handlers:
        # Reuse an existing handler rather than stacking duplicates.
        root.handlers[0].setFormatter(fmt)
    else:
        handler = logging.StreamHandler()
        handler.setFormatter(fmt)
        root.addHandler(handler)
| |
|
| |
|
def configure_process_logging(is_main: bool) -> None:
    """Quiet non-main ranks (WARNING) and silence noisy third-party loggers."""
    logging.getLogger().setLevel(logging.INFO if is_main else logging.WARNING)
    for noisy in ("jax", "jaxlib", "absl"):
        logging.getLogger(noisy).setLevel(logging.WARNING)
| |
|
| |
|
def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
    """Start (or resume) a wandb run for this experiment.

    The run id is persisted to `wandb_id.txt` inside the checkpoint directory
    so a later `--resume` can reattach to the same wandb run.
    """
    if not enabled:
        return

    wandb = get_wandb()
    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    run_id_file = ckpt_dir / "wandb_id.txt"
    if resuming:
        wandb.init(id=run_id_file.read_text().strip(), resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        run_id_file.write_text(wandb.run.id)
| |
|
| |
|
def setup_ddp():
    """Initialize torch.distributed when launched under torchrun.

    Returns:
        (use_ddp, local_rank, device): whether DDP is active, this process's
        local rank, and the torch.device the model should live on.
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    use_ddp = world_size > 1

    # Fix: this must be exported BEFORE init_process_group — the debug level
    # is read when the process group is created, so setting it afterwards
    # (as the previous code did) had no effect on the group.
    if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"

    if use_ddp and not torch.distributed.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, init_method="env://")

    # LOCAL_RANK indexes devices within a node; fall back to RANK for
    # launchers that only set the global rank, then to 0 for single-process runs.
    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device)
    return use_ddp, local_rank, device
| |
|
| |
|
def cleanup_ddp():
    """Synchronize all ranks, then tear down the process group (no-op otherwise)."""
    if not torch.distributed.is_initialized():
        return
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
| |
|
| |
|
def set_seed(seed: int, local_rank: int):
    """Seed torch/numpy RNGs, offset by rank so each process draws a distinct stream."""
    rank_seed = seed + local_rank
    torch.manual_seed(rank_seed)
    np.random.seed(rank_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(rank_seed)
| |
|
| |
|
def build_datasets(config: _config.TrainConfig):
    """Create the PyTorch data loader for this config; return (loader, data_config)."""
    loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
    return loader, loader.data_config()
| |
|
| |
|
def resolve_norm_stats_path(config: _config.TrainConfig, data_config: _config.DataConfig):
    """Path to this dataset's norm_stats.json under the assets dir, or None without an asset id."""
    asset_id = data_config.asset_id
    if asset_id is None:
        return None
    return config.assets_dirs / asset_id / "norm_stats.json"
| |
|
| |
|
def norm_stats_summary(data_config: _config.DataConfig) -> dict[str, object]:
    """Summarize normalization stats: available keys plus state/action vector lengths."""
    if data_config.norm_stats is None:
        return {"keys": []}

    summary: dict[str, object] = {"keys": sorted(data_config.norm_stats.keys())}
    for key in ("state", "actions"):
        stats = data_config.norm_stats.get(key)
        if stats is not None:
            # Last-axis length of mean/std tells us the vector dimensionality.
            summary[f"{key}_mean_len"] = int(stats.mean.shape[-1])
            summary[f"{key}_std_len"] = int(stats.std.shape[-1])
    return summary
| |
|
| |
|
| | def action_mask_indices(action_loss_mask: tuple[float, ...] | None) -> tuple[list[int], list[int]]: |
| | if action_loss_mask is None: |
| | return [], [] |
| | active = [i for i, value in enumerate(action_loss_mask) if value != 0] |
| | masked = [i for i, value in enumerate(action_loss_mask) if value == 0] |
| | return active, masked |
| |
|
| |
|
def is_packed_transform_active(data_config: _config.DataConfig) -> bool:
    """Report whether any model input transform is a PackPerArmBlocks instance."""
    for transform in data_config.model_transforms.inputs:
        if isinstance(transform, _transforms.PackPerArmBlocks):
            return True
    return False
| |
|
| |
|
| | def compute_masked_action_loss( |
| | losses: torch.Tensor, action_loss_mask: tuple[float, ...] | None |
| | ) -> tuple[torch.Tensor, torch.Tensor | None]: |
| | if losses.ndim != 3: |
| | raise ValueError(f"Expected losses with shape [B, H, D], got {tuple(losses.shape)}.") |
| |
|
| | if action_loss_mask is None: |
| | return losses.mean(), None |
| |
|
| | mask = torch.as_tensor(action_loss_mask, device=losses.device, dtype=losses.dtype) |
| | if mask.ndim != 1 or mask.shape[0] != losses.shape[-1]: |
| | raise ValueError( |
| | f"Action loss mask must be 1D with length {losses.shape[-1]}, got shape {tuple(mask.shape)}." |
| | ) |
| |
|
| | denom = mask.sum() * losses.shape[0] * losses.shape[1] |
| | if float(denom.item()) <= 0: |
| | raise ValueError("Action loss mask must include at least one active dimension.") |
| | return (losses * mask.view(1, 1, -1)).sum() / denom, mask |
| |
|
| |
|
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Strip a DistributedDataParallel wrapper, returning the underlying module."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model
| |
|
| |
|
def grad_norm_for_parameters(parameters) -> float:
    """Return the global L2 norm of gradients over *parameters* (0.0 if none have grads)."""
    squares = [
        torch.sum(param.grad.detach().to(torch.float32) ** 2)
        for param in parameters
        if param.grad is not None
    ]
    if not squares:
        return 0.0
    # Accumulate sequentially to mirror the original summation order.
    total = squares[0]
    for sq in squares[1:]:
        total = total + sq
    return float(torch.sqrt(total).item())
| |
|
| |
|
def collect_gradient_bucket_norms(model: torch.nn.Module) -> dict[str, float]:
    """Collect per-component gradient L2 norms (and gate values) for logging.

    The bucket layout follows the model's architecture flags:
      * use_split_action_expert: separate norms for the shared PaliGemma
        backbone, the per-arm action input projections, the left/right
        experts, and the action output projections. When
        use_communicating_action_expert is also set, the cross-arm
        communication gates (and, when populated, the latest cross-arm
        attention mass) are logged per layer too.
      * use_parallel_action_heads (without split experts): norms for the
        per-arm input projections, the arm token fuse, and the per-arm
        output projections, plus the shared expert.
      * otherwise: norms for the single action input/output projections
        plus the shared expert.

    Intended to be called after backward() while grads are still populated
    (parameters without grads contribute 0.0 via grad_norm_for_parameters).
    """
    model_for_logging = unwrap_model(model)
    if model_for_logging.use_split_action_expert:
        metrics = {
            "grad_shared_backbone": grad_norm_for_parameters(
                model_for_logging.paligemma_with_expert.paligemma.parameters()
            ),
            "grad_left_action_in": grad_norm_for_parameters(model_for_logging.action_in_proj_arms[0].parameters()),
            "grad_right_action_in": grad_norm_for_parameters(model_for_logging.action_in_proj_arms[1].parameters()),
            "grad_left_expert": grad_norm_for_parameters(
                model_for_logging.paligemma_with_expert.left_gemma_expert.parameters()
            ),
            "grad_right_expert": grad_norm_for_parameters(
                model_for_logging.paligemma_with_expert.right_gemma_expert.parameters()
            ),
            "grad_action_out": grad_norm_for_parameters(model_for_logging.action_out_proj_arms.parameters()),
        }
        if model_for_logging.use_communicating_action_expert:
            # cross_arm_comm is a single tensor (one gate per layer), not a
            # module, so it is wrapped in a list for the norm helper.
            metrics["grad_cross_arm_comm"] = grad_norm_for_parameters(
                [model_for_logging.paligemma_with_expert.cross_arm_comm]
            )
            for layer_idx, gate_value in enumerate(model_for_logging.paligemma_with_expert.cross_arm_comm.detach().cpu()):
                metrics[f"cross_arm_comm_gate_layer_{layer_idx}"] = float(gate_value.item())
            if model_for_logging.paligemma_with_expert.latest_cross_arm_attention_mass is not None:
                for layer_idx, attn_mass in enumerate(
                    model_for_logging.paligemma_with_expert.latest_cross_arm_attention_mass.detach().cpu()
                ):
                    metrics[f"cross_arm_attention_mass_layer_{layer_idx}"] = float(attn_mass.item())
        return metrics

    metrics = {"grad_shared_expert": grad_norm_for_parameters(model_for_logging.paligemma_with_expert.parameters())}
    if model_for_logging.use_parallel_action_heads:
        metrics["grad_action_in_proj_arms"] = grad_norm_for_parameters(model_for_logging.action_in_proj_arms.parameters())
        metrics["grad_arm_token_fuse"] = grad_norm_for_parameters(model_for_logging.arm_token_fuse.parameters())
        metrics["grad_action_out_proj_arms"] = grad_norm_for_parameters(
            model_for_logging.action_out_proj_arms.parameters()
        )
    else:
        metrics["grad_action_in_proj"] = grad_norm_for_parameters(model_for_logging.action_in_proj.parameters())
        metrics["grad_action_out_proj"] = grad_norm_for_parameters(model_for_logging.action_out_proj.parameters())
    return metrics
| |
|
| |
|
def tensor_stats(tensor: torch.Tensor) -> tuple[float, float, float, float]:
    """Return (min, max, mean, population std) of *tensor* as Python floats."""
    values = tensor.detach().to(torch.float32)
    parts = (values.min(), values.max(), values.mean(), values.std(unbiased=False))
    return tuple(float(part.item()) for part in parts)
| |
|
| |
|
def block_nonzero_counts(tensor: torch.Tensor, block_size: int = 8) -> list[int]:
    """Count nonzero entries per consecutive block of the last dimension.

    Generalized from the previous hard-coded width of 8; the default keeps
    the original behavior.

    Args:
        tensor: tensor of any shape; blocks are taken along the last dim.
        block_size: width of each block (the final block may be narrower).

    Returns:
        One nonzero count per block, in order.

    Raises:
        ValueError: if block_size is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}.")
    counts = []
    # Slicing past the end is safe in PyTorch, so no explicit min() is needed.
    for start in range(0, tensor.shape[-1], block_size):
        chunk = tensor[..., start : start + block_size]
        counts.append(int(torch.count_nonzero(chunk).item()))
    return counts
| |
|
| |
|
def masked_zero_count(tensor: torch.Tensor, masked_dims: list[int]) -> int:
    """Count zero entries within the masked last-axis dims (0 when no dims are masked)."""
    if not masked_dims:
        return 0
    selected = tensor[..., masked_dims]
    return int(selected.eq(0).sum().item())
| |
|
| |
|
def cuda_memory_summary(device: torch.device) -> dict[str, float]:
    """Current/peak CUDA memory in GB for *device*; all zeros when CUDA is unavailable."""
    keys = ("allocated_gb", "reserved_gb", "max_allocated_gb", "max_reserved_gb")
    if not torch.cuda.is_available():
        return dict.fromkeys(keys, 0.0)
    values = (
        torch.cuda.memory_allocated(device),
        torch.cuda.memory_reserved(device),
        torch.cuda.max_memory_allocated(device),
        torch.cuda.max_memory_reserved(device),
    )
    return {key: value / 1e9 for key, value in zip(keys, values)}
| |
|
| |
|
def log_startup_summary(
    config: _config.TrainConfig,
    data_config: _config.DataConfig,
    *,
    world_size: int,
    local_batch_size: int,
    model_kind: str,
) -> None:
    """Log a one-time, human-readable summary of the run configuration.

    Covers dataset/norm-stats locations, checkpoint source, model kind,
    batch sizing across ranks, precision, the LR schedule, save/log
    intervals, and the action-loss mask split into active vs. masked dims.

    Args:
        config: full training config for this run.
        data_config: resolved data config (repo id, norm stats, transforms).
        world_size: number of participating processes.
        local_batch_size: per-GPU batch size (global size comes from config).
        model_kind: short model-type label for the log (e.g. the action
            expert mode).
    """
    norm_path = resolve_norm_stats_path(config, data_config)
    norm_summary = norm_stats_summary(data_config)
    active_dims, masked_dims = action_mask_indices(config.action_loss_mask)

    logging.info(f"Resolved config name: {config.name}")
    logging.info(f"Dataset repo_id: {data_config.repo_id}")
    logging.info(f"Norm-stats file path: {norm_path}")
    logging.info(f"Norm-stats summary: {norm_summary}")
    logging.info(f"Checkpoint source path: {config.pytorch_weight_path}")
    logging.info(f"Model type: {model_kind}")
    logging.info(f"Packed transforms active: {is_packed_transform_active(data_config)}")
    logging.info(f"World size: {world_size}")
    logging.info(f"Batch size: local={local_batch_size}, global={config.batch_size}")
    logging.info(f"num_workers: {config.num_workers}")
    logging.info(f"Precision: {config.pytorch_training_precision}")
    logging.info(
        "LR schedule summary: "
        f"warmup_steps={config.lr_schedule.warmup_steps}, "
        f"peak_lr={config.lr_schedule.peak_lr:.2e}, "
        f"decay_steps={config.lr_schedule.decay_steps}, "
        f"decay_lr={config.lr_schedule.decay_lr:.2e}"
    )
    logging.info(f"Save/log intervals: save_interval={config.save_interval}, log_interval={config.log_interval}")
    logging.info(f"Action-loss mask: {config.action_loss_mask}")
    logging.info(f"Active mask dims: {active_dims}")
    logging.info(f"Masked dims: {masked_dims}")
| |
|
| |
|
def get_model_state_dict(model):
    """Return the model's state dict, unwrapping any DDP container first."""
    bare_model = unwrap_model(model)
    return bare_model.state_dict()
| |
|
| |
|
def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
    """Save a checkpoint with model state, optimizer state, and metadata.

    Only the main rank writes. A checkpoint is produced every
    `config.save_interval` steps (skipping step 0) and at the final step.
    Everything is written into a `tmp_<step>` staging directory and renamed
    into place at the end, so a partially written checkpoint is never
    mistaken for a complete one.
    """
    if not is_main:
        return

    if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps:
        final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
        tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"

        # Start from a clean staging directory (e.g. after an interrupted save).
        if tmp_ckpt_dir.exists():
            shutil.rmtree(tmp_ckpt_dir)
        tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Model weights, unwrapped from any DDP container, in safetensors format.
        model_to_save = unwrap_model(model)
        safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")

        # Optimizer state for exact resume.
        torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")

        # Step, full config, and wall-clock timestamp.
        metadata = {
            "global_step": global_step,
            "config": dataclasses.asdict(config),
            "timestamp": time.time(),
        }
        torch.save(metadata, tmp_ckpt_dir / "metadata.pt")

        # Bundle normalization stats so the checkpoint is usable standalone.
        norm_stats = data_config.norm_stats
        if norm_stats is not None and data_config.asset_id is not None:
            _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)

        # Replace any existing checkpoint for this step, then swap the
        # staging directory into its final name.
        if final_ckpt_dir.exists():
            shutil.rmtree(final_ckpt_dir)
        tmp_ckpt_dir.rename(final_ckpt_dir)

        logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")

        if config.wandb_enabled:
            wandb = get_wandb()
            wandb.log({"checkpoint_step": global_step}, step=global_step)
| |
|
| |
|
def load_checkpoint(model, optimizer, checkpoint_dir, device):
    """Load the latest checkpoint and return the global step.

    Scans `checkpoint_dir` for numeric step subdirectories (ignoring `tmp_`
    staging dirs), picks the highest, and restores model weights, optimizer
    state, and metadata in sequence — freeing CUDA/host memory between
    stages to keep the peak footprint down.

    Raises:
        FileNotFoundError: if no checkpoint directories or expected files exist.
        RuntimeError: re-raised with remediation guidance on CUDA OOM.
    """
    checkpoint_steps = [
        int(d.name)
        for d in checkpoint_dir.iterdir()
        if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
    ]

    if not checkpoint_steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")

    latest_step = max(checkpoint_steps)
    ckpt_dir = checkpoint_dir / f"{latest_step}"

    # Free as much memory as possible before materializing checkpoint tensors.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    log_memory_usage(device, latest_step, "before_loading_checkpoint")

    try:
        # Model weights first (the largest allocation).
        logging.info("Loading model state...")
        safetensors_path = ckpt_dir / "model.safetensors"

        if safetensors_path.exists():
            model_to_load = unwrap_model(model)
            safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
            logging.info("Loaded model state from safetensors format")
        else:
            raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")

        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_model")

        # Optimizer state next; weights_only=False because the state dict
        # contains arbitrary Python objects, not just tensors.
        logging.info("Loading optimizer state...")
        optimizer_path = ckpt_dir / "optimizer.pt"

        if optimizer_path.exists():
            optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
            logging.info("Loaded optimizer state from pt format")
        else:
            raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")

        optimizer.load_state_dict(optimizer_state_dict)
        del optimizer_state_dict
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_optimizer")

        # Metadata last; fall back to the directory's step number if the
        # stored global_step is missing.
        logging.info("Loading metadata...")
        metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
        global_step = metadata.get("global_step", latest_step)
        del metadata
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_metadata")

        logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
        return global_step

    except RuntimeError as e:
        if "out of memory" in str(e):
            # Release whatever was partially allocated before surfacing the error.
            torch.cuda.empty_cache()
            gc.collect()
            logging.error(f"Out of memory error while loading checkpoint: {e!s}")
            log_memory_usage(device, latest_step, "after_oom_error")
            raise RuntimeError(
                "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
            ) from e
        raise
| |
|
| |
|
def get_latest_checkpoint_step(checkpoint_dir):
    """Return the highest numeric step directory in *checkpoint_dir*, or None if there is none."""
    steps = [
        int(entry.name)
        for entry in checkpoint_dir.iterdir()
        if entry.is_dir() and entry.name.isdigit() and not entry.name.startswith("tmp_")
    ]
    if not steps:
        return None
    return max(steps)
| |
|
| |
|
def log_memory_usage(device, step, phase="unknown"):
    """Log allocated/reserved/peak GPU memory (plus DDP rank info); no-op without CUDA."""
    if not torch.cuda.is_available():
        return

    gb = 1e9
    memory_allocated = torch.cuda.memory_allocated(device) / gb
    memory_reserved = torch.cuda.memory_reserved(device) / gb
    # "free" here means reserved-but-unallocated within the caching allocator.
    memory_free = (torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)) / gb

    stats = torch.cuda.memory_stats(device)
    max_memory_allocated = stats.get("allocated_bytes.all.peak", 0) / gb
    max_memory_reserved = stats.get("reserved_bytes.all.peak", 0) / gb

    ddp_info = (
        f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"
        if dist.is_initialized()
        else ""
    )

    logging.info(
        f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
    )
| |
|
| |
|
| | def train_loop(config: _config.TrainConfig): |
| | use_ddp, local_rank, device = setup_ddp() |
| | is_main = (not use_ddp) or (dist.get_rank() == 0) |
| | configure_process_logging(is_main) |
| | set_seed(config.seed, local_rank) |
| |
|
| | |
| | resuming = False |
| | if is_main: |
| | if config.resume: |
| | |
| | exp_checkpoint_dir = config.checkpoint_dir |
| | if exp_checkpoint_dir.exists(): |
| | |
| | latest_step = get_latest_checkpoint_step(exp_checkpoint_dir) |
| | if latest_step is not None: |
| | resuming = True |
| | logging.info( |
| | f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}" |
| | ) |
| | else: |
| | raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume") |
| | else: |
| | raise FileNotFoundError( |
| | f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume" |
| | ) |
| | elif config.overwrite and config.checkpoint_dir.exists(): |
| | shutil.rmtree(config.checkpoint_dir) |
| | logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") |
| |
|
| | |
| | if not resuming: |
| | |
| | exp_checkpoint_dir = config.checkpoint_dir |
| | exp_checkpoint_dir.mkdir(parents=True, exist_ok=True) |
| | logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}") |
| | else: |
| | |
| | logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}") |
| |
|
| | if use_ddp: |
| | dist.barrier() |
| |
|
| | |
| | if is_main: |
| | init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) |
| |
|
| | |
| | |
| | |
| | world_size = torch.distributed.get_world_size() if use_ddp else 1 |
| | if config.batch_size % world_size != 0: |
| | raise ValueError(f"Global batch size {config.batch_size} must be divisible by world_size {world_size}.") |
| | effective_batch_size = config.batch_size // world_size |
| | logging.info( |
| | f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})" |
| | ) |
| |
|
| | |
| | if use_ddp: |
| | if is_main: |
| | loader, data_config = build_datasets(config) |
| | dist.barrier() |
| | else: |
| | dist.barrier() |
| | loader, data_config = build_datasets(config) |
| | else: |
| | loader, data_config = build_datasets(config) |
| |
|
| | |
| | if is_main and config.wandb_enabled and not resuming: |
| | wandb = get_wandb() |
| | |
| | sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False) |
| | sample_batch = next(iter(sample_data_loader)) |
| | |
| | observation, actions = sample_batch |
| | sample_batch = observation.to_dict() |
| | sample_batch["actions"] = actions |
| |
|
| | |
| | images_to_log = [] |
| | |
| | batch_size = next(iter(sample_batch["image"].values())).shape[0] |
| | for i in range(min(5, batch_size)): |
| | |
| | |
| | img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1) |
| | img_concatenated = img_concatenated.cpu().numpy() |
| | images_to_log.append(wandb.Image(img_concatenated)) |
| |
|
| | wandb.log({"camera_views": images_to_log}, step=0) |
| |
|
| | |
| | del sample_batch, observation, actions, images_to_log, img_concatenated |
| | del sample_data_loader |
| | gc.collect() |
| | if torch.cuda.is_available(): |
| | torch.cuda.empty_cache() |
| | logging.info("Cleared sample batch and data loader from memory") |
| |
|
| | |
| | if not isinstance(config.model, openpi.models.pi0_config.Pi0Config): |
| | |
| | model_cfg = openpi.models.pi0_config.Pi0Config( |
| | dtype=config.pytorch_training_precision, |
| | action_dim=config.model.action_dim, |
| | action_horizon=config.model.action_horizon, |
| | max_token_len=config.model.max_token_len, |
| | paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"), |
| | action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"), |
| | pi05=getattr(config.model, "pi05", False), |
| | arm_action_dims=getattr(config.model, "arm_action_dims", None), |
| | ) |
| | else: |
| | model_cfg = config.model |
| | |
| | object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision) |
| |
|
| | import openpi.models_pytorch.pi0_pytorch as pi0_pytorch |
| |
|
| | model = pi0_pytorch.PI0Pytorch(model_cfg).to(device) |
| |
|
| | if hasattr(model, "gradient_checkpointing_enable"): |
| | enable_gradient_checkpointing = True |
| | model.gradient_checkpointing_enable() |
| | logging.info("Enabled gradient checkpointing for memory optimization") |
| | else: |
| | enable_gradient_checkpointing = False |
| | logging.info("Gradient checkpointing is not supported for this model") |
| |
|
| | |
| | if is_main and torch.cuda.is_available(): |
| | log_memory_usage(device, 0, "after_model_creation") |
| |
|
| | |
| | if world_size >= 8: |
| | torch.backends.cudnn.benchmark = True |
| | torch.backends.cuda.matmul.allow_tf32 = True |
| | torch.backends.cudnn.allow_tf32 = True |
| | |
| | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True" |
| | logging.info("Enabled memory optimizations for 8+ GPU training") |
| |
|
| | if use_ddp: |
| | model = torch.nn.parallel.DistributedDataParallel( |
| | model, |
| | device_ids=[device.index] if device.type == "cuda" else None, |
| | find_unused_parameters=False, |
| | gradient_as_bucket_view=True, |
| | static_graph=True, |
| | ) |
| |
|
| | |
| | if config.pytorch_weight_path is not None: |
| | logging.info(f"Loading weights from: {config.pytorch_weight_path}") |
| |
|
| | model_path = os.path.join(config.pytorch_weight_path, "model.safetensors") |
| | missing, unexpected = safetensors.torch.load_model(unwrap_model(model), model_path, strict=False) |
| | logging.info(f"Weight loading missing key count: {len(missing)}") |
| | logging.info(f"Weight loading missing keys: {missing}") |
| | logging.info(f"Weight loading unexpected key count: {len(unexpected)}") |
| | logging.info(f"Weight loading unexpected keys: {unexpected}") |
| | logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}") |
| |
|
| | |
| | warmup_steps = config.lr_schedule.warmup_steps |
| | peak_lr = config.lr_schedule.peak_lr |
| | decay_steps = config.lr_schedule.decay_steps |
| | end_lr = config.lr_schedule.decay_lr |
| |
|
| | |
| | optim = torch.optim.AdamW( |
| | model.parameters(), |
| | lr=peak_lr, |
| | betas=(config.optimizer.b1, config.optimizer.b2), |
| | eps=config.optimizer.eps, |
| | weight_decay=config.optimizer.weight_decay, |
| | ) |
| |
|
| | |
| | global_step = 0 |
| | if resuming: |
| | global_step = load_checkpoint(model, optim, config.checkpoint_dir, device) |
| | logging.info(f"Resumed training from step {global_step}") |
| |
|
| | def lr_schedule(step: int): |
| | if step < warmup_steps: |
| | |
| | init_lr = peak_lr / (warmup_steps + 1) |
| | return init_lr + (peak_lr - init_lr) * step / warmup_steps |
| | |
| | progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps)) |
| | cos = 0.5 * (1 + np.cos(np.pi * progress)) |
| | return end_lr + (peak_lr - end_lr) * cos |
| |
|
| | model.train() |
| | infos = [] |
| | interval_start_time = time.perf_counter() |
| | last_step_end = time.perf_counter() |
| | smoothed_loss = None |
| | if is_main: |
| | model_kind = model_cfg.action_expert_mode |
| | logging.info(f"Running on: {platform.node()} | world_size={world_size}") |
| | logging.info( |
| | f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}" |
| | ) |
| | logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}") |
| | logging.info("DDP settings: find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=True") |
| | logging.info( |
| | f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}" |
| | ) |
| | logging.info( |
| | f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}" |
| | ) |
| | logging.info("EMA is not supported for PyTorch training") |
| | logging.info(f"Training precision: {model_cfg.dtype}") |
| | log_startup_summary( |
| | config, |
| | data_config, |
| | world_size=world_size, |
| | local_batch_size=effective_batch_size, |
| | model_kind=model_kind, |
| | ) |
| | logging.info( |
| | "Gradient bucket diagnostics: " |
| | + ( |
| | "action_in_proj, action_out_proj, shared_expert" |
| | if not model_cfg.use_parallel_action_heads |
| | else ( |
| | "left_action_in, right_action_in, left_expert, right_expert, action_out, cross_arm_comm" |
| | if model_cfg.use_split_action_expert |
| | else "action_in_proj_arms, arm_token_fuse, action_out_proj_arms, shared_expert" |
| | ) |
| | ) |
| | ) |
| |
|
| | |
| | pbar = ( |
| | tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main) |
| | if is_main |
| | else None |
| | ) |
| |
|
| | while global_step < config.num_train_steps: |
| | |
| | if use_ddp and hasattr(loader, "set_epoch"): |
| | loader.set_epoch(global_step // len(loader)) |
| |
|
| | for observation, actions in loader: |
| | |
| | if global_step >= config.num_train_steps: |
| | break |
| |
|
| | data_ready_time = time.perf_counter() |
| | data_time = data_ready_time - last_step_end |
| | step_start_time = data_ready_time |
| | step_index = global_step |
| | completed_step = step_index + 1 |
| |
|
| | |
| | observation = jax.tree.map(lambda x: x.to(device), observation) |
| | actions = actions.to(torch.float32) |
| | actions = actions.to(device) |
| |
|
| | |
| | for pg in optim.param_groups: |
| | pg["lr"] = lr_schedule(global_step) |
| |
|
| | |
| | losses = model(observation, actions) |
| | |
| | if isinstance(losses, list | tuple): |
| | losses = torch.stack(losses) |
| | elif not isinstance(losses, torch.Tensor): |
| | losses = torch.tensor(losses, device=device, dtype=torch.float32) |
| |
|
| | loss, _ = compute_masked_action_loss(losses, config.action_loss_mask) |
| | if not torch.isfinite(loss): |
| | raise FloatingPointError( |
| | f"Non-finite loss detected at step {step_index}: loss={float(loss.detach().item())}" |
| | ) |
| |
|
| | |
| | loss.backward() |
| |
|
| | should_log_grad_buckets = is_main and ( |
| | completed_step <= 5 or completed_step % config.log_interval == 0 |
| | ) |
| | grad_bucket_metrics = collect_gradient_bucket_norms(model) if should_log_grad_buckets else {} |
| |
|
| | |
| | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm) |
| | grad_norm_value = float(grad_norm.item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm) |
| |
|
| | |
| | optim.step() |
| | optim.zero_grad(set_to_none=True) |
| |
|
| | |
| | for param in model.parameters(): |
| | if param.grad is not None: |
| | param.grad.detach_() |
| | param.grad = None |
| |
|
| | step_end_time = time.perf_counter() |
| | step_time = step_end_time - step_start_time |
| | last_step_end = step_end_time |
| | loss_value = float(loss.item()) |
| |
|
| | |
| | if is_main: |
| | smoothed_loss = loss_value if smoothed_loss is None else 0.9 * smoothed_loss + 0.1 * loss_value |
| | infos.append( |
| | { |
| | "loss": loss_value, |
| | "smoothed_loss": smoothed_loss, |
| | "learning_rate": optim.param_groups[0]["lr"], |
| | "grad_norm": grad_norm_value, |
| | "step_time": step_time, |
| | "data_time": data_time, |
| | } |
| | ) |
| |
|
| | if is_main and completed_step <= 5: |
| | active_dims, masked_dims = action_mask_indices(config.action_loss_mask) |
| | prompt_lengths = [] |
| | if observation.tokenized_prompt_mask is not None: |
| | prompt_lengths = observation.tokenized_prompt_mask.sum(dim=1).tolist() |
| | image_shapes = {k: tuple(v.shape) for k, v in observation.images.items()} |
| | state_min, state_max, state_mean, state_std = tensor_stats(observation.state) |
| | action_min, action_max, action_mean, action_std = tensor_stats(actions) |
| | memory = cuda_memory_summary(device) |
| | logging.info( |
| | f"debug_step={completed_step} observation.state shape={tuple(observation.state.shape)} dtype={observation.state.dtype} " |
| | f"actions shape={tuple(actions.shape)} dtype={actions.dtype}" |
| | ) |
| | logging.info( |
| | f"debug_step={completed_step} image_keys={list(observation.images.keys())} image_shapes={image_shapes}" |
| | ) |
| | logging.info(f"debug_step={completed_step} prompt_token_lengths={prompt_lengths}") |
| | logging.info( |
| | f"debug_step={completed_step} state_stats min={state_min:.4f} max={state_max:.4f} mean={state_mean:.4f} std={state_std:.4f}" |
| | ) |
| | logging.info( |
| | f"debug_step={completed_step} action_stats min={action_min:.4f} max={action_max:.4f} mean={action_mean:.4f} std={action_std:.4f}" |
| | ) |
| | logging.info( |
| | f"debug_step={completed_step} state_nonzero_counts_8d_blocks={block_nonzero_counts(observation.state)} " |
| | f"action_nonzero_counts_8d_blocks={block_nonzero_counts(actions)}" |
| | ) |
| | logging.info( |
| | f"debug_step={completed_step} masked_dims={masked_dims} active_dims={active_dims} " |
| | f"masked_zero_counts state={masked_zero_count(observation.state, masked_dims)} " |
| | f"actions={masked_zero_count(actions, masked_dims)}" |
| | ) |
| | logging.info( |
| | f"debug_step={completed_step} lr={optim.param_groups[0]['lr']:.2e} grad_norm={grad_norm_value:.4f} " |
| | f"data_time={data_time:.4f}s step_time={step_time:.4f}s " |
| | f"gpu_mem_allocated={memory['allocated_gb']:.2f}GB gpu_mem_reserved={memory['reserved_gb']:.2f}GB " |
| | f"gpu_mem_max_allocated={memory['max_allocated_gb']:.2f}GB gpu_mem_max_reserved={memory['max_reserved_gb']:.2f}GB" |
| | ) |
| | if grad_bucket_metrics: |
| | grad_metrics_text = " ".join(f"{key}={value:.4f}" for key, value in grad_bucket_metrics.items()) |
| | logging.info(f"debug_step={completed_step} {grad_metrics_text}") |
| |
|
| | global_step = completed_step |
| |
|
| | if is_main and (global_step % config.log_interval == 0): |
| | interval_elapsed = time.perf_counter() - interval_start_time |
| | avg_lr = sum(info["learning_rate"] for info in infos) / len(infos) |
| | avg_grad_norm = sum(info["grad_norm"] for info in infos) / len(infos) |
| | avg_step_time = sum(info["step_time"] for info in infos) / len(infos) |
| | avg_data_time = sum(info["data_time"] for info in infos) / len(infos) |
| | items_per_second = len(infos) / interval_elapsed if interval_elapsed > 0 else float("inf") |
| | eta_seconds = ( |
| | max(config.num_train_steps - global_step, 0) / items_per_second |
| | if items_per_second > 0 |
| | else float("inf") |
| | ) |
| | memory = cuda_memory_summary(device) |
| | grad_metrics_text = " ".join( |
| | f"{key}={value:.4f}" for key, value in sorted(grad_bucket_metrics.items()) |
| | ) |
| | logging.info( |
| | f"step={global_step} loss={loss_value:.4f} smoothed_loss={smoothed_loss:.4f} " |
| | f"lr={avg_lr:.2e} grad_norm={avg_grad_norm:.4f} step_time={avg_step_time:.4f}s " |
| | f"data_time={avg_data_time:.4f}s it/s={items_per_second:.3f} eta_to_{config.num_train_steps}={eta_seconds:.1f}s " |
| | f"max_cuda_memory={memory['max_allocated_gb']:.2f}GB " |
| | f"{grad_metrics_text}".rstrip() |
| | ) |
| |
|
| | if config.wandb_enabled and len(infos) > 0: |
| | wandb = get_wandb() |
| | log_payload = { |
| | "loss": loss_value, |
| | "smoothed_loss": smoothed_loss, |
| | "learning_rate": avg_lr, |
| | "step": global_step, |
| | "time_per_step": avg_step_time, |
| | "data_time": avg_data_time, |
| | "items_per_second": items_per_second, |
| | "eta_seconds": eta_seconds, |
| | "max_cuda_memory_gb": memory["max_allocated_gb"], |
| | } |
| | log_payload["grad_norm"] = avg_grad_norm |
| | log_payload.update(grad_bucket_metrics) |
| | wandb.log(log_payload, step=global_step) |
| |
|
| | interval_start_time = time.perf_counter() |
| | infos = [] |
| |
|
| | |
| | save_checkpoint(model, optim, global_step, config, is_main, data_config) |
| |
|
| | |
| | if pbar is not None: |
| | pbar.update(1) |
| | pbar.set_postfix( |
| | {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step} |
| | ) |
| |
|
| | |
| | if pbar is not None: |
| | pbar.close() |
| |
|
| | |
| | if is_main and config.wandb_enabled: |
| | wandb = get_wandb() |
| | wandb.finish() |
| |
|
| | cleanup_ddp() |
| |
|
| |
|
def main():
    """CLI entrypoint: set up logging, parse the training config, and train."""
    init_logging()
    train_config = _config.cli()
    train_loop(train_config)
| |
|
| |
|
# Run directly (single GPU) or under torchrun for multi-GPU/multi-node DDP.
if __name__ == "__main__":
    main()
| |
|