import argparse
import torch
import numpy as np
import torch.distributed as dist
from omegaconf import OmegaConf

from starVLA.training.trainer_utils.trainer_tools import normalize_dotlist_args, TrainerUtils
from starVLA.model.framework import build_framework
from starVLA.training.train_qwenlatent import (
    accelerator,
    logger,
    setup_directories,
    prepare_data,
    VLATrainer,
)


class QwenPITrainer(VLATrainer):
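    """VLATrainer specialization for QwenPI: a single `action_loss` objective and
    MAE evaluation over the model's predicted action horizon."""
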
    def _train_step(self, batch_vla, batch_vlm=None):
        """Execute one training step for QwenPI (single `action_loss`)."""
        with self.accelerator.accumulate(self.model):
            # QwenPI.forward() manages autocast internally (bfloat16 for VLM, float32 for action model);
            # do NOT wrap again here to avoid interfering with internal precision management.
            output_dict = self.model.forward(batch_vla)
            action_loss = output_dict["action_loss"]
            total_loss = action_loss

            self.accelerator.backward(total_loss)

            grad_norm = None
            if self.config.trainer.gradient_clipping is not None:
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), self.config.trainer.gradient_clipping
                )

            self.optimizer.step()
            # Zero gradients only after stepping so that, under gradient accumulation, grads from
            # earlier micro-batches are kept (zeroing at the top of the step would discard them).
            self.optimizer.zero_grad()

        if self.accelerator.sync_gradients:
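            # sync_gradients is True only on the micro-batch where the accumulated optimizer step is
            # applied, so the LR schedule advances once per effective (accumulated) batch.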
            self.lr_scheduler.step()

        step_metrics = {"action_loss": action_loss.item()}
        if grad_norm is not None:
            step_metrics["grad_norm"] = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
        return step_metrics

    def eval_action_model(self, step_metrics: dict = None, examples=None) -> dict:
        """
        Evaluate action MAE for QwenPI over the model's predicted horizon and record it in `step_metrics`.
        """
        if step_metrics is None:
            step_metrics = {}
        if examples is None:
            examples = self._get_next_batch()

        output_dict = self.model.predict_action(examples=examples)

        if self.accelerator.is_main_process:
            normalized_actions = output_dict["normalized_actions"]  # [B, T_pred, D]
            pred_horizon = normalized_actions.shape[1]

            # QwenPI forward trains on the last future window (`[-pred_horizon:]`)
            actions = [example["action"][-pred_horizon:] for example in examples]
            actions = np.array(actions)

            num_points = np.prod(actions.shape)
            score = TrainerUtils.l1_distance(normalized_actions, actions)
            average_score = score / num_points
            step_metrics["mae_score"] = average_score

        del examples
        if dist.is_initialized():
            dist.barrier()
        return step_metrics


def main(cfg) -> None:
    logger.info("QwenPI Training :: Warming Up")

    output_dir = setup_directories(cfg=cfg)
    vla = build_framework(cfg)
    vla_train_dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)

    trainer = QwenPITrainer(
        cfg=cfg,
        model=vla,
        vla_train_dataloader=vla_train_dataloader,
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )
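    # prepare_training() is expected to build the optimizer and LR scheduler (passed as None above)
    # and to wrap the model, dataloader, and optimizer with the shared Accelerator.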
    trainer.prepare_training()
    trainer.train()

    logger.info("QwenPI training finished.")
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="starVLA/config/training/starvla_train_qwenpi.yaml",
        help="Path to YAML config",
    )
    args, cli_args = parser.parse_known_args()

    # Unrecognized CLI arguments are treated as OmegaConf dotlist overrides
    # (e.g. `trainer.gradient_clipping=1.0`) and merged on top of the YAML config.
    cfg = OmegaConf.load(args.config_yaml)
    dotlist = normalize_dotlist_args(cli_args)
    cli_cfg = OmegaConf.from_dotlist(dotlist)
    cfg = OmegaConf.merge(cfg, cli_cfg)

    main(cfg)
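
# Example invocation (illustrative; the script path and launcher are assumptions based on the
# imports above, since `accelerator` comes from starVLA.training.train_qwenlatent):
#
#   accelerate launch starVLA/training/train_qwenpi.py \
#       --config_yaml starVLA/config/training/starvla_train_qwenpi.yaml \
#       trainer.gradient_clipping=1.0
#
# Trailing `key=value` arguments are normalized by normalize_dotlist_args() and merged into the
# YAML config as OmegaConf dotlist overrides (see __main__ above).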