lsnu's picture
Add files using upload-large-folder tool
ccf25b1 verified
"""
PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.
Usage
Single GPU:
python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
Example:
python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint
Multi-GPU (single node):
torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
Example:
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
Multi-Node Training:
torchrun \
--nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
--master_addr=<master_ip> --master_port=<port> \
scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
"""
import dataclasses
import gc
import logging
import os
import platform
import shutil
import time
import jax
import numpy as np
import safetensors.torch
import torch
import torch.distributed as dist
import torch.nn.parallel
import tqdm
import openpi.models.pi0_config
import openpi.shared.normalize as _normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data
import openpi.transforms as _transforms
def get_wandb():
    """Import and return the ``wandb`` module on demand.

    The import happens inside the function, so ``wandb`` is only loaded when a
    caller actually asks for it.
    """
    import wandb as wandb_module

    return wandb_module
def init_logging():
    """Configure the root logger with a compact format and single-letter level names.

    The root level is set to INFO. If the root logger already has handlers, only
    the first one receives the new formatter; otherwise a ``StreamHandler`` is added.
    """
    abbreviations = {
        "DEBUG": "D",
        "INFO": "I",
        "WARNING": "W",
        "ERROR": "E",
        "CRITICAL": "C",
    }

    class _AbbreviatedLevelFormatter(logging.Formatter):
        def format(self, record):
            # Rewrite the level name in place so "[%(levelname)s]" renders as e.g. "[I]".
            record.levelname = abbreviations.get(record.levelname, record.levelname)
            return super().format(record)

    compact_formatter = _AbbreviatedLevelFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    if root.handlers:
        root.handlers[0].setFormatter(compact_formatter)
        return
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(compact_formatter)
    root.addHandler(stream_handler)
def configure_process_logging(is_main: bool) -> None:
    """Set per-process log verbosity.

    The main process logs at INFO and every other rank at WARNING. The ``jax``,
    ``jaxlib`` and ``absl`` loggers are always capped at WARNING.
    """
    logging.getLogger().setLevel(logging.INFO if is_main else logging.WARNING)
    # External libraries emit noisy startup INFO logs on every rank and can block
    # torchrun worker pipes before the first train step.
    for noisy_logger_name in ("jax", "jaxlib", "absl"):
        logging.getLogger(noisy_logger_name).setLevel(logging.WARNING)
def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
    """Start or resume a Weights & Biases run for ``config``.

    When ``resuming``, the run id stored in ``<checkpoint_dir>/wandb_id.txt`` is
    reused with ``resume="must"``. Otherwise a new run is created and its id is
    written to that file. Does nothing when ``enabled`` is False.

    Raises:
        FileNotFoundError: if ``config.checkpoint_dir`` does not exist.
    """
    if not enabled:
        return
    wandb = get_wandb()
    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
    run_id_file = ckpt_dir / "wandb_id.txt"
    if resuming:
        wandb.init(id=run_id_file.read_text().strip(), resume="must", project=config.project_name)
        return
    wandb.init(
        name=config.exp_name,
        config=dataclasses.asdict(config),
        project=config.project_name,
    )
    run_id_file.write_text(wandb.run.id)
def setup_ddp():
    """Initialise torch.distributed when WORLD_SIZE > 1 and select this process's device.

    Uses NCCL when CUDA is available and gloo otherwise. The process group is only
    initialised if it is not already.

    Returns:
        ``(use_ddp, local_rank, device)``. ``local_rank`` comes from LOCAL_RANK,
        falling back to RANK, then 0. ``device`` is ``cuda:<local_rank>`` when CUDA
        is available, else CPU.
    """
    use_ddp = int(os.environ.get("WORLD_SIZE", "1")) > 1
    has_cuda = torch.cuda.is_available()
    if use_ddp and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl" if has_cuda else "gloo", init_method="env://")

    # Set up debugging environment variables for DDP issues.
    # NOTE(review): this runs after init_process_group (and after torch.distributed was
    # imported), so it may not affect the already-created process group — confirm intent.
    os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "INFO")

    rank_text = os.environ.get("LOCAL_RANK")
    if rank_text is None:
        rank_text = os.environ.get("RANK", "0")
    local_rank = int(rank_text)

    device = torch.device(f"cuda:{local_rank}") if has_cuda else torch.device("cpu")
    if has_cuda:
        torch.cuda.set_device(device)
    return use_ddp, local_rank, device
def cleanup_ddp():
    """Barrier all ranks and destroy the default process group, if one was initialised."""
    if not torch.distributed.is_initialized():
        return
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
def set_seed(seed: int, local_rank: int):
    """Seed torch, NumPy and (when available) all CUDA RNGs with ``seed + local_rank``."""
    effective_seed = seed + local_rank
    torch.manual_seed(effective_seed)
    np.random.seed(effective_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(effective_seed)
def build_datasets(config: _config.TrainConfig):
    """Create the shuffled PyTorch training loader and return ``(loader, data_config)``."""
    # Use the unified data loader with PyTorch framework
    loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
    data_config = loader.data_config()
    return loader, data_config
def resolve_norm_stats_path(config: _config.TrainConfig, data_config: _config.DataConfig):
    """Return ``<assets_dirs>/<asset_id>/norm_stats.json``, or None when no asset id is set."""
    asset_id = data_config.asset_id
    return None if asset_id is None else config.assets_dirs / asset_id / "norm_stats.json"
def norm_stats_summary(data_config: _config.DataConfig) -> dict[str, object]:
    """Summarise the loaded normalisation statistics for logging.

    Returns a dict with the sorted stat ``keys``. For ``state`` and ``actions``
    (when present) it also records the last-dimension length of ``mean`` and
    ``std``. Returns ``{"keys": []}`` when no stats are loaded.
    """
    stats_by_key = data_config.norm_stats
    if stats_by_key is None:
        return {"keys": []}
    summary: dict[str, object] = {"keys": sorted(stats_by_key.keys())}
    for name in ("state", "actions"):
        entry = stats_by_key.get(name)
        if entry is not None:
            summary[f"{name}_mean_len"] = int(entry.mean.shape[-1])
            summary[f"{name}_std_len"] = int(entry.std.shape[-1])
    return summary
def action_mask_indices(action_loss_mask: tuple[float, ...] | None) -> tuple[list[int], list[int]]:
if action_loss_mask is None:
return [], []
active = [i for i, value in enumerate(action_loss_mask) if value != 0]
masked = [i for i, value in enumerate(action_loss_mask) if value == 0]
return active, masked
def is_packed_transform_active(data_config: _config.DataConfig) -> bool:
    """Return True if any model input transform is a ``PackPerArmBlocks`` instance."""
    for transform in data_config.model_transforms.inputs:
        if isinstance(transform, _transforms.PackPerArmBlocks):
            return True
    return False
def compute_masked_action_loss(
losses: torch.Tensor, action_loss_mask: tuple[float, ...] | None
) -> tuple[torch.Tensor, torch.Tensor | None]:
if losses.ndim != 3:
raise ValueError(f"Expected losses with shape [B, H, D], got {tuple(losses.shape)}.")
if action_loss_mask is None:
return losses.mean(), None
mask = torch.as_tensor(action_loss_mask, device=losses.device, dtype=losses.dtype)
if mask.ndim != 1 or mask.shape[0] != losses.shape[-1]:
raise ValueError(
f"Action loss mask must be 1D with length {losses.shape[-1]}, got shape {tuple(mask.shape)}."
)
denom = mask.sum() * losses.shape[0] * losses.shape[1]
if float(denom.item()) <= 0:
raise ValueError("Action loss mask must include at least one active dimension.")
return (losses * mask.view(1, 1, -1)).sum() / denom, mask
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Return the module wrapped by DistributedDataParallel, or ``model`` itself otherwise."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model
def grad_norm_for_parameters(parameters) -> float:
    """Return the global L2 norm of the gradients of ``parameters``.

    Gradients are cast to float32 before squaring. Parameters without a gradient
    are skipped, and 0.0 is returned when none have one.
    """
    squared_norms = [
        torch.sum(param.grad.detach().to(torch.float32) ** 2) for param in parameters if param.grad is not None
    ]
    if not squared_norms:
        return 0.0
    # Accumulate left to right, one tensor at a time.
    running_total = squared_norms[0]
    for squared in squared_norms[1:]:
        running_total = running_total + squared
    return float(torch.sqrt(running_total).item())
def collect_gradient_bucket_norms(model: torch.nn.Module) -> dict[str, float]:
    """Compute per-component gradient L2 norms (and gate diagnostics) for logging.

    Reads each parameter's ``.grad``; parameters without gradients are skipped. The
    buckets depend on the model's action-expert layout.

    Split action expert:
      * the shared PaliGemma backbone;
      * left/right action input projections;
      * left/right Gemma experts;
      * the action output projections.

    A communicating expert additionally gets the ``cross_arm_comm`` gradient norm,
    each per-layer gate value and, when recorded, the latest per-layer cross-arm
    attention mass.

    Any other layout: the whole ``paligemma_with_expert`` stack plus either the per-arm
    input/fuse/output heads (parallel heads) or the single in/out projections.

    Returns:
        Mapping of metric name to float, in a fixed insertion order.
    """
    core = unwrap_model(model)
    stack = core.paligemma_with_expert

    if not core.use_split_action_expert:
        norms = {"grad_shared_expert": grad_norm_for_parameters(stack.parameters())}
        if core.use_parallel_action_heads:
            norms["grad_action_in_proj_arms"] = grad_norm_for_parameters(core.action_in_proj_arms.parameters())
            norms["grad_arm_token_fuse"] = grad_norm_for_parameters(core.arm_token_fuse.parameters())
            norms["grad_action_out_proj_arms"] = grad_norm_for_parameters(core.action_out_proj_arms.parameters())
        else:
            norms["grad_action_in_proj"] = grad_norm_for_parameters(core.action_in_proj.parameters())
            norms["grad_action_out_proj"] = grad_norm_for_parameters(core.action_out_proj.parameters())
        return norms

    norms = {
        "grad_shared_backbone": grad_norm_for_parameters(stack.paligemma.parameters()),
        "grad_left_action_in": grad_norm_for_parameters(core.action_in_proj_arms[0].parameters()),
        "grad_right_action_in": grad_norm_for_parameters(core.action_in_proj_arms[1].parameters()),
        "grad_left_expert": grad_norm_for_parameters(stack.left_gemma_expert.parameters()),
        "grad_right_expert": grad_norm_for_parameters(stack.right_gemma_expert.parameters()),
        "grad_action_out": grad_norm_for_parameters(core.action_out_proj_arms.parameters()),
    }
    if not core.use_communicating_action_expert:
        return norms

    gates = stack.cross_arm_comm
    norms["grad_cross_arm_comm"] = grad_norm_for_parameters([gates])
    for layer, gate in enumerate(gates.detach().cpu()):
        norms[f"cross_arm_comm_gate_layer_{layer}"] = float(gate.item())
    attention_mass = stack.latest_cross_arm_attention_mass
    if attention_mass is not None:
        for layer, mass in enumerate(attention_mass.detach().cpu()):
            norms[f"cross_arm_attention_mass_layer_{layer}"] = float(mass.item())
    return norms
def tensor_stats(tensor: torch.Tensor) -> tuple[float, float, float, float]:
    """Return ``(min, max, mean, population std)`` of ``tensor`` as floats, computed in float32."""
    values = tensor.detach().to(torch.float32)
    lowest = values.min().item()
    highest = values.max().item()
    average = values.mean().item()
    spread = values.std(unbiased=False).item()
    return float(lowest), float(highest), float(average), float(spread)
def block_nonzero_counts(tensor: torch.Tensor) -> list[int]:
    """Count non-zero entries in consecutive 8-wide slices of the last dimension.

    The final slice may be narrower than 8. Returns an empty list when the last
    dimension has size 0.
    """
    width = tensor.shape[-1]
    return [
        int(torch.count_nonzero(tensor[..., offset : min(offset + 8, width)]).item())
        for offset in range(0, width, 8)
    ]
def masked_zero_count(tensor: torch.Tensor, masked_dims: list[int]) -> int:
    """Count exact zeros among the last-dimension indices ``masked_dims``; 0 if none are given."""
    if masked_dims:
        return int(tensor[..., masked_dims].eq(0).sum().item())
    return 0
def cuda_memory_summary(device: torch.device) -> dict[str, float]:
    """Return current and peak CUDA allocated/reserved memory for ``device`` in GB (1e9 bytes).

    All values are 0.0 when CUDA is unavailable.
    """
    bytes_per_gb = 1e9
    if torch.cuda.is_available():
        return {
            "allocated_gb": torch.cuda.memory_allocated(device) / bytes_per_gb,
            "reserved_gb": torch.cuda.memory_reserved(device) / bytes_per_gb,
            "max_allocated_gb": torch.cuda.max_memory_allocated(device) / bytes_per_gb,
            "max_reserved_gb": torch.cuda.max_memory_reserved(device) / bytes_per_gb,
        }
    return dict.fromkeys(("allocated_gb", "reserved_gb", "max_allocated_gb", "max_reserved_gb"), 0.0)
def log_startup_summary(
    config: _config.TrainConfig,
    data_config: _config.DataConfig,
    *,
    world_size: int,
    local_batch_size: int,
    model_kind: str,
) -> None:
    """Log a one-time summary of the resolved run configuration at INFO level.

    Covers:
      * config name, dataset repo, norm-stats path and summary;
      * checkpoint source and model type;
      * whether packed transforms are active;
      * world size and local/global batch sizes;
      * workers and precision;
      * the LR schedule and save/log intervals;
      * the action-loss mask and its active/masked split.
    """
    stats_path = resolve_norm_stats_path(config, data_config)
    stats_overview = norm_stats_summary(data_config)
    active_dims, masked_dims = action_mask_indices(config.action_loss_mask)

    logging.info(f"Resolved config name: {config.name}")
    logging.info(f"Dataset repo_id: {data_config.repo_id}")
    logging.info(f"Norm-stats file path: {stats_path}")
    logging.info(f"Norm-stats summary: {stats_overview}")
    logging.info(f"Checkpoint source path: {config.pytorch_weight_path}")
    logging.info(f"Model type: {model_kind}")
    logging.info(f"Packed transforms active: {is_packed_transform_active(data_config)}")
    logging.info(f"World size: {world_size}")
    logging.info(f"Batch size: local={local_batch_size}, global={config.batch_size}")
    logging.info(f"num_workers: {config.num_workers}")
    logging.info(f"Precision: {config.pytorch_training_precision}")

    schedule = config.lr_schedule
    logging.info(
        f"LR schedule summary: warmup_steps={schedule.warmup_steps}, peak_lr={schedule.peak_lr:.2e}, "
        f"decay_steps={schedule.decay_steps}, decay_lr={schedule.decay_lr:.2e}"
    )
    logging.info(f"Save/log intervals: save_interval={config.save_interval}, log_interval={config.log_interval}")
    logging.info(f"Action-loss mask: {config.action_loss_mask}")
    logging.info(f"Active mask dims: {active_dims}")
    logging.info(f"Masked dims: {masked_dims}")
def get_model_state_dict(model):
    """Return the state dict of ``model``, looking through a DDP wrapper if present."""
    base_module = unwrap_model(model)
    return base_module.state_dict()
def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
    """Write a checkpoint for ``global_step`` from the main process.

    A checkpoint is written only when ``global_step`` is a positive multiple of
    ``config.save_interval`` or equals ``config.num_train_steps``. Everything is
    staged under ``tmp_<step>`` and then renamed to ``<step>``. Any existing
    ``<step>`` directory is removed first.

    Files written:
      * ``model.safetensors`` (weights of the unwrapped model);
      * ``optimizer.pt``;
      * ``metadata.pt`` (step, config as a plain dict, timestamp);
      * normalisation stats under ``assets/<asset_id>``, when both are available.

    Also logs ``checkpoint_step`` to wandb when enabled.
    """
    if not is_main:
        return
    periodic = global_step % config.save_interval == 0 and global_step > 0
    final = global_step == config.num_train_steps
    if not (periodic or final):
        return

    step_dir = config.checkpoint_dir / f"{global_step}"
    staging_dir = config.checkpoint_dir / f"tmp_{global_step}"
    # Always start from an empty staging directory.
    if staging_dir.exists():
        shutil.rmtree(staging_dir)
    staging_dir.mkdir(parents=True, exist_ok=True)

    # save_model (rather than save_file) copes with tied/shared tensors.
    safetensors.torch.save_model(unwrap_model(model), staging_dir / "model.safetensors")
    torch.save(optimizer.state_dict(), staging_dir / "optimizer.pt")
    # The config is stored as a plain dict (dataclasses.asdict), not as the TrainConfig object.
    torch.save(
        {"global_step": global_step, "config": dataclasses.asdict(config), "timestamp": time.time()},
        staging_dir / "metadata.pt",
    )

    stats = data_config.norm_stats
    if stats is not None and data_config.asset_id is not None:
        _normalize.save(staging_dir / "assets" / data_config.asset_id, stats)

    # Replace any previous checkpoint for this step, then promote the staged one.
    if step_dir.exists():
        shutil.rmtree(step_dir)
    staging_dir.rename(step_dir)
    logging.info(f"Saved checkpoint at step {global_step} -> {step_dir}")

    if config.wandb_enabled:
        get_wandb().log({"checkpoint_step": global_step}, step=global_step)
def load_checkpoint(model, optimizer, checkpoint_dir, device):
    """Restore model and optimizer state from the newest checkpoint under ``checkpoint_dir``.

    The newest checkpoint is the all-digit subdirectory with the largest step.
    Sources:
      * model weights: ``model.safetensors``, loaded into the unwrapped model;
      * optimizer state: ``optimizer.pt``;
      * step: ``metadata.pt``.

    Caches are released and memory is logged between stages.

    Returns:
        The ``global_step`` stored in the metadata, falling back to the directory's step.

    Raises:
        FileNotFoundError: if there is no checkpoint directory, or if the model or
            optimizer file is missing.
        RuntimeError: re-raised with an allocator hint when a load runs out of memory.
    """
    steps = [
        int(entry.name)
        for entry in checkpoint_dir.iterdir()
        if entry.is_dir() and entry.name.isdigit() and not entry.name.startswith("tmp_")
    ]
    if not steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    latest_step = max(steps)
    ckpt_dir = checkpoint_dir / f"{latest_step}"

    def _free_cached_memory():
        # Release cached CUDA blocks and collect garbage between loading stages.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    _free_cached_memory()
    log_memory_usage(device, latest_step, "before_loading_checkpoint")

    model_file = ckpt_dir / "model.safetensors"
    optimizer_file = ckpt_dir / "optimizer.pt"
    try:
        logging.info("Loading model state...")
        if not model_file.exists():
            raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
        safetensors.torch.load_model(unwrap_model(model), model_file, device=str(device))
        logging.info("Loaded model state from safetensors format")
        _free_cached_memory()
        log_memory_usage(device, latest_step, "after_loading_model")

        logging.info("Loading optimizer state...")
        if not optimizer_file.exists():
            raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
        optimizer_state = torch.load(optimizer_file, map_location=device, weights_only=False)
        logging.info("Loaded optimizer state from pt format")
        optimizer.load_state_dict(optimizer_state)
        del optimizer_state
        _free_cached_memory()
        log_memory_usage(device, latest_step, "after_loading_optimizer")

        logging.info("Loading metadata...")
        metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
        global_step = metadata.get("global_step", latest_step)
        del metadata
        _free_cached_memory()
        log_memory_usage(device, latest_step, "after_loading_metadata")

        logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
        return global_step
    except RuntimeError as err:
        if "out of memory" not in str(err):
            raise
        _free_cached_memory()
        logging.error(f"Out of memory error while loading checkpoint: {err!s}")
        log_memory_usage(device, latest_step, "after_oom_error")
        raise RuntimeError(
            "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
        ) from err
def get_latest_checkpoint_step(checkpoint_dir):
    """Return the largest step among all-digit subdirectories of ``checkpoint_dir``, or None.

    Plain files and non-numeric directories (including ``tmp_<step>`` staging
    directories) are ignored.
    """
    latest = None
    for entry in checkpoint_dir.iterdir():
        if not (entry.is_dir() and entry.name.isdigit()) or entry.name.startswith("tmp_"):
            continue
        step = int(entry.name)
        if latest is None or step > latest:
            latest = step
    return latest
def log_memory_usage(device, step, phase="unknown"):
    """Log current and peak CUDA memory for ``device`` at INFO level (no-op without CUDA).

    "free" in the message is reserved-but-unallocated memory held by the caching
    allocator, not free device memory. Values are in GB (1e9 bytes). The rank and
    world size are appended when torch.distributed is initialised.
    """
    if not torch.cuda.is_available():
        return
    bytes_per_gb = 1e9
    allocated_bytes = torch.cuda.memory_allocated(device)
    reserved_bytes = torch.cuda.memory_reserved(device)
    allocator_stats = torch.cuda.memory_stats(device)
    peak_allocated = allocator_stats.get("allocated_bytes.all.peak", 0) / bytes_per_gb
    peak_reserved = allocator_stats.get("reserved_bytes.all.peak", 0) / bytes_per_gb
    ddp_info = (
        f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}" if dist.is_initialized() else ""
    )
    logging.info(
        f"Step {step} ({phase}): GPU memory - allocated: {allocated_bytes / bytes_per_gb:.2f}GB, "
        f"reserved: {reserved_bytes / bytes_per_gb:.2f}GB, "
        f"free: {(reserved_bytes - allocated_bytes) / bytes_per_gb:.2f}GB, "
        f"peak_allocated: {peak_allocated:.2f}GB, peak_reserved: {peak_reserved:.2f}GB{ddp_info}"
    )
def train_loop(config: _config.TrainConfig):
    """Train PI0/PI05 for ``config.num_train_steps`` steps, optionally under DDP.

    Order of work:
      1. Set up the process group and per-rank RNG seeds.
      2. Prepare (or resume from) the checkpoint directory.
      3. Build the data loader and the model.
      4. Optionally load initial weights and/or a checkpoint.
      5. Run the optimisation loop with periodic logging and checkpointing.

    Only rank 0 touches the checkpoint directory, wandb and the progress bar. Rank 0's
    resume decision is broadcast so that every rank restores the same state and step.
    """
    use_ddp, local_rank, device = setup_ddp()
    is_main = (not use_ddp) or (dist.get_rank() == 0)
    configure_process_logging(is_main)
    # Seed with the global rank: LOCAL_RANK repeats on every node, so seeding with it
    # would give same-local-rank processes on different nodes identical RNG streams.
    # Without DDP (or on a single torchrun node) this equals the local rank.
    seed_rank = dist.get_rank() if use_ddp else local_rank
    set_seed(config.seed, seed_rank)

    # Initialize checkpoint directory and wandb (filesystem work happens on rank 0 only).
    resuming = False
    if is_main:
        if config.resume:
            # Find checkpoint directory based on experiment name
            exp_checkpoint_dir = config.checkpoint_dir
            if not exp_checkpoint_dir.exists():
                raise FileNotFoundError(
                    f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume"
                )
            latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
            if latest_step is None:
                raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
            resuming = True
            logging.info(
                f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
            )
        elif config.overwrite and config.checkpoint_dir.exists():
            shutil.rmtree(config.checkpoint_dir)
            logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")

        if resuming:
            # For resume, checkpoint_dir is already set to the experiment directory
            logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")
        else:
            # For new runs, create experiment-specific checkpoint directory
            exp_checkpoint_dir = config.checkpoint_dir
            exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
            logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")

    if use_ddp:
        dist.barrier()
        # Share rank 0's resume decision. Without this, the other ranks would keep
        # resuming=False: they would skip load_checkpoint, train from stale weights and
        # optimizer state, and start at step 0, running a different number of steps than rank 0.
        resume_flag = [resuming]
        dist.broadcast_object_list(resume_flag, src=0)
        resuming = bool(resume_flag[0])

    # Initialize wandb (only on main process)
    if is_main:
        init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    # Each GPU gets batch_size / world_size samples, so the total across GPUs is batch_size.
    world_size = torch.distributed.get_world_size() if use_ddp else 1
    if config.batch_size % world_size != 0:
        raise ValueError(f"Global batch size {config.batch_size} must be divisible by world_size {world_size}.")
    effective_batch_size = config.batch_size // world_size
    logging.info(
        f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
    )

    # Pass the original batch size to the data loader; it handles DDP splitting internally.
    # Rank 0 builds first, presumably so one-time dataset preparation is not raced by the
    # other ranks.
    if use_ddp:
        if is_main:
            loader, data_config = build_datasets(config)
            dist.barrier()
        else:
            dist.barrier()
            loader, data_config = build_datasets(config)
    else:
        loader, data_config = build_datasets(config)

    # Log sample images to wandb on first batch
    if is_main and config.wandb_enabled and not resuming:
        wandb = get_wandb()
        # Create a separate data loader for sample batch to avoid consuming the main loader
        sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
        observation, actions = next(iter(sample_data_loader))
        sample_batch = observation.to_dict()
        sample_batch["actions"] = actions
        camera_images = list(sample_batch["image"].values())
        batch_size = camera_images[0].shape[0]
        # Per sample: concatenate all camera views horizontally, converting CHW -> HWC for wandb.
        images_to_log = [
            wandb.Image(torch.cat([img[i].permute(1, 2, 0) for img in camera_images], dim=1).cpu().numpy())
            for i in range(min(5, batch_size))
        ]
        wandb.log({"camera_views": images_to_log}, step=0)
        # Clear sample batch from memory aggressively
        del sample_batch, observation, actions, images_to_log, camera_images
        del sample_data_loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info("Cleared sample batch and data loader from memory")

    # Build model
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # Convert dataclass to Pi0Config if needed
        model_cfg = openpi.models.pi0_config.Pi0Config(
            dtype=config.pytorch_training_precision,
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
            arm_action_dims=getattr(config.model, "arm_action_dims", None),
        )
    else:
        model_cfg = config.model
        # Update dtype to match pytorch_training_precision (object.__setattr__ presumably
        # because the config dataclass is frozen).
        object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)

    import openpi.models_pytorch.pi0_pytorch as pi0_pytorch

    model = pi0_pytorch.PI0Pytorch(model_cfg).to(device)
    enable_gradient_checkpointing = hasattr(model, "gradient_checkpointing_enable")
    if enable_gradient_checkpointing:
        model.gradient_checkpointing_enable()
        logging.info("Enabled gradient checkpointing for memory optimization")
    else:
        logging.info("Gradient checkpointing is not supported for this model")

    # Log initial memory usage after model creation
    if is_main and torch.cuda.is_available():
        log_memory_usage(device, 0, "after_model_creation")

    # Enable memory optimizations for large-scale training
    if world_size >= 8:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # NOTE(review): the CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it is
        # first initialised. When training on GPU the model is already on the device here,
        # so this likely has no effect; export it before launch if it is needed.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
        logging.info("Enabled memory optimizations for 8+ GPU training")

    if use_ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[device.index] if device.type == "cuda" else None,
            find_unused_parameters=False,
            gradient_as_bucket_view=True,  # Enable for memory efficiency
            static_graph=True,
        )

    # Load weights from weight_loader if specified (for fine-tuning)
    if config.pytorch_weight_path is not None:
        logging.info(f"Loading weights from: {config.pytorch_weight_path}")
        model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
        missing, unexpected = safetensors.torch.load_model(unwrap_model(model), model_path, strict=False)
        logging.info(f"Weight loading missing key count: {len(missing)}")
        logging.info(f"Weight loading missing keys: {missing}")
        logging.info(f"Weight loading unexpected key count: {len(unexpected)}")
        logging.info(f"Weight loading unexpected keys: {unexpected}")
        logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")

    # Optimizer + learning rate schedule from config
    warmup_steps = config.lr_schedule.warmup_steps
    peak_lr = config.lr_schedule.peak_lr
    decay_steps = config.lr_schedule.decay_steps
    end_lr = config.lr_schedule.decay_lr

    # Create optimizer with config parameters
    optim = torch.optim.AdamW(
        model.parameters(),
        lr=peak_lr,
        betas=(config.optimizer.b1, config.optimizer.b2),
        eps=config.optimizer.eps,
        weight_decay=config.optimizer.weight_decay,
    )

    # Load checkpoint if resuming (now on every rank, thanks to the broadcast above)
    global_step = 0
    if resuming:
        global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
        logging.info(f"Resumed training from step {global_step}")

    def lr_schedule(step: int):
        """Linear warmup from peak_lr / (warmup_steps + 1) to peak_lr, then cosine decay to end_lr."""
        if step < warmup_steps:
            # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
            init_lr = peak_lr / (warmup_steps + 1)
            return init_lr + (peak_lr - init_lr) * step / warmup_steps
        # cosine decay
        progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
        cos = 0.5 * (1 + np.cos(np.pi * progress))
        return end_lr + (peak_lr - end_lr) * cos

    model.train()
    infos = []
    interval_start_time = time.perf_counter()
    last_step_end = time.perf_counter()
    smoothed_loss = None
    if is_main:
        model_kind = model_cfg.action_expert_mode
        logging.info(f"Running on: {platform.node()} | world_size={world_size}")
        logging.info(
            f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
        )
        logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
        logging.info("DDP settings: find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=True")
        logging.info(
            f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
        )
        logging.info(
            f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
        )
        logging.info("EMA is not supported for PyTorch training")
        logging.info(f"Training precision: {model_cfg.dtype}")
        log_startup_summary(
            config,
            data_config,
            world_size=world_size,
            local_batch_size=effective_batch_size,
            model_kind=model_kind,
        )
        logging.info(
            "Gradient bucket diagnostics: "
            + (
                "action_in_proj, action_out_proj, shared_expert"
                if not model_cfg.use_parallel_action_heads
                else (
                    "left_action_in, right_action_in, left_expert, right_expert, action_out, cross_arm_comm"
                    if model_cfg.use_split_action_expert
                    else "action_in_proj_arms, arm_token_fuse, action_out_proj_arms, shared_expert"
                )
            )
        )

    # Training loop - iterate until we reach num_train_steps
    pbar = (
        tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
        if is_main
        else None
    )

    while global_step < config.num_train_steps:
        # Set epoch for distributed training
        if use_ddp and hasattr(loader, "set_epoch"):
            loader.set_epoch(global_step // len(loader))

        for observation, actions in loader:
            # Check if we've reached the target number of steps
            if global_step >= config.num_train_steps:
                break

            data_ready_time = time.perf_counter()
            data_time = data_ready_time - last_step_end
            step_start_time = data_ready_time
            step_index = global_step
            completed_step = step_index + 1

            # The unified data loader returns (observation, actions) tuple
            observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901
            actions = actions.to(torch.float32)  # noqa: PLW2901
            actions = actions.to(device)  # noqa: PLW2901

            # Update LR
            for pg in optim.param_groups:
                pg["lr"] = lr_schedule(global_step)

            # Forward pass
            losses = model(observation, actions)
            # Ensure losses is a tensor and handle different return types
            if isinstance(losses, list | tuple):
                losses = torch.stack(losses)
            elif not isinstance(losses, torch.Tensor):
                losses = torch.tensor(losses, device=device, dtype=torch.float32)

            loss, _ = compute_masked_action_loss(losses, config.action_loss_mask)
            if not torch.isfinite(loss):
                raise FloatingPointError(
                    f"Non-finite loss detected at step {step_index}: loss={float(loss.detach().item())}"
                )

            # Backward pass
            loss.backward()

            # Bucket norms are measured before clipping.
            should_log_grad_buckets = is_main and (
                completed_step <= 5 or completed_step % config.log_interval == 0
            )
            grad_bucket_metrics = collect_gradient_bucket_norms(model) if should_log_grad_buckets else {}

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)
            grad_norm_value = float(grad_norm.item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm)

            # Optimizer step. zero_grad(set_to_none=True) already drops every .grad, so no
            # extra per-parameter clearing pass is needed.
            optim.step()
            optim.zero_grad(set_to_none=True)

            step_end_time = time.perf_counter()
            step_time = step_end_time - step_start_time
            last_step_end = step_end_time
            loss_value = float(loss.item())

            # Collect stats
            if is_main:
                smoothed_loss = loss_value if smoothed_loss is None else 0.9 * smoothed_loss + 0.1 * loss_value
                infos.append(
                    {
                        "loss": loss_value,
                        "smoothed_loss": smoothed_loss,
                        "learning_rate": optim.param_groups[0]["lr"],
                        "grad_norm": grad_norm_value,
                        "step_time": step_time,
                        "data_time": data_time,
                    }
                )

            # Detailed per-step diagnostics for the first few steps only.
            if is_main and completed_step <= 5:
                active_dims, masked_dims = action_mask_indices(config.action_loss_mask)
                prompt_lengths = []
                if observation.tokenized_prompt_mask is not None:
                    prompt_lengths = observation.tokenized_prompt_mask.sum(dim=1).tolist()
                image_shapes = {k: tuple(v.shape) for k, v in observation.images.items()}
                state_min, state_max, state_mean, state_std = tensor_stats(observation.state)
                action_min, action_max, action_mean, action_std = tensor_stats(actions)
                memory = cuda_memory_summary(device)
                logging.info(
                    f"debug_step={completed_step} observation.state shape={tuple(observation.state.shape)} dtype={observation.state.dtype} "
                    f"actions shape={tuple(actions.shape)} dtype={actions.dtype}"
                )
                logging.info(
                    f"debug_step={completed_step} image_keys={list(observation.images.keys())} image_shapes={image_shapes}"
                )
                logging.info(f"debug_step={completed_step} prompt_token_lengths={prompt_lengths}")
                logging.info(
                    f"debug_step={completed_step} state_stats min={state_min:.4f} max={state_max:.4f} mean={state_mean:.4f} std={state_std:.4f}"
                )
                logging.info(
                    f"debug_step={completed_step} action_stats min={action_min:.4f} max={action_max:.4f} mean={action_mean:.4f} std={action_std:.4f}"
                )
                logging.info(
                    f"debug_step={completed_step} state_nonzero_counts_8d_blocks={block_nonzero_counts(observation.state)} "
                    f"action_nonzero_counts_8d_blocks={block_nonzero_counts(actions)}"
                )
                logging.info(
                    f"debug_step={completed_step} masked_dims={masked_dims} active_dims={active_dims} "
                    f"masked_zero_counts state={masked_zero_count(observation.state, masked_dims)} "
                    f"actions={masked_zero_count(actions, masked_dims)}"
                )
                logging.info(
                    f"debug_step={completed_step} lr={optim.param_groups[0]['lr']:.2e} grad_norm={grad_norm_value:.4f} "
                    f"data_time={data_time:.4f}s step_time={step_time:.4f}s "
                    f"gpu_mem_allocated={memory['allocated_gb']:.2f}GB gpu_mem_reserved={memory['reserved_gb']:.2f}GB "
                    f"gpu_mem_max_allocated={memory['max_allocated_gb']:.2f}GB gpu_mem_max_reserved={memory['max_reserved_gb']:.2f}GB"
                )
                if grad_bucket_metrics:
                    grad_metrics_text = " ".join(f"{key}={value:.4f}" for key, value in grad_bucket_metrics.items())
                    logging.info(f"debug_step={completed_step} {grad_metrics_text}")

            global_step = completed_step

            if is_main and (global_step % config.log_interval == 0):
                interval_elapsed = time.perf_counter() - interval_start_time
                avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)
                avg_grad_norm = sum(info["grad_norm"] for info in infos) / len(infos)
                avg_step_time = sum(info["step_time"] for info in infos) / len(infos)
                avg_data_time = sum(info["data_time"] for info in infos) / len(infos)
                items_per_second = len(infos) / interval_elapsed if interval_elapsed > 0 else float("inf")
                eta_seconds = (
                    max(config.num_train_steps - global_step, 0) / items_per_second
                    if items_per_second > 0
                    else float("inf")
                )
                memory = cuda_memory_summary(device)
                grad_metrics_text = " ".join(
                    f"{key}={value:.4f}" for key, value in sorted(grad_bucket_metrics.items())
                )
                logging.info(
                    f"step={global_step} loss={loss_value:.4f} smoothed_loss={smoothed_loss:.4f} "
                    f"lr={avg_lr:.2e} grad_norm={avg_grad_norm:.4f} step_time={avg_step_time:.4f}s "
                    f"data_time={avg_data_time:.4f}s it/s={items_per_second:.3f} eta_to_{config.num_train_steps}={eta_seconds:.1f}s "
                    f"max_cuda_memory={memory['max_allocated_gb']:.2f}GB "
                    f"{grad_metrics_text}".rstrip()
                )

                if config.wandb_enabled and len(infos) > 0:
                    wandb = get_wandb()
                    log_payload = {
                        "loss": loss_value,
                        "smoothed_loss": smoothed_loss,
                        "learning_rate": avg_lr,
                        "step": global_step,
                        "time_per_step": avg_step_time,
                        "data_time": avg_data_time,
                        "items_per_second": items_per_second,
                        "eta_seconds": eta_seconds,
                        "max_cuda_memory_gb": memory["max_allocated_gb"],
                    }
                    log_payload["grad_norm"] = avg_grad_norm
                    log_payload.update(grad_bucket_metrics)
                    wandb.log(log_payload, step=global_step)

                interval_start_time = time.perf_counter()
                infos = []

            # Save checkpoint (no-op except on the main process at save steps)
            save_checkpoint(model, optim, global_step, config, is_main, data_config)

            # Update progress bar
            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(
                    {"loss": f"{loss_value:.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
                )

    # Close progress bar
    if pbar is not None:
        pbar.close()

    # Finish wandb run
    if is_main and config.wandb_enabled:
        wandb = get_wandb()
        wandb.finish()

    cleanup_ddp()
def main():
    """CLI entry point: set up logging, parse the training config from argv, then train."""
    init_logging()
    train_loop(_config.cli())


if __name__ == "__main__":
    main()