"""
PI0 Trainer

Modeled on train_qwenlatent.py; used to train the PI0 model.

Supports:
- loading weights from a pi0 pretrained checkpoint
- the unified 37D action representation (truncated inside the framework to the 32D PI0 expects)
- compatibility with lerobot_datasets
"""
import sys

sys.path.append("/mnt/data/fangyu/code/reward_new")
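
# Example launch command (illustrative; the script filename, process count, and override values
# are assumptions, and extra CLI args are merged into the YAML config as an OmegaConf dotlist):
#   accelerate launch --num_processes 8 pi0_trainer.py \
#       --config_yaml starVLA/config/training/starvla_train_pi0.yaml \
#       trainer.max_train_steps=20000 datasets.vla_data.per_device_batch_size=8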

import warnings

warnings.filterwarnings("ignore")

import argparse
import glob
import json
import os
import re
import time
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
import torch.distributed as dist
import wandb
import yaml
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import get_scheduler

from starVLA.dataloader import build_dataloader
from starVLA.model.framework import build_framework
from starVLA.training.trainer_utils.trainer_tools import (
    TrainerUtils,
    build_param_lr_groups,
    normalize_dotlist_args,
)


os.environ.setdefault("WANDB_API_KEY", "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")

deepspeed_plugin = DeepSpeedPlugin()
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
accelerator.print(accelerator.state)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
logger = get_logger(__name__)


def setup_directories(cfg) -> Path:
    """Create output directory and save config."""
    cfg.output_dir = os.path.join(cfg.run_root_dir, cfg.run_id)
    output_dir = Path(cfg.output_dir)

    if not dist.is_initialized() or dist.get_rank() == 0:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(output_dir / "checkpoints", exist_ok=True)
        OmegaConf.save(cfg, output_dir / "config.yaml")
        with open(output_dir / "config.yaml", "r") as f_yaml, open(output_dir / "config.json", "w") as f_json:
            yaml_cfg = yaml.safe_load(f_yaml)
            json.dump(yaml_cfg, f_json, indent=2)

    return output_dir


def prepare_data(cfg, accelerator, output_dir) -> DataLoader:
    """Prepare training data."""
    logger.info(f"Creating VLA Dataset with Mixture `{cfg.datasets.vla_data.data_mix}`")
    vla_train_dataloader = build_dataloader(cfg=cfg, dataset_py=cfg.datasets.vla_data.dataset_py)

    accelerator.dataloader_config.dispatch_batches = False
    accelerator.wait_for_everyone()

    return vla_train_dataloader


def get_warmup_stable_cosine_scheduler(optimizer, num_warmup_steps, num_stable_steps, num_training_steps, min_lr_ratio=0.01):
    """Warmup → Stable → Cosine Decay scheduler."""
    import math

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        stable_end = num_warmup_steps + num_stable_steps
        if current_step < stable_end:
            return 1.0
        decay_steps = num_training_steps - stable_end
        if decay_steps <= 0:
            return min_lr_ratio
        progress = float(current_step - stable_end) / float(decay_steps)
        return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    num_param_groups = len(optimizer.param_groups)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, [lr_lambda] * num_param_groups)
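

# Illustrative LR multiplier values for the schedule above (numbers chosen for this example only):
# with num_warmup_steps=1000, num_stable_steps=4000, num_training_steps=10000, min_lr_ratio=0.01,
#   step 0      -> 0.0      (start of linear warmup)
#   step 1000   -> 1.0      (stable phase begins)
#   step 5000   -> 1.0      (stable phase ends, cosine decay begins)
#   step 7500   -> ~0.505   (halfway through decay: 0.01 + 0.99 * 0.5)
#   step 10000  -> 0.01     (decays to the min_lr_ratio floor)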


def setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """Set optimizer and scheduler."""
    param_groups = build_param_lr_groups(model=model, cfg=cfg)
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=cfg.trainer.learning_rate.base,
        betas=tuple(cfg.trainer.optimizer.betas),
        weight_decay=cfg.trainer.optimizer.weight_decay,
        eps=cfg.trainer.optimizer.eps,
    )

    if dist.is_initialized() and dist.get_rank() == 0:
        for group in optimizer.param_groups:
            logger.info(f"LR Group {group['name']}: lr={group['lr']}, num_params={len(group['params'])}")

    if cfg.trainer.lr_scheduler_type == "warmup_stable_cosine":
        min_lr_ratio = cfg.trainer.scheduler_specific_kwargs.get("min_lr_ratio", 0.01)
        num_stable_steps = cfg.trainer.get("num_stable_steps", 0)
        lr_scheduler = get_warmup_stable_cosine_scheduler(
            optimizer=optimizer,
            num_warmup_steps=cfg.trainer.num_warmup_steps,
            num_stable_steps=num_stable_steps,
            num_training_steps=cfg.trainer.max_train_steps,
            min_lr_ratio=min_lr_ratio,
        )
    else:
        lr_scheduler = get_scheduler(
            name=cfg.trainer.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=cfg.trainer.num_warmup_steps,
            num_training_steps=cfg.trainer.max_train_steps,
            scheduler_specific_kwargs=cfg.trainer.get("scheduler_specific_kwargs"),
        )

    return optimizer, lr_scheduler
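

# The fields consumed above imply a trainer config layout like the following (key names are taken
# from the attribute accesses in this file; the concrete values are placeholders, not defaults):
#   trainer:
#     learning_rate: {base: 5.0e-5}
#     optimizer: {betas: [0.9, 0.95], weight_decay: 0.01, eps: 1.0e-8}
#     lr_scheduler_type: warmup_stable_cosine
#     num_warmup_steps: 1000
#     num_stable_steps: 4000
#     max_train_steps: 10000
#     scheduler_specific_kwargs: {min_lr_ratio: 0.01}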


class PI0Trainer(TrainerUtils):
    """Trainer for PI0 Framework."""

    def __init__(self, cfg, model, vla_train_dataloader, optimizer, lr_scheduler, accelerator):
        self.config = cfg
        self.model = model
        self.vla_train_dataloader = vla_train_dataloader
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = accelerator
        self._printed_first_batch = False

        self.completed_steps = 0
        self.total_batch_size = (
            self.config.datasets.vla_data.per_device_batch_size
            * self.accelerator.num_processes
            * self.accelerator.gradient_accumulation_steps
        )
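
    # Worked example of the global batch size computed above (values illustrative):
    # per_device_batch_size=8, num_processes=8, gradient_accumulation_steps=2
    # gives total_batch_size = 8 * 8 * 2 = 128 samples per optimizer step.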

    def _debug_print_first_batch(self, batch) -> None:
        """Print first batch structure for debugging (only once, on local main process)."""
        if self._printed_first_batch or not self.accelerator.is_local_main_process:
            return
        self._printed_first_batch = True

        sample = None
        if isinstance(batch, list):
            sample = batch[0] if len(batch) > 0 else None
        elif isinstance(batch, dict):
            sample = batch

        if sample is None:
            self.accelerator.print("First batch is empty.")
            return

        def _describe_value(value):
            if hasattr(value, "shape"):
                try:
                    return f"{type(value).__name__}(shape={tuple(value.shape)})"
                except Exception:
                    return type(value).__name__
            if isinstance(value, list):
                inner = type(value[0]).__name__ if value else "empty"
                return f"list(len={len(value)}, inner={inner})"
            return type(value).__name__

        self.accelerator.print(
            f"[PI0Trainer] First batch type: {type(batch).__name__}, "
            f"size: {len(batch) if isinstance(batch, list) else 1}"
        )
        self.accelerator.print("[PI0Trainer] First sample keys:")
        for key, value in sample.items():
            self.accelerator.print(f" - {key}: {_describe_value(value)}")

    def prepare_training(self):
        rank = dist.get_rank() if dist.is_initialized() else 0
        seed = self.config.seed + rank if hasattr(self.config, "seed") else rank + 3047
        set_seed(seed)

        if hasattr(self.config.trainer, "pretrained_checkpoint") and self.config.trainer.pretrained_checkpoint:
            pretrained_checkpoint = self.config.trainer.pretrained_checkpoint
            self.model = self.load_pretrained_backbones(
                self.model,
                pretrained_checkpoint,
                reload_modules=getattr(self.config.trainer, "reload_modules", None),
            )

        self.print_trainable_parameters(self.model)

        self.optimizer, self.lr_scheduler = setup_optimizer_and_scheduler(model=self.model, cfg=self.config)

        self.model, self.optimizer, self.vla_train_dataloader = self.setup_distributed_training(
            self.accelerator,
            self.model,
            self.optimizer,
            self.vla_train_dataloader,
        )

        self._init_wandb()
        self._init_checkpointing()

    def _init_wandb(self):
        if self.accelerator.is_main_process:
            wandb.init(
                name=self.config.run_id,
                dir=os.path.join(self.config.output_dir, "wandb"),
                project=self.config.wandb_project,
                entity=self.config.wandb_entity,
                group="pi0-train",
            )

    def _init_checkpointing(self):
        self.checkpoint_dir = os.path.join(self.config.output_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        if getattr(self.config.trainer, "is_resume", False) and getattr(self.config.trainer, "resume_from_checkpoint", None):
            self.accelerator.load_state(self.config.trainer.resume_from_checkpoint)
            self.accelerator.print(f"Resumed from checkpoint: {self.config.trainer.resume_from_checkpoint}")

    def _save_checkpoint(self):
        if self.accelerator.is_main_process:
            checkpoint_path = os.path.join(self.checkpoint_dir, f"steps_{self.completed_steps}")
            state_dict = self.accelerator.get_state_dict(self.model)
            torch.save(state_dict, checkpoint_path + "_pytorch_model.pt")

            with open(os.path.join(self.config.output_dir, "summary.jsonl"), "a") as f:
                f.write(json.dumps({"steps": self.completed_steps}) + "\n")
            self.accelerator.print(f"✅ Checkpoint saved at {checkpoint_path}")

            max_checkpoints = getattr(self.config.trainer, "max_checkpoints_to_keep", None)
            if max_checkpoints and max_checkpoints > 0:
                self._cleanup_old_checkpoints(max_checkpoints)

        self.accelerator.wait_for_everyone()

    def _cleanup_old_checkpoints(self, max_checkpoints: int):
        if not self.accelerator.is_main_process:
            return
        checkpoint_pattern = os.path.join(self.checkpoint_dir, "steps_*_pytorch_model.pt")
        checkpoint_files = glob.glob(checkpoint_pattern)
        if len(checkpoint_files) <= max_checkpoints:
            return

        def extract_steps(filepath):
            match = re.search(r"steps_(\d+)_pytorch_model\.pt", filepath)
            return int(match.group(1)) if match else 0

        checkpoint_files.sort(key=extract_steps)
        for filepath in checkpoint_files[:-max_checkpoints]:
            try:
                os.remove(filepath)
                self.accelerator.print(f"🗑️ Deleted old checkpoint: {os.path.basename(filepath)}")
            except Exception as e:
                self.accelerator.print(f"⚠️ Failed to delete {filepath}: {e}")

    def _log_metrics(self, metrics):
        if self.completed_steps % self.config.trainer.logging_frequency == 0:
            is_main = not dist.is_initialized() or dist.get_rank() == 0
            if is_main:
                metrics["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
                metrics["epoch"] = round(self.completed_steps / max(len(self.vla_train_dataloader), 1), 2)
                wandb.log(metrics, step=self.completed_steps)
                logger.info(f"\nStep {self.completed_steps}, Loss: {metrics}")

    def _create_data_iterators(self):
        self.vla_iter = iter(self.vla_train_dataloader)

    def _get_next_batch(self):
        try:
            batch_vla = next(self.vla_iter)
        except StopIteration:
            if not hasattr(self, "vla_epoch_count"):
                self.vla_epoch_count = 0
            self.vla_iter, self.vla_epoch_count = TrainerUtils._reset_dataloader(
                self.vla_train_dataloader, self.vla_epoch_count
            )
            batch_vla = next(self.vla_iter)
        return batch_vla

    def train(self):
        self._log_training_config()
        self._create_data_iterators()

        progress_bar = tqdm(
            range(self.config.trainer.max_train_steps),
            disable=not self.accelerator.is_local_main_process,
        )

        while self.completed_steps < self.config.trainer.max_train_steps:
            t_start_data = time.perf_counter()
            batch_vla = self._get_next_batch()
            self._debug_print_first_batch(batch_vla)
            t_end_data = time.perf_counter()

            t_start_model = time.perf_counter()
            step_metrics = self._train_step(batch_vla)
            t_end_model = time.perf_counter()

            # Only count, log, and checkpoint on real optimizer steps (gradient-sync boundaries),
            # so gradient accumulation does not trigger duplicate logs or saves.
            if self.accelerator.sync_gradients:
                progress_bar.update(1)
                self.completed_steps += 1

                if self.accelerator.is_local_main_process:
                    progress_bar.set_postfix({
                        "data_t": f"{t_end_data - t_start_data:.3f}s",
                        "model_t": f"{t_end_model - t_start_model:.3f}s",
                        "loss": f"{step_metrics.get('action_loss', 0):.4f}",
                    })

                step_metrics["data_time"] = t_end_data - t_start_data
                step_metrics["model_time"] = t_end_model - t_start_model
                self._log_metrics(step_metrics)

                if self.completed_steps % self.config.trainer.save_interval == 0 and self.completed_steps > 0:
                    self._save_checkpoint()

                if self.completed_steps >= self.config.trainer.max_train_steps:
                    break

        self._finalize_training()

    def _log_training_config(self):
        if self.accelerator.is_main_process:
            logger.info("***** PI0 Training Configuration *****")
            logger.info(f" Total steps = {self.config.trainer.max_train_steps}")
            logger.info(f" Per device batch size = {self.config.datasets.vla_data.per_device_batch_size}")
            logger.info(f" Total batch size (global) = {self.total_batch_size}")
            logger.info(f" Gradient accumulation steps = {self.accelerator.gradient_accumulation_steps}")
            logger.info(f" Num processes = {self.accelerator.num_processes}")
            pi0_cfg = getattr(self.config.framework, "pi0", None)
            if pi0_cfg is not None:
                logger.info(
                    f" PI0 action_dim = {getattr(pi0_cfg, 'action_dim', 'N/A')} "
                    f"(dataset 37D unified actions will be truncated to this dim)"
                )
                logger.info(f" PI0 action_horizon = {getattr(pi0_cfg, 'action_horizon', 'N/A')}")
                logger.info(f" PI0 pi05 = {getattr(pi0_cfg, 'pi05', 'N/A')}")

    def _train_step(self, batch_vla):
        with self.accelerator.accumulate(self.model):
            self.optimizer.zero_grad()

            output_dict = self.model.forward(batch_vla)
            action_loss = output_dict["action_loss"]
            total_loss = action_loss

            self.accelerator.backward(total_loss)

            grad_norm = None
            if self.config.trainer.gradient_clipping is not None:
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), self.config.trainer.gradient_clipping
                )

            self.optimizer.step()

            if self.accelerator.sync_gradients:
                self.lr_scheduler.step()

        step_metrics = {"action_loss": action_loss.item()}
        if grad_norm is not None:
            step_metrics["grad_norm"] = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
        return step_metrics

    def _finalize_training(self):
        if self.accelerator.is_main_process:
            final_checkpoint = os.path.join(self.config.output_dir, "final_model")
            os.makedirs(final_checkpoint, exist_ok=True)
            state_dict = self.accelerator.get_state_dict(self.model)
            torch.save(state_dict, os.path.join(final_checkpoint, "pytorch_model.pt"))
            logger.info(f"Training complete. Final model saved at {final_checkpoint}")

        if self.accelerator.is_main_process:
            wandb.finish()

        self.accelerator.wait_for_everyone()
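

# Re-loading the final weights elsewhere (illustrative sketch; assumes the state_dict saved above
# matches the module layout that build_framework(cfg) constructs, and that strict=False is acceptable):
#   model = build_framework(cfg)
#   state = torch.load(os.path.join(cfg.output_dir, "final_model", "pytorch_model.pt"), map_location="cpu")
#   model.load_state_dict(state, strict=False)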


def main(cfg) -> None:
    logger.info("PI0 Training :: Warming Up")

    output_dir = setup_directories(cfg=cfg)
    model = build_framework(cfg)
    vla_train_dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)

    trainer = PI0Trainer(
        cfg=cfg,
        model=model,
        vla_train_dataloader=vla_train_dataloader,
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )

    trainer.prepare_training()
    trainer.train()

    logger.info("... and that's all, folks!")
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="starVLA/config/training/starvla_train_pi0.yaml",
        help="Path to YAML config",
    )
    args, clipargs = parser.parse_known_args()

    cfg = OmegaConf.load(args.config_yaml)
    dotlist = normalize_dotlist_args(clipargs)
    cli_cfg = OmegaConf.from_dotlist(dotlist)
    cfg = OmegaConf.merge(cfg, cli_cfg)

    if getattr(cfg, "is_debug", False) and dist.is_initialized() and dist.get_rank() == 0:
        import debugpy

        debugpy.listen(("0.0.0.0", 10092))
        print("🔍 Rank 0 waiting for debugger attach on port 10092...")
        debugpy.wait_for_client()

    main(cfg)