"""
StarVLA’s trainer is built directly on native PyTorch + Accelerate + DeepSpeed, keeping the loop explicit and easy to hack.
Conventions:
1. Store runtime state in dicts where possible (simplifies data info, processing info, config, etc.).
2. Use multiple dataloaders to adapt to heterogeneous data types / task mixtures.
3. Put each training strategy in its own `trainer_*.py` file (avoid large if-else chains).
"""
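
# A minimal illustration of convention 3 above; the file names below are hypothetical,
# not the repo's actual modules. Each strategy owns its full loop instead of branching
# on cfg inside a single trainer:
#   starVLA/training/trainer_cotrain.py    -> VLA + VLM co-training loop
#   starVLA/training/trainer_finetune.py   -> single-mixture finetuning loop
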
import warnings

warnings.filterwarnings("ignore")

import argparse
import json
import os

# NOTE: supply the Weights & Biases credential via the WANDB_API_KEY environment
# variable (or `wandb login`); API keys must not be hardcoded in source.
from pathlib import Path
from typing import Tuple

import glob
import re
import time

import numpy as np
import torch
import torch.distributed as dist
import wandb
import yaml
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.logging import get_logger
from accelerate.utils import DistributedType, set_seed
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoProcessor, get_scheduler

from starVLA.model.framework import build_framework
from starVLA.training.trainer_utils.trainer_tools import (
    TrainerUtils,
    build_param_lr_groups,
    normalize_dotlist_args,
)

deepspeed_plugin = DeepSpeedPlugin()
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
accelerator.print(accelerator.state)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logger = get_logger(__name__)


def load_fast_tokenizer():
    fast_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
    return fast_tokenizer


def setup_directories(cfg) -> Path:
    """create output directory and save config"""
    cfg.output_dir = os.path.join(cfg.run_root_dir, cfg.run_id)
    output_dir = Path(cfg.output_dir)

    if not dist.is_initialized() or dist.get_rank() == 0:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(output_dir / "checkpoints", exist_ok=True)

        OmegaConf.save(cfg, output_dir / "config.yaml")
        with open(output_dir / "config.yaml", "r") as f_yaml, open(output_dir / "config.json", "w") as f_json:
            yaml_cfg = yaml.safe_load(f_yaml)
            json.dump(yaml_cfg, f_json, indent=2)

    return output_dir


def build_model(cfg) -> torch.nn.Module:
    """build model framework"""
    logger.info(f"Loading Base VLM `{cfg.framework.qwenvl.base_vlm}` from ID/Path")
    model = build_framework(cfg)

    return model


from starVLA.dataloader import build_dataloader


def prepare_data(cfg, accelerator, output_dir) -> DataLoader:
    """prepare training data"""
    logger.info(f"Creating VLA Dataset with Mixture `{cfg.datasets.vla_data.data_mix}`")
    vla_train_dataloader = build_dataloader(cfg=cfg, dataset_py=cfg.datasets.vla_data.dataset_py)

    # Let each process iterate its own dataloader instead of dispatching batches from rank 0.
    accelerator.dataloader_config.dispatch_batches = False
    if dist.is_initialized():
        dist.barrier()

    return vla_train_dataloader


def get_warmup_stable_cosine_scheduler(optimizer, num_warmup_steps, num_stable_steps, num_training_steps, min_lr_ratio=0.01):
    """
    Warmup → Stable → Cosine Decay scheduler.

    Args:
        optimizer: PyTorch optimizer
        num_warmup_steps: number of warmup steps
        num_stable_steps: number of steps held at max_lr (after warmup)
        num_training_steps: total number of training steps
        min_lr_ratio: ratio of the final lr to max_lr

    Returns:
        LambdaLR scheduler
    """
    import math

    def lr_lambda(current_step):
        # Linear warmup from 0 to max_lr.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        # Hold at max_lr for the stable phase.
        stable_end = num_warmup_steps + num_stable_steps
        if current_step < stable_end:
            return 1.0

        # Cosine decay from max_lr down to min_lr_ratio * max_lr.
        decay_steps = num_training_steps - stable_end
        if decay_steps <= 0:
            return min_lr_ratio
        progress = float(current_step - stable_end) / float(decay_steps)
        return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    # One lambda per param group so every group follows the same schedule.
    num_param_groups = len(optimizer.param_groups)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, [lr_lambda] * num_param_groups)
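
# Worked example (hypothetical numbers, not taken from any config): with num_warmup_steps=100,
# num_stable_steps=400, num_training_steps=1000 and min_lr_ratio=0.01, the multiplier applied
# to each group's base lr evolves as:
#   step   50 -> 0.50    (halfway through warmup)
#   step  300 -> 1.00    (stable phase covers steps 100-499)
#   step  750 -> 0.505   (cosine decay: 0.01 + 0.99 * 0.5 * (1 + cos(pi * 0.5)))
#   step 1000 -> 0.01    (floor, min_lr_ratio)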


def setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """set optimizer and scheduler"""
    param_groups = build_param_lr_groups(model=model, cfg=cfg)
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=cfg.trainer.learning_rate.base,
        betas=tuple(cfg.trainer.optimizer.betas),
        weight_decay=cfg.trainer.optimizer.weight_decay,
        eps=cfg.trainer.optimizer.eps,
    )

    if dist.is_initialized() and dist.get_rank() == 0:
        for group in optimizer.param_groups:
            logger.info(f"LR Group {group['name']}: lr={group['lr']}, num_params={len(group['params'])}")

    if cfg.trainer.lr_scheduler_type == "warmup_stable_cosine":
        min_lr_ratio = cfg.trainer.scheduler_specific_kwargs.get("min_lr_ratio", 0.01)
        num_stable_steps = cfg.trainer.get("num_stable_steps", 0)
        lr_scheduler = get_warmup_stable_cosine_scheduler(
            optimizer=optimizer,
            num_warmup_steps=cfg.trainer.num_warmup_steps,
            num_stable_steps=num_stable_steps,
            num_training_steps=cfg.trainer.max_train_steps,
            min_lr_ratio=min_lr_ratio,
        )
        if dist.is_initialized() and dist.get_rank() == 0:
            logger.info(f"Using warmup_stable_cosine scheduler: warmup={cfg.trainer.num_warmup_steps}, "
                        f"stable={num_stable_steps}, total={cfg.trainer.max_train_steps}, min_lr_ratio={min_lr_ratio}")
    elif cfg.trainer.lr_scheduler_type == "onecycle":
        scheduler_kwargs = cfg.trainer.scheduler_specific_kwargs or {}
        pct_start = scheduler_kwargs.get("pct_start", None)
        if pct_start is None:
            pct_start = float(cfg.trainer.num_warmup_steps) / float(max(1, cfg.trainer.max_train_steps))
        pct_start = max(0.0, min(1.0, float(pct_start)))

        max_lrs = [group["lr"] for group in optimizer.param_groups]
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=max_lrs,
            total_steps=cfg.trainer.max_train_steps,
            pct_start=pct_start,
            anneal_strategy=scheduler_kwargs.get("anneal_strategy", "cos"),
            cycle_momentum=scheduler_kwargs.get("cycle_momentum", False),
            div_factor=scheduler_kwargs.get("div_factor", 25.0),
            final_div_factor=scheduler_kwargs.get("final_div_factor", 10000.0),
            three_phase=scheduler_kwargs.get("three_phase", False),
        )
        if dist.is_initialized() and dist.get_rank() == 0:
            logger.info(
                "Using onecycle scheduler: total=%s, pct_start=%.6f, max_lrs=%s, anneal=%s, "
                "div_factor=%s, final_div_factor=%s, cycle_momentum=%s, three_phase=%s",
                cfg.trainer.max_train_steps,
                pct_start,
                max_lrs,
                scheduler_kwargs.get("anneal_strategy", "cos"),
                scheduler_kwargs.get("div_factor", 25.0),
                scheduler_kwargs.get("final_div_factor", 10000.0),
                scheduler_kwargs.get("cycle_momentum", False),
                scheduler_kwargs.get("three_phase", False),
            )
    else:
        lr_scheduler = get_scheduler(
            name=cfg.trainer.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=cfg.trainer.num_warmup_steps,
            num_training_steps=cfg.trainer.max_train_steps,
            scheduler_specific_kwargs=cfg.trainer.scheduler_specific_kwargs,
        )

    return optimizer, lr_scheduler
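
# A minimal sketch (hypothetical values) of the cfg.trainer fields this function reads,
# written in the YAML layout the training configs use:
#   trainer:
#     learning_rate: {base: 1.0e-4}
#     optimizer: {betas: [0.9, 0.95], weight_decay: 0.01, eps: 1.0e-8}
#     lr_scheduler_type: warmup_stable_cosine   # or "onecycle", or any name get_scheduler accepts
#     num_warmup_steps: 1000
#     num_stable_steps: 0
#     max_train_steps: 50000
#     scheduler_specific_kwargs: {min_lr_ratio: 0.01}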


class VLATrainer(TrainerUtils):
    def __init__(self, cfg, model, vla_train_dataloader, optimizer, lr_scheduler, accelerator):
        self.config = cfg
        self.model = model
        self.vla_train_dataloader = vla_train_dataloader

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = accelerator
        self._printed_first_batch = False

        self.completed_steps = 0
        self.total_batch_size = self._calculate_total_batch_size()
        self._grad_norm_buffer: list[float] = []
        self.training_mode = getattr(self.config.trainer, "mode", "default")
        self.loss_weights_decay_steps = int(getattr(self.config.trainer, "loss_weights_decay_steps", 5000))
        if self.loss_weights_decay_steps <= 0:
            logger.warning(
                f"Invalid loss_weights_decay_steps={self.loss_weights_decay_steps}, fallback to 1."
            )
            self.loss_weights_decay_steps = 1

    def _debug_print_first_batch(self, batch) -> None:
        if self._printed_first_batch or not self.accelerator.is_local_main_process:
            return
        self._printed_first_batch = True

        sample = None
        if isinstance(batch, list):
            sample = batch[0] if len(batch) > 0 else None
        elif isinstance(batch, dict):
            sample = batch

        if sample is None:
            self.accelerator.print("First batch is empty.")
            return

        def _describe_value(value):
            if hasattr(value, "shape"):
                try:
                    return f"{type(value).__name__}(shape={tuple(value.shape)})"
                except Exception:
                    return type(value).__name__
            if isinstance(value, list):
                inner = type(value[0]).__name__ if value else "empty"
                return f"list(len={len(value)}, inner={inner})"
            return type(value).__name__

        self.accelerator.print(f"First batch type: {type(batch).__name__}")
        if isinstance(batch, list):
            self.accelerator.print(f"First batch size: {len(batch)}")
        self.accelerator.print("First sample keys:")
        for key, value in sample.items():
            self.accelerator.print(f"  - {key}: {_describe_value(value)}")

        if isinstance(batch, list):
            max_samples = min(5, len(batch))
            for i in range(max_samples):
                self.accelerator.print(f"Sample[{i}] content:")
                for key, value in batch[i].items():
                    if hasattr(value, "shape"):
                        try:
                            value_str = np.array2string(
                                value, threshold=np.inf, max_line_width=200
                            )
                        except Exception:
                            value_str = repr(value)
                    else:
                        value_str = repr(value)
                    self.accelerator.print(f"  - {key}: {value_str}")

    def prepare_training(self):
        rank = dist.get_rank() if dist.is_initialized() else 0
        seed = self.config.seed + rank if hasattr(self.config, "seed") else rank + 3047
        set_seed(seed)

        # If the action model was already initialized from its own checkpoint, snapshot its
        # weights so they can be verified after any full-model checkpoint is loaded.
        action_model_ckpt_path = getattr(self.config.framework.action_model, "ckpt_path", None)
        if action_model_ckpt_path:
            action_model_state_before = {
                k: v.clone() for k, v in self.model.action_model.state_dict().items()
            }

        if hasattr(self.config.trainer, "pretrained_checkpoint") and self.config.trainer.pretrained_checkpoint:
            pretrained_checkpoint = self.config.trainer.pretrained_checkpoint
            reload_modules = (
                self.config.trainer.reload_modules if hasattr(self.config.trainer, "reload_modules") else None
            )

            if action_model_ckpt_path and not reload_modules:
                try:
                    checkpoint = torch.load(pretrained_checkpoint, map_location="cpu")
                    has_action_model_keys = any(k.startswith("action_model.") for k in checkpoint.keys())
                    if has_action_model_keys:
                        logger.warning(
                            f"⚠️ pretrained_checkpoint contains action_model weights, but action_model "
                            f"was already loaded from {action_model_ckpt_path}. "
                            f"Will reload action_model from {action_model_ckpt_path} after loading checkpoint."
                        )
                except Exception:
                    pass

            self.model = self.load_pretrained_backbones(self.model, pretrained_checkpoint, reload_modules=reload_modules)

            if action_model_ckpt_path:
                logger.info(f"🔄 Reloading action_model from {action_model_ckpt_path} to ensure correct weights")
                self.model.action_model.load_state_dict(
                    torch.load(action_model_ckpt_path, map_location="cpu"), strict=True
                )

                action_model_state_after = self.model.action_model.state_dict()
                mismatched = []
                for k in action_model_state_before.keys():
                    if not torch.equal(action_model_state_before[k], action_model_state_after[k]):
                        mismatched.append(k)
                if mismatched:
                    logger.error(f"❌ action_model weights mismatch after reload: {mismatched}")
                else:
                    logger.info("✅ action_model weights verified after checkpoint loading")

        self.print_trainable_parameters(self.model)

        self.optimizer, self.lr_scheduler = setup_optimizer_and_scheduler(model=self.model, cfg=self.config)

        self.model, self.optimizer, self.vla_train_dataloader = self.setup_distributed_training(
            self.accelerator,
            self.model,
            self.optimizer,
            self.vla_train_dataloader,
        )

        self._init_wandb()
        self._init_checkpointing()
    def _calculate_total_batch_size(self):
        """calculate global batch size"""
        return (
            self.config.datasets.vla_data.per_device_batch_size
            * self.accelerator.num_processes
            * self.accelerator.gradient_accumulation_steps
        )
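
    # Example (hypothetical numbers): per_device_batch_size=8 on 4 processes with
    # gradient_accumulation_steps=2 gives a global batch size of 8 * 4 * 2 = 64
    # samples per optimizer update.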

    def _init_wandb(self):
        """initialize Weights & Biases"""
        if self.accelerator.is_main_process:
            wandb.init(
                name=self.config.run_id,
                dir=os.path.join(self.config.output_dir, "wandb"),
                project=self.config.wandb_project,
                entity=self.config.wandb_entity,
                group="vla-train",
                settings=wandb.Settings(
                    _disable_stats=False,
                    x_stats_sampling_interval=10.0,
                ),
            )

    def _init_checkpointing(self):
        """initialize checkpoint directory"""
        self.checkpoint_dir = os.path.join(self.config.output_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        pretrained_checkpoint = getattr(self.config.trainer, "pretrained_checkpoint", None)
        is_resume = getattr(self.config.trainer, "is_resume", False)

        # Resume training state via `accelerator.load_state` from cfg.resume_from_checkpoint
        # when both a pretrained checkpoint and is_resume are set.
        if pretrained_checkpoint and is_resume:
            self._load_checkpoint(self.config.resume_from_checkpoint)

    def _load_checkpoint(self, checkpoint_path):
        """load checkpoint"""
        self.accelerator.load_state(checkpoint_path)
        self.accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")

    def _save_checkpoint(self):
        """save current training state"""
        if self.accelerator.is_main_process:
            checkpoint_path = os.path.join(self.checkpoint_dir, f"steps_{self.completed_steps}")
            state_dict = self.accelerator.get_state_dict(self.model)
            torch.save(state_dict, checkpoint_path + "_pytorch_model.pt")

            summary_data = {
                "steps": self.completed_steps,
            }
            with open(os.path.join(self.config.output_dir, "summary.jsonl"), "a") as f:
                f.write(json.dumps(summary_data) + "\n")
            self.accelerator.print(f"✅ Checkpoint saved at {checkpoint_path}")

            max_checkpoints = getattr(self.config.trainer, "max_checkpoints_to_keep", None)
            if max_checkpoints is not None and max_checkpoints > 0:
                self._cleanup_old_checkpoints(max_checkpoints)

        self.accelerator.wait_for_everyone()

    def _cleanup_old_checkpoints(self, max_checkpoints: int):
        """Delete old checkpoints, keeping only the most recent N."""
        if not self.accelerator.is_main_process:
            return

        checkpoint_pattern = os.path.join(self.checkpoint_dir, "steps_*_pytorch_model.pt")
        checkpoint_files = glob.glob(checkpoint_pattern)

        if len(checkpoint_files) <= max_checkpoints:
            return

        def extract_steps(filepath):
            match = re.search(r"steps_(\d+)_pytorch_model\.pt", filepath)
            return int(match.group(1)) if match else 0

        checkpoint_files.sort(key=extract_steps)

        # Keep the `max_checkpoints` newest files and delete the rest.
        files_to_delete = checkpoint_files[:-max_checkpoints]
        for filepath in files_to_delete:
            try:
                os.remove(filepath)
                self.accelerator.print(f"🗑️ Deleted old checkpoint: {os.path.basename(filepath)}")
            except Exception as e:
                self.accelerator.print(f"⚠️ Failed to delete checkpoint {filepath}: {e}")
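
    # Example (hypothetical file names): with max_checkpoints_to_keep=2 and existing files
    # steps_1000_pytorch_model.pt, steps_2000_pytorch_model.pt, steps_3000_pytorch_model.pt,
    # the files are sorted by step count and only steps_1000_pytorch_model.pt is removed.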

    def _log_metrics(self, metrics):
        """record training metrics"""
        if self.completed_steps % self.config.trainer.logging_frequency == 0:
            # Average the pre-clip grad norms accumulated since the last logging step.
            if self._grad_norm_buffer:
                metrics["grad_norm_pre_clip_avg"] = float(
                    sum(self._grad_norm_buffer) / len(self._grad_norm_buffer)
                )
                self._grad_norm_buffer.clear()
            if not dist.is_initialized() or dist.get_rank() == 0:
                metrics["learning_rate"] = self.lr_scheduler.get_last_lr()[0]

                metrics["epoch"] = round(self.completed_steps / len(self.vla_train_dataloader), 2)

                wandb.log(metrics, step=self.completed_steps)

                gn_str = f"{metrics['grad_norm_pre_clip']:.4f}" if "grad_norm_pre_clip" in metrics else "N/A"
                gn_avg_str = f"{metrics['grad_norm_pre_clip_avg']:.4f}" if "grad_norm_pre_clip_avg" in metrics else "N/A"
                logger.info(
                    f"\nStep {self.completed_steps} | "
                    f"grad_norm_pre_clip={gn_str} | grad_norm_pre_clip_avg={gn_avg_str} | "
                    f"Metrics: {metrics}"
                )

    def _create_data_iterators(self):
        """create data iterators"""
        self.vla_iter = iter(self.vla_train_dataloader)

    def _get_next_batch(self):
        """get next batch (automatically handle data loop)"""
        try:
            batch_vla = next(self.vla_iter)
        except StopIteration:
            if not hasattr(self, "vla_epoch_count"):
                self.vla_epoch_count = 0
            self.vla_iter, self.vla_epoch_count = TrainerUtils._reset_dataloader(
                self.vla_train_dataloader, self.vla_epoch_count
            )
            batch_vla = next(self.vla_iter)

        return batch_vla

    def train(self):
        """execute training loop"""
        self._log_training_config()

        self._create_data_iterators()

        progress_bar = tqdm(
            range(self.config.trainer.max_train_steps), disable=not self.accelerator.is_local_main_process
        )

        while self.completed_steps < self.config.trainer.max_train_steps:
            t_start_data = time.perf_counter()
            batch_vla = self._get_next_batch()
            self._debug_print_first_batch(batch_vla)
            t_end_data = time.perf_counter()

            t_start_model = time.perf_counter()
            step_metrics = self._train_step(batch_vla)
            t_end_model = time.perf_counter()

            # Only count an optimization step once gradients have actually been synced,
            # i.e. after a full gradient-accumulation cycle.
            if self.accelerator.sync_gradients:
                progress_bar.update(1)
                self.completed_steps += 1

                if self.accelerator.is_local_main_process:
                    progress_bar.set_postfix(
                        {
                            "data_times": f"{t_end_data - t_start_data:.3f}",
                            "model_times": f"{t_end_model - t_start_model:.3f}",
                        }
                    )

                if self.completed_steps % self.config.trainer.eval_interval == 0:
                    step_metrics = self.eval_action_model(step_metrics)

                step_metrics["data_time"] = t_end_data - t_start_data
                step_metrics["model_time"] = t_end_model - t_start_model
                self._log_metrics(step_metrics)

                if self.completed_steps % self.config.trainer.save_interval == 0 and self.completed_steps > 0:
                    self._save_checkpoint()

                if self.completed_steps >= self.config.trainer.max_train_steps:
                    break

        self._finalize_training()
    def eval_action_model(self, step_metrics: dict = None, examples=None) -> dict:
        """
        Evaluate action prediction on a batch of examples and record the mean absolute error.

        :param step_metrics: Metrics dict for the current step; the MAE is added under "mae_score".
        :param examples: Optional list of evaluation samples (each containing 'image', 'instruction',
                         and 'action'); if None, the next training batch is used.
        :return: The updated step_metrics dict.
        """
        if examples is None:
            examples = self._get_next_batch()
        score = 0.0

        # Align the ground-truth action chunk with the model's prediction horizon.
        if self.model.num_history_steps > 0:
            start = self.model.num_history_steps
            end = start + self.model.chunk_size
            actions = [example["action"][start:end] for example in examples]
        else:
            actions = [example["action"][: self.model.chunk_size] for example in examples]

        output_dict = self.model.predict_action(examples=examples)

        if self.accelerator.is_main_process:
            normalized_actions = output_dict["normalized_actions"]
            actions = np.array(actions)

            num_points = np.prod(actions.shape)

            score = TrainerUtils.l1_distance(normalized_actions, actions)
            average_score = score / num_points
            step_metrics["mae_score"] = average_score

        del examples
        if dist.is_initialized():
            dist.barrier()
        return step_metrics
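
    # Worked example (hypothetical shapes): with 2 examples, chunk_size=8 and a 7-dim action
    # space, `actions` has shape (2, 8, 7) so num_points = 112; assuming TrainerUtils.l1_distance
    # returns the summed absolute error, mae_score = score / 112.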

    def _log_training_config(self):
        """record training config"""
        if self.accelerator.is_main_process:
            logger.info("***** Training Configuration *****")
            logger.info(f"  Total optimization steps = {self.config.trainer.max_train_steps}")
            logger.info(f"  Per device batch size = {self.config.datasets.vla_data.per_device_batch_size}")
            logger.info(f"  Gradient accumulation steps = {self.config.trainer.gradient_accumulation_steps}")
            logger.info(f"  Total batch size = {self.total_batch_size}")

            logger.info("***** LR Scheduler Debug Info *****")
            logger.info(f"  lr_scheduler type = {type(self.lr_scheduler)}")
            base_scheduler = getattr(self.lr_scheduler, "scheduler", self.lr_scheduler)
            logger.info(f"  base_scheduler type = {type(base_scheduler)}")
            logger.info(f"  initial last_epoch = {getattr(base_scheduler, 'last_epoch', 'N/A')}")
            logger.info(f"  initial lr = {self.lr_scheduler.get_last_lr()}")
            logger.info(f"  num_warmup_steps = {self.config.trainer.num_warmup_steps}")
            logger.info(f"  num_stable_steps = {self.config.trainer.get('num_stable_steps', 0)}")
            logger.info(f"  max_train_steps = {self.config.trainer.max_train_steps}")
            logger.info(f"  accelerator.num_processes = {self.accelerator.num_processes}")
            logger.info(f"  accelerator.gradient_accumulation_steps = {self.accelerator.gradient_accumulation_steps}")
            logger.info(f"  trainer.mode = {self.training_mode}")
            logger.info(f"  loss_weights_decay_steps = {self.loss_weights_decay_steps}")

    def _get_aux_loss_decay_weight(self) -> float:
        if self.training_mode != "decay_aux_loss":
            return 1.0
        progress = min(float(self.completed_steps) / float(self.loss_weights_decay_steps), 1.0)
        return 1.0 - progress
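
    # Example (hypothetical numbers): in "decay_aux_loss" mode with loss_weights_decay_steps=5000,
    # the auxiliary-loss weight is 1.0 at step 0, 0.5 at step 2500, and 0.0 from step 5000 onward;
    # in any other trainer.mode it stays fixed at 1.0.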

    @staticmethod
    def _total_grad_norm_l2_local(parameters) -> float:
        """L2 norm over all grads (same recipe as torch.nn.utils.clip_grad_norm_). DeepSpeed-safe fallback when clip_grad_norm_ returns None."""
        total_sq = 0.0
        for p in parameters:
            if p.grad is None:
                continue
            param_norm = p.grad.detach().float().norm(2)
            total_sq += float(param_norm) ** 2
        return total_sq ** 0.5

    @staticmethod
    def _grad_norm_scalar(value) -> float:
        if value is None:
            return float("nan")
        if isinstance(value, torch.Tensor):
            return float(value.detach().item())
        return float(value)

    def _train_step(self, batch_vla, batch_vlm=None):
        """execute single training step"""
        is_deepspeed = self.accelerator.distributed_type == DistributedType.DEEPSPEED
        grad_norm_pre_clip = None

        with self.accelerator.accumulate(self.model):
            self.optimizer.zero_grad()

            with torch.autocast("cuda", dtype=torch.bfloat16):
                output_dict = self.model.forward(batch_vla, training_step=self.completed_steps)

            align_loss = output_dict["align_loss"]
            recon_loss = output_dict["recon_loss"]
            predict_loss = output_dict["predict_loss"]
            aux_loss_decay_weight = self._get_aux_loss_decay_weight()

            if align_loss is not None and recon_loss is not None:
                total_loss = (
                    self.config.trainer.loss_scale.align_loss * aux_loss_decay_weight * align_loss
                    + self.config.trainer.loss_scale.recon_loss * aux_loss_decay_weight * recon_loss
                    + predict_loss
                )
            else:
                total_loss = predict_loss

            self.accelerator.backward(total_loss)

            # DeepSpeed applies gradient clipping inside its engine (configured via the DeepSpeed
            # plugin/config), so only clip manually in the non-DeepSpeed path.
            if not is_deepspeed:
                gc = getattr(self.config.trainer, "gradient_clipping", None)
                max_norm = float(gc) if gc is not None else float("inf")
                grad_norm_pre_clip = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), max_norm
                )
                if grad_norm_pre_clip is None:
                    grad_norm_pre_clip = self._total_grad_norm_l2_local(self.model.parameters())

            self.optimizer.step()

            if self.accelerator.sync_gradients:
                self.lr_scheduler.step()

            # Under DeepSpeed, read the global grad norm computed by the engine; fall back to a
            # no-op clip call (max_norm=inf) just to obtain the norm if the attribute is missing.
            if is_deepspeed:
                gn = getattr(self.model, "_global_grad_norm", None)
                if gn is None:
                    gn = self.accelerator.clip_grad_norm_(self.model.parameters(), float("inf"))
                grad_norm_pre_clip = gn

        gn_scalar = self._grad_norm_scalar(grad_norm_pre_clip)
        self._grad_norm_buffer.append(gn_scalar)
        step_metrics = {
            "align_loss": align_loss.item() if align_loss is not None else None,
            "recon_loss": recon_loss.item() if recon_loss is not None else None,
            "predict_loss": predict_loss.item(),
            "aux_loss_decay_weight": aux_loss_decay_weight,
            "grad_norm_pre_clip": gn_scalar,
        }
        return step_metrics

    def _finalize_training(self):
        """training end processing"""
        if self.accelerator.is_main_process:
            final_checkpoint = os.path.join(self.config.output_dir, "final_model")
            os.makedirs(final_checkpoint, exist_ok=True)
            state_dict = self.accelerator.get_state_dict(self.model)
            torch.save(state_dict, os.path.join(final_checkpoint, "pytorch_model.pt"))
            logger.info(f"Training complete. Final model saved at {final_checkpoint}")

        if self.accelerator.is_main_process:
            wandb.finish()

        self.accelerator.wait_for_everyone()


def main(cfg) -> None:
    logger.info("VLA Training :: Warming Up")

    output_dir = setup_directories(cfg=cfg)

    vla = build_framework(cfg)

    vla_train_dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)

    trainer = VLATrainer(
        cfg=cfg,
        model=vla,
        vla_train_dataloader=vla_train_dataloader,
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )

    # The optimizer and lr scheduler are built inside prepare_training(), after any
    # pretrained checkpoints have been loaded.
    trainer.prepare_training()

    trainer.train()

    logger.info("... and that's all, folks!")
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_yaml", type=str, default="starVLA/config/training/starvla_cotrain_oxe.yaml", help="Path to YAML config")
    args, clipargs = parser.parse_known_args()

    # Merge the YAML config with any remaining command-line overrides given in dotlist form.
    cfg = OmegaConf.load(args.config_yaml)
    dotlist = normalize_dotlist_args(clipargs)
    cli_cfg = OmegaConf.from_dotlist(dotlist)
    cfg = OmegaConf.merge(cfg, cli_cfg)

    if cfg.is_debug and dist.is_initialized() and dist.get_rank() == 0:
        import debugpy

        debugpy.listen(("0.0.0.0", 10092))
        print("🔍 Rank 0 waiting for debugger attach on port 10092...")
        debugpy.wait_for_client()

    main(cfg)
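
# Example launch (hypothetical script path and values): extra `key=value` arguments are parsed
# as OmegaConf dotlist overrides on top of the YAML config, e.g.
#   accelerate launch starVLA/training/trainer.py \
#       --config_yaml starVLA/config/training/starvla_cotrain_oxe.yaml \
#       trainer.max_train_steps=20000 datasets.vla_data.per_device_batch_size=8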