# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License, Version 1.0 (the "License");
# Implemented by [Jinhui YE / HKUST University] in [2025].
import sys

# NOTE(review): machine-specific path injection; prefer installing the package
# or setting PYTHONPATH in the launch script.
sys.path.append("/mnt/data/fangyu/code/reward_new")

"""
StarVLA's trainer is built directly on native PyTorch + Accelerate + DeepSpeed,
keeping the loop explicit and easy to hack.

Conventions:
1. Store runtime state in dicts where possible (simplifies data info, processing info, config, etc).
2. Use multiple dataloaders to adapt heterogeneous data types / task mixtures.
3. Put each training strategy in its own `trainer_*.py` file (avoid large if-else chains).
"""

import warnings

warnings.filterwarnings("ignore")

# Standard Library
import argparse
import glob
import json
import os
import re
import time
from pathlib import Path
from typing import Tuple

# SECURITY FIX(review): the original file hard-coded a live WANDB API key here
# (os.environ["WANDB_API_KEY"] = "wandb_v1_..."). Secrets must never be checked
# into source control — the leaked key should be revoked, and WANDB_API_KEY
# supplied through the environment / launch script instead.

# Third-Party Libraries
import numpy as np
import torch
import torch.distributed as dist

# import wandb
import yaml
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoProcessor, get_scheduler

# Local Modules
from starVLA.dataloader import build_dataloader
from starVLA.model.framework import build_framework
from starVLA.training.trainer_utils.trainer_tools import (
    TrainerUtils,
    build_param_lr_groups,
    normalize_dotlist_args,
)

deepspeed_plugin = DeepSpeedPlugin()
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
accelerator.print(accelerator.state)

# Sane Defaults
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Initialize Overwatch =>> Wraps `logging.Logger`
logger = get_logger(__name__)


def load_fast_tokenizer():
    """Load the FAST action tokenizer from the HuggingFace hub."""
    fast_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
    return fast_tokenizer


def setup_directories(cfg) -> Path:
    """Create the output directory and persist the resolved config (YAML + JSON).

    Only rank 0 touches the filesystem; other ranks just compute the path.
    """
    cfg.output_dir = os.path.join(cfg.run_root_dir, cfg.run_id)
    output_dir = Path(cfg.output_dir)
    if not dist.is_initialized() or dist.get_rank() == 0:
        # create output directory and checkpoint directory
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(output_dir / "checkpoints", exist_ok=True)
        # save config as YAML, then mirror it to JSON for easy downstream parsing
        OmegaConf.save(cfg, output_dir / "config.yaml")
        with open(output_dir / "config.yaml", "r") as f_yaml, open(output_dir / "config.json", "w") as f_json:
            yaml_cfg = yaml.safe_load(f_yaml)
            json.dump(yaml_cfg, f_json, indent=2)
    return output_dir


def build_model(cfg) -> torch.nn.Module:
    """Build the model framework from config."""
    logger.info(f"Loading Base VLM `{cfg.framework.qwenvl.base_vlm}` from ID/Path")
    model = build_framework(cfg)
    return model


def prepare_data(cfg, accelerator, output_dir) -> DataLoader:
    """Prepare the VLA training dataloader.

    FIX(review): the original annotation claimed Tuple[DataLoader, DataLoader],
    but a single dataloader is returned; annotation corrected.
    """
    # VLA data loader
    logger.info(f"Creating VLA Dataset with Mixture `{cfg.datasets.vla_data.data_mix}`")
    vla_train_dataloader = build_dataloader(cfg=cfg, dataset_py=cfg.datasets.vla_data.dataset_py)
    # Each rank builds its own loader; don't let Accelerate dispatch batches from rank 0.
    accelerator.dataloader_config.dispatch_batches = False
    dist.barrier()
    return vla_train_dataloader


def get_warmup_stable_cosine_scheduler(optimizer, num_warmup_steps, num_stable_steps, num_training_steps, min_lr_ratio=0.01):
    """
    Warmup → Stable → Cosine Decay scheduler.

    Args:
        optimizer: PyTorch optimizer
        num_warmup_steps: number of warmup steps
        num_stable_steps: number of steps held at max_lr (after warmup)
        num_training_steps: total training steps
        min_lr_ratio: ratio of final lr to max_lr

    Returns:
        LambdaLR scheduler
    """
    import math

    def lr_lambda(current_step):
        # Warmup phase: linear ramp from 0 to 1
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Stable phase: hold max_lr
        stable_end = num_warmup_steps + num_stable_steps
        if current_step < stable_end:
            return 1.0
        # Cosine decay phase
        decay_steps = num_training_steps - stable_end
        if decay_steps <= 0:
            return min_lr_ratio
        progress = float(current_step - stable_end) / float(decay_steps)
        return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    # Give every param group the same lr_lambda (supports multi-group optimizers)
    num_param_groups = len(optimizer.param_groups)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, [lr_lambda] * num_param_groups)


def setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """Build the AdamW optimizer (per-module LR groups) and LR scheduler."""
    # initialize optimizer
    param_groups = build_param_lr_groups(model=model, cfg=cfg)
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=cfg.trainer.learning_rate.base,
        betas=tuple(cfg.trainer.optimizer.betas),
        weight_decay=cfg.trainer.optimizer.weight_decay,
        eps=cfg.trainer.optimizer.eps,
    )

    # print optimizer group info (rank 0 only)
    if dist.is_initialized() and dist.get_rank() == 0:
        for group in optimizer.param_groups:
            logger.info(f"LR Group {group['name']}: lr={group['lr']}, num_params={len(group['params'])}")

    # initialize learning rate scheduler
    if cfg.trainer.lr_scheduler_type == "warmup_stable_cosine":
        # Custom scheduler: Warmup → Stable → Cosine Decay
        min_lr_ratio = cfg.trainer.scheduler_specific_kwargs.get("min_lr_ratio", 0.01)
        num_stable_steps = cfg.trainer.get("num_stable_steps", 0)
        lr_scheduler = get_warmup_stable_cosine_scheduler(
            optimizer=optimizer,
            num_warmup_steps=cfg.trainer.num_warmup_steps,
            num_stable_steps=num_stable_steps,
            num_training_steps=cfg.trainer.max_train_steps,
            min_lr_ratio=min_lr_ratio,
        )
        if dist.is_initialized() and dist.get_rank() == 0:
            logger.info(
                f"Using warmup_stable_cosine scheduler: warmup={cfg.trainer.num_warmup_steps}, "
                f"stable={num_stable_steps}, total={cfg.trainer.max_train_steps}, min_lr_ratio={min_lr_ratio}"
            )
    else:
        # Use transformers built-in scheduler
        lr_scheduler = get_scheduler(
            name=cfg.trainer.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=cfg.trainer.num_warmup_steps,
            num_training_steps=cfg.trainer.max_train_steps,
            scheduler_specific_kwargs=cfg.trainer.scheduler_specific_kwargs,
        )
    return optimizer, lr_scheduler


class VLATrainer(TrainerUtils):
    """Explicit VLA training loop on top of Accelerate/DeepSpeed utilities."""

    def __init__(self, cfg, model, vla_train_dataloader, optimizer, lr_scheduler, accelerator):
        self.config = cfg
        self.model = model
        self.vla_train_dataloader = vla_train_dataloader
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = accelerator
        self._printed_first_batch = False
        # training status tracking
        self.completed_steps = 0
        self.total_batch_size = self._calculate_total_batch_size()

    def _debug_print_first_batch(self, batch) -> None:
        """One-shot debug dump of the first batch's structure (local main process only)."""
        if self._printed_first_batch or not self.accelerator.is_local_main_process:
            return
        self._printed_first_batch = True
        sample = None
        if isinstance(batch, list):
            sample = batch[0] if len(batch) > 0 else None
        elif isinstance(batch, dict):
            sample = batch
        if sample is None:
            self.accelerator.print("First batch is empty.")
            return

        def _describe_value(value):
            # Tensors/arrays -> type + shape; lists -> length + inner type; else type name.
            if hasattr(value, "shape"):
                try:
                    return f"{type(value).__name__}(shape={tuple(value.shape)})"
                except Exception:
                    return type(value).__name__
            if isinstance(value, list):
                inner = type(value[0]).__name__ if value else "empty"
                return f"list(len={len(value)}, inner={inner})"
            return type(value).__name__

        self.accelerator.print(f"First batch type: {type(batch).__name__}")
        if isinstance(batch, list):
            self.accelerator.print(f"First batch size: {len(batch)}")
        self.accelerator.print("First sample keys:")
        for key, value in sample.items():
            self.accelerator.print(f" - {key}: {_describe_value(value)}")
        # Print full content for first 5 samples to inspect inputs.
        if isinstance(batch, list):
            max_samples = min(5, len(batch))
            for i in range(max_samples):
                self.accelerator.print(f"Sample[{i}] content:")
                for key, value in batch[i].items():
                    if hasattr(value, "shape"):
                        try:
                            value_str = np.array2string(value, threshold=np.inf, max_line_width=200)
                        except Exception:
                            value_str = repr(value)
                    else:
                        value_str = repr(value)
                    self.accelerator.print(f" - {key}: {value_str}")

    def prepare_training(self):
        """Seed, load/freeze weights, build optimizer, and wrap for distributed training."""
        rank = dist.get_rank() if dist.is_initialized() else 0
        # Per-rank seed so data shuffling differs across ranks.
        seed = self.config.seed + rank if hasattr(self.config, "seed") else rank + 3047
        set_seed(seed)

        # load pretrained weights
        if hasattr(self.config.trainer, "pretrained_checkpoint") and self.config.trainer.pretrained_checkpoint:
            pretrained_checkpoint = self.config.trainer.pretrained_checkpoint
            reload_modules = (
                self.config.trainer.reload_modules if hasattr(self.config.trainer, "reload_modules") else None
            )
            self.model = self.load_pretrained_backbones(self.model, pretrained_checkpoint, reload_modules=reload_modules)

        # freeze parameters
        freeze_modules = (
            self.config.trainer.freeze_modules if (self.config and hasattr(self.config.trainer, "freeze_modules")) else None
        )
        self.model = self.freeze_backbones(self.model, freeze_modules=freeze_modules)

        # print model trainable parameters:
        self.print_trainable_parameters(self.model)

        # build optimizer and scheduler AFTER freezing (critical for DeepSpeed ZeRO)
        self.optimizer, self.lr_scheduler = setup_optimizer_and_scheduler(model=self.model, cfg=self.config)

        # initialize distributed training components
        # NOTE: lr_scheduler is deliberately NOT passed here — wrapping it in
        # AcceleratedScheduler would make step() fire num_processes times per update.
        self.model, self.optimizer, self.vla_train_dataloader = self.setup_distributed_training(
            self.accelerator,  # must be the first param
            self.model,
            self.optimizer,
            self.vla_train_dataloader,
        )
        # lr_scheduler stays a plain LambdaLR, unwrapped by Accelerate

        # self._init_wandb()
        self._init_checkpointing()

    def _calculate_total_batch_size(self):
        """Global batch size = per-device batch * world size * grad-accum steps."""
        return (
            self.config.datasets.vla_data.per_device_batch_size
            * self.accelerator.num_processes
            * self.accelerator.gradient_accumulation_steps
        )

    def _init_wandb(self):
        """Initialize Weights & Biases (currently disabled)."""
        # if self.accelerator.is_main_process:
        #     wandb.init(
        #         name=self.config.run_id,
        #         dir=os.path.join(self.config.output_dir, "wandb"),
        #         project=self.config.wandb_project,
        #         entity=self.config.wandb_entity,
        #         group="vla-train",
        #         settings=wandb.Settings(
        #             _disable_stats=False,  # ensure system monitoring is enabled
        #             x_stats_sampling_interval=10.0,  # sample system metrics every 10s
        #         ),
        #     )
        pass

    def _init_checkpointing(self):
        """Create the checkpoint directory and optionally resume training state."""
        self.checkpoint_dir = os.path.join(self.config.output_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        pretrained_checkpoint = getattr(self.config.trainer, "pretrained_checkpoint", None)
        is_resume = getattr(self.config.trainer, "is_resume", False)
        # resume train ckpt
        # NOTE(review): the guard checks trainer.pretrained_checkpoint but loads
        # cfg.resume_from_checkpoint — verify these two config keys are intended
        # to differ.
        if pretrained_checkpoint and is_resume:
            self._load_checkpoint(self.config.resume_from_checkpoint)

    def _load_checkpoint(self, checkpoint_path):
        """Restore full training state (model/optimizer/RNG) via Accelerate."""
        self.accelerator.load_state(checkpoint_path)
        self.accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")

    def _save_checkpoint(self):
        """Save the current model weights and append run metadata (rank 0 only)."""
        if self.accelerator.is_main_process:
            checkpoint_path = os.path.join(self.checkpoint_dir, f"steps_{self.completed_steps}")
            # save model state
            state_dict = self.accelerator.get_state_dict(self.model)
            torch.save(state_dict, checkpoint_path + "_pytorch_model.pt")
            # save training metadata
            summary_data = {
                "steps": self.completed_steps,
            }
            with open(os.path.join(self.config.output_dir, "summary.jsonl"), "a") as f:
                f.write(json.dumps(summary_data) + "\n")
            self.accelerator.print(f"✅ Checkpoint saved at {checkpoint_path}")
            # Delete old checkpoints, keeping only the most recent N
            max_checkpoints = getattr(self.config.trainer, "max_checkpoints_to_keep", None)
            if max_checkpoints is not None and max_checkpoints > 0:
                self._cleanup_old_checkpoints(max_checkpoints)
        # FIX(review): wait_for_everyone() is a collective; it must sit outside
        # the rank-0 branch so every process reaches it, or the job deadlocks.
        self.accelerator.wait_for_everyone()

    def _cleanup_old_checkpoints(self, max_checkpoints: int):
        """Delete old checkpoints, keeping only the most recent N."""
        # Main process only, to avoid multi-process race conditions
        if not self.accelerator.is_main_process:
            return
        # Collect all checkpoint files
        checkpoint_pattern = os.path.join(self.checkpoint_dir, "steps_*_pytorch_model.pt")
        checkpoint_files = glob.glob(checkpoint_pattern)
        if len(checkpoint_files) <= max_checkpoints:
            return

        # Extract the step count from each filename and sort by it
        def extract_steps(filepath):
            match = re.search(r'steps_(\d+)_pytorch_model\.pt', filepath)
            return int(match.group(1)) if match else 0

        checkpoint_files.sort(key=extract_steps)
        # Delete the oldest checkpoints
        files_to_delete = checkpoint_files[:-max_checkpoints]
        for filepath in files_to_delete:
            try:
                os.remove(filepath)
                self.accelerator.print(f"🗑️ Deleted old checkpoint: {os.path.basename(filepath)}")
            except Exception as e:
                self.accelerator.print(f"⚠️ Failed to delete checkpoint {filepath}: {e}")

    def _log_metrics(self, metrics):
        """Record training metrics at the configured logging frequency (rank 0)."""
        if self.completed_steps % self.config.trainer.logging_frequency == 0:
            # FIX(review): original called dist.get_rank() unguarded, which raises
            # when torch.distributed is not initialized (single-process runs).
            if not dist.is_initialized() or dist.get_rank() == 0:
                # add learning rate
                metrics["learning_rate"] = self.lr_scheduler.get_last_lr()[0]  # see lr group in yaml.trainer.learning_rate
                # add epoch info
                metrics["epoch"] = round(self.completed_steps / len(self.vla_train_dataloader), 2)
                # record to W&B
                # wandb.log(metrics, step=self.completed_steps)
                # debug output
                logger.info(f"\nStep {self.completed_steps}, Loss: {metrics})")

    def _create_data_iterators(self):
        """Create data iterators."""
        self.vla_iter = iter(self.vla_train_dataloader)
        # self.vlm_iter = iter(self.vlm_train_dataloader)

    def _get_next_batch(self):
        """Get the next batch, transparently restarting the dataloader at epoch end."""
        try:
            batch_vla = next(self.vla_iter)
        except StopIteration:
            if not hasattr(self, "vla_epoch_count"):
                self.vla_epoch_count = 0
            self.vla_iter, self.vla_epoch_count = TrainerUtils._reset_dataloader(
                self.vla_train_dataloader, self.vla_epoch_count
            )
            batch_vla = next(self.vla_iter)
        return batch_vla

    def train(self):
        """Execute the training loop."""
        # print training config
        self._log_training_config()
        # prepare data iterators
        self._create_data_iterators()
        # create progress bar
        progress_bar = tqdm(
            range(self.config.trainer.max_train_steps), disable=not self.accelerator.is_local_main_process
        )

        # main training loop
        while self.completed_steps < self.config.trainer.max_train_steps:
            # get data batch
            t_start_data = time.perf_counter()
            batch_vla = self._get_next_batch()
            self._debug_print_first_batch(batch_vla)
            t_end_data = time.perf_counter()

            # execute training step
            t_start_model = time.perf_counter()
            step_metrics = self._train_step(batch_vla)
            t_end_model = time.perf_counter()

            # update progress only when gradients were actually synchronized
            # (i.e. once per effective optimization step under grad accumulation)
            if self.accelerator.sync_gradients:
                progress_bar.update(1)
                self.completed_steps += 1
                if self.accelerator.is_local_main_process:
                    progress_bar.set_postfix(
                        {
                            "data_times": f"{t_end_data - t_start_data:.3f}",
                            "model_times": f"{t_end_model - t_start_model:.3f}",
                        }
                    )

                # evaluate model (predict action once and compute MAE)
                eval_interval = getattr(self.config.trainer, "eval_interval", 0)
                if eval_interval > 0 and self.completed_steps > 0 and self.completed_steps % eval_interval == 0:
                    step_metrics = self.eval_action_model(step_metrics)

                # record metrics
                step_metrics["data_time"] = t_end_data - t_start_data
                step_metrics["model_time"] = t_end_model - t_start_model
                self._log_metrics(step_metrics)

                # save checkpoint
                if self.completed_steps % self.config.trainer.save_interval == 0 and self.completed_steps > 0:
                    self._save_checkpoint()

            # check termination condition
            if self.completed_steps >= self.config.trainer.max_train_steps:
                break

        # training end processing
        self._finalize_training()

    # execute evaluation step
    def eval_action_model(self, step_metrics: dict = None):
        """
        Evaluate action model: encode -> decode one batch, then compute MAE (L1)
        between predicted and ground-truth actions.
        Compatible with ActionModelFM (encode_actions + decode_actions).
        """
        if step_metrics is None:
            step_metrics = {}
        examples = self._get_next_batch()
        device = next(self.model.parameters()).device

        # Use same chunk length for all samples (min over batch, capped by config)
        max_chunk = getattr(self.model.config, "max_action_chunk_size", 50)
        chunk_len = min(max_chunk, min(len(ex["action"]) for ex in examples))
        if chunk_len < 1:
            dist.barrier()
            return step_metrics

        # (B, L, D) action tensor in the model's parameter dtype
        param_dtype = next(self.model.parameters()).dtype
        raw_actions = np.array([ex["action"][:chunk_len] for ex in examples])
        actions_tensor = torch.tensor(raw_actions, device=device, dtype=param_dtype)  # [B, L, D]

        use_state = self.model.use_state
        if use_state:
            states_tensor = torch.tensor(
                np.array([ex["state"][:chunk_len] for ex in examples]),
                device=device,
                dtype=param_dtype,
            )  # [B, L, state_dim]
        else:
            states_tensor = None
        dataset_ids = [ex.get("dataset_id") for ex in examples]

        with torch.no_grad():
            action_embedding = self.model.encode_actions(actions_tensor, dataset_ids, states_tensor)
            pred_actions = self.model.decode_actions(action_embedding, chunk_size=chunk_len)

        pred_np = pred_actions.cpu().float().numpy()
        gt_np = raw_actions
        if self.accelerator.is_main_process:
            score = TrainerUtils.l1_distance(pred_np, gt_np)
            num_elements = pred_np.size
            mae_score = score / max(num_elements, 1)
            step_metrics["mae_score"] = float(mae_score)

        # Release eval tensors before returning to training
        del examples, actions_tensor, action_embedding, pred_actions
        dist.barrier()
        return step_metrics

    def _log_training_config(self):
        """Record the effective training configuration (rank 0 only)."""
        if self.accelerator.is_main_process:
            logger.info("***** Training Configuration *****")
            logger.info(f"  Total optimization steps = {self.config.trainer.max_train_steps}")
            logger.info(f"  Per device batch size = {self.config.datasets.vla_data.per_device_batch_size}")
            logger.info(f"  Gradient accumulation steps = {self.config.trainer.gradient_accumulation_steps}")
            logger.info(f"  Total batch size = {self.total_batch_size}")
            logger.info("***** LR Scheduler Debug Info *****")
            logger.info(f"  lr_scheduler type = {type(self.lr_scheduler)}")
            base_scheduler = getattr(self.lr_scheduler, 'scheduler', self.lr_scheduler)
            logger.info(f"  base_scheduler type = {type(base_scheduler)}")
            logger.info(f"  initial last_epoch = {getattr(base_scheduler, 'last_epoch', 'N/A')}")
            logger.info(f"  initial lr = {self.lr_scheduler.get_last_lr()}")
            logger.info(f"  num_warmup_steps = {self.config.trainer.num_warmup_steps}")
            logger.info(f"  num_stable_steps = {self.config.trainer.get('num_stable_steps', 0)}")
            logger.info(f"  max_train_steps = {self.config.trainer.max_train_steps}")
            logger.info(f"  accelerator.num_processes = {self.accelerator.num_processes}")
            logger.info(f"  accelerator.gradient_accumulation_steps = {self.accelerator.gradient_accumulation_steps}")

    def _train_step(self, batch_vla, batch_vlm=None):
        """Execute a single training step (forward, backward, clip, optimize)."""
        with self.accelerator.accumulate(self.model):
            self.optimizer.zero_grad()

            # VLA task forward propagation
            with torch.autocast("cuda", dtype=torch.bfloat16):
                recon_loss = self.model.forward(batch_vla)

            # VLA backward propagation
            self.accelerator.backward(recon_loss)

            # gradient clipping
            grad_norm = None
            if self.config.trainer.gradient_clipping is not None:
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), self.config.trainer.gradient_clipping
                )

            # optimizer step
            self.optimizer.step()
            # Step the (unwrapped) scheduler only once per effective update
            if self.accelerator.sync_gradients:
                self.lr_scheduler.step()

        step_metrics = {
            "recon_loss": recon_loss.item(),
        }
        if grad_norm is not None:
            step_metrics["grad_norm"] = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
        return step_metrics

    def _finalize_training(self):
        """Save the final model and shut down cleanly."""
        # save final model
        if self.accelerator.is_main_process:
            final_checkpoint = os.path.join(self.config.output_dir, "final_model")
            os.makedirs(final_checkpoint, exist_ok=True)
            state_dict = self.accelerator.get_state_dict(self.model)
            torch.save(state_dict, os.path.join(final_checkpoint, "pytorch_model.pt"))
            logger.info(f"Training complete. Final model saved at {final_checkpoint}")

        # close W&B
        if self.accelerator.is_main_process:
            # wandb.finish()
            pass
        self.accelerator.wait_for_everyone()


def main(cfg) -> None:
    """Entry point: set up directories, model, data, and run the trainer."""
    logger.info("VLA Training :: Warming Up")

    # create output directory and save config
    output_dir = setup_directories(cfg=cfg)

    # build model
    vla = build_framework(cfg)

    # prepare data
    vla_train_dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)

    # create trainer; optimizer/scheduler are built later, after freezing,
    # inside trainer.prepare_training() (required ordering for DeepSpeed ZeRO)
    trainer = VLATrainer(
        cfg=cfg,
        model=vla,
        vla_train_dataloader=vla_train_dataloader,
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )

    # execute training preparation
    trainer.prepare_training()
    # execute training
    trainer.train()

    # And... we're done!
    logger.info("... and that's all, folks!")
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="starVLA/config/training/starvla_cotrain_oxe.yaml",
        help="Path to YAML config",
    )
    args, clipargs = parser.parse_known_args()

    # Load YAML config & Convert CLI overrides to dotlist config
    cfg = OmegaConf.load(args.config_yaml)
    dotlist = normalize_dotlist_args(clipargs)  # Normalize CLI args to dotlist format
    cli_cfg = OmegaConf.from_dotlist(dotlist)
    cfg = OmegaConf.merge(cfg, cli_cfg)

    # if cfg.is_debug:
    if cfg.is_debug and dist.is_initialized() and dist.get_rank() == 0:
        import debugpy

        debugpy.listen(("0.0.0.0", 10092))
        print("🔍 Rank 0 waiting for debugger attach on port 10092...")
        debugpy.wait_for_client()

    main(cfg)