# cross13tasks/code/training/train_qwengr00t.py
# NOTE: the lines below were HuggingFace Hub upload metadata accidentally pasted
# into the source; kept here as a comment for provenance.
#   Uploaded by Timsty via huggingface_hub (commit e94400c, verified)
import argparse
import torch
import numpy as np
import torch.distributed as dist
from accelerate.utils import DistributedType
from omegaconf import OmegaConf
from starVLA.training.trainer_utils.trainer_tools import normalize_dotlist_args, TrainerUtils
from starVLA.model.framework import build_framework
from starVLA.training.train_qwenlatent import (
accelerator,
logger,
setup_directories,
prepare_data,
VLATrainer,
)
class QwenGR00TTrainer(VLATrainer):
    """Trainer specialization for QwenGR00T: optimizes a single `action_loss`."""

    def _train_step(self, batch_vla, batch_vlm=None):
        """Execute one training step for QwenGR00T (single `action_loss`).

        Args:
            batch_vla: Batch of VLA examples fed to `self.model.forward`.
            batch_vlm: Unused here; kept for signature compatibility with VLATrainer.

        Returns:
            dict: `action_loss` (float) and `grad_norm_pre_clip` (float scalar).
        """
        is_deepspeed = self.accelerator.distributed_type == DistributedType.DEEPSPEED
        grad_norm_pre_clip = None
        with self.accelerator.accumulate(self.model):
            self.optimizer.zero_grad()
            # bf16 autocast for the forward pass only; backward runs outside it.
            with torch.autocast("cuda", dtype=torch.bfloat16):
                output_dict = self.model.forward(batch_vla, training_step=self.completed_steps)
                action_loss = output_dict["action_loss"]
                total_loss = action_loss
            self.accelerator.backward(total_loss)
            # For non-DeepSpeed: clip explicitly and capture pre-clip norm.
            # For DeepSpeed: gradient clipping is handled internally during optimizer.step(),
            # so we skip it here and retrieve the norm after step() below.
            if not is_deepspeed:
                gc = getattr(self.config.trainer, "gradient_clipping", None)
                max_norm = float(gc) if gc is not None else float("inf")
                grad_norm_pre_clip = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), max_norm
                )
                # clip_grad_norm_ can return None (e.g. mid gradient-accumulation);
                # fall back to a locally computed L2 norm in that case.
                if grad_norm_pre_clip is None:
                    grad_norm_pre_clip = self._total_grad_norm_l2_local(self.model.parameters())
            self.optimizer.step()
            if self.accelerator.sync_gradients:
                self.lr_scheduler.step()
            # For DeepSpeed: read the global grad norm populated by optimizer.step().
            if is_deepspeed:
                gn = getattr(self.model, "_global_grad_norm", None)
                if gn is None:
                    # inf max_norm makes this a no-op clip used purely to obtain the norm.
                    gn = self.accelerator.clip_grad_norm_(self.model.parameters(), float("inf"))
                grad_norm_pre_clip = gn
            gn_scalar = self._grad_norm_scalar(grad_norm_pre_clip)
            self._grad_norm_buffer.append(gn_scalar)
            step_metrics = {
                "action_loss": action_loss.item(),
                "grad_norm_pre_clip": gn_scalar,
            }
        return step_metrics

    def eval_action_model(self, step_metrics: dict = None, examples=None) -> dict:
        """
        Evaluate MAE for QwenGR00T using predicted horizon length.

        Args:
            step_metrics: Metrics dict to extend in place. A fresh dict is created
                when None (the previous version crashed with TypeError on the main
                process in that case).
            examples: Optional evaluation batch; fetched from the dataloader if None.

        Returns:
            dict: `step_metrics`, with `mae_score` added on the main process only
            (non-main ranks return it unchanged).
        """
        # Bug fix: `step_metrics["mae_score"] = ...` below failed when the default
        # None was passed through.
        if step_metrics is None:
            step_metrics = {}
        if examples is None:
            examples = self._get_next_batch()
        output_dict = self.model.predict_action(examples=examples)
        if self.accelerator.is_main_process:
            normalized_actions = output_dict["normalized_actions"]  # [B, T_pred, D]
            pred_horizon = normalized_actions.shape[1]
            # QwenGR00T forward trains on the last future window (`[-pred_horizon:]`)
            actions = np.array([example["action"][-pred_horizon:] for example in examples])
            num_points = np.prod(actions.shape)
            score = TrainerUtils.l1_distance(normalized_actions, actions)
            # Mean absolute error per action dimension.
            step_metrics["mae_score"] = score / num_points
        del examples
        # Keep all ranks in lockstep after evaluation (main process did extra work).
        if dist.is_initialized():
            dist.barrier()
        return step_metrics
def main(cfg) -> None:
    """Build the QwenGR00T model, data pipeline, and trainer, then run training.

    Args:
        cfg: Merged OmegaConf configuration for the whole run.
    """
    logger.info("QwenGR00T Training :: Warming Up")
    output_dir = setup_directories(cfg=cfg)
    model = build_framework(cfg)
    dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)

    # Optimizer and scheduler are supplied later via prepare_training().
    trainer = QwenGR00TTrainer(
        cfg=cfg,
        model=model,
        vla_train_dataloader=dataloader,
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )
    trainer.prepare_training()
    trainer.train()
    logger.info("QwenGR00T training finished.")

    # Tear down the distributed process group cleanly, if one was initialized.
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()
if __name__ == "__main__":
    # Parse the YAML path; any remaining CLI args become dotlist config overrides.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config_yaml",
        type=str,
        default="starVLA/config/training/starvla_train_qwengr00t_oxe.yaml",
        help="Path to YAML config",
    )
    known, extra = cli.parse_known_args()

    base_cfg = OmegaConf.load(known.config_yaml)
    override_cfg = OmegaConf.from_dotlist(normalize_dotlist_args(extra))
    # CLI overrides take precedence over the YAML file.
    main(OmegaConf.merge(base_cfg, override_cfg))