cross13tasks / code /training /train_actionmodel.py
Timsty's picture
Upload folder using huggingface_hub
e94400c verified
# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License, Version 1.0 (the "License");
# Implemented by [Jinhui YE / HKUST University] in [2025].
"""
StarVLA's trainer is built directly on native PyTorch + Accelerate + DeepSpeed, keeping the loop explicit and easy to hack.
Conventions:
1. Store runtime state in dicts where possible (simplifies data info, processing info, config, etc).
2. Use multiple dataloaders to adapt heterogeneous data types / task mixtures.
3. Put each training strategy in its own `trainer_*.py` file (avoid large if-else chains).
"""
import sys

# NOTE(review): machine-specific absolute path; prefer PYTHONPATH or an installed package.
sys.path.append("/mnt/data/fangyu/code/reward_new")

import warnings

# Silence all warnings; installed before the heavy imports below so import-time warnings are hidden too.
warnings.filterwarnings("ignore")

# Standard Library
import argparse
import glob
import json
import os
import re
import time
from pathlib import Path
from typing import Tuple

# Third-Party Libraries
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset

# import wandb
import yaml
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import AutoProcessor, get_scheduler

# Local Modules
from starVLA.model.framework import build_framework
from starVLA.training.trainer_utils.trainer_tools import (
    TrainerUtils,
    build_param_lr_groups,
    normalize_dotlist_args,
)

# Credentials such as WANDB_API_KEY must be supplied via the environment (or `wandb login`),
# never hard-coded in source.

deepspeed_plugin = DeepSpeedPlugin()
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
accelerator.print(accelerator.state)

# Sane Defaults
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Module-level logger (Accelerate's wrapper around `logging.Logger`).
logger = get_logger(__name__)
def load_fast_tokenizer():
    """Load the FAST action tokenizer.

    Returns:
        The processor produced by ``AutoProcessor.from_pretrained`` for the
        ``physical-intelligence/fast`` repo, loaded with ``trust_remote_code=True``.
    """
    repo_id = "physical-intelligence/fast"
    return AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
def setup_directories(cfg) -> Path:
    """Create the run's output directory tree and persist its config.

    Sets ``cfg.output_dir`` to ``<run_root_dir>/<run_id>`` on every process. Only the
    writer process (rank 0, or the sole process when torch.distributed is not
    initialized) creates ``output_dir`` and ``output_dir/checkpoints`` and saves the
    config as both ``config.yaml`` and ``config.json``.

    Returns:
        The output directory as a ``Path``.
    """
    cfg.output_dir = os.path.join(cfg.run_root_dir, cfg.run_id)
    run_dir = Path(cfg.output_dir)

    is_writer = (not dist.is_initialized()) or dist.get_rank() == 0
    if not is_writer:
        return run_dir

    for directory in (run_dir, run_dir / "checkpoints"):
        os.makedirs(directory, exist_ok=True)

    yaml_path = run_dir / "config.yaml"
    OmegaConf.save(cfg, yaml_path)
    # Re-read the saved YAML so the JSON copy mirrors exactly what was written.
    with open(yaml_path, "r") as yaml_file, open(run_dir / "config.json", "w") as json_file:
        json.dump(yaml.safe_load(yaml_file), json_file, indent=2)
    return run_dir
def build_model(cfg) -> torch.nn.Module:
    """Build the model framework described by ``cfg``.

    Logs the base VLM identifier read from ``cfg.framework.qwenvl.base_vlm``,
    then delegates construction to ``build_framework``.
    """
    base_vlm = cfg.framework.qwenvl.base_vlm
    logger.info(f"Loading Base VLM `{base_vlm}` from ID/Path")
    return build_framework(cfg)
# here changes need to 📦 encapsulate Dataloader
from starVLA.dataloader import build_dataloader
def prepare_data(cfg, accelerator, output_dir) -> DataLoader:
    """Build the VLA training dataloader.

    Args:
        cfg: run config; reads ``cfg.datasets.vla_data.data_mix`` (for logging) and
            ``cfg.datasets.vla_data.dataset_py`` (passed to ``build_dataloader``).
        accelerator: the Accelerate ``Accelerator``; its
            ``dataloader_config.dispatch_batches`` is set to ``False``.
        output_dir: not used by this function; kept for call-site compatibility.

    Returns:
        The single dataloader returned by ``build_dataloader``.
    """
    logger.info(f"Creating VLA Dataset with Mixture `{cfg.datasets.vla_data.data_mix}`")
    vla_train_dataloader = build_dataloader(cfg=cfg, dataset_py=cfg.datasets.vla_data.dataset_py)
    # With dispatch_batches=False every process iterates its own dataloader instead of
    # rank 0 fetching batches and dispatching them to the others.
    accelerator.dataloader_config.dispatch_batches = False
    # Sync ranks once every loader is built; skip when no process group exists
    # (single-process run), where dist.barrier() would raise.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    return vla_train_dataloader
def get_warmup_stable_cosine_scheduler(optimizer, num_warmup_steps, num_stable_steps, num_training_steps, min_lr_ratio=0.01):
    """
    Warmup → Stable → Cosine Decay scheduler.

    The LR multiplier applied to every param group's base LR is:
      * warmup (step < num_warmup_steps): linear ramp ``step / num_warmup_steps``;
      * stable (step < num_warmup_steps + num_stable_steps): 1.0 (max LR);
      * decay (afterwards): cosine from 1.0 down to ``min_lr_ratio`` at
        ``num_training_steps``, then held at ``min_lr_ratio`` for any later step.

    Args:
        optimizer: PyTorch optimizer (any number of param groups).
        num_warmup_steps: number of linear-warmup steps.
        num_stable_steps: number of steps held at max LR after warmup.
        num_training_steps: total number of training steps.
        min_lr_ratio: final LR / max LR.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` using the same multiplier for every param group.
    """
    import math

    stable_end = num_warmup_steps + num_stable_steps
    decay_steps = num_training_steps - stable_end

    def lr_lambda(current_step):
        # Warmup phase: linear increase.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Stable phase: hold max LR.
        if current_step < stable_end:
            return 1.0
        # Cosine decay phase.
        if decay_steps <= 0:
            return min_lr_ratio
        # Clamp so steps beyond num_training_steps stay at min_lr_ratio instead of the
        # cosine swinging back up toward max LR.
        progress = min(1.0, float(current_step - stable_end) / float(decay_steps))
        return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    # One lambda per param group (LambdaLR requires the list length to match).
    num_param_groups = len(optimizer.param_groups)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, [lr_lambda] * num_param_groups)
def setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """Build the AdamW optimizer and LR scheduler for ``model``.

    Param groups (with per-group LRs) come from ``build_param_lr_groups``. The scheduler is
    the custom warmup → stable → cosine schedule when
    ``cfg.trainer.lr_scheduler_type == "warmup_stable_cosine"``, otherwise whatever
    ``transformers.get_scheduler`` builds for that name.

    Args:
        model: model whose parameters are optimized (call after freezing modules).
        cfg: run config; reads ``trainer.learning_rate.base``, ``trainer.optimizer.{betas,
            weight_decay, eps}``, ``trainer.lr_scheduler_type``, ``trainer.num_warmup_steps``,
            ``trainer.max_train_steps`` and the optional ``trainer.num_stable_steps`` /
            ``trainer.scheduler_specific_kwargs``.

    Returns:
        ``(optimizer, lr_scheduler)``.
    """
    trainer_cfg = cfg.trainer

    param_groups = build_param_lr_groups(model=model, cfg=cfg)
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=trainer_cfg.learning_rate.base,
        betas=tuple(trainer_cfg.optimizer.betas),
        weight_decay=trainer_cfg.optimizer.weight_decay,
        eps=trainer_cfg.optimizer.eps,
    )

    # Log on rank 0, or unconditionally when running without a process group.
    is_rank0 = not dist.is_initialized() or dist.get_rank() == 0
    if is_rank0:
        for i, group in enumerate(optimizer.param_groups):
            # A group may lack a "name" key; fall back to its index instead of raising KeyError.
            logger.info(f"LR Group {group.get('name', i)}: lr={group['lr']}, num_params={len(group['params'])}")

    # Optional key: read with .get so a config without it does not raise.
    scheduler_kwargs = trainer_cfg.get("scheduler_specific_kwargs")

    if trainer_cfg.lr_scheduler_type == "warmup_stable_cosine":
        # Custom scheduler: Warmup → Stable → Cosine Decay.
        min_lr_ratio = (scheduler_kwargs or {}).get("min_lr_ratio", 0.01)
        num_stable_steps = trainer_cfg.get("num_stable_steps", 0)
        lr_scheduler = get_warmup_stable_cosine_scheduler(
            optimizer=optimizer,
            num_warmup_steps=trainer_cfg.num_warmup_steps,
            num_stable_steps=num_stable_steps,
            num_training_steps=trainer_cfg.max_train_steps,
            min_lr_ratio=min_lr_ratio,
        )
        if is_rank0:
            logger.info(f"Using warmup_stable_cosine scheduler: warmup={trainer_cfg.num_warmup_steps}, "
                        f"stable={num_stable_steps}, total={trainer_cfg.max_train_steps}, min_lr_ratio={min_lr_ratio}")
    else:
        # Use a scheduler built into transformers.
        lr_scheduler = get_scheduler(
            name=trainer_cfg.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=trainer_cfg.num_warmup_steps,
            num_training_steps=trainer_cfg.max_train_steps,
            scheduler_specific_kwargs=scheduler_kwargs,
        )
    return optimizer, lr_scheduler
class VLATrainer(TrainerUtils):
    """Explicit Accelerate/DeepSpeed training loop.

    Owns the model, the VLA training dataloader, the optimizer and the LR scheduler.
    ``prepare_training`` loads/freezes weights, builds the optimizer and scheduler, and wraps
    model, optimizer and dataloader for distributed training. ``train`` then runs until
    ``cfg.trainer.max_train_steps`` optimizer steps have completed.
    """

    def __init__(self, cfg, model, vla_train_dataloader, optimizer, lr_scheduler, accelerator):
        self.config = cfg
        self.model = model
        self.vla_train_dataloader = vla_train_dataloader
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = accelerator
        self._printed_first_batch = False
        # Number of completed optimizer steps (incremented only when gradients are synced).
        self.completed_steps = 0
        self.total_batch_size = self._calculate_total_batch_size()

    def _debug_print_first_batch(self, batch) -> None:
        """Print the structure (and the first 5 samples in full) of the first batch, once.

        Only runs on the local main process. ``batch`` is expected to be a list of dicts or a
        dict; any other type is reported as empty.
        """
        if self._printed_first_batch or not self.accelerator.is_local_main_process:
            return
        self._printed_first_batch = True
        sample = None
        if isinstance(batch, list):
            sample = batch[0] if len(batch) > 0 else None
        elif isinstance(batch, dict):
            sample = batch
        if sample is None:
            self.accelerator.print("First batch is empty.")
            return

        def _describe_value(value):
            # Short type/shape summary for one field.
            if hasattr(value, "shape"):
                try:
                    return f"{type(value).__name__}(shape={tuple(value.shape)})"
                except Exception:
                    return type(value).__name__
            if isinstance(value, list):
                inner = type(value[0]).__name__ if value else "empty"
                return f"list(len={len(value)}, inner={inner})"
            return type(value).__name__

        self.accelerator.print(f"First batch type: {type(batch).__name__}")
        if isinstance(batch, list):
            self.accelerator.print(f"First batch size: {len(batch)}")
        self.accelerator.print("First sample keys:")
        for key, value in sample.items():
            self.accelerator.print(f" - {key}: {_describe_value(value)}")
        # Print full content for first 5 samples to inspect inputs.
        if isinstance(batch, list):
            max_samples = min(5, len(batch))
            for i in range(max_samples):
                self.accelerator.print(f"Sample[{i}] content:")
                for key, value in batch[i].items():
                    if hasattr(value, "shape"):
                        try:
                            value_str = np.array2string(
                                value, threshold=np.inf, max_line_width=200
                            )
                        except Exception:
                            value_str = repr(value)
                    else:
                        value_str = repr(value)
                    self.accelerator.print(f" - {key}: {value_str}")

    def prepare_training(self):
        """Seed, load/freeze weights, build optimizer+scheduler, and wrap for distributed training."""
        rank = dist.get_rank() if dist.is_initialized() else 0
        seed = self.config.seed + rank if hasattr(self.config, "seed") else rank + 3047
        set_seed(seed)
        # load pretrained weights
        if hasattr(self.config.trainer, "pretrained_checkpoint") and self.config.trainer.pretrained_checkpoint:
            pretrained_checkpoint = self.config.trainer.pretrained_checkpoint
            reload_modules = (
                self.config.trainer.reload_modules if hasattr(self.config.trainer, "reload_modules") else None
            )
            self.model = self.load_pretrained_backbones(self.model, pretrained_checkpoint,
                                                        reload_modules=reload_modules)
        # freeze parameters
        freeze_modules = (
            self.config.trainer.freeze_modules
            if (self.config and hasattr(self.config.trainer, "freeze_modules"))
            else None
        )
        self.model = self.freeze_backbones(self.model, freeze_modules=freeze_modules)
        # print model trainable parameters:
        self.print_trainable_parameters(self.model)
        # build optimizer and scheduler AFTER freezing (critical for DeepSpeed ZeRO)
        self.optimizer, self.lr_scheduler = setup_optimizer_and_scheduler(model=self.model, cfg=self.config)
        # initialize distributed training components.
        # NOTE: lr_scheduler is deliberately NOT passed in, so it is not wrapped by
        # AcceleratedScheduler (which would step it num_processes times per optimizer step).
        self.model, self.optimizer, self.vla_train_dataloader = self.setup_distributed_training(
            self.accelerator,  # must be the first param
            self.model,
            self.optimizer,
            self.vla_train_dataloader,
        )
        # lr_scheduler stays the raw scheduler; _train_step advances it manually.
        # self._init_wandb()
        self._init_checkpointing()

    def _calculate_total_batch_size(self):
        """Global batch size: per-device batch × num processes × gradient accumulation steps."""
        return (
            self.config.datasets.vla_data.per_device_batch_size
            * self.accelerator.num_processes
            * self.accelerator.gradient_accumulation_steps
        )

    def _init_wandb(self):
        """initialize Weights & Biases (currently disabled)"""
        # if self.accelerator.is_main_process:
        #     wandb.init(
        #         name=self.config.run_id,
        #         dir=os.path.join(self.config.output_dir, "wandb"),
        #         project=self.config.wandb_project,
        #         entity=self.config.wandb_entity,
        #         group="vla-train",
        #         settings=wandb.Settings(
        #             _disable_stats=False,  # make sure system monitoring is enabled
        #             x_stats_sampling_interval=10.0,  # sample system metrics every 10 seconds
        #         ),
        #     )
        pass

    def _init_checkpointing(self):
        """Create the checkpoint directory and optionally resume training state."""
        self.checkpoint_dir = os.path.join(self.config.output_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        pretrained_checkpoint = getattr(self.config.trainer, "pretrained_checkpoint", None)
        is_resume = getattr(self.config.trainer, "is_resume", False)
        # resume train ckpt
        # NOTE(review): accelerator.load_state expects a directory written by save_state, but
        # _save_checkpoint only writes `.pt` weight files; lr_scheduler and completed_steps are
        # not restored either. Confirm the intended resume path before relying on this.
        if pretrained_checkpoint and is_resume:
            self._load_checkpoint(self.config.resume_from_checkpoint)

    def _load_checkpoint(self, checkpoint_path):
        """Restore Accelerate-managed state from ``checkpoint_path``."""
        self.accelerator.load_state(checkpoint_path)
        self.accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")

    def _gather_state_dict(self):
        """Return the model state dict on the ranks that need to compute it, else ``None``.

        Under DeepSpeed ZeRO-3 the parameters are sharded and ``get_state_dict`` is a collective,
        so every rank must call it (calling it on the main process only would hang). Otherwise
        only the main process materializes the state dict.
        """
        plugin = getattr(self.accelerator.state, "deepspeed_plugin", None)
        ds_config = (getattr(plugin, "deepspeed_config", None) or {}) if plugin is not None else {}
        is_zero3 = ds_config.get("zero_optimization", {}).get("stage") == 3
        if is_zero3 or self.accelerator.is_main_process:
            return self.accelerator.get_state_dict(self.model)
        return None

    def _save_checkpoint(self):
        """Save model weights for the current step and append a line to ``summary.jsonl``.

        Weights go to ``<checkpoint_dir>/steps_<N>_pytorch_model.pt``. Must be called on every
        rank: it ends with ``wait_for_everyone()`` and may gather weights collectively.
        """
        state_dict = self._gather_state_dict()
        if self.accelerator.is_main_process:
            checkpoint_path = os.path.join(self.checkpoint_dir, f"steps_{self.completed_steps}")
            # save model state
            torch.save(state_dict, checkpoint_path + "_pytorch_model.pt")
            # save training metadata
            summary_data = {
                "steps": self.completed_steps,
            }
            with open(os.path.join(self.config.output_dir, "summary.jsonl"), "a") as f:
                f.write(json.dumps(summary_data) + "\n")
            self.accelerator.print(f"✅ Checkpoint saved at {checkpoint_path}")
            # Delete old checkpoints, keeping only the most recent N.
            max_checkpoints = getattr(self.config.trainer, "max_checkpoints_to_keep", None)
            if max_checkpoints is not None and max_checkpoints > 0:
                self._cleanup_old_checkpoints(max_checkpoints)
        del state_dict
        self.accelerator.wait_for_everyone()

    def _cleanup_old_checkpoints(self, max_checkpoints: int):
        """Delete old checkpoint files, keeping only the ``max_checkpoints`` most recent steps."""
        # Run on the main process only, to avoid multi-process races.
        if not self.accelerator.is_main_process:
            return
        # Collect all checkpoint files.
        checkpoint_pattern = os.path.join(self.checkpoint_dir, "steps_*_pytorch_model.pt")
        checkpoint_files = glob.glob(checkpoint_pattern)
        if len(checkpoint_files) <= max_checkpoints:
            return

        # Extract the step number from the file name and sort by it.
        def extract_steps(filepath):
            match = re.search(r'steps_(\d+)_pytorch_model\.pt', filepath)
            return int(match.group(1)) if match else 0

        checkpoint_files.sort(key=extract_steps)
        # Delete the oldest checkpoints.
        files_to_delete = checkpoint_files[:-max_checkpoints]
        for filepath in files_to_delete:
            try:
                os.remove(filepath)
                self.accelerator.print(f"🗑️ Deleted old checkpoint: {os.path.basename(filepath)}")
            except Exception as e:
                self.accelerator.print(f"⚠️ Failed to delete checkpoint {filepath}: {e}")

    def _log_metrics(self, metrics):
        """Add LR/epoch info to ``metrics`` and log every ``logging_frequency`` steps (main process only).

        ``epoch`` counts dataloader passes: each optimizer step consumes
        ``gradient_accumulation_steps`` batches. It is omitted when the loader has no length.
        """
        if self.completed_steps % self.config.trainer.logging_frequency != 0:
            return
        # is_main_process works with or without an initialized process group.
        if not self.accelerator.is_main_process:
            return
        # LR of the first param group (see lr groups in yaml.trainer.learning_rate).
        metrics["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
        try:
            num_batches = len(self.vla_train_dataloader)
        except TypeError:
            num_batches = 0  # loader without __len__
        if num_batches > 0:
            batches_seen = self.completed_steps * self.accelerator.gradient_accumulation_steps
            metrics["epoch"] = round(batches_seen / num_batches, 2)
        # record to W&B
        # wandb.log(metrics, step=self.completed_steps)
        logger.info(f"\nStep {self.completed_steps}, Loss: {metrics}")

    def _create_data_iterators(self):
        """create data iterators"""
        self.vla_iter = iter(self.vla_train_dataloader)
        # self.vlm_iter = iter(self.vlm_train_dataloader)

    def _get_next_batch(self):
        """Return the next batch, restarting the dataloader when it is exhausted."""
        try:
            batch_vla = next(self.vla_iter)
        except StopIteration:
            if not hasattr(self, "vla_epoch_count"):
                self.vla_epoch_count = 0
            self.vla_iter, self.vla_epoch_count = TrainerUtils._reset_dataloader(
                self.vla_train_dataloader, self.vla_epoch_count
            )
            batch_vla = next(self.vla_iter)
        return batch_vla

    def train(self):
        """Run the training loop until ``max_train_steps`` optimizer steps have completed.

        One batch is consumed per iteration; ``completed_steps`` advances only when gradients
        are synced (once per optimizer step). Eval, logging and checkpointing are keyed on
        ``completed_steps`` and therefore run only on sync iterations, so they fire once per
        optimizer step even with gradient accumulation. Must run on every rank.
        """
        # print training config
        self._log_training_config()
        # prepare data iterators
        self._create_data_iterators()
        max_train_steps = self.config.trainer.max_train_steps
        eval_interval = getattr(self.config.trainer, "eval_interval", 0)
        progress_bar = tqdm(
            range(max_train_steps), disable=not self.accelerator.is_local_main_process
        )
        while self.completed_steps < max_train_steps:
            # get data batch
            t_start_data = time.perf_counter()
            batch_vla = self._get_next_batch()
            self._debug_print_first_batch(batch_vla)
            t_end_data = time.perf_counter()
            # execute training step
            t_start_model = time.perf_counter()
            step_metrics = self._train_step(batch_vla)
            t_end_model = time.perf_counter()

            synced = self.accelerator.sync_gradients
            if synced:
                progress_bar.update(1)
                self.completed_steps += 1
            if self.accelerator.is_local_main_process:
                progress_bar.set_postfix(
                    {
                        "data_times": f"{t_end_data - t_start_data:.3f}",
                        "model_times": f"{t_end_model - t_start_model:.3f}",
                    }
                )
            # Accumulation-only micro-steps leave completed_steps unchanged; running the
            # step-keyed work below on them would repeat it for the same step.
            if not synced:
                continue

            # evaluate model (predict action once and compute MAE)
            if eval_interval > 0 and self.completed_steps % eval_interval == 0:
                step_metrics = self.eval_action_model(step_metrics)
            # record metrics
            step_metrics["data_time"] = t_end_data - t_start_data
            step_metrics["model_time"] = t_end_model - t_start_model
            self._log_metrics(step_metrics)
            # save checkpoint
            if self.completed_steps % self.config.trainer.save_interval == 0:
                self._save_checkpoint()
        progress_bar.close()
        # training end processing
        self._finalize_training()

    def eval_action_model(self, step_metrics: dict = None):
        """
        Evaluate action model: encode -> decode one batch, then compute MAE (L1) between
        predicted and ground-truth actions. Compatible with ActionModelFM (encode_actions + decode_actions).

        NOTE: the batch is drawn from the training iterator, so that batch is not trained on.
        Must be called on every rank (it synchronizes before returning). ``mae_score`` is only
        added on the main process, computed from that process's batch.
        """
        if step_metrics is None:
            step_metrics = {}
        examples = self._get_next_batch()
        # Unwrap DDP/DeepSpeed so model-specific attributes and methods resolve on the raw module.
        model = self.accelerator.unwrap_model(self.model)
        first_param = next(model.parameters())
        device, param_dtype = first_param.device, first_param.dtype
        # Use same chunk length for all samples (min over batch, capped by config).
        max_chunk = getattr(getattr(model, "config", None), "max_action_chunk_size", 50)
        chunk_len = min(max_chunk, min(len(ex["action"]) for ex in examples))
        if chunk_len < 1:
            self.accelerator.wait_for_everyone()
            return step_metrics
        raw_actions = np.array([ex["action"][:chunk_len] for ex in examples])
        actions_tensor = torch.tensor(raw_actions, device=device, dtype=param_dtype)
        if model.use_state:
            states_tensor = torch.tensor(
                np.array([ex["state"][:chunk_len] for ex in examples]),
                device=device,
                dtype=param_dtype,
            )
        else:
            states_tensor = None
        dataset_ids = [ex.get("dataset_id") for ex in examples]
        with torch.no_grad():
            action_embedding = model.encode_actions(actions_tensor, dataset_ids, states_tensor)
            pred_actions = model.decode_actions(action_embedding, chunk_size=chunk_len)
        if self.accelerator.is_main_process:
            pred_np = pred_actions.cpu().float().numpy()
            score = TrainerUtils.l1_distance(pred_np, raw_actions)
            step_metrics["mae_score"] = float(score / max(pred_np.size, 1))
        del examples, actions_tensor, action_embedding, pred_actions
        self.accelerator.wait_for_everyone()
        return step_metrics

    def _log_training_config(self):
        """Log the training and LR-scheduler configuration (main process only)."""
        if self.accelerator.is_main_process:
            logger.info("***** Training Configuration *****")
            logger.info(f"  Total optimization steps = {self.config.trainer.max_train_steps}")
            logger.info(f"  Per device batch size = {self.config.datasets.vla_data.per_device_batch_size}")
            logger.info(f"  Gradient accumulation steps = {self.config.trainer.gradient_accumulation_steps}")
            logger.info(f"  Total batch size = {self.total_batch_size}")
            logger.info("***** LR Scheduler Debug Info *****")
            logger.info(f"  lr_scheduler type = {type(self.lr_scheduler)}")
            base_scheduler = getattr(self.lr_scheduler, 'scheduler', self.lr_scheduler)
            logger.info(f"  base_scheduler type = {type(base_scheduler)}")
            logger.info(f"  initial last_epoch = {getattr(base_scheduler, 'last_epoch', 'N/A')}")
            logger.info(f"  initial lr = {self.lr_scheduler.get_last_lr()}")
            logger.info(f"  num_warmup_steps = {self.config.trainer.num_warmup_steps}")
            logger.info(f"  num_stable_steps = {self.config.trainer.get('num_stable_steps', 0)}")
            logger.info(f"  max_train_steps = {self.config.trainer.max_train_steps}")
            logger.info(f"  accelerator.num_processes = {self.accelerator.num_processes}")
            logger.info(f"  accelerator.gradient_accumulation_steps = {self.accelerator.gradient_accumulation_steps}")

    def _train_step(self, batch_vla, batch_vlm=None):
        """Run one forward/backward micro-step on ``batch_vla`` (``batch_vlm`` is unused).

        Inside ``accelerator.accumulate`` the wrapped optimizer only steps on sync micro-steps.
        Gradients are zeroed *after* the optimizer step; zeroing at the start would wipe the
        gradients accumulated by earlier micro-steps on the sync step.

        Returns:
            dict with ``recon_loss`` and, on sync steps with clipping enabled, ``grad_norm``.
        """
        grad_norm = None
        with self.accelerator.accumulate(self.model):
            # VLA task forward propagation
            with torch.autocast("cuda", dtype=torch.bfloat16):
                recon_loss = self.model.forward(batch_vla)
            # VLA backward propagation
            self.accelerator.backward(recon_loss)
            # Clip only the fully-accumulated gradient, once per optimizer step.
            if self.config.trainer.gradient_clipping is not None and self.accelerator.sync_gradients:
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), self.config.trainer.gradient_clipping
                )
            # optimizer step
            self.optimizer.step()
            # The scheduler is not wrapped by Accelerate (see prepare_training), so advance it
            # manually once per real optimizer step.
            if self.accelerator.sync_gradients:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
        step_metrics = {
            "recon_loss": recon_loss.item(),
        }
        if grad_norm is not None:
            step_metrics["grad_norm"] = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
        return step_metrics

    def _finalize_training(self):
        """Save the final weights to ``<output_dir>/final_model/pytorch_model.pt``; must run on all ranks."""
        state_dict = self._gather_state_dict()
        if self.accelerator.is_main_process:
            final_checkpoint = os.path.join(self.config.output_dir, "final_model")
            os.makedirs(final_checkpoint, exist_ok=True)
            torch.save(state_dict, os.path.join(final_checkpoint, "pytorch_model.pt"))
            logger.info(f"Training complete. Final model saved at {final_checkpoint}")
            # close W&B
            # wandb.finish()
        del state_dict
        self.accelerator.wait_for_everyone()
def main(cfg) -> None:
    """Train end-to-end: set up output dir, build model and data, train, then tear down.

    Uses the module-level ``accelerator``. Process-group teardown is skipped when
    torch.distributed was never initialized (single-process run), where
    ``dist.barrier()`` would raise.
    """
    logger.info("VLA Training :: Warming Up")
    # create output directory and save config
    output_dir = setup_directories(cfg=cfg)
    # build model (directly via build_framework; build_model would additionally read
    # cfg.framework.qwenvl.base_vlm)
    vla = build_framework(cfg)
    # prepare data
    vla_train_dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)
    # create trainer; optimizer/scheduler are built inside prepare_training
    trainer = VLATrainer(
        cfg=cfg,
        model=vla,
        vla_train_dataloader=vla_train_dataloader,
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )
    trainer.prepare_training()
    trainer.train()
    # And... we're done!
    logger.info("... and that's all, folks!")
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_yaml", type=str, default="starVLA/config/training/starvla_cotrain_oxe.yaml",
                        help="Path to YAML config")
    args, clipargs = parser.parse_known_args()
    # Load YAML config & Convert CLI overrides to dotlist config
    cfg = OmegaConf.load(args.config_yaml)
    dotlist = normalize_dotlist_args(clipargs)  # Normalize CLI args to dotlist format
    cli_cfg = OmegaConf.from_dotlist(dotlist)
    cfg = OmegaConf.merge(cfg, cli_cfg)
    # Optional debug hook: rank 0 blocks until a debugger attaches.
    # `is_debug` is read with .get so configs without the key do not raise.
    if cfg.get("is_debug", False) and dist.is_initialized() and dist.get_rank() == 0:
        import debugpy
        debugpy.listen(("0.0.0.0", 10092))
        print("🔍 Rank 0 waiting for debugger attach on port 10092...")
        debugpy.wait_for_client()
    main(cfg)