import argparse

import numpy as np
import torch
import torch.distributed as dist
from omegaconf import OmegaConf

from starVLA.model.framework import build_framework
from starVLA.training.train_qwenlatent import (
    VLATrainer,
    accelerator,
    logger,
    prepare_data,
    setup_directories,
)
from starVLA.training.trainer_utils.trainer_tools import TrainerUtils, normalize_dotlist_args


class QwenPITrainer(VLATrainer):
    """VLATrainer variant for QwenPI, which trains on a single `action_loss`."""

    def _train_step(self, batch_vla, batch_vlm=None):
        """Execute one training step for QwenPI (single `action_loss`).

        Args:
            batch_vla: VLA training batch forwarded to ``self.model.forward``.
            batch_vlm: Unused; kept for signature compatibility with the base
                trainer's step interface.

        Returns:
            dict: Per-step metrics (``action_loss`` and, when gradient clipping
            is enabled, ``grad_norm``).
        """
        with self.accelerator.accumulate(self.model):
            self.optimizer.zero_grad()
            # QwenPI.forward() manages autocast internally (bfloat16 for VLM,
            # float32 for action model); do NOT wrap again here to avoid
            # interfering with internal precision management.
            output_dict = self.model.forward(batch_vla)
            action_loss = output_dict["action_loss"]
            total_loss = action_loss
            self.accelerator.backward(total_loss)
            grad_norm = None
            if self.config.trainer.gradient_clipping is not None:
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), self.config.trainer.gradient_clipping
                )
            self.optimizer.step()
            if self.accelerator.sync_gradients:
                self.lr_scheduler.step()
            step_metrics = {"action_loss": action_loss.item()}
            if grad_norm is not None:
                # Accelerate may return a tensor or a plain float depending on backend.
                step_metrics["grad_norm"] = (
                    grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
                )
            return step_metrics

    def eval_action_model(self, step_metrics: dict = None, examples=None) -> dict:
        """Evaluate MAE for QwenPI using the predicted horizon length.

        Args:
            step_metrics: Metrics dict to update in place; a new dict is
                created when ``None`` (previously this raised ``TypeError``
                on the main process).
            examples: Evaluation batch; fetched via ``self._get_next_batch()``
                when ``None``.

        Returns:
            dict: ``step_metrics`` with ``mae_score`` added on the main
            process (other ranks return it unchanged).
        """
        if step_metrics is None:
            step_metrics = {}
        if examples is None:
            examples = self._get_next_batch()
        output_dict = self.model.predict_action(examples=examples)
        if self.accelerator.is_main_process:
            normalized_actions = output_dict["normalized_actions"]  # [B, T_pred, D]
            pred_horizon = normalized_actions.shape[1]
            # QwenPI forward trains on the last future window (`[-pred_horizon:]`),
            # so compare against that same slice of the ground-truth actions.
            actions = [example["action"][-pred_horizon:] for example in examples]
            actions = np.array(actions)
            num_points = np.prod(actions.shape)
            score = TrainerUtils.l1_distance(normalized_actions, actions)
            average_score = score / num_points
            step_metrics["mae_score"] = average_score
        del examples
        if dist.is_initialized():
            # Keep ranks in lockstep so evaluation on rank 0 does not desync training.
            dist.barrier()
        return step_metrics


def main(cfg) -> None:
    """Build the QwenPI framework, data pipeline, and trainer, then run training."""
    logger.info("QwenPI Training :: Warming Up")
    output_dir = setup_directories(cfg=cfg)
    vla = build_framework(cfg)
    vla_train_dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)
    trainer = QwenPITrainer(
        cfg=cfg,
        model=vla,
        vla_train_dataloader=vla_train_dataloader,
        # Optimizer / scheduler are constructed inside prepare_training().
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )
    trainer.prepare_training()
    trainer.train()
    logger.info("QwenPI training finished.")
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="starVLA/config/training/starvla_train_qwenpi.yaml",
        help="Path to YAML config",
    )
    args, clipargs = parser.parse_known_args()
    cfg = OmegaConf.load(args.config_yaml)
    # Merge CLI dotlist overrides (e.g. `trainer.lr=1e-4`) on top of the YAML config.
    dotlist = normalize_dotlist_args(clipargs)
    cli_cfg = OmegaConf.from_dotlist(dotlist)
    cfg = OmegaConf.merge(cfg, cli_cfg)
    main(cfg)