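"""Training entry point for the QwenGR00T VLA framework.

Reuses the training loop from `train_qwenlatent` (accelerator, logging,
data preparation, `VLATrainer`) and overrides two pieces: the per-step
loss, which is a single `action_loss`, and the action-MAE evaluation,
which compares predictions against the last `pred_horizon` ground-truth
actions.
"""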
import argparse
import torch
import numpy as np
import torch.distributed as dist
from accelerate.utils import DistributedType
from omegaconf import OmegaConf

from starVLA.training.trainer_utils.trainer_tools import normalize_dotlist_args, TrainerUtils
from starVLA.model.framework import build_framework
from starVLA.training.train_qwenlatent import (
    accelerator,
    logger,
    setup_directories,
    prepare_data,
    VLATrainer,
)


class QwenGR00TTrainer(VLATrainer):
    def _train_step(self, batch_vla, batch_vlm=None):
        """Execute one training step for QwenGR00T (single `action_loss`)."""
        is_deepspeed = self.accelerator.distributed_type == DistributedType.DEEPSPEED
        grad_norm_pre_clip = None

        with self.accelerator.accumulate(self.model):
            with torch.autocast("cuda", dtype=torch.bfloat16):
                output_dict = self.model.forward(batch_vla, training_step=self.completed_steps)
                action_loss = output_dict["action_loss"]
                total_loss = action_loss

            self.accelerator.backward(total_loss)

            # For non-DeepSpeed: clip explicitly and capture pre-clip norm.
            # For DeepSpeed: gradient clipping is handled internally during optimizer.step(),
            # so we skip it here and retrieve the norm after step() below.
            if not is_deepspeed:
                gc = getattr(self.config.trainer, "gradient_clipping", None)
                max_norm = float(gc) if gc is not None else float("inf")
                grad_norm_pre_clip = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), max_norm
                )
                if grad_norm_pre_clip is None:
                    grad_norm_pre_clip = self._total_grad_norm_l2_local(self.model.parameters())

            self.optimizer.step()
            # Zero gradients after stepping (not before backward) so they
            # accumulate correctly across micro-steps; with an optimizer
            # prepared by Accelerate, both calls are no-ops on non-sync steps.
            self.optimizer.zero_grad()

        if self.accelerator.sync_gradients:
            self.lr_scheduler.step()

            # For DeepSpeed: read the global grad norm populated by optimizer.step().
            if is_deepspeed:
                gn = getattr(self.model, "_global_grad_norm", None)
                if gn is None:
                    gn = self.accelerator.clip_grad_norm_(self.model.parameters(), float("inf"))
                grad_norm_pre_clip = gn

        gn_scalar = self._grad_norm_scalar(grad_norm_pre_clip)
        self._grad_norm_buffer.append(gn_scalar)

        step_metrics = {
            "action_loss": action_loss.item(),
            "grad_norm_pre_clip": gn_scalar,
        }
        return step_metrics

    def eval_action_model(self, step_metrics: dict = None, examples=None) -> dict:
        """
        Evaluate action MAE for QwenGR00T over the model's predicted horizon.
        """
        if step_metrics is None:
            step_metrics = {}
        if examples is None:
            examples = self._get_next_batch()
        output_dict = self.model.predict_action(examples=examples)

        if self.accelerator.is_main_process:
            normalized_actions = output_dict["normalized_actions"]  # [B, T_pred, D]
            pred_horizon = normalized_actions.shape[1]

            # QwenGR00T forward trains on the last future window (`[-pred_horizon:]`)
            actions = [example["action"][-pred_horizon:] for example in examples]
            actions = np.array(actions)

            # MAE: average the summed L1 distance over every action element.
            num_points = np.prod(actions.shape)
            score = TrainerUtils.l1_distance(normalized_actions, actions)
            average_score = score / num_points
            step_metrics["mae_score"] = average_score

        del examples
        if dist.is_initialized():
            dist.barrier()
        return step_metrics


def main(cfg) -> None:
    logger.info("QwenGR00T Training :: Warming Up")

    output_dir = setup_directories(cfg=cfg)
    vla = build_framework(cfg)
    vla_train_dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)

    trainer = QwenGR00TTrainer(
        cfg=cfg,
        model=vla,
        vla_train_dataloader=vla_train_dataloader,
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )
    trainer.prepare_training()
    trainer.train()

    logger.info("QwenGR00T training finished.")
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="starVLA/config/training/starvla_train_qwengr00t_oxe.yaml",
        help="Path to YAML config",
    )
    args, cli_args = parser.parse_known_args()

    # Merge the YAML config with dotlist-style CLI overrides.
    cfg = OmegaConf.load(args.config_yaml)
    dotlist = normalize_dotlist_args(cli_args)
    cli_cfg = OmegaConf.from_dotlist(dotlist)
    cfg = OmegaConf.merge(cfg, cli_cfg)

    main(cfg)
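
# Example launch (the script path below is hypothetical; adjust it to your
# checkout). The config path is this file's default, and
# `trainer.gradient_clipping` is the override read in `_train_step` above:
#
#   accelerate launch starVLA/training/train_qwengr00t.py \
#       --config_yaml starVLA/config/training/starvla_train_qwengr00t_oxe.yaml \
#       trainer.gradient_clipping=1.0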