# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License (the "License");
# Implemented by [Jinhui YE / HKUST University] in [2025].
import sys

sys.path.append("/mnt/data/fangyu/code/reward_new")  # NOTE: local development path hack

"""
StarVLA's trainer is built directly on native PyTorch + Accelerate + DeepSpeed,
keeping the loop explicit and easy to hack.

Conventions:
1. Store runtime state in dicts where possible (simplifies data info, processing info, config, etc.).
2. Use multiple dataloaders to adapt heterogeneous data types / task mixtures.
3. Put each training strategy in its own `trainer_*.py` file (avoid large if-else chains).
"""

import warnings

warnings.filterwarnings("ignore")

# Standard Library
import argparse
import glob
import json
import math
import os
import re
import time
from pathlib import Path
from typing import Tuple

# Do not hardcode credentials here; provide WANDB_API_KEY via the environment
# (e.g. `export WANDB_API_KEY=...`) before launching.

# Third-Party Libraries
import numpy as np
import torch
import torch.distributed as dist
import wandb
import yaml
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.logging import get_logger
from accelerate.utils import set_seed, DistributedType
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, get_scheduler

# Local Modules
from starVLA.training.trainer_utils.trainer_tools import normalize_dotlist_args
from starVLA.model.framework import build_framework
from starVLA.training.trainer_utils.trainer_tools import TrainerUtils
from starVLA.training.trainer_utils.trainer_tools import build_param_lr_groups

deepspeed_plugin = DeepSpeedPlugin()
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
accelerator.print(accelerator.state)

# Sane Defaults
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Initialize Overwatch =>> Wraps `logging.Logger`
logger = get_logger(__name__)


def load_fast_tokenizer():
    fast_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
    return fast_tokenizer


def setup_directories(cfg) -> Path:
    """Create the output directory and save the resolved config."""
    cfg.output_dir = os.path.join(cfg.run_root_dir, cfg.run_id)
    output_dir = Path(cfg.output_dir)
    if not dist.is_initialized() or dist.get_rank() == 0:
        # create output directory and checkpoint directory
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(output_dir / "checkpoints", exist_ok=True)
        # save config
        OmegaConf.save(cfg, output_dir / "config.yaml")
        with open(output_dir / "config.yaml", "r") as f_yaml, open(output_dir / "config.json", "w") as f_json:
            yaml_cfg = yaml.safe_load(f_yaml)
            json.dump(yaml_cfg, f_json, indent=2)
    return output_dir


def build_model(cfg) -> torch.nn.Module:
    """Build the model framework."""
    logger.info(f"Loading Base VLM `{cfg.framework.qwenvl.base_vlm}` from ID/Path")
    model = build_framework(cfg)
    return model


# TODO: 📦 encapsulate dataloader construction behind a cleaner interface.
from starVLA.dataloader import build_dataloader


def prepare_data(cfg, accelerator, output_dir) -> DataLoader:
    """Prepare the training dataloader."""
    # VLA data loader
    logger.info(f"Creating VLA Dataset with Mixture `{cfg.datasets.vla_data.data_mix}`")
    vla_train_dataloader = build_dataloader(cfg=cfg, dataset_py=cfg.datasets.vla_data.dataset_py)
    accelerator.dataloader_config.dispatch_batches = False

    dist.barrier()
    return vla_train_dataloader
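
# Illustrative output layout produced by `setup_directories` above (assuming, e.g.,
# run_root_dir="runs" and run_id="exp0"; these names are placeholders, not defaults):
#   runs/exp0/config.yaml      # OmegaConf dump of the merged config
#   runs/exp0/config.json      # same config re-serialized as JSON
#   runs/exp0/checkpoints/     # populated later by VLATrainer._save_checkpoint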


def get_warmup_stable_cosine_scheduler(
    optimizer, num_warmup_steps, num_stable_steps, num_training_steps, min_lr_ratio=0.01
):
    """
    Warmup → Stable → Cosine Decay scheduler.

    Args:
        optimizer: PyTorch optimizer.
        num_warmup_steps: number of warmup steps.
        num_stable_steps: number of steps held at max_lr (after warmup).
        num_training_steps: total number of training steps.
        min_lr_ratio: ratio of the final lr to max_lr.

    Returns:
        LambdaLR scheduler.
    """

    def lr_lambda(current_step):
        # Warmup phase: linear ramp-up
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Stable phase: hold max_lr
        stable_end = num_warmup_steps + num_stable_steps
        if current_step < stable_end:
            return 1.0
        # Cosine decay phase
        decay_steps = num_training_steps - stable_end
        if decay_steps <= 0:
            return min_lr_ratio
        progress = float(current_step - stable_end) / float(decay_steps)
        return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    # Give every parameter group the same lr_lambda (supports multi-group optimizers).
    num_param_groups = len(optimizer.param_groups)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, [lr_lambda] * num_param_groups)
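

# Illustrative sketch (defined but never called): print the LR produced by
# `get_warmup_stable_cosine_scheduler` at a few steps for an assumed
# warmup=100 / stable=400 / total=1000 split with base lr 1.0. The dummy optimizer,
# step counts, and lr below are placeholders, not values from any shipped config.
def _example_wsc_schedule_profile():
    dummy_opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
    sched = get_warmup_stable_cosine_scheduler(
        dummy_opt, num_warmup_steps=100, num_stable_steps=400, num_training_steps=1000, min_lr_ratio=0.01
    )
    lrs = []
    for _ in range(1000):
        lrs.append(sched.get_last_lr()[0])  # lr in effect at this step
        dummy_opt.step()
        sched.step()
    for step in (0, 50, 100, 300, 500, 750, 999):
        # Expected shape: 0.0 → 0.5 → 1.0 (warmup), 1.0 (stable), then cosine decay toward 0.01.
        print(f"step={step:4d}  lr={lrs[step]:.4f}")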


def setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """Set up the optimizer and learning-rate scheduler."""
    # initialize optimizer
    param_groups = build_param_lr_groups(model=model, cfg=cfg)
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=cfg.trainer.learning_rate.base,
        betas=tuple(cfg.trainer.optimizer.betas),
        weight_decay=cfg.trainer.optimizer.weight_decay,
        eps=cfg.trainer.optimizer.eps,
    )

    # print optimizer group info
    if dist.is_initialized() and dist.get_rank() == 0:
        for i, group in enumerate(optimizer.param_groups):
            logger.info(f"LR Group {group['name']}: lr={group['lr']}, num_params={len(group['params'])}")

    # initialize learning rate scheduler
    if cfg.trainer.lr_scheduler_type == "warmup_stable_cosine":
        # Custom scheduler: warmup → stable → cosine decay
        min_lr_ratio = cfg.trainer.scheduler_specific_kwargs.get("min_lr_ratio", 0.01)
        num_stable_steps = cfg.trainer.get("num_stable_steps", 0)
        lr_scheduler = get_warmup_stable_cosine_scheduler(
            optimizer=optimizer,
            num_warmup_steps=cfg.trainer.num_warmup_steps,
            num_stable_steps=num_stable_steps,
            num_training_steps=cfg.trainer.max_train_steps,
            min_lr_ratio=min_lr_ratio,
        )
        if dist.is_initialized() and dist.get_rank() == 0:
            logger.info(
                f"Using warmup_stable_cosine scheduler: warmup={cfg.trainer.num_warmup_steps}, "
                f"stable={num_stable_steps}, total={cfg.trainer.max_train_steps}, min_lr_ratio={min_lr_ratio}"
            )
    elif cfg.trainer.lr_scheduler_type == "onecycle":
        # OneCycleLR: supports multiple param groups with different peak lrs.
        scheduler_kwargs = cfg.trainer.scheduler_specific_kwargs or {}
        pct_start = scheduler_kwargs.get("pct_start", None)
        if pct_start is None:
            pct_start = float(cfg.trainer.num_warmup_steps) / float(max(1, cfg.trainer.max_train_steps))
        pct_start = max(0.0, min(1.0, float(pct_start)))
        max_lrs = [group["lr"] for group in optimizer.param_groups]
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=max_lrs,
            total_steps=cfg.trainer.max_train_steps,
            pct_start=pct_start,
            anneal_strategy=scheduler_kwargs.get("anneal_strategy", "cos"),
            cycle_momentum=scheduler_kwargs.get("cycle_momentum", False),
            div_factor=scheduler_kwargs.get("div_factor", 25.0),
            final_div_factor=scheduler_kwargs.get("final_div_factor", 10000.0),
            three_phase=scheduler_kwargs.get("three_phase", False),
        )
        if dist.is_initialized() and dist.get_rank() == 0:
            logger.info(
                "Using onecycle scheduler: total=%s, pct_start=%.6f, max_lrs=%s, anneal=%s, "
                "div_factor=%s, final_div_factor=%s, cycle_momentum=%s, three_phase=%s",
                cfg.trainer.max_train_steps,
                pct_start,
                max_lrs,
                scheduler_kwargs.get("anneal_strategy", "cos"),
                scheduler_kwargs.get("div_factor", 25.0),
                scheduler_kwargs.get("final_div_factor", 10000.0),
                scheduler_kwargs.get("cycle_momentum", False),
                scheduler_kwargs.get("three_phase", False),
            )
    else:
        # Fall back to a built-in transformers scheduler.
        lr_scheduler = get_scheduler(
            name=cfg.trainer.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=cfg.trainer.num_warmup_steps,
            num_training_steps=cfg.trainer.max_train_steps,
            scheduler_specific_kwargs=cfg.trainer.scheduler_specific_kwargs,
        )

    return optimizer, lr_scheduler
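
# Worked example for the OneCycleLR branch above (illustrative numbers only, not
# from any shipped config): with num_warmup_steps=1000, max_train_steps=20000 and
# no explicit `pct_start`, the warmup fraction resolves to 1000 / 20000 = 0.05,
# i.e. the first 5% of the cycle ramps each parameter group up to its peak lr.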


class VLATrainer(TrainerUtils):
    def __init__(self, cfg, model, vla_train_dataloader, optimizer, lr_scheduler, accelerator):
        self.config = cfg
        self.model = model
        self.vla_train_dataloader = vla_train_dataloader
        # Note: optimizer/lr_scheduler are intentionally created in `prepare_training()`
        # after we load checkpoints and freeze modules, to avoid empty param-groups in
        # DeepSpeed ZeRO initialization.
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = accelerator
        self._printed_first_batch = False

        # training status tracking
        self.completed_steps = 0
        self.total_batch_size = self._calculate_total_batch_size()
        self._grad_norm_buffer: list[float] = []
        self.training_mode = getattr(self.config.trainer, "mode", "default")
        self.loss_weights_decay_steps = int(getattr(self.config.trainer, "loss_weights_decay_steps", 5000))
        if self.loss_weights_decay_steps <= 0:
            logger.warning(
                f"Invalid loss_weights_decay_steps={self.loss_weights_decay_steps}, fallback to 1."
            )
            self.loss_weights_decay_steps = 1

    def _debug_print_first_batch(self, batch) -> None:
        if self._printed_first_batch or not self.accelerator.is_local_main_process:
            return
        self._printed_first_batch = True

        sample = None
        if isinstance(batch, list):
            sample = batch[0] if len(batch) > 0 else None
        elif isinstance(batch, dict):
            sample = batch
        if sample is None:
            self.accelerator.print("First batch is empty.")
            return

        def _describe_value(value):
            if hasattr(value, "shape"):
                try:
                    return f"{type(value).__name__}(shape={tuple(value.shape)})"
                except Exception:
                    return type(value).__name__
            if isinstance(value, list):
                inner = type(value[0]).__name__ if value else "empty"
                return f"list(len={len(value)}, inner={inner})"
            return type(value).__name__

        self.accelerator.print(f"First batch type: {type(batch).__name__}")
        if isinstance(batch, list):
            self.accelerator.print(f"First batch size: {len(batch)}")
        self.accelerator.print("First sample keys:")
        for key, value in sample.items():
            self.accelerator.print(f"  - {key}: {_describe_value(value)}")

        # Print full content for first 5 samples to inspect inputs.
        if isinstance(batch, list):
            max_samples = min(5, len(batch))
            for i in range(max_samples):
                self.accelerator.print(f"Sample[{i}] content:")
                for key, value in batch[i].items():
                    if hasattr(value, "shape"):
                        try:
                            value_str = np.array2string(
                                value, threshold=np.inf, max_line_width=200
                            )
                        except Exception:
                            value_str = repr(value)
                    else:
                        value_str = repr(value)
                    self.accelerator.print(f"  - {key}: {value_str}")

    def prepare_training(self):
        rank = dist.get_rank() if dist.is_initialized() else 0
        seed = self.config.seed + rank if hasattr(self.config, "seed") else rank + 3047
        set_seed(seed)

        # load pretrained weights
        # If action_model was already loaded from `ckpt_path` in __init__, protect those
        # weights from being silently overwritten by `pretrained_checkpoint`.
        action_model_ckpt_path = getattr(self.config.framework.action_model, "ckpt_path", None)
        if action_model_ckpt_path:
            # Snapshot action_model weights for later verification.
            action_model_state_before = {
                k: v.clone() for k, v in self.model.action_model.state_dict().items()
            }

        if hasattr(self.config.trainer, "pretrained_checkpoint") and self.config.trainer.pretrained_checkpoint:
            pretrained_checkpoint = self.config.trainer.pretrained_checkpoint
            reload_modules = (
                self.config.trainer.reload_modules if hasattr(self.config.trainer, "reload_modules") else None
            )

            # If action_model has preloaded weights and reload_modules is unspecified,
            # check whether the checkpoint would overwrite them.
            if action_model_ckpt_path and not reload_modules:
                # Check whether the checkpoint contains action_model weights.
                try:
                    checkpoint = torch.load(pretrained_checkpoint, map_location="cpu")
                    has_action_model_keys = any(k.startswith("action_model.") for k in checkpoint.keys())
                    if has_action_model_keys:
                        logger.warning(
                            f"⚠️ pretrained_checkpoint contains action_model weights, but action_model "
                            f"was already loaded from {action_model_ckpt_path}. "
                            f"Will reload action_model from {action_model_ckpt_path} after loading checkpoint."
                        )
                except Exception:
                    pass  # If the checkpoint cannot be read, continue with the normal flow.

            self.model = self.load_pretrained_backbones(self.model, pretrained_checkpoint, reload_modules=reload_modules)

            # If action_model had preloaded weights, reload them to make sure they were not overwritten.
            if action_model_ckpt_path:
                logger.info(f"🔄 Reloading action_model from {action_model_ckpt_path} to ensure correct weights")
                self.model.action_model.load_state_dict(
                    torch.load(action_model_ckpt_path, map_location="cpu"), strict=True
                )
                # Verify the weights were correctly restored.
                action_model_state_after = self.model.action_model.state_dict()
                mismatched = []
                for k in action_model_state_before.keys():
                    if not torch.equal(action_model_state_before[k], action_model_state_after[k]):
                        mismatched.append(k)
                if mismatched:
                    logger.error(f"❌ action_model weights mismatch after reload: {mismatched}")
                else:
                    logger.info("✅ action_model weights verified after checkpoint loading")

        # print model trainable parameters:
        self.print_trainable_parameters(self.model)

        # build optimizer and scheduler AFTER freezing (critical for DeepSpeed ZeRO)
        self.optimizer, self.lr_scheduler = setup_optimizer_and_scheduler(model=self.model, cfg=self.config)

        # initialize distributed training components
        # Note: do not pass lr_scheduler here, to avoid it being wrapped by AcceleratedScheduler
        # (which would call step() num_processes times per optimizer update).
        self.model, self.optimizer, self.vla_train_dataloader = self.setup_distributed_training(
            self.accelerator,  # must be the first param
            self.model,
            self.optimizer,
            self.vla_train_dataloader,
        )

        self._init_wandb()
        self._init_checkpointing()
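
    # Worked example for the global batch-size bookkeeping below (illustrative
    # numbers, not from any shipped config): per_device_batch_size=8 on 4 processes
    # with gradient_accumulation_steps=2 gives 8 * 4 * 2 = 64 samples per optimizer update.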

    def _calculate_total_batch_size(self):
        """Calculate the global batch size."""
        return (
            self.config.datasets.vla_data.per_device_batch_size
            * self.accelerator.num_processes
            * self.accelerator.gradient_accumulation_steps
        )

    def _init_wandb(self):
        """Initialize Weights & Biases."""
        if self.accelerator.is_main_process:
            wandb.init(
                name=self.config.run_id,
                dir=os.path.join(self.config.output_dir, "wandb"),
                project=self.config.wandb_project,
                entity=self.config.wandb_entity,
                group="vla-train",
                settings=wandb.Settings(
                    _disable_stats=False,  # keep system monitoring enabled
                    x_stats_sampling_interval=10.0,  # sample system metrics every 10 seconds
                ),
            )

    def _init_checkpointing(self):
        """Initialize the checkpoint directory."""
        self.checkpoint_dir = os.path.join(self.config.output_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        pretrained_checkpoint = getattr(self.config.trainer, "pretrained_checkpoint", None)
        is_resume = getattr(self.config.trainer, "is_resume", False)
        # resume train ckpt
        if pretrained_checkpoint and is_resume:
            self._load_checkpoint(self.config.resume_from_checkpoint)

    def _load_checkpoint(self, checkpoint_path):
        """Load a checkpoint."""
        self.accelerator.load_state(checkpoint_path)
        self.accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")

    def _save_checkpoint(self):
        """Save the current training state."""
        if self.accelerator.is_main_process:
            checkpoint_path = os.path.join(self.checkpoint_dir, f"steps_{self.completed_steps}")

            # save model state
            state_dict = self.accelerator.get_state_dict(self.model)
            torch.save(state_dict, checkpoint_path + "_pytorch_model.pt")

            # save training metadata
            summary_data = {
                "steps": self.completed_steps,
            }
            with open(os.path.join(self.config.output_dir, "summary.jsonl"), "a") as f:
                f.write(json.dumps(summary_data) + "\n")

            self.accelerator.print(f"✅ Checkpoint saved at {checkpoint_path}")

            # Delete old checkpoints, keeping only the most recent N.
            max_checkpoints = getattr(self.config.trainer, "max_checkpoints_to_keep", None)
            if max_checkpoints is not None and max_checkpoints > 0:
                self._cleanup_old_checkpoints(max_checkpoints)

        self.accelerator.wait_for_everyone()

    def _cleanup_old_checkpoints(self, max_checkpoints: int):
        """Delete old checkpoints, keeping only the most recent `max_checkpoints`."""
        # Only run on the main process to avoid multi-process race conditions.
        if not self.accelerator.is_main_process:
            return

        # Collect all checkpoint files.
        checkpoint_pattern = os.path.join(self.checkpoint_dir, "steps_*_pytorch_model.pt")
        checkpoint_files = glob.glob(checkpoint_pattern)
        if len(checkpoint_files) <= max_checkpoints:
            return

        # Extract the step count from each filename and sort by it.
        def extract_steps(filepath):
            match = re.search(r'steps_(\d+)_pytorch_model\.pt', filepath)
            return int(match.group(1)) if match else 0

        checkpoint_files.sort(key=extract_steps)

        # Delete the oldest checkpoints.
        files_to_delete = checkpoint_files[:-max_checkpoints]
        for filepath in files_to_delete:
            try:
                os.remove(filepath)
                self.accelerator.print(f"🗑️ Deleted old checkpoint: {os.path.basename(filepath)}")
            except Exception as e:
                self.accelerator.print(f"⚠️ Failed to delete checkpoint {filepath}: {e}")
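
    # Retention behaviour of `_cleanup_old_checkpoints` above (illustrative numbers):
    # with max_checkpoints_to_keep=3 and files for steps 1000, 2000, 3000, and 4000,
    # the oldest file (steps_1000_pytorch_model.pt) is deleted, leaving the three
    # most recent step checkpoints on disk.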

    def _log_metrics(self, metrics):
        """Record training metrics."""
        if self.completed_steps % self.config.trainer.logging_frequency == 0:
            # Average grad_norm over the logging window (cleared every emit).
            if self._grad_norm_buffer:
                metrics["grad_norm_pre_clip_avg"] = float(
                    sum(self._grad_norm_buffer) / len(self._grad_norm_buffer)
                )
                self._grad_norm_buffer.clear()

            if dist.get_rank() == 0:
                # add learning rate
                metrics["learning_rate"] = self.lr_scheduler.get_last_lr()[0]  # see lr group in yaml.trainer.learning_rate
                # add epoch info
                metrics["epoch"] = round(self.completed_steps / len(self.vla_train_dataloader), 2)
                # record to W&B
                wandb.log(metrics, step=self.completed_steps)
                # debug output
                gn_str = f"{metrics['grad_norm_pre_clip']:.4f}" if "grad_norm_pre_clip" in metrics else "N/A"
                gn_avg_str = f"{metrics['grad_norm_pre_clip_avg']:.4f}" if "grad_norm_pre_clip_avg" in metrics else "N/A"
                logger.info(
                    f"\nStep {self.completed_steps} | "
                    f"grad_norm_pre_clip={gn_str} | grad_norm_pre_clip_avg={gn_avg_str} | "
                    f"Metrics: {metrics}"
                )

    def _create_data_iterators(self):
        """Create data iterators."""
        self.vla_iter = iter(self.vla_train_dataloader)
        # self.vlm_iter = iter(self.vlm_train_dataloader)

    def _get_next_batch(self):
        """Get the next batch (automatically handles dataloader exhaustion)."""
        try:
            batch_vla = next(self.vla_iter)
        except StopIteration:
            if not hasattr(self, "vla_epoch_count"):
                self.vla_epoch_count = 0
            self.vla_iter, self.vla_epoch_count = TrainerUtils._reset_dataloader(
                self.vla_train_dataloader, self.vla_epoch_count
            )
            batch_vla = next(self.vla_iter)
        return batch_vla

    def train(self):
        """Execute the training loop."""
        # print training config
        self._log_training_config()

        # prepare data iterators
        self._create_data_iterators()

        # create progress bar
        progress_bar = tqdm(
            range(self.config.trainer.max_train_steps), disable=not self.accelerator.is_local_main_process
        )

        # main training loop
        while self.completed_steps < self.config.trainer.max_train_steps:
            # get data batch
            t_start_data = time.perf_counter()
            batch_vla = self._get_next_batch()
            self._debug_print_first_batch(batch_vla)
            t_end_data = time.perf_counter()

            # execute training step
            t_start_model = time.perf_counter()
            step_metrics = self._train_step(batch_vla)
            t_end_model = time.perf_counter()

            # update progress
            if self.accelerator.sync_gradients:
                progress_bar.update(1)
                self.completed_steps += 1
                if self.accelerator.is_local_main_process:
                    progress_bar.set_postfix(
                        {
                            "data_times": f"{t_end_data - t_start_data:.3f}",
                            "model_times": f"{t_end_model - t_start_model:.3f}",
                        }
                    )

                # evaluate model (reuse current training batch to avoid consuming extra samples)
                if self.completed_steps % self.config.trainer.eval_interval == 0:
                    step_metrics = self.eval_action_model(step_metrics)

                # record metrics
                step_metrics["data_time"] = t_end_data - t_start_data
                step_metrics["model_time"] = t_end_model - t_start_model
                self._log_metrics(step_metrics)

                # save checkpoint
                if self.completed_steps % self.config.trainer.save_interval == 0 and self.completed_steps > 0:
                    self._save_checkpoint()

                # check termination condition
                if self.completed_steps >= self.config.trainer.max_train_steps:
                    break

        # training end processing
        self._finalize_training()
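
    # Note on the evaluation below (hedged: this assumes `TrainerUtils.l1_distance`
    # returns the absolute error summed over all elements, which is why it is divided
    # by num_pots = B * chunk * dim). Illustrative numbers: an (8, 16, 7) action
    # array gives num_pots = 896 scalar entries.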

    # execute evaluation step
    def eval_action_model(self, step_metrics: dict = None, examples=None) -> dict:
        """
        Evaluate the model on a batch of examples and report an action-prediction error.

        :param step_metrics: metrics dict for the current step; the MAE score is added to it.
        :param examples: optional list of evaluation samples (each containing 'image',
            'instruction', and 'action'); if None, the next training batch is reused.
        :return: `step_metrics` with the average MAE score added.
        """
        if examples is None:
            examples = self._get_next_batch()
        score = 0.0

        # When using history, actions contain both history and future.
        # We only evaluate on the future part (predicted actions).
        if self.model.num_history_steps > 0:
            start = self.model.num_history_steps
            end = start + self.model.chunk_size
            actions = [example["action"][start:end] for example in examples]  # label aligned with predicted future chunk
        else:
            actions = [example["action"][: self.model.chunk_size] for example in examples]  # label aligned with prediction length

        # Predict actions using the model
        output_dict = self.model.predict_action(examples=examples)

        if self.accelerator.is_main_process:
            normalized_actions = output_dict["normalized_actions"]  # B, T, D
            actions = np.array(actions)  # convert actions to numpy.ndarray
            # B, Chunk, dim = actions.shape
            num_pots = np.prod(actions.shape)
            # Compute the metric score (L1 = MAE, more intuitive)
            score = TrainerUtils.l1_distance(normalized_actions, actions)
            average_score = score / num_pots
            step_metrics["mae_score"] = average_score

        del examples
        dist.barrier()  # ensure all processes are synchronized
        return step_metrics

    def _log_training_config(self):
        """Record the training configuration."""
        if self.accelerator.is_main_process:
            logger.info("***** Training Configuration *****")
            logger.info(f"  Total optimization steps = {self.config.trainer.max_train_steps}")
            logger.info(f"  Per device batch size = {self.config.datasets.vla_data.per_device_batch_size}")
            logger.info(f"  Gradient accumulation steps = {self.config.trainer.gradient_accumulation_steps}")
            logger.info(f"  Total batch size = {self.total_batch_size}")
            logger.info("***** LR Scheduler Debug Info *****")
            logger.info(f"  lr_scheduler type = {type(self.lr_scheduler)}")
            base_scheduler = getattr(self.lr_scheduler, 'scheduler', self.lr_scheduler)
            logger.info(f"  base_scheduler type = {type(base_scheduler)}")
            logger.info(f"  initial last_epoch = {getattr(base_scheduler, 'last_epoch', 'N/A')}")
            logger.info(f"  initial lr = {self.lr_scheduler.get_last_lr()}")
            logger.info(f"  num_warmup_steps = {self.config.trainer.num_warmup_steps}")
            logger.info(f"  num_stable_steps = {self.config.trainer.get('num_stable_steps', 0)}")
            logger.info(f"  max_train_steps = {self.config.trainer.max_train_steps}")
            logger.info(f"  accelerator.num_processes = {self.accelerator.num_processes}")
            logger.info(f"  accelerator.gradient_accumulation_steps = {self.accelerator.gradient_accumulation_steps}")
            logger.info(f"  trainer.mode = {self.training_mode}")
            logger.info(f"  loss_weights_decay_steps = {self.loss_weights_decay_steps}")

    def _get_aux_loss_decay_weight(self) -> float:
        if self.training_mode != "decay_aux_loss":
            return 1.0
        progress = min(float(self.completed_steps) / float(self.loss_weights_decay_steps), 1.0)
        return 1.0 - progress
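
    # Decay profile of `_get_aux_loss_decay_weight` above (illustrative): in
    # "decay_aux_loss" mode with loss_weights_decay_steps=5000, the auxiliary loss
    # weight is 1.0 at step 0, 0.5 at step 2500, and 0.0 from step 5000 onward;
    # in any other mode it stays fixed at 1.0.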

    @staticmethod
    def _total_grad_norm_l2_local(parameters) -> float:
        """L2 norm over all grads (same recipe as torch.nn.utils.clip_grad_norm_).

        DeepSpeed-safe fallback when clip_grad_norm_ returns None.
        """
        total_sq = 0.0
        for p in parameters:
            if p.grad is None:
                continue
            # float32 for stable norm under bf16/fp16 grads
            param_norm = p.grad.detach().float().norm(2)
            total_sq += float(param_norm) ** 2
        return total_sq ** 0.5

    @staticmethod
    def _grad_norm_scalar(value) -> float:
        if value is None:
            return float("nan")
        if isinstance(value, torch.Tensor):
            return float(value.detach().item())
        return float(value)

    def _train_step(self, batch_vla, batch_vlm=None):
        """Execute a single training step."""
        is_deepspeed = self.accelerator.distributed_type == DistributedType.DEEPSPEED
        grad_norm_pre_clip = None

        with self.accelerator.accumulate(self.model):
            self.optimizer.zero_grad()

            # VLA forward pass (pass training_step so history sampling is identical across
            # ranks and stays in sync).
            with torch.autocast("cuda", dtype=torch.bfloat16):
                output_dict = self.model.forward(batch_vla, training_step=self.completed_steps)
            align_loss = output_dict["align_loss"]
            recon_loss = output_dict["recon_loss"]
            predict_loss = output_dict["predict_loss"]

            aux_loss_decay_weight = self._get_aux_loss_decay_weight()
            if align_loss is not None and recon_loss is not None:
                total_loss = (
                    self.config.trainer.loss_scale.align_loss * aux_loss_decay_weight * align_loss
                    + self.config.trainer.loss_scale.recon_loss * aux_loss_decay_weight * recon_loss
                    + predict_loss
                )
            else:
                total_loss = predict_loss

            # VLA backward propagation
            self.accelerator.backward(total_loss)

            # For non-DeepSpeed: clip explicitly and capture the pre-clip norm before optimizer.step().
            # For DeepSpeed: gradient clipping is handled by ds_config internally; calling
            # clip_grad_norm_ here returns the *previous* step's norm (stored in engine._global_grad_norm,
            # which is only updated during optimizer.step()), so we skip it here and retrieve
            # the norm after step() below.
            if not is_deepspeed:
                gc = getattr(self.config.trainer, "gradient_clipping", None)
                max_norm = float(gc) if gc is not None else float("inf")
                grad_norm_pre_clip = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), max_norm
                )
                if grad_norm_pre_clip is None:
                    grad_norm_pre_clip = self._total_grad_norm_l2_local(self.model.parameters())

            self.optimizer.step()
            if self.accelerator.sync_gradients:
                self.lr_scheduler.step()

            # For DeepSpeed: gradient clipping is handled internally during optimizer.step(),
            # which also populates engine._global_grad_norm. Calling clip_grad_norm_(inf)
            # is a no-op for DeepSpeed and returns None, so we read _global_grad_norm directly.
            if is_deepspeed:
                gn = getattr(self.model, "_global_grad_norm", None)
                if gn is None:
                    # Older DeepSpeed / different ZeRO stage: try the accelerator fallback.
                    gn = self.accelerator.clip_grad_norm_(self.model.parameters(), float("inf"))
                grad_norm_pre_clip = gn

        gn_scalar = self._grad_norm_scalar(grad_norm_pre_clip)
        self._grad_norm_buffer.append(gn_scalar)

        step_metrics = {
            "align_loss": align_loss.item() if align_loss is not None else None,
            "recon_loss": recon_loss.item() if recon_loss is not None else None,
            "predict_loss": predict_loss.item(),
            "aux_loss_decay_weight": aux_loss_decay_weight,
            "grad_norm_pre_clip": gn_scalar,
        }

        return step_metrics
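
    # Loss composition in `_train_step` above (illustrative numbers, not from a real run):
    # with loss_scale.align_loss=0.5, loss_scale.recon_loss=0.5, decay weight 0.8,
    # align_loss=0.2, recon_loss=0.4, and predict_loss=1.0, the objective becomes
    # total_loss = 0.5*0.8*0.2 + 0.5*0.8*0.4 + 1.0 = 0.08 + 0.16 + 1.0 = 1.24.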

    def _finalize_training(self):
        """Training end processing."""
        # save final model
        if self.accelerator.is_main_process:
            final_checkpoint = os.path.join(self.config.output_dir, "final_model")
            os.makedirs(final_checkpoint, exist_ok=True)
            state_dict = self.accelerator.get_state_dict(self.model)
            torch.save(state_dict, os.path.join(final_checkpoint, "pytorch_model.pt"))
            logger.info(f"Training complete. Final model saved at {final_checkpoint}")

        # close W&B
        if self.accelerator.is_main_process:
            wandb.finish()

        self.accelerator.wait_for_everyone()


def main(cfg) -> None:
    logger.info("VLA Training :: Warming Up")

    # create output directory and save config
    output_dir = setup_directories(cfg=cfg)

    # build model
    vla = build_model(cfg)

    # prepare data
    vla_train_dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)

    # create trainer & run VLA training
    trainer = VLATrainer(
        cfg=cfg,
        model=vla,
        vla_train_dataloader=vla_train_dataloader,
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )

    # execute training preparation
    trainer.prepare_training()

    # execute training
    trainer.train()

    # And... we're done!
    logger.info("... and that's all, folks!")
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="starVLA/config/training/starvla_cotrain_oxe.yaml",
        help="Path to YAML config",
    )
    args, clipargs = parser.parse_known_args()

    # Load YAML config & convert CLI overrides to a dotlist config
    cfg = OmegaConf.load(args.config_yaml)
    dotlist = normalize_dotlist_args(clipargs)  # normalize CLI args to dotlist format
    cli_cfg = OmegaConf.from_dotlist(dotlist)
    cfg = OmegaConf.merge(cfg, cli_cfg)

    if cfg.is_debug and dist.is_initialized() and dist.get_rank() == 0:
        import debugpy

        debugpy.listen(("0.0.0.0", 10092))
        print("🔍 Rank 0 waiting for debugger attach on port 10092...")
        debugpy.wait_for_client()

    main(cfg)
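
# Illustrative launch command (hedged: the script path is a placeholder and the exact
# override syntax depends on `normalize_dotlist_args`; the values are examples only):
#   accelerate launch path/to/this_trainer.py \
#       --config_yaml starVLA/config/training/starvla_cotrain_oxe.yaml \
#       trainer.max_train_steps=20000 datasets.vla_data.per_device_batch_size=8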