import argparse

import numpy as np
import torch
import torch.distributed as dist
from accelerate.utils import DistributedType
from omegaconf import OmegaConf

from starVLA.model.framework import build_framework
from starVLA.training.trainer_utils.trainer_tools import normalize_dotlist_args, TrainerUtils
from starVLA.training.train_qwenlatent import (
    accelerator,
    logger,
    setup_directories,
    prepare_data,
    VLATrainer,
)


class QwenGR00TTrainer(VLATrainer):
    def _train_step(self, batch_vla, batch_vlm=None):
        """Execute one training step for QwenGR00T (single `action_loss`)."""
        is_deepspeed = self.accelerator.distributed_type == DistributedType.DEEPSPEED
        grad_norm_pre_clip = None
        with self.accelerator.accumulate(self.model):
            self.optimizer.zero_grad()
            with torch.autocast("cuda", dtype=torch.bfloat16):
                output_dict = self.model.forward(batch_vla, training_step=self.completed_steps)
                action_loss = output_dict["action_loss"]
                total_loss = action_loss
            self.accelerator.backward(total_loss)
            # For non-DeepSpeed: clip explicitly and capture the pre-clip norm.
            # For DeepSpeed: gradient clipping is handled internally during optimizer.step(),
            # so we skip it here and retrieve the norm after step() below.
            if not is_deepspeed:
                gc = getattr(self.config.trainer, "gradient_clipping", None)
                max_norm = float(gc) if gc is not None else float("inf")
                grad_norm_pre_clip = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), max_norm
                )
                if grad_norm_pre_clip is None:
                    grad_norm_pre_clip = self._total_grad_norm_l2_local(self.model.parameters())
            self.optimizer.step()
            if self.accelerator.sync_gradients:
                self.lr_scheduler.step()
            # For DeepSpeed: read the global grad norm populated by optimizer.step().
            if is_deepspeed:
                gn = getattr(self.model, "_global_grad_norm", None)
                if gn is None:
                    gn = self.accelerator.clip_grad_norm_(self.model.parameters(), float("inf"))
                grad_norm_pre_clip = gn

        gn_scalar = self._grad_norm_scalar(grad_norm_pre_clip)
        self._grad_norm_buffer.append(gn_scalar)
        step_metrics = {
            "action_loss": action_loss.item(),
            "grad_norm_pre_clip": gn_scalar,
        }
        return step_metrics
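
    # Note on the norm bookkeeping above (an assumption about VLATrainer
    # internals, for readers of this subclass): `_grad_norm_scalar` is assumed
    # to coerce whatever the clipping path produces (a CUDA tensor, a float,
    # or None) into a plain Python float for logging, roughly:
    #
    #   def _grad_norm_scalar(self, gn):
    #       if gn is None:
    #           return float("nan")
    #       return gn.item() if torch.is_tensor(gn) else float(gn)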
    def eval_action_model(self, step_metrics: dict = None, examples=None) -> dict:
        """
        Evaluate mean absolute error (MAE) for QwenGR00T over the predicted horizon.
        """
        if step_metrics is None:
            step_metrics = {}
        if examples is None:
            examples = self._get_next_batch()
        output_dict = self.model.predict_action(examples=examples)
        if self.accelerator.is_main_process:
            normalized_actions = output_dict["normalized_actions"]  # [B, T_pred, D]
            pred_horizon = normalized_actions.shape[1]
            # QwenGR00T forward trains on the last future window (`[-pred_horizon:]`),
            # so evaluation compares against the same slice of the ground truth.
            actions = [example["action"][-pred_horizon:] for example in examples]
            actions = np.array(actions)
            num_points = np.prod(actions.shape)
            score = TrainerUtils.l1_distance(normalized_actions, actions)
            average_score = score / num_points
            step_metrics["mae_score"] = average_score
        del examples
        if dist.is_initialized():
            dist.barrier()
        return step_metrics
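
# Illustrative sketch, not used by the trainer: eval_action_model above assumes
# that `TrainerUtils.l1_distance` returns the *summed* absolute error, so that
# dividing by the total number of action points yields a mean absolute error.
# A minimal stand-in with those assumed semantics (the real helper may also
# move tensors off-device before comparing):
def _example_l1_distance(pred, target) -> float:
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    return float(np.abs(pred - target).sum())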

def main(cfg) -> None:
    logger.info("QwenGR00T Training :: Warming Up")
    output_dir = setup_directories(cfg=cfg)
    vla = build_framework(cfg)
    vla_train_dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)
    trainer = QwenGR00TTrainer(
        cfg=cfg,
        model=vla,
        vla_train_dataloader=vla_train_dataloader,
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )
    trainer.prepare_training()
    trainer.train()
    logger.info("QwenGR00T training finished.")
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()
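
# Note (assumption about the VLATrainer base class): optimizer and lr_scheduler
# are passed as None above because `prepare_training()` is expected to build
# them from the trainer config and wrap model, optimizer, and dataloader with
# `accelerator.prepare(...)` before `train()` runs.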

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="starVLA/config/training/starvla_train_qwengr00t_oxe.yaml",
        help="Path to YAML config",
    )
    args, clipargs = parser.parse_known_args()
    cfg = OmegaConf.load(args.config_yaml)
    # Remaining CLI tokens are treated as OmegaConf dotlist overrides (key=value)
    # and merged on top of the YAML config.
    dotlist = normalize_dotlist_args(clipargs)
    cli_cfg = OmegaConf.from_dotlist(dotlist)
    cfg = OmegaConf.merge(cfg, cli_cfg)
    main(cfg)
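
# Example launch (illustrative; the script path and override key are hypothetical):
#   accelerate launch starVLA/training/train_qwengr00t.py \
#       --config_yaml starVLA/config/training/starvla_train_qwengr00t_oxe.yaml \
#       trainer.gradient_clipping=1.0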