import argparse
import torch
import numpy as np
import torch.distributed as dist
from omegaconf import OmegaConf
from starVLA.training.trainer_utils.trainer_tools import normalize_dotlist_args, TrainerUtils
from starVLA.model.framework import build_framework
from starVLA.training.train_qwenlatent import (
accelerator,
logger,
setup_directories,
prepare_data,
VLATrainer,
)
class QwenPITrainer(VLATrainer):
    """Trainer for QwenPI: optimizes a single ``action_loss`` objective.

    Overrides ``_train_step`` (one loss term instead of the base trainer's
    combined losses) and ``eval_action_model`` (MAE over the predicted
    action horizon).
    """

    def _train_step(self, batch_vla, batch_vlm=None):
        """Execute one training step for QwenPI (single ``action_loss``).

        Args:
            batch_vla: VLA training batch consumed by ``self.model.forward``.
            batch_vlm: Unused; kept for signature compatibility with VLATrainer.

        Returns:
            dict: ``{"action_loss": float}``, plus ``"grad_norm"`` when
            gradient clipping is enabled.
        """
        with self.accelerator.accumulate(self.model):
            self.optimizer.zero_grad()
            # QwenPI.forward() manages autocast internally (bfloat16 for VLM,
            # float32 for action model); do NOT wrap again here to avoid
            # interfering with internal precision management.
            output_dict = self.model.forward(batch_vla)
            action_loss = output_dict["action_loss"]
            total_loss = action_loss
            self.accelerator.backward(total_loss)
            grad_norm = None
            if self.config.trainer.gradient_clipping is not None:
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), self.config.trainer.gradient_clipping
                )
            self.optimizer.step()
            # Scheduler steps only on real optimizer steps (not on
            # gradient-accumulation micro-steps).
            if self.accelerator.sync_gradients:
                self.lr_scheduler.step()
            step_metrics = {"action_loss": action_loss.item()}
            if grad_norm is not None:
                step_metrics["grad_norm"] = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
            return step_metrics

    def eval_action_model(self, step_metrics: dict = None, examples=None) -> dict:
        """Evaluate MAE for QwenPI using the predicted horizon length.

        Args:
            step_metrics: Metrics dict to extend with ``"mae_score"``. A new
                dict is created when None (previously this crashed with a
                ``TypeError`` on the main process).
            examples: Optional evaluation batch; fetched from the training
                iterator via ``_get_next_batch`` when None.

        Returns:
            dict: ``step_metrics``; on the main process it gains
            ``"mae_score"`` (mean absolute error per action element).
        """
        if step_metrics is None:
            step_metrics = {}
        if examples is None:
            examples = self._get_next_batch()
        output_dict = self.model.predict_action(examples=examples)
        # Only rank 0 computes the metric; other ranks just synchronize below.
        if self.accelerator.is_main_process:
            normalized_actions = output_dict["normalized_actions"]  # [B, T_pred, D]
            pred_horizon = normalized_actions.shape[1]
            # QwenPI forward trains on the last future window (`[-pred_horizon:]`),
            # so compare against that same slice of the ground-truth actions.
            actions = np.array([example["action"][-pred_horizon:] for example in examples])
            num_points = actions.size
            score = TrainerUtils.l1_distance(normalized_actions, actions)
            average_score = score / num_points
            step_metrics["mae_score"] = average_score
        del examples
        if dist.is_initialized():
            dist.barrier()
        return step_metrics
def main(cfg) -> None:
    """Build the QwenPI framework from `cfg`, prepare data, and run training."""
    logger.info("QwenPI Training :: Warming Up")
    output_dir = setup_directories(cfg=cfg)
    model = build_framework(cfg)
    dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)

    # Optimizer/scheduler are passed as None here — presumably constructed
    # inside prepare_training(); confirm against VLATrainer.
    trainer = QwenPITrainer(
        cfg=cfg,
        model=model,
        vla_train_dataloader=dataloader,
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )
    trainer.prepare_training()
    trainer.train()
    logger.info("QwenPI training finished.")

    # Tear down the distributed process group cleanly, if one was initialized.
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()
if __name__ == "__main__":
    # Known args: the YAML config path. Everything else is forwarded as
    # OmegaConf dotlist overrides (e.g. `trainer.lr=1e-4`).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="starVLA/config/training/starvla_train_qwenpi.yaml",
        help="Path to YAML config",
    )
    known_args, override_args = parser.parse_known_args()

    base_cfg = OmegaConf.load(known_args.config_yaml)
    overrides = OmegaConf.from_dotlist(normalize_dotlist_args(override_args))
    # CLI overrides take precedence over the YAML file.
    main(OmegaConf.merge(base_cfg, overrides))