File size: 7,733 Bytes
78e6df8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccf25b1
 
 
 
 
 
 
78e6df8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#!/usr/bin/env python3

import dataclasses

import jax
import numpy as np
import safetensors.torch
import torch
import tyro

import openpi.models.pi0_config
import openpi.training.config as _config
import openpi.training.data_loader as _data


@dataclasses.dataclass
class Args:
    """CLI arguments (parsed by tyro) for the warm-start equivalence check."""

    # Name of the registered TrainConfig for the baseline (single-expert) run.
    baseline_config_name: str = "pi05_twin_handover_256_packed_baseline_pytorch_10k"
    # Name of the registered TrainConfig for the parallel (packed) run.
    parallel_config_name: str = "pi05_twin_handover_256_packed_parallel_pytorch_10k"
    # Directory containing the baseline checkpoint's model.safetensors.
    baseline_ckpt: str = "/workspace/checkpoints/pi05_base_single_pytorch"
    # Directory containing the parallel checkpoint's model.safetensors.
    parallel_ckpt: str = "/workspace/checkpoints/pi05_base_parallel_packed_from_single"
    # Dataset repo id used for the single evaluation batch (overrides the config's repo_id).
    repo_id: str = "lsnu/twin_handover_256_train"
    batch_size: int = 4
    # 0 keeps data loading in-process for determinism.
    num_workers: int = 0
    # Seed for the fixed evaluation noise/time draws (NOT the data-loader seed).
    eval_seed: int = 777
    # Max allowed abs diff for declaring the two models equivalent.
    tolerance: float = 1e-6


def build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    """Return a float32 ``Pi0Config`` mirroring ``config.model``.

    If the train config already carries a ``Pi0Config``, a copy with
    ``dtype="float32"`` is returned; otherwise an equivalent ``Pi0Config`` is
    built field-by-field (missing fields fall back to pi0 defaults).

    Args:
        config: Train config whose ``model`` section to mirror.

    Returns:
        A ``Pi0Config`` identical to ``config.model`` except ``dtype="float32"``.
    """
    if isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # dataclasses.replace re-runs __init__/__post_init__ validation, unlike
        # the object.__setattr__ hack it replaces, which silently bypassed the
        # frozen-dataclass invariants when overriding dtype.
        return dataclasses.replace(config.model, dtype="float32")
    return openpi.models.pi0_config.Pi0Config(
        dtype="float32",
        action_dim=config.model.action_dim,
        action_horizon=config.model.action_horizon,
        max_token_len=config.model.max_token_len,
        paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
        action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
        pi05=getattr(config.model, "pi05", False),
        arm_action_dims=getattr(config.model, "arm_action_dims", None),
    )


def compute_masked_action_loss(
    losses: torch.Tensor, action_loss_mask: tuple[float, ...] | None
) -> torch.Tensor:
    if action_loss_mask is None:
        return losses.mean()
    mask = torch.as_tensor(action_loss_mask, device=losses.device, dtype=losses.dtype)
    denom = mask.sum() * losses.shape[0] * losses.shape[1]
    return (losses * mask.view(1, 1, -1)).sum() / denom


def make_eval_noise(actions: torch.Tensor, seed: int) -> torch.Tensor:
    """Draw deterministic standard-normal noise shaped like ``actions``.

    Uses a fresh NumPy ``Generator`` seeded with ``seed`` so the same seed and
    shape always yield identical noise, independent of global RNG state.

    Args:
        actions: Tensor whose shape and device the noise should match.
        seed: Seed for the NumPy generator.

    Returns:
        float32 noise tensor on ``actions.device``.
    """
    generator = np.random.default_rng(seed)
    # Sample in float64 first, then downcast — matches the reference draw exactly.
    sample = generator.standard_normal(size=tuple(actions.shape))
    return torch.from_numpy(sample.astype(np.float32)).to(device=actions.device)


def make_eval_time(batch_size: int, device: torch.device, seed: int) -> torch.Tensor:
    """Draw deterministic flow-matching timesteps, one per batch element.

    Samples ``Beta(1.5, 1.0)`` and affinely maps into ``[0.001, 1.0)`` so time
    never hits exactly zero.

    Args:
        batch_size: Number of timesteps to draw.
        device: Target device for the returned tensor.
        seed: Seed for the NumPy generator.

    Returns:
        float32 tensor of shape ``(batch_size,)`` on ``device``.
    """
    generator = np.random.default_rng(seed)
    raw = generator.beta(1.5, 1.0, size=batch_size).astype(np.float32)
    # Squeeze into [0.001, 1.0) to keep timesteps strictly positive.
    scaled = raw * 0.999 + 0.001
    return torch.from_numpy(scaled).to(device=device, dtype=torch.float32)


def main() -> None:
    """Verify two checkpoints produce numerically equivalent forward passes.

    Loads one deterministic batch, runs both the baseline and the parallel
    model on identical (seeded) noise/time inputs, and reports max/mean
    absolute differences of the action-input projections and per-element
    losses, plus a final pass/fail against ``args.tolerance``.
    """
    args = tyro.cli(Args)
    print(
        f"starting_warmstart_equivalence baseline_config={args.baseline_config_name} "
        f"parallel_config={args.parallel_config_name} repo_id={args.repo_id}",
        flush=True,
    )
    baseline_config = _config.get_config(args.baseline_config_name)
    parallel_config = _config.get_config(args.parallel_config_name)
    parallel_model_cfg = build_model_config(parallel_config)
    # Split experts are expected to diverge; this check only makes sense for
    # shared/parallel experts, so bail out early with pointers to the right tools.
    if parallel_model_cfg.use_split_action_expert:
        raise ValueError(
            "Exact end-to-end warm-start equivalence is not expected for split action experts. "
            "Use init_parallel_pi05_from_single_pytorch.py for branch copy checks and "
            "check_split_expert_invariants.py for branch-local invariants."
        )

    # Both models are evaluated on the SAME batch built from the baseline
    # config; only the repo_id is overridden from the CLI.
    data_config = baseline_config.data.create(baseline_config.assets_dirs, baseline_config.model)
    data_config = dataclasses.replace(data_config, repo_id=args.repo_id)
    loader = _data.create_torch_data_loader(
        data_config,
        model_config=baseline_config.model,
        action_horizon=baseline_config.model.action_horizon,
        batch_size=args.batch_size,
        shuffle=False,  # deterministic single batch
        num_batches=1,
        num_workers=args.num_workers,
        seed=baseline_config.seed,
        framework="pytorch",
    )
    print("loaded_eval_dataloader", flush=True)

    # Imported late so config parsing/data loading fails fast without torch-model deps.
    import openpi.models_pytorch.pi0_pytorch as pi0_pytorch

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    observation, actions = next(iter(loader))
    print("loaded_reference_batch", flush=True)
    actions = actions.to(torch.float32)
    actions_device = actions.to(device)
    # Fixed noise/time draws shared by both models so their forward passes see
    # byte-identical inputs; offset keeps the time seed independent of the noise seed.
    noise = make_eval_noise(actions_device, args.eval_seed)
    time_values = make_eval_time(actions_device.shape[0], device, args.eval_seed + 10_000)
    time_expanded = time_values[:, None, None]
    # Flow-matching interpolant between clean actions (t=0) and noise (t=1).
    x_t = time_expanded * noise + (1 - time_expanded) * actions_device

    def run_model(config: _config.TrainConfig, checkpoint_dir: str):
        # Load one model, run a single deterministic forward pass, and return
        # (projected action inputs, per-element losses, missing keys, unexpected keys),
        # all moved to CPU so the model can be freed between runs.
        print(f"loading_model config={config.name} checkpoint={checkpoint_dir}", flush=True)
        model = pi0_pytorch.PI0Pytorch(build_model_config(config)).to(device)
        # strict=False: key mismatches are reported in the summary instead of raising.
        missing, unexpected = safetensors.torch.load_model(model, f"{checkpoint_dir}/model.safetensors", strict=False)
        model.eval()
        print(f"running_forward config={config.name}", flush=True)
        observation_device = jax.tree.map(lambda x: x.to(device), observation)
        actions_local = actions.to(device)
        with torch.inference_mode():
            # NOTE(review): relies on the private _project_action_inputs API to
            # compare the very first projection layer — confirm it stays stable.
            projected_inputs = model._project_action_inputs(x_t)
            losses = model(observation_device, actions_local, noise=noise, time=time_values).to(torch.float32)
        print(f"finished_forward config={config.name}", flush=True)
        # Free GPU memory before loading the second model.
        del model, observation_device, actions_local
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return projected_inputs.cpu(), losses.cpu(), list(missing), list(unexpected)

    baseline_in, baseline_losses, baseline_missing, baseline_unexpected = run_model(
        baseline_config, args.baseline_ckpt
    )
    parallel_in, parallel_losses, parallel_missing, parallel_unexpected = run_model(
        parallel_config, args.parallel_ckpt
    )

    # Element-wise discrepancies between the two models' outputs.
    projection_diff = (baseline_in - parallel_in).abs()
    loss_diff = (baseline_losses - parallel_losses).abs()
    baseline_masked_loss = compute_masked_action_loss(baseline_losses, baseline_config.action_loss_mask)
    parallel_masked_loss = compute_masked_action_loss(parallel_losses, parallel_config.action_loss_mask)
    masked_loss_abs_diff = abs(float(baseline_masked_loss.item()) - float(parallel_masked_loss.item()))

    # Summary report: inputs, key mismatches, and all diff statistics.
    print(f"baseline_config_name: {args.baseline_config_name}")
    print(f"parallel_config_name: {args.parallel_config_name}")
    print(f"repo_id_used: {args.repo_id}")
    print(f"baseline_ckpt: {args.baseline_ckpt}")
    print(f"parallel_ckpt: {args.parallel_ckpt}")
    print(f"batch_size: {args.batch_size}")
    print(f"eval_seed: {args.eval_seed}")
    print(f"tolerance: {args.tolerance}")
    print(f"baseline_missing_keys: {list(baseline_missing)}")
    print(f"baseline_unexpected_keys: {list(baseline_unexpected)}")
    print(f"parallel_missing_keys: {list(parallel_missing)}")
    print(f"parallel_unexpected_keys: {list(parallel_unexpected)}")
    print(f"input_projection_max_abs_diff: {float(projection_diff.max().item()):.8f}")
    print(f"input_projection_mean_abs_diff: {float(projection_diff.mean().item()):.8f}")
    print(f"loss_max_abs_diff: {float(loss_diff.max().item()):.8f}")
    print(f"loss_mean_abs_diff: {float(loss_diff.mean().item()):.8f}")
    print(f"baseline_masked_loss: {float(baseline_masked_loss.item()):.8f}")
    print(f"parallel_masked_loss: {float(parallel_masked_loss.item()):.8f}")
    print(f"masked_loss_abs_diff: {masked_loss_abs_diff:.8f}")
    # Equivalence verdict: both max diffs must be within tolerance.
    print(
        "warmstart_equivalent: "
        f"{float(projection_diff.max().item()) <= args.tolerance and float(loss_diff.max().item()) <= args.tolerance}"
    )


# Script entry point: parse CLI args and run the equivalence check.
if __name__ == "__main__":
    main()