File size: 52,266 Bytes
e94400c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 | # Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License, Version 1.0 (the "License");
"""
PI0 Framework
将 openpi 的 PI0Pytorch 模型封装为 starVLA framework 接口。
支持:
- 从 pi0 safetensors checkpoint 加载预训练参数
- 与其他 starVLA framework 相同的 __init__ 和 predict_action 接口
- 使用 PaliGemma SentencePiece tokenizer 处理语言指令
- 将 starVLA 样本格式(PIL 图像列表 + lang 字符串)转换为 PI0 Observation 格式
参考来源:
- openpi/src/openpi/models_pytorch/pi0_pytorch.py (PI0Pytorch)
- openpi/src/openpi/policies/policy.py (Policy.infer)
"""
import sys
import logging
from pathlib import Path
from typing import List, Optional, Dict, Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from starVLA.model.framework.base_framework import baseframework
from starVLA.model.tools import FRAMEWORK_REGISTRY
from starVLA.training.trainer_utils import initialize_overwatch
logger = initialize_overwatch(__name__)
# ────────────────────────────────────────────────────────────────
# 常量
# ────────────────────────────────────────────────────────────────
# PI0 默认图像分辨率
_IMAGE_RESOLUTION = (224, 224)
# PI0 默认使用的图像键名(与 openpi 保持一致)
_DEFAULT_IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
# ────────────────────────────────────────────────────────────────
# 辅助函数
# ────────────────────────────────────────────────────────────────
def _pil_to_tensor_normalized(img, resolution=(224, 224)) -> torch.Tensor:
"""
将图像转换为 PI0 所需的张量格式。
输出为 channels-first 格式 [C, H, W],归一化至 [-1, 1]。
PI0 的 preprocess_observation_pytorch 检测到 channels-first(shape[1]==3)时
会先转为 [B, H, W, C] 做增广,再转回 [B, C, H, W],最终送入 SigLIP conv2d
(SigLIP 的 patch_embedding 期望 [B, C, H, W] 格式)。
Args:
img: PIL.Image.Image 或 np.ndarray (H, W, 3) uint8。
同时兼容来自 eval_libero.py 的 numpy uint8 数组和 PIL Image。
resolution: 目标分辨率 (H, W),默认 (224, 224)。
Returns:
torch.Tensor: shape [C, H, W], dtype float32, 值域 [-1, 1]。
"""
# ── 统一转为 PIL Image ────────────────────────────────────────
if isinstance(img, np.ndarray):
# eval_libero.py 传入 (H, W, 3) uint8 numpy 数组
img = Image.fromarray(img.astype(np.uint8))
# ── PIL 预处理 ────────────────────────────────────────────────
if img.mode != "RGB":
img = img.convert("RGB")
if img.size != (resolution[1], resolution[0]):
img = img.resize((resolution[1], resolution[0]), Image.BILINEAR)
arr = np.array(img, dtype=np.float32) / 255.0 # [H, W, C], [0, 1]
arr = arr * 2.0 - 1.0 # [H, W, C], [-1, 1]
t = torch.from_numpy(arr) # [H, W, C]
return t.permute(2, 0, 1) # [C, H, W] channels-first
def _build_pi0_config_obj(framework_cfg):
"""
从 starVLA 的 framework 配置节点中构造 PI0Pytorch 所需的轻量配置对象。
PI0Pytorch.__init__ 需要的字段:
- paligemma_variant : str (e.g. "gemma_2b")
- action_expert_variant: str (e.g. "gemma_300m")
- pi05 : bool
- action_dim : int
- action_horizon : int
- dtype : str (e.g. "bfloat16")
这些字段可写在 yaml 的 framework.pi0 节点下,缺省值来自 openpi 的 Pi0Config。
"""
pi0_node = getattr(framework_cfg, "pi0", None) or {}
def _get(key, default):
if hasattr(pi0_node, key):
return getattr(pi0_node, key)
if isinstance(pi0_node, dict):
return pi0_node.get(key, default)
return default
class _PI0Config:
paligemma_variant = _get("paligemma_variant", "gemma_2b")
action_expert_variant= _get("action_expert_variant","gemma_300m")
pi05 = _get("pi05", False)
action_dim = _get("action_dim", 37)
state_dim = _get("state_dim", 74)
action_horizon = _get("action_horizon", 15)
dtype = _get("dtype", "bfloat16")
# max_token_len 按 pi05 标志自动推断(与 Pi0Config.__post_init__ 保持一致)
_max_token_len_raw = _get("max_token_len", None)
@property
def max_token_len(self):
if self._max_token_len_raw is not None:
return self._max_token_len_raw
return 200 if self.pi05 else 48
return _PI0Config()
class _SentencePieceTokenizer:
"""
轻量封装 SentencePiece tokenizer,用于 PaliGemma 风格的提示词编码。
与 openpi.models.tokenizer.PaligemmaTokenizer 逻辑相同,
但不依赖 openpi 的 GCS 下载逻辑,改为接受本地文件路径。
"""
def __init__(self, model_path: str, max_len: int = 48):
import sentencepiece
self._max_len = max_len
with open(model_path, "rb") as f:
self._sp = sentencepiece.SentencePieceProcessor(model_proto=f.read())
def tokenize(self, prompt: str, state: Optional[np.ndarray] = None):
"""
返回 (tokens: np.ndarray[int32], mask: np.ndarray[bool]),长度均为 max_len。
"""
cleaned = prompt.strip().replace("_", " ").replace("\n", " ")
if state is not None:
# pi05 格式:将连续 state 离散化后拼入提示词
bins = np.linspace(-1, 1, 257)[:-1]
disc = np.digitize(state, bins=bins) - 1
state_str = " ".join(map(str, disc))
full_prompt = f"Task: {cleaned}, State: {state_str};\nAction: "
tokens = self._sp.encode(full_prompt, add_bos=True)
else:
tokens = self._sp.encode(cleaned, add_bos=True) + self._sp.encode("\n")
tokens_len = len(tokens)
if tokens_len < self._max_len:
pad = [False] * (self._max_len - tokens_len)
mask = [True] * tokens_len + pad
tokens = tokens + [0] * (self._max_len - tokens_len)
else:
if tokens_len > self._max_len:
logger.warning(
f"Token length ({tokens_len}) > max_len ({self._max_len}), truncating."
)
tokens = tokens[: self._max_len]
mask = [True] * self._max_len
return np.array(tokens, dtype=np.int32), np.array(mask, dtype=bool)
# ────────────────────────────────────────────────────────────────
# PI0 Framework
# ────────────────────────────────────────────────────────────────
@FRAMEWORK_REGISTRY.register("PI0")
class PI0Framework(baseframework):
"""
starVLA framework 封装,将 openpi 的 PI0Pytorch 模型集成到 starVLA 生态。
Config 节点(yaml 示例):
```yaml
framework:
name: "PI0"
pi0:
paligemma_variant: "gemma_2b" # PaliGemma backbone 变体
action_expert_variant: "gemma_300m" # Action Expert 变体
pi05: false # 是否使用 Pi0.5 版本
action_dim: 32 # 动作维度
action_horizon: 50 # 预测动作步数
dtype: "bfloat16" # 模型权重精度
tokenizer_path: "/path/to/paligemma_tokenizer.model" # SentencePiece 模型
pi0_checkpoint: "/path/to/model.safetensors" # 可选,pi0 预训练权重
image_keys: # 图像键名顺序(对应 examples["image"] 中的顺序)
- "base_0_rgb"
- "left_wrist_0_rgb"
- "right_wrist_0_rgb"
num_inference_steps: 10 # 流匹配推理步数
```
接口:
__init__(config) : 构建模型、加载 tokenizer、可选加载预训练权重
predict_action(examples) : 推理,返回 {"normalized_actions": np.ndarray}
forward(examples) : 训练前向(暂未实现)
"""
def __init__(
self,
config: Optional[Any] = None,
**kwargs,
) -> None:
"""
初始化 PI0Framework。
Args:
config: starVLA 层级配置(OmegaConf DictConfig 或兼容对象),
须包含 framework.pi0、framework.tokenizer_path 等字段。
**kwargs: 预留。
"""
super().__init__()
self.config = config
fw_cfg = getattr(config, "framework", config)
# ── 1. 构建 PI0Pytorch 配置 ──────────────────────────────
pi0_cfg = _build_pi0_config_obj(fw_cfg)
self._pi0_cfg = pi0_cfg
# ── 2. 初始化 PI0Pytorch 模型 ────────────────────────────
try:
sys.path.append("/mnt/data/fangyu/code/MixtureOfHorizons/src")
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch
except ImportError as e:
raise ImportError(
"PI0Framework 依赖 openpi 包。请确保 openpi 已安装或 "
"openpi/src 已添加到 PYTHONPATH。原始错误: " + str(e)
) from e
self.pi0_model: nn.Module = PI0Pytorch(config=pi0_cfg)
logger.info(
f"PI0Pytorch 已初始化:variant={pi0_cfg.paligemma_variant}, "
f"action_dim={pi0_cfg.action_dim}, action_horizon={pi0_cfg.action_horizon}, "
f"pi05={pi0_cfg.pi05}"
)
# ── 2b. 替换硬编码的 32D 投影层 ────────────────────────────
# PI0Pytorch 源码中 action_in_proj / action_out_proj / state_proj 均硬编码为 Linear(32, ...),
# 不读取 config.action_dim / config.state_dim。
# action_in_proj / action_out_proj 按 action_dim 替换;
# state_proj 按 state_dim 替换(state_dim 可与 action_dim 不同,如 unified 74D state)。
# 替换后的层为随机初始化,加载 checkpoint 时会因 shape 不匹配而自动跳过(由 _filter_ckpt_by_shape 保证)。
self._replace_pi0_projection_layers(pi0_cfg.action_dim, pi0_cfg.state_dim)
# ── 3. 图像键名映射 ──────────────────────────────────────
_ik = getattr(fw_cfg, "image_keys", None)
if _ik is not None:
self.image_keys = list(_ik)
else:
self.image_keys = list(_DEFAULT_IMAGE_KEYS)
# ── 4. 推理步数 ──────────────────────────────────────────
self.num_inference_steps = getattr(fw_cfg, "num_inference_steps", 10)
# ── 4b. 有效动作维度(用于截断模型输出)────────────────────
# PI0 action_dim=32,但 LIBERO 实际只用前 7 维(3 pos + 3 rot + 1 gripper)。
# 若 config.framework.effective_action_dim 已设置,predict_action 会将输出
# 截断至前 N 维,以匹配 model2libero_interface.py 的 unnormalize 期望维度。
self.effective_action_dim = getattr(fw_cfg, "effective_action_dim", None)
# ── 4c. 单视角复制(用于 gr1 video.ego_view 等仅单视角的 dataset)────────
# 若 True,当 example["image"] 只有 1 张时,复制到 image_keys 数量以填充多视角
self._replicate_single_view = getattr(fw_cfg, "replicate_single_view", False)
# ── 4d. 是否使用 state 输入(训练/推理时)────────────────────────────
# 若 False,不读取 example["state"],tokenizer 与 Observation.state 均用 None/零
self._use_state = getattr(fw_cfg, "use_state", True)
# ── 4e. 动态视角数(可选)────────────────────────────────────────────
# 若 True,根据 example["image"] 的实际数量使用前 N 个 image_keys,不补零
# 若 False,固定使用全部 image_keys,不足的视角用零+mask=False 填充
self._dynamic_image_keys = getattr(fw_cfg, "dynamic_image_keys", False)
# ── 5. 设置 Tokenizer ────────────────────────────────────
tokenizer_path = getattr(fw_cfg, "tokenizer_path", None)
self._tokenizer = self._load_tokenizer(tokenizer_path, pi0_cfg.max_token_len)
# ── 6. 可选:加载 pi0 预训练权重 ─────────────────────────
pi0_ckpt = getattr(fw_cfg, "pi0_checkpoint", None)
if pi0_ckpt:
self.load_pi0_weights(pi0_ckpt)
# ──────────────────────────────────────────────────────────────
# 内部工具
# ──────────────────────────────────────────────────────────────
def _load_tokenizer(self, tokenizer_path: Optional[str], max_len: int):
"""
加载 PaliGemma SentencePiece tokenizer。
优先级:
1. 使用 tokenizer_path 指定的本地 .model 文件
2. 尝试通过 openpi 的下载工具自动获取 paligemma_tokenizer.model
3. 若均失败,tokenizer 设为 None,predict_action 时会报错提示用户
Args:
tokenizer_path: 本地 sentencepiece .model 文件路径,可为 None。
max_len: 最大 token 长度。
Returns:
_SentencePieceTokenizer 实例,或 None。
"""
if tokenizer_path and Path(tokenizer_path).exists():
logger.info(f"从本地路径加载 tokenizer:{tokenizer_path}")
return _SentencePieceTokenizer(tokenizer_path, max_len=max_len)
# 尝试使用 openpi 的下载工具(会缓存到本地)
try:
from openpi.shared import download as _download
path = _download.maybe_download(
"gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}
)
logger.info(f"通过 openpi 下载工具加载 tokenizer:{path}")
return _SentencePieceTokenizer(str(path), max_len=max_len)
except Exception as e:
logger.warning(
f"无法自动下载 paligemma tokenizer:{e}。"
"请在 config 中设置 framework.tokenizer_path 指向本地 .model 文件。"
)
return None
def _preprocess_examples(self, examples: List[dict], device: torch.device):
"""
将 starVLA 样本格式转换为 PI0 Observation 对象。
starVLA 样本格式:
examples[i]["image"] : List[PIL.Image] —— 各视角图像
examples[i]["lang"] : str —— 语言指令
examples[i]["state"] : np.ndarray (可选) —— 机器人本体状态,shape (1, state_dim)
PI0 Observation 格式:
images : dict[key -> Tensor[B, H, W, C]], 值域 [-1, 1]
image_masks : dict[key -> Tensor[B]], bool
state : Tensor[B, action_dim]
tokenized_prompt : Tensor[B, max_token_len], int32
tokenized_prompt_mask : Tensor[B, max_token_len], bool
Args:
examples: List[dict],每个 dict 为一个样本。
device: 目标 torch 设备。
Returns:
Observation 对象(openpi.models.model.Observation)。
"""
from openpi.models.model import Observation
batch_size = len(examples)
# ── 图像 ────────────────────────────────────────────────
# 视角数可配置:image_keys 长度决定最大视角数;dynamic_image_keys 时按实际图像数截断
replicate_single_view = getattr(self, "_replicate_single_view", False)
dynamic_image_keys = getattr(self, "_dynamic_image_keys", False)
# 确定本 batch 使用的 keys:dynamic 时以首样本图像数为准
num_views = len(examples[0].get("image", [])) if examples else len(self.image_keys)
if replicate_single_view and num_views == 1 and len(self.image_keys) > 1:
num_views = len(self.image_keys)
if dynamic_image_keys:
active_keys = list(self.image_keys)[: max(1, num_views)]
else:
active_keys = list(self.image_keys)
images_batch: Dict[str, List[torch.Tensor]] = {k: [] for k in active_keys}
masks_batch: Dict[str, List[bool]] = {k: [] for k in active_keys}
for example in examples:
imgs: List[Image.Image] = example.get("image", [])
if replicate_single_view and len(imgs) == 1 and len(active_keys) > 1:
imgs = imgs * len(active_keys)
for idx, key in enumerate(active_keys):
if idx < len(imgs):
t = _pil_to_tensor_normalized(imgs[idx], _IMAGE_RESOLUTION)
images_batch[key].append(t)
masks_batch[key].append(True)
else:
# 缺失视角:用全零占位,mask=False
t = torch.zeros(
(3, _IMAGE_RESOLUTION[0], _IMAGE_RESOLUTION[1]), dtype=torch.float32
)
images_batch[key].append(t)
masks_batch[key].append(False)
images_tensor: Dict[str, torch.Tensor] = {
k: torch.stack(v, dim=0).to(device, dtype=torch.float32) # [B, C, H, W]
for k, v in images_batch.items()
}
image_masks_tensor: Dict[str, torch.Tensor] = {
k: torch.tensor(v, dtype=torch.bool, device=device) # [B]
for k, v in masks_batch.items()
}
# ── 语言 Tokenization ───────────────────────────────────
if self._tokenizer is None:
raise RuntimeError(
"Tokenizer 未初始化。请在 config 中设置 framework.tokenizer_path,"
"或确保 openpi 网络可访问以自动下载。"
)
pi05 = self._pi0_cfg.pi05
use_state = getattr(self, "_use_state", True)
_pi05_missing_state_warned = False
all_tokens = []
all_masks = []
for example in examples:
lang = example.get("lang", example.get("language", ""))
state_for_tok = None
if pi05 and use_state:
# pi0.5:将 state 离散化后并入提示词
# ⚠️ 当 use_state=False 时,不喂 state,tokenizer 使用 non-pi05 格式
raw_state = example.get("state", None)
if raw_state is not None:
s = np.array(raw_state)
if s.ndim > 1:
s = s[0]
s = s.flatten()[: self._pi0_cfg.state_dim]
state_for_tok = s
elif not _pi05_missing_state_warned:
# logger.warning(
# "PI0Framework [pi05=True]: example 中没有 'state' 字段!"
# "Tokenizer 将退回 non-pi05 prompt 格式,与训练格式不符,"
# "可能显著降低模型性能。"
# "请在 example_dict 中添加 'state': robot_state_array。"
# )
_pi05_missing_state_warned = True
toks, mask = self._tokenizer.tokenize(lang, state_for_tok)
all_tokens.append(toks)
all_masks.append(mask)
tokenized_prompt = torch.tensor(
np.stack(all_tokens, axis=0), dtype=torch.int32, device=device
) # [B, max_token_len]
tokenized_prompt_mask = torch.tensor(
np.stack(all_masks, axis=0), dtype=torch.bool, device=device
) # [B, max_token_len]
# ── State ───────────────────────────────────────────────
# 当 use_state=False 时不读取 example["state"],全部填零。
# state 对齐目标为 state_dim(而非 action_dim),两者可以不同,
# 例如 unified 74D state + 37D action 场景下 state_proj 为 Linear(74, width),
# state_tensor shape 为 [B, 74],不截断。
state_dim = self._pi0_cfg.state_dim
state_list = []
for example in examples:
raw = example.get("state", None) if use_state else None
if raw is not None:
s = np.array(raw, dtype=np.float32)
if s.ndim > 1:
s = s[0] # 取首帧 (chunk 时 state 为 [T, state_dim])
s = s.flatten()
else:
s = np.zeros(state_dim, dtype=np.float32)
# 对齐到 state_dim(截断或 zero-pad)
if len(s) >= state_dim:
s = s[:state_dim]
else:
s = np.concatenate([s, np.zeros(state_dim - len(s), dtype=np.float32)])
state_list.append(s)
state_tensor = torch.tensor(
np.stack(state_list, axis=0), dtype=torch.float32, device=device
) # [B, state_dim]
return Observation(
images=images_tensor,
image_masks=image_masks_tensor,
state=state_tensor,
tokenized_prompt=tokenized_prompt,
tokenized_prompt_mask=tokenized_prompt_mask,
)
# ──────────────────────────────────────────────────────────────
# 公开接口
# ──────────────────────────────────────────────────────────────
# ──────────────────────────────────────────────────────────────
# load_state_dict 重写(解决 key 前缀不匹配问题)
# ──────────────────────────────────────────────────────────────
def load_state_dict(self, state_dict, strict=True, assign=False):
"""
重写 load_state_dict,解决 PI0 checkpoint key 前缀不匹配问题。
**问题根因:**
`baseframework.from_pretrained` 调用 `FrameworkModel.load_state_dict(ckpt_dict, strict=True)`。
PI0Framework 的 state_dict key 带 `pi0_model.` 前缀(`self.pi0_model` 是子模块),
但 `convert_jax_model_to_pytorch.py` 输出的 safetensors 是裸 key(无前缀)。
直接 `load_state_dict` 严格模式必然失败。
**修复:**
检测到 state_dict 中均为裸 key(不含 `pi0_model.` 前缀)时,
自动将权重加载到 `self.pi0_model`(绕过 `PI0Framework` 层的前缀问题)。
"""
if state_dict and not any(k.startswith("pi0_model.") for k in state_dict.keys()):
logger.info(
"[PI0Framework.load_state_dict] 检测到裸 key(pi05_libero_pytorch 格式),"
"直接加载到 self.pi0_model(跳过 pi0_model. 前缀)"
)
# PaliGemma 中 embed_tokens.weight 与 lm_head.weight 权重绑定(同一 tensor 对象),
# convert_jax_model_to_pytorch.py 只保存了 lm_head.weight,不重复保存 embed_tokens.weight。
# 加载 lm_head.weight 后,绑定机制自动同步 embed_tokens.weight,用 strict=False 跳过此检查。
missing, unexpected = self.pi0_model.load_state_dict(state_dict, strict=False)
expected_missing = {"paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"}
real_missing = set(missing) - expected_missing
if real_missing:
logger.warning(f"[PI0Framework.load_state_dict] 真正缺失的 key(非权重绑定):{real_missing}")
if unexpected:
logger.warning(f"[PI0Framework.load_state_dict] 多余的 key:{set(unexpected)}")
logger.info("[PI0Framework.load_state_dict] pi0 权重加载完成(embed_tokens.weight 由权重绑定自动同步)")
from torch.nn.modules.module import _IncompatibleKeys
return _IncompatibleKeys(missing, unexpected)
return super().load_state_dict(state_dict, strict=strict, assign=assign)
# ──────────────────────────────────────────────────────────────
# from_pretrained 重写(补充支持直接 pi05_libero_pytorch 目录格式)
# ──────────────────────────────────────────────────────────────
@classmethod
def from_pretrained(cls, pretrained_checkpoint: str, **kwargs):
"""
从 checkpoint 加载 PI0Framework,供 server_policy.py 调用。
**为什么需要重写?**
`baseframework.from_pretrained` 最终调用 `FrameworkModel.load_state_dict(ckpt_dict)`,
而 PI0Framework 的 state_dict key 带 `pi0_model.` 前缀(因为 self.pi0_model 是子模块),
但 convert_jax_model_to_pytorch.py 输出的 safetensors 是裸 key(无前缀)。
直接 load_state_dict 会因 key 不匹配而失败。
本方法将权重直接加载到 self.pi0_model,绕过前缀问题。
支持两种 checkpoint 格式:
**格式 A:starVLA wrapper 目录**(推荐用于 server_policy.py 部署)
```
<run_dir>/
├── config.yaml # framework.name=PI0, framework.pi0.*, framework.effective_action_dim=7
├── dataset_statistics.json # franka → action → min/max/mask(7 维,由 pi05_libero q01/q99 转换)
└── checkpoints/
└── model.safetensors # 即 pi05_libero_pytorch/model.safetensors(可软链接)
```
server_policy.py 启动命令:
`--ckpt_path <run_dir>/checkpoints/model.safetensors`
**格式 B:直接 pi05_libero_pytorch 目录**(不含 dataset_statistics,仅供快速测试)
```
<pi0_dir>/
├── model.safetensors # convert_jax_model_to_pytorch.py 输出
└── config.json # {"action_dim": 32, "action_horizon": 10, ...}
```
注意:此格式下 norm_stats 为空,ModelClient.get_action_stats 会失败,
需要手动提供统计量或改用格式 A。
Args:
pretrained_checkpoint: checkpoint 文件路径(.safetensors 或 .pt)。
Returns:
PI0Framework 实例,已加载权重。
"""
from pathlib import Path as _Path
pretrained_checkpoint = _Path(pretrained_checkpoint)
# ─── 格式 A:starVLA wrapper 目录 ─────────────────────────
wrapper_run_dir = pretrained_checkpoint.parents[1]
wrapper_cfg_yaml = wrapper_run_dir / "config.yaml"
wrapper_stats_json = wrapper_run_dir / "dataset_statistics.json"
if wrapper_cfg_yaml.exists() and wrapper_stats_json.exists():
from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
from starVLA.model.framework import build_framework
model_config, norm_stats = read_mode_config(pretrained_checkpoint)
config = dict_to_namespace(model_config)
config.trainer.pretrained_checkpoint = None
# 防止 __init__ 内二次加载(若 config.yaml 也设了 pi0_checkpoint)
fw_cfg_ref = getattr(config, "framework", config)
if hasattr(fw_cfg_ref, "pi0_checkpoint"):
fw_cfg_ref.pi0_checkpoint = None
model = build_framework(cfg=config)
model.norm_stats = norm_stats
# 直接将 safetensors/pt 权重加载到 pi0_model(跳过 PI0Framework 前缀问题)
if pretrained_checkpoint.exists():
model.load_pi0_weights(str(pretrained_checkpoint))
logger.info(f"[from_pretrained 格式A] 从 wrapper 目录加载:{wrapper_run_dir}")
return model
# ─── 格式 B:直接 pi05_libero_pytorch 目录 ───────────────
pi0_dir = pretrained_checkpoint.parent
pi0_config_json = pi0_dir / "config.json"
if pi0_config_json.exists():
import json
from omegaconf import OmegaConf
with open(pi0_config_json) as _f:
pi0_cfg_dict = json.load(_f)
# 尝试从 safetensors 的 key 自动判断 pi05
pi05_detected = False
if pretrained_checkpoint.suffix == ".safetensors":
try:
import safetensors
with safetensors.safe_open(str(pretrained_checkpoint), framework="pt") as _sf:
_keys = list(_sf.keys())
pi05_detected = "time_mlp_in.weight" in _keys
except Exception:
pass
config = OmegaConf.create({
"framework": {
"name": "PI0",
"pi0": {
"paligemma_variant": pi0_cfg_dict.get("paligemma_variant", "gemma_2b"),
"action_expert_variant": pi0_cfg_dict.get("action_expert_variant", "gemma_300m"),
"pi05": pi05_detected,
"action_dim": pi0_cfg_dict.get("action_dim", 32),
"action_horizon": pi0_cfg_dict.get("action_horizon", 50),
"dtype": pi0_cfg_dict.get("precision", "bfloat16"),
},
"pi0_checkpoint": None,
"num_inference_steps": 10,
},
"trainer": {"pretrained_checkpoint": None},
})
model = cls(config=config)
model.load_pi0_weights(str(pretrained_checkpoint))
model.norm_stats = {}
logger.warning(
"[from_pretrained 格式B] 直接加载 pi05_libero_pytorch 目录,"
"norm_stats 为空。ModelClient.get_action_stats 会失败。"
"建议创建 starVLA wrapper 目录(含 config.yaml + dataset_statistics.json)。"
)
logger.info(f"[from_pretrained 格式B] pi05_detected={pi05_detected},路径:{pretrained_checkpoint}")
return model
# ─── Fallback ─────────────────────────────────────────────
logger.warning(
f"PI0Framework.from_pretrained:无法识别 checkpoint 格式({pretrained_checkpoint})。"
"请参考 docstring 创建 starVLA wrapper 目录(格式 A)。"
"尝试调用父类 from_pretrained(大概率因 key 不匹配而失败)。"
)
return super().from_pretrained(pretrained_checkpoint, **kwargs)
# ── PI0Pytorch 结构适配工具 ──────────────────────────────────────
# PI0Pytorch 源码中以下三个投影层硬编码为 32D,不读取 config.action_dim:
# self.action_in_proj = nn.Linear(32, width)
# self.action_out_proj = nn.Linear(width, 32)
# self.state_proj = nn.Linear(32, width) # pi05=False 时
_PI0_HARDCODED_ACTION_DIM: int = 32
def _replace_pi0_projection_layers(self, action_dim: int, state_dim: int = None) -> None:
"""
将 PI0Pytorch 硬编码的 32D 投影层替换为目标维度。
- action_in_proj / action_out_proj 按 action_dim 替换。
- state_proj 按 state_dim 替换(state_dim 可与 action_dim 不同,
如 unified 74D state + 37D action 的场景)。
- 当某维度与硬编码的 32D 相同时,对应层不做替换。
- 替换后的层为随机初始化;加载 checkpoint 时,
`_filter_ckpt_by_shape` 会因 shape 不匹配而自动跳过这些 key。
Args:
action_dim: 动作维度(用于 action_in_proj / action_out_proj)。
state_dim: 状态维度(用于 state_proj)。缺省时与 action_dim 相同。
"""
if state_dim is None:
state_dim = action_dim
# action_in_proj / action_out_proj 始终存在
if action_dim != self._PI0_HARDCODED_ACTION_DIM:
proj_width = self.pi0_model.action_in_proj.out_features # e.g. 1024
self.pi0_model.action_in_proj = nn.Linear(action_dim, proj_width)
self.pi0_model.action_out_proj = nn.Linear(proj_width, action_dim)
logger.info(
f"[PI0Framework] action_in/out_proj 已替换:"
f"Linear({self._PI0_HARDCODED_ACTION_DIM}, {proj_width}) "
f"→ Linear({action_dim}, {proj_width}) "
f"[随机初始化,不加载 checkpoint 权重]"
)
# state_proj 仅在 pi05=False 时存在
if not self._pi0_cfg.pi05 and hasattr(self.pi0_model, "state_proj"):
if state_dim != self._PI0_HARDCODED_ACTION_DIM:
state_proj_width = self.pi0_model.state_proj.out_features
self.pi0_model.state_proj = nn.Linear(state_dim, state_proj_width)
logger.info(
f"[PI0Framework] state_proj 已替换:"
f"Linear({self._PI0_HARDCODED_ACTION_DIM}, {state_proj_width}) "
f"→ Linear({state_dim}, {state_proj_width}) "
f"[随机初始化,不加载 checkpoint 权重]"
)
def _filter_ckpt_by_shape(self, state_dict: dict) -> dict:
"""
过滤 checkpoint state_dict,移除与当前模型 shape 不一致的 key。
load_state_dict(strict=False) 遇到 shape 不匹配时仍会抛出 RuntimeError;
本方法提前过滤,让这些 key 保持随机初始化,其余 key 正常加载。
Args:
state_dict: 从文件读取的原始权重字典(key 为裸名,不含前缀)。
Returns:
过滤后的字典,只保留 shape 匹配(或模型中不存在)的 key。
"""
model_sd = self.pi0_model.state_dict()
filtered: dict = {}
skipped: list = []
for k, v in state_dict.items():
if k in model_sd and model_sd[k].shape != v.shape:
skipped.append(
f" {k}: ckpt{tuple(v.shape)} ≠ model{tuple(model_sd[k].shape)}"
)
else:
filtered[k] = v
if skipped:
logger.info(
f"[load_pi0_weights] 跳过 {len(skipped)} 个 shape 不匹配的 key"
f"(这些层保持随机初始化,将在训练中从头学习):"
)
for s in skipped:
logger.info(s)
return filtered
def load_pi0_weights(self, checkpoint_path: str) -> None:
"""
从 pi0 预训练 checkpoint 加载权重到 self.pi0_model。
支持格式:
- .safetensors : 使用 safetensors.torch.load_file + 过滤后 load_state_dict
- .pt / .pth : 使用 torch.load(map_location="cpu")
当 config.action_dim 与 checkpoint 的硬编码维度(32)不一致时,
_filter_ckpt_by_shape 会自动跳过 action_in_proj / action_out_proj / state_proj,
这些层保持 _replace_pi0_projection_layers 的随机初始化,由训练自行学习。
Args:
checkpoint_path: 权重文件路径。
Raises:
FileNotFoundError: 文件不存在。
RuntimeError: 权重加载失败。
"""
checkpoint_path = Path(checkpoint_path)
if not checkpoint_path.exists():
raise FileNotFoundError(f"pi0 checkpoint 不存在:{checkpoint_path}")
logger.info(f"加载 pi0 预训练权重:{checkpoint_path}")
if checkpoint_path.suffix == ".safetensors":
import safetensors.torch as sf_torch
state_dict = sf_torch.load_file(str(checkpoint_path))
state_dict = self._filter_ckpt_by_shape(state_dict)
missing, unexpected = self.pi0_model.load_state_dict(state_dict, strict=False)
# embed_tokens.weight 与 lm_head.weight 绑定,只保存一份属正常现象
expected_missing = {
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
}
real_missing = set(missing) - expected_missing
if real_missing:
logger.warning(f"加载后真正缺失的 key({len(real_missing)} 个):{list(real_missing)[:10]} ...")
if unexpected:
logger.warning(f"加载后多余的 key({len(unexpected)} 个):{list(unexpected)[:10]} ...")
else:
state_dict = torch.load(str(checkpoint_path), map_location="cpu")
# 兼容外层包装(有些 checkpoint 会多套一层 key)
if "model" in state_dict and isinstance(state_dict["model"], dict):
state_dict = state_dict["model"]
state_dict = self._filter_ckpt_by_shape(state_dict)
model_keys = set(self.pi0_model.state_dict().keys())
checkpoint_keys = set(state_dict.keys())
missing = model_keys - checkpoint_keys
unexpected = checkpoint_keys - model_keys
if missing:
logger.warning(f"权重中缺少的 key({len(missing)} 个):{list(missing)[:10]} ...")
if unexpected:
logger.warning(f"权重中多余的 key({len(unexpected)} 个):{list(unexpected)[:10]} ...")
self.pi0_model.load_state_dict(state_dict, strict=False)
logger.info("pi0 预训练权重加载完毕。")
def forward(
self,
examples: List[dict] = None,
**kwargs,
) -> dict:
"""
训练前向传播,计算流匹配(flow matching)损失。
基于 PI0Pytorch.forward 的训练逻辑:
1. 从 examples 中读取 image / lang / action / state
2. 将 action pad 到 model 的 action_dim(如 32),不足的维度补零
3. 通过 _preprocess_examples 构建 PI0 Observation 对象
4. 调用 PI0Pytorch.forward(observation, actions_target):
- 采样 noise 和 time
- 线性插值加噪:x_t = t * noise + (1-t) * actions
- 目标速度:u_t = noise - actions
- 预测速度:v_t = model(x_t, t)
- 损失:MSE(u_t, v_t) per element → 再取均值
5. 返回 {"action_loss": scalar}
Args:
examples: List[dict],每个 dict 包含:
- "image" (List[PIL.Image | np.ndarray]): 各视角图像
- "lang" (str): 语言指令
- "action" (np.ndarray): 目标动作,shape (action_horizon, D),
D 通常是 effective_action_dim(如 7),不足 action_dim 会自动 zero-pad
- "state" (np.ndarray, 可选): 机器人本体状态
**kwargs: 预留,暂未使用。
Returns:
dict:
"action_loss" (torch.Tensor): 标量,batch 平均流匹配 MSE 损失。
"""
if not isinstance(examples, list):
examples = [examples]
device = next(self.pi0_model.parameters()).device
action_dim = self._pi0_cfg.action_dim # 32(模型全维度)
action_horizon = self._pi0_cfg.action_horizon # 如 10
# ── 1. 构建 PI0 Observation ────────────────────────────────
observation = self._preprocess_examples(examples, device)
# ── 2. 整理 action target,pad 到 action_dim ───────────────
# examples[i]["action"]: np.ndarray (action_horizon, D),D 可能 < action_dim
actions_list = []
for example in examples:
a = np.array(example["action"], dtype=np.float32)
# 若缺少 horizon 维,视为单步
if a.ndim == 1:
a = a[np.newaxis, :] # (1, D)
# action_horizon 对齐(截断或 zero-pad)
H, D = a.shape
if H > action_horizon:
a = a[:action_horizon]
elif H < action_horizon:
a = np.concatenate(
[a, np.zeros((action_horizon - H, D), dtype=np.float32)], axis=0
)
# action_dim 对齐(截断或 zero-pad)
D = a.shape[1]
if D > action_dim:
a = a[:, :action_dim]
elif D < action_dim:
a = np.concatenate(
[a, np.zeros((action_horizon, action_dim - D), dtype=np.float32)], axis=1
)
actions_list.append(a)
actions_target = torch.tensor(
np.stack(actions_list, axis=0), dtype=torch.float32, device=device
) # [B, action_horizon, action_dim]
# ── 3. 调用 PI0Pytorch.forward 计算流匹配损失 ──────────────
# PI0Pytorch.forward 返回 F.mse_loss(u_t, v_t, reduction="none")
# shape: [B, action_horizon, action_dim]
# 内部会自动处理 bfloat16 cast(对 embedding 层)并以 float32 输出 loss
with torch.autocast("cuda", dtype=torch.bfloat16):
loss_per_element = self.pi0_model.forward(observation, actions_target)
# 对 batch / horizon / dim 全部取均值,得到标量损失
action_loss = loss_per_element.mean()
return {"action_loss": action_loss}
@torch.inference_mode()
def predict_action(
self,
examples: List[dict],
**kwargs,
) -> dict:
"""
推理:根据观测预测动作序列。
流程:
1. 将 examples 中的图像转换为归一化张量
2. Tokenize 语言指令
3. 提取/对齐本体状态
4. 构造 PI0 Observation 对象
5. 调用 PI0Pytorch.sample_actions 进行流匹配去噪推理
6. 返回归一化动作
Args:
examples: List[dict],每个 dict 包含:
- "image" (List[PIL.Image | np.ndarray]): 各视角图像,顺序与 config.framework.image_keys 对应。
接受 PIL Image 或 (H, W, 3) uint8 numpy 数组(eval_libero.py 格式)。
- "lang" (str) : 任务语言指令
- "state" (np.ndarray, 可选): 机器人本体状态,shape (state_dim,) 或 (1, state_dim)。
⚠️ pi05 模式下 state 是 tokenized prompt 的一部分,缺失会导致 prompt 格式
错误、性能下降。eval_libero.py 需要在 example_dict 中添加此字段。
**kwargs: 额外参数(暂未使用)。
Returns:
dict:
"normalized_actions" (np.ndarray): shape [B, action_horizon, action_dim],
归一化至 [-1, 1] 的预测动作序列。
"""
if not isinstance(examples, list):
examples = [examples]
# 确定运行设备
device = next(self.pi0_model.parameters()).device
# 准备 PI0 Observation
observation = self._preprocess_examples(examples, device)
# 执行流匹配去噪推理
num_steps = kwargs.get("num_steps", self.num_inference_steps)
pred_actions = self.pi0_model.sample_actions(
device, observation, num_steps=num_steps
) # [B, action_horizon, action_dim]
normalized_actions = pred_actions.detach().cpu().numpy() # [B, action_horizon, action_dim]
# 截断到有效动作维度(用于 LIBERO 等只需要前 N 维的场景)
# PI0 pi05_libero action_dim=32,但 LIBERO 机器人只有 7 DOF
# ModelClient.unnormalize_actions 的 min/max stats 只有 7 维,
# 若不截断会导致 numpy broadcast 形状不匹配错误。
# 在 config.framework.effective_action_dim 中设置有效维度(如 7)来启用截断。
if self.effective_action_dim is not None:
normalized_actions = normalized_actions[:, :, : self.effective_action_dim]
# print(normalized_actions.shape)
return {"normalized_actions": normalized_actions}
# ────────────────────────────────────────────────────────────────
# 快速验证(仅供调试,不依赖真实 checkpoint)
# ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
import argparse
from omegaconf import OmegaConf
# ── 默认路径(对应 pi05_libero 模型) ────────────────────────
_DEFAULT_CKPT = (
"/mnt/data/fangyu/model/openpi/openpi-assets/checkpoints"
"/pi05_libero_pytorch/model.safetensors"
)
_DEFAULT_TOKENIZER = (
"/root/.cache/openpi/big_vision/paligemma_tokenizer.model"
)
parser = argparse.ArgumentParser(description="PI0Framework 快速冒烟测试")
parser.add_argument("--config_yaml", type=str, default=None, help="YAML 配置文件路径(优先于内联配置)")
parser.add_argument("--pi0_checkpoint", type=str, default=_DEFAULT_CKPT, help="pi0 safetensors 权重路径")
parser.add_argument("--tokenizer_path", type=str, default=_DEFAULT_TOKENIZER, help="PaliGemma tokenizer .model 路径")
parser.add_argument("--device", type=str, default=None, help="运行设备,如 cuda:0 / cpu(默认自动检测)")
parser.add_argument("--steps", type=int, default=10, help="流匹配推理步数")
parser.add_argument("--batch_size", type=int, default=1, help="测试 batch size")
args, _ = parser.parse_known_args()
if args.config_yaml:
cfg = OmegaConf.load(args.config_yaml)
else:
# pi05_libero 内联配置
cfg = OmegaConf.create({
"framework": {
"name": "PI0",
"pi0": {
"paligemma_variant": "gemma_2b",
"action_expert_variant": "gemma_300m",
"pi05": True, # pi05_libero 使用 pi0.5
"action_dim": 32,
"action_horizon": 10, # pi05_libero action_horizon=10
"dtype": "bfloat16",
},
"tokenizer_path": args.tokenizer_path,
"pi0_checkpoint": args.pi0_checkpoint,
"image_keys": ["base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb"],
"num_inference_steps": args.steps,
}
})
print("=" * 60)
print("PI0Framework 测试")
print("=" * 60)
print(f" checkpoint : {args.pi0_checkpoint}")
print(f" tokenizer : {args.tokenizer_path}")
print(f" infer steps : {args.steps}")
print(f" batch size : {args.batch_size}")
# ── 1. 构建模型 ──────────────────────────────────────────────
print("\n[1/3] 初始化 PI0Framework 并加载权重...")
model = PI0Framework(cfg)
total_params = sum(p.numel() for p in model.pi0_model.parameters())
print(f" 模型参数量: {total_params / 1e9:.2f}B")
# ── 2. 移到目标设备 ──────────────────────────────────────────
if args.device:
device = torch.device(args.device)
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\n[2/3] 移至设备: {device}")
model = model.to(device)
# pi05_libero 用 bfloat16
model.pi0_model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
# ── 3. 构造假样本并推理 ──────────────────────────────────────
print(f"\n[3/3] 构造 batch_size={args.batch_size} 的假样本并调用 predict_action...")
fake_img = Image.fromarray(
np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
)
sample = {
"image": [fake_img, fake_img, fake_img], # 3 视角:base + left_wrist + right_wrist
"lang": "put the red cup on the plate",
"state": np.random.uniform(-1, 1, size=(32,)).astype(np.float32),
}
batch = [sample] * args.batch_size
import time
t0 = time.time()
result = model.predict_action(batch)
elapsed = time.time() - t0
actions = result["normalized_actions"]
print(f"\n 输出 normalized_actions shape : {actions.shape}")
print(f" 推理耗时 : {elapsed*1000:.1f} ms")
print(f" 动作值域 [min, max] : [{actions.min():.4f}, {actions.max():.4f}]")
print(f" 动作均值 ± 标准差 : {actions.mean():.4f} ± {actions.std():.4f}")
print("\n[OK] PI0Framework 测试完成!")
|