| import argparse |
| import torch |
| import numpy as np |
| import torch.distributed as dist |
| from omegaconf import OmegaConf |
|
|
| from starVLA.training.trainer_utils.trainer_tools import normalize_dotlist_args, TrainerUtils |
| from starVLA.model.framework import build_framework |
| from starVLA.training.train_qwenlatent import ( |
| accelerator, |
| logger, |
| setup_directories, |
| prepare_data, |
| VLATrainer, |
| ) |
|
|
|
|
class QwenPITrainer(VLATrainer):
    """VLA trainer specialization for QwenPI: optimizes a single `action_loss`."""

    def _train_step(self, batch_vla, batch_vlm=None):
        """Execute one training step for QwenPI (single `action_loss`).

        Args:
            batch_vla: batch of VLA examples consumed by ``self.model.forward``.
            batch_vlm: unused here; kept for signature compatibility with the base trainer.

        Returns:
            dict: scalar metrics for this step (``action_loss``; ``grad_norm`` when clipping is enabled).
        """
        with self.accelerator.accumulate(self.model):
            # NOTE(review): under Accelerate's accumulation context, optimizer
            # step/zero_grad are presumably applied only on gradient-sync
            # boundaries — confirm against the Accelerate version in use.
            self.optimizer.zero_grad()

            output_dict = self.model.forward(batch_vla)
            action_loss = output_dict["action_loss"]
            total_loss = action_loss

            self.accelerator.backward(total_loss)

            grad_norm = None
            if self.config.trainer.gradient_clipping is not None:
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), self.config.trainer.gradient_clipping
                )

            self.optimizer.step()

            # Advance the LR schedule only once per effective (synced) optimizer step.
            if self.accelerator.sync_gradients:
                self.lr_scheduler.step()

            step_metrics = {"action_loss": action_loss.item()}
            if grad_norm is not None:
                # clip_grad_norm_ may return a tensor or a plain float depending on backend.
                step_metrics["grad_norm"] = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
            return step_metrics

    def eval_action_model(self, step_metrics: dict = None, examples=None) -> dict:
        """Evaluate MAE for QwenPI using the model's predicted horizon length.

        Args:
            step_metrics: metrics dict to update; a fresh dict is created when None.
            examples: evaluation batch; fetched via ``self._get_next_batch()`` when None.

        Returns:
            dict: ``step_metrics`` with ``mae_score`` added on the main process.
        """
        # Fix: a None default previously crashed with TypeError on the main
        # process at `step_metrics["mae_score"] = ...`.
        if step_metrics is None:
            step_metrics = {}
        if examples is None:
            examples = self._get_next_batch()

        output_dict = self.model.predict_action(examples=examples)

        # Only the main process computes and records the metric; the others
        # just meet at the barrier below.
        if self.accelerator.is_main_process:
            normalized_actions = output_dict["normalized_actions"]
            # assumes normalized_actions is (batch, horizon, action_dim) — TODO confirm
            pred_horizon = normalized_actions.shape[1]

            # Compare against the last `pred_horizon` ground-truth actions of each example.
            actions = np.array([example["action"][-pred_horizon:] for example in examples])

            num_points = np.prod(actions.shape)
            score = TrainerUtils.l1_distance(normalized_actions, actions)
            step_metrics["mae_score"] = score / num_points

        del examples
        if dist.is_initialized():
            dist.barrier()
        return step_metrics
|
|
|
|
def main(cfg) -> None:
    """Build the QwenPI model and dataloader from `cfg`, run training, then tear down DDP."""
    logger.info("QwenPI Training :: Warming Up")

    run_dir = setup_directories(cfg=cfg)
    model = build_framework(cfg)
    dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=run_dir)

    qwenpi_trainer = QwenPITrainer(
        cfg=cfg,
        model=model,
        vla_train_dataloader=dataloader,
        optimizer=None,  # presumably built inside prepare_training() — verify
        lr_scheduler=None,
        accelerator=accelerator,
    )
    qwenpi_trainer.prepare_training()
    qwenpi_trainer.train()

    logger.info("QwenPI training finished.")
    # Shut down the process group cleanly if distributed mode was initialized.
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument( |
| "--config_yaml", |
| type=str, |
| default="starVLA/config/training/starvla_train_qwenpi.yaml", |
| help="Path to YAML config", |
| ) |
| args, clipargs = parser.parse_known_args() |
|
|
| cfg = OmegaConf.load(args.config_yaml) |
| dotlist = normalize_dotlist_args(clipargs) |
| cli_cfg = OmegaConf.from_dotlist(dotlist) |
| cfg = OmegaConf.merge(cfg, cli_cfg) |
|
|
| main(cfg) |
|
|