# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License, Version 1.0 (the "License");
"""
PI0 Framework

Wraps openpi's PI0Pytorch model behind the starVLA framework interface.

Features:
  - Loads pretrained parameters from a pi0 safetensors checkpoint.
  - Exposes the same ``__init__`` / ``predict_action`` interface as the other
    starVLA frameworks.
  - Tokenizes language instructions with the PaliGemma SentencePiece tokenizer.
  - Converts starVLA samples (list of images + ``lang`` string) into a PI0
    ``Observation``.

Reference sources:
  - openpi/src/openpi/models_pytorch/pi0_pytorch.py  (PI0Pytorch)
  - openpi/src/openpi/policies/policy.py             (Policy.infer)
"""

import sys
import logging
from pathlib import Path
from typing import List, Optional, Dict, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from starVLA.model.framework.base_framework import baseframework
from starVLA.model.tools import FRAMEWORK_REGISTRY
from starVLA.training.trainer_utils import initialize_overwatch

logger = initialize_overwatch(__name__)


# ────────────────────────────────────────────────────────────────
# Constants
# ────────────────────────────────────────────────────────────────

# Default PI0 image resolution as (H, W).
_IMAGE_RESOLUTION = (224, 224)

# Default PI0 image key names (kept consistent with openpi).
_DEFAULT_IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")

# Developer-machine fallback location of the openpi sources. It is appended
# (never prepended) to sys.path, so an installed openpi still takes priority.
_OPENPI_FALLBACK_SRC = "/mnt/data/fangyu/code/MixtureOfHorizons/src"

# State-dict key that is expected to be absent from converted checkpoints:
# per the original author's notes, embed_tokens.weight is tied to
# lm_head.weight and only the latter is saved.
_TIED_EMBED_KEY = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"


# ────────────────────────────────────────────────────────────────
# Helpers
# ────────────────────────────────────────────────────────────────

def _pil_to_tensor_normalized(img, resolution=(224, 224)) -> torch.Tensor:
    """
    Convert an image into the tensor layout consumed by this framework.

    The result is channels-first ``[C, H, W]`` and scaled to ``[-1, 1]``.
    (Author's note: openpi's ``preprocess_observation_pytorch`` is expected to
    detect channels-first input and hand ``[B, C, H, W]`` to SigLIP.)

    Args:
        img: ``PIL.Image.Image`` or ``np.ndarray`` of shape (H, W, 3). Arrays
            are cast to uint8 before conversion (e.g. inputs coming from
            eval_libero.py).
        resolution: Target resolution as (H, W). Defaults to (224, 224).

    Returns:
        torch.Tensor: shape ``[C, H, W]``, dtype float32, values in ``[-1, 1]``.
    """
    # Normalise the input type to a PIL image.
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img.astype(np.uint8))

    if img.mode != "RGB":
        img = img.convert("RGB")

    # PIL sizes are (W, H) whereas `resolution` is (H, W).
    target_size = (resolution[1], resolution[0])
    if img.size != target_size:
        img = img.resize(target_size, Image.BILINEAR)

    arr = np.array(img, dtype=np.float32) / 255.0  # [H, W, C] in [0, 1]
    arr = arr * 2.0 - 1.0                          # [H, W, C] in [-1, 1]
    return torch.from_numpy(arr).permute(2, 0, 1)  # [C, H, W]


def _build_pi0_config_obj(framework_cfg):
    """
    Build the lightweight config object passed to ``PI0Pytorch``.

    Values are read from ``framework_cfg.pi0`` (attribute-style node or plain
    dict). Fields and the defaults used when a key is absent:

        - paligemma_variant     : "gemma_2b"
        - action_expert_variant : "gemma_300m"
        - pi05                  : False
        - action_dim            : 37
        - state_dim             : 74
        - action_horizon        : 15
        - dtype                 : "bfloat16"
        - max_token_len         : 200 if pi05 else 48 (unless set explicitly)

    NOTE(review): these defaults differ from the YAML example in the
    ``PI0Framework`` docstring (action_dim 32 / action_horizon 50); set the
    fields explicitly when loading an upstream pi0 checkpoint.

    Args:
        framework_cfg: The ``framework`` config node (may lack a ``pi0`` child).

    Returns:
        An object exposing the fields above as attributes.
    """
    pi0_node = getattr(framework_cfg, "pi0", None) or {}

    def _get(key, default):
        # Attribute-style access first (OmegaConf / namespace), then dict access.
        if hasattr(pi0_node, key):
            return getattr(pi0_node, key)
        if isinstance(pi0_node, dict):
            return pi0_node.get(key, default)
        return default

    class _PI0Config:
        paligemma_variant = _get("paligemma_variant", "gemma_2b")
        action_expert_variant = _get("action_expert_variant", "gemma_300m")
        pi05 = _get("pi05", False)
        action_dim = _get("action_dim", 37)
        state_dim = _get("state_dim", 74)
        action_horizon = _get("action_horizon", 15)
        dtype = _get("dtype", "bfloat16")
        # Explicit override; when None the value is derived from the pi05 flag.
        _max_token_len_raw = _get("max_token_len", None)

        @property
        def max_token_len(self):
            if self._max_token_len_raw is not None:
                return self._max_token_len_raw
            return 200 if self.pi05 else 48

    return _PI0Config()


class _SentencePieceTokenizer:
    """
    Thin SentencePiece wrapper producing PaliGemma-style prompt encodings.

    Intended to mirror openpi's ``PaligemmaTokenizer`` but takes a local model
    file path instead of relying on openpi's download logic.
    """

    def __init__(self, model_path: str, max_len: int = 48):
        """
        Args:
            model_path: Path to a SentencePiece ``.model`` file.
            max_len: Fixed output length; prompts are padded/truncated to it.
        """
        import sentencepiece

        self._max_len = max_len
        with open(model_path, "rb") as f:
            self._sp = sentencepiece.SentencePieceProcessor(model_proto=f.read())

    def tokenize(self, prompt: str, state: Optional[np.ndarray] = None):
        """
        Encode a prompt to fixed-length token ids and a validity mask.

        Args:
            prompt: Language instruction. Underscores and newlines are replaced
                by spaces.
            state: Optional 1-D state vector. When given, it is discretised
                into 256 bins over [-1, 1] and embedded in the prompt
                (pi0.5 prompt format).

        Returns:
            (tokens, mask): ``np.ndarray[int32]`` and ``np.ndarray[bool]``,
            both of length ``max_len``. Padding positions hold token 0 and
            mask False.
        """
        cleaned = prompt.strip().replace("_", " ").replace("\n", " ")
        if state is not None:
            # pi05 format: discretised state is appended to the prompt text.
            bins = np.linspace(-1, 1, 257)[:-1]
            disc = np.digitize(state, bins=bins) - 1
            state_str = " ".join(map(str, disc))
            full_prompt = f"Task: {cleaned}, State: {state_str};\nAction: "
            tokens = self._sp.encode(full_prompt, add_bos=True)
        else:
            tokens = self._sp.encode(cleaned, add_bos=True) + self._sp.encode("\n")

        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            n_pad = self._max_len - tokens_len
            mask = [True] * tokens_len + [False] * n_pad
            tokens = tokens + [0] * n_pad
        else:
            if tokens_len > self._max_len:
                logger.warning(
                    f"Token length ({tokens_len}) > max_len ({self._max_len}), truncating."
                )
            tokens = tokens[: self._max_len]
            mask = [True] * self._max_len

        return np.array(tokens, dtype=np.int32), np.array(mask, dtype=bool)


# ────────────────────────────────────────────────────────────────
# PI0 Framework
# ────────────────────────────────────────────────────────────────

@FRAMEWORK_REGISTRY.register("PI0")
class PI0Framework(baseframework):
    """
    starVLA framework wrapper that integrates openpi's PI0Pytorch model.

    Config node (YAML example):

    ```yaml
    framework:
      name: "PI0"
      pi0:
        paligemma_variant: "gemma_2b"        # PaliGemma backbone variant
        action_expert_variant: "gemma_300m"  # action expert variant
        pi05: false                          # use the Pi0.5 variant
        action_dim: 32                       # action dimension
        action_horizon: 50                   # number of predicted action steps
        dtype: "bfloat16"                    # weight precision
      tokenizer_path: "/path/to/paligemma_tokenizer.model"  # SentencePiece model
      pi0_checkpoint: "/path/to/model.safetensors"          # optional pretrained weights
      image_keys:            # key order (matches the order of examples["image"])
        - "base_0_rgb"
        - "left_wrist_0_rgb"
        - "right_wrist_0_rgb"
      num_inference_steps: 10                # flow-matching inference steps
    ```

    Interface:
        __init__(config)         : build the model, load the tokenizer and
                                   optionally pretrained weights
        predict_action(examples) : inference, returns
                                   {"normalized_actions": np.ndarray}
        forward(examples)        : training forward pass, returns
                                   {"action_loss": Tensor}
    """

    def __init__(
        self,
        config: Optional[Any] = None,
        **kwargs,
    ) -> None:
        """
        Initialise the framework.

        Args:
            config: Hierarchical starVLA config (OmegaConf DictConfig or a
                compatible object). Settings are read from ``config.framework``
                when present, otherwise from ``config`` itself.
            **kwargs: Reserved, unused.

        Raises:
            ImportError: If the openpi package cannot be imported.
        """
        super().__init__()
        self.config = config
        fw_cfg = getattr(config, "framework", config)

        # ── 1. Build the PI0Pytorch config ───────────────────────
        pi0_cfg = _build_pi0_config_obj(fw_cfg)
        self._pi0_cfg = pi0_cfg

        # ── 2. Instantiate PI0Pytorch ────────────────────────────
        try:
            # Append the fallback source tree once; repeated construction must
            # not keep growing sys.path.
            if _OPENPI_FALLBACK_SRC not in sys.path:
                sys.path.append(_OPENPI_FALLBACK_SRC)
            from openpi.models_pytorch.pi0_pytorch import PI0Pytorch
        except ImportError as e:
            raise ImportError(
                "PI0Framework 依赖 openpi 包。请确保 openpi 已安装或 "
                "openpi/src 已添加到 PYTHONPATH。原始错误: " + str(e)
            ) from e

        self.pi0_model: nn.Module = PI0Pytorch(config=pi0_cfg)
        logger.info(
            f"PI0Pytorch 已初始化:variant={pi0_cfg.paligemma_variant}, "
            f"action_dim={pi0_cfg.action_dim}, action_horizon={pi0_cfg.action_horizon}, "
            f"pi05={pi0_cfg.pi05}"
        )

        # ── 2b. Replace the hard-coded 32-D projection layers ────
        # Per the original author's notes, PI0Pytorch hard-codes
        # action_in_proj / action_out_proj / state_proj to 32 dims. The
        # replaced layers are randomly initialised; checkpoint entries whose
        # shape no longer matches are dropped by _filter_ckpt_by_shape.
        self._replace_pi0_projection_layers(pi0_cfg.action_dim, pi0_cfg.state_dim)

        # ── 3. Image key mapping ─────────────────────────────────
        _ik = getattr(fw_cfg, "image_keys", None)
        if _ik is not None:
            self.image_keys = list(_ik)
        else:
            self.image_keys = list(_DEFAULT_IMAGE_KEYS)

        # ── 4. Number of inference steps ─────────────────────────
        self.num_inference_steps = getattr(fw_cfg, "num_inference_steps", 10)

        # ── 4b. Effective action dimension ───────────────────────
        # When set, predict_action truncates its output to the first N action
        # dims (e.g. 7 for LIBERO) so it matches downstream un-normalisation.
        self.effective_action_dim = getattr(fw_cfg, "effective_action_dim", None)

        # ── 4c. Single-view replication ──────────────────────────
        # When True and a sample has exactly one image, that image is repeated
        # to fill every active image key.
        self._replicate_single_view = getattr(fw_cfg, "replicate_single_view", False)

        # ── 4d. Whether the state input is used ──────────────────
        # When False, example["state"] is ignored: the tokenizer gets no state
        # and Observation.state is all zeros.
        self._use_state = getattr(fw_cfg, "use_state", True)

        # ── 4e. Dynamic number of views ──────────────────────────
        # True : use only the first N image keys, N = images in the first sample.
        # False: always use all image keys; missing views are zeros with mask=False.
        self._dynamic_image_keys = getattr(fw_cfg, "dynamic_image_keys", False)

        # ── 5. Tokenizer ─────────────────────────────────────────
        tokenizer_path = getattr(fw_cfg, "tokenizer_path", None)
        self._tokenizer = self._load_tokenizer(tokenizer_path, pi0_cfg.max_token_len)

        # ── 6. Optional: load pretrained pi0 weights ─────────────
        pi0_ckpt = getattr(fw_cfg, "pi0_checkpoint", None)
        if pi0_ckpt:
            self.load_pi0_weights(pi0_ckpt)

    # ──────────────────────────────────────────────────────────────
    # Internal utilities
    # ──────────────────────────────────────────────────────────────

    def _load_tokenizer(self, tokenizer_path: Optional[str], max_len: int):
        """
        Load the PaliGemma SentencePiece tokenizer.

        Resolution order:
          1. The local ``.model`` file at ``tokenizer_path``, if it exists.
          2. openpi's download helper for ``paligemma_tokenizer.model``.
          3. Otherwise ``None``; ``_preprocess_examples`` then raises with a hint.

        Args:
            tokenizer_path: Local SentencePiece ``.model`` path, may be None.
            max_len: Maximum token length.

        Returns:
            A ``_SentencePieceTokenizer`` instance, or ``None``.
        """
        if tokenizer_path:
            if Path(tokenizer_path).exists():
                logger.info(f"从本地路径加载 tokenizer:{tokenizer_path}")
                return _SentencePieceTokenizer(tokenizer_path, max_len=max_len)
            # A configured-but-missing path is almost certainly a typo; say so
            # instead of silently falling through to the download.
            logger.warning(
                f"Configured tokenizer_path does not exist: {tokenizer_path}; "
                "falling back to the openpi download helper."
            )

        # Best effort: any failure here (import, network, parsing) degrades to
        # "no tokenizer" rather than aborting construction.
        try:
            from openpi.shared import download as _download

            path = _download.maybe_download(
                "gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}
            )
            logger.info(f"通过 openpi 下载工具加载 tokenizer:{path}")
            return _SentencePieceTokenizer(str(path), max_len=max_len)
        except Exception as e:
            logger.warning(
                f"无法自动下载 paligemma tokenizer:{e}。"
                "请在 config 中设置 framework.tokenizer_path 指向本地 .model 文件。"
            )
            return None

    def _preprocess_examples(self, examples: List[dict], device: torch.device):
        """
        Convert starVLA samples into a PI0 ``Observation``.

        starVLA sample format:
            examples[i]["image"] : list of PIL images / uint8 arrays, one per view
            examples[i]["lang"]  : str, language instruction (``"language"`` is
                                   accepted as a fallback key)
            examples[i]["state"] : optional array; if it has more than one
                                   dimension only the first row is used

        Resulting Observation fields:
            images                : dict[key -> Tensor[B, C, H, W]], values in [-1, 1]
            image_masks           : dict[key -> Tensor[B]], bool
            state                 : Tensor[B, state_dim], float32
            tokenized_prompt      : Tensor[B, max_token_len], int32
            tokenized_prompt_mask : Tensor[B, max_token_len], bool

        Args:
            examples: Non-empty list of sample dicts.
            device: Target torch device.

        Returns:
            ``openpi.models.model.Observation``.

        Raises:
            ValueError: If ``examples`` is empty.
            RuntimeError: If no tokenizer is available.
        """
        from openpi.models.model import Observation

        if not examples:
            raise ValueError("_preprocess_examples requires at least one example.")

        # ── Images ───────────────────────────────────────────────
        # getattr defaults keep instances created without __init__ usable.
        replicate_single_view = getattr(self, "_replicate_single_view", False)
        dynamic_image_keys = getattr(self, "_dynamic_image_keys", False)

        # In dynamic mode the first sample decides how many keys are active.
        num_views = len(examples[0].get("image", []))
        if replicate_single_view and num_views == 1 and len(self.image_keys) > 1:
            num_views = len(self.image_keys)
        if dynamic_image_keys:
            active_keys = list(self.image_keys)[: max(1, num_views)]
        else:
            active_keys = list(self.image_keys)

        images_batch: Dict[str, List[torch.Tensor]] = {k: [] for k in active_keys}
        masks_batch: Dict[str, List[bool]] = {k: [] for k in active_keys}

        for example in examples:
            imgs = example.get("image", [])
            if replicate_single_view and len(imgs) == 1 and len(active_keys) > 1:
                # list(...) guards against an ndarray being multiplied numerically.
                imgs = list(imgs) * len(active_keys)
            for idx, key in enumerate(active_keys):
                if idx < len(imgs):
                    t = _pil_to_tensor_normalized(imgs[idx], _IMAGE_RESOLUTION)
                    images_batch[key].append(t)
                    masks_batch[key].append(True)
                else:
                    # Missing view: zero placeholder with mask=False.
                    t = torch.zeros(
                        (3, _IMAGE_RESOLUTION[0], _IMAGE_RESOLUTION[1]),
                        dtype=torch.float32,
                    )
                    images_batch[key].append(t)
                    masks_batch[key].append(False)

        images_tensor: Dict[str, torch.Tensor] = {
            k: torch.stack(v, dim=0).to(device, dtype=torch.float32)  # [B, C, H, W]
            for k, v in images_batch.items()
        }
        image_masks_tensor: Dict[str, torch.Tensor] = {
            k: torch.tensor(v, dtype=torch.bool, device=device)  # [B]
            for k, v in masks_batch.items()
        }

        # ── Language tokenization ────────────────────────────────
        if self._tokenizer is None:
            raise RuntimeError(
                "Tokenizer 未初始化。请在 config 中设置 framework.tokenizer_path,"
                "或确保 openpi 网络可访问以自动下载。"
            )

        pi05 = self._pi0_cfg.pi05
        use_state = getattr(self, "_use_state", True)

        all_tokens = []
        all_masks = []
        for example in examples:
            lang = example.get("lang", example.get("language", ""))
            state_for_tok = None
            if pi05 and use_state:
                # pi0.5: the discretised state becomes part of the prompt.
                # If the sample has no "state", the tokenizer silently falls
                # back to the non-pi05 prompt format.
                raw_state = example.get("state", None)
                if raw_state is not None:
                    s = np.array(raw_state)
                    if s.ndim > 1:
                        s = s[0]
                    state_for_tok = s.flatten()[: self._pi0_cfg.state_dim]
            toks, mask = self._tokenizer.tokenize(lang, state_for_tok)
            all_tokens.append(toks)
            all_masks.append(mask)

        tokenized_prompt = torch.tensor(
            np.stack(all_tokens, axis=0), dtype=torch.int32, device=device
        )  # [B, max_token_len]
        tokenized_prompt_mask = torch.tensor(
            np.stack(all_masks, axis=0), dtype=torch.bool, device=device
        )  # [B, max_token_len]

        # ── State ────────────────────────────────────────────────
        # The state is aligned to state_dim (not action_dim); the two may
        # differ. With use_state=False every row is zeros.
        state_dim = self._pi0_cfg.state_dim
        state_list = []
        for example in examples:
            raw = example.get("state", None) if use_state else None
            if raw is not None:
                s = np.array(raw, dtype=np.float32)
                if s.ndim > 1:
                    s = s[0]  # first frame when the state is chunked as [T, state_dim]
                s = s.flatten()
            else:
                s = np.zeros(state_dim, dtype=np.float32)
            # Align to state_dim: truncate or zero-pad.
            if len(s) >= state_dim:
                s = s[:state_dim]
            else:
                s = np.concatenate([s, np.zeros(state_dim - len(s), dtype=np.float32)])
            state_list.append(s)

        state_tensor = torch.tensor(
            np.stack(state_list, axis=0), dtype=torch.float32, device=device
        )  # [B, state_dim]

        return Observation(
            images=images_tensor,
            image_masks=image_masks_tensor,
            state=state_tensor,
            tokenized_prompt=tokenized_prompt,
            tokenized_prompt_mask=tokenized_prompt_mask,
        )

    # ──────────────────────────────────────────────────────────────
    # load_state_dict override (handles the key-prefix mismatch)
    # ──────────────────────────────────────────────────────────────

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """
        Load weights, accepting bare (un-prefixed) pi0 checkpoint keys.

        This module's own keys carry a ``pi0_model.`` prefix because
        ``self.pi0_model`` is a sub-module. Per the original author's notes,
        converted pi0 safetensors use bare keys, so a strict load on this
        wrapper would fail. When no key starts with ``pi0_model.``, the weights
        are loaded directly into ``self.pi0_model`` non-strictly, after dropping
        entries whose shape differs from the current model (e.g. replaced
        projection layers).

        Args:
            state_dict: Mapping of parameter names to tensors.
            strict: Used only for prefixed state dicts (delegated to the parent).
            assign: Forwarded to the underlying ``load_state_dict`` call.

        Returns:
            The incompatible-keys result (``missing_keys``, ``unexpected_keys``).
        """
        if state_dict and not any(k.startswith("pi0_model.") for k in state_dict.keys()):
            logger.info(
                "[PI0Framework.load_state_dict] 检测到裸 key(pi05_libero_pytorch 格式),"
                "直接加载到 self.pi0_model(跳过 pi0_model. 前缀)"
            )
            # strict=False does not tolerate shape mismatches, so filter first;
            # this keeps the path consistent with load_pi0_weights.
            compatible = self._filter_ckpt_by_shape(state_dict)
            missing, unexpected = self.pi0_model.load_state_dict(
                compatible, strict=False, assign=assign
            )
            # The tied embedding key is expected to be absent (see _TIED_EMBED_KEY).
            real_missing = set(missing) - {_TIED_EMBED_KEY}
            if real_missing:
                logger.warning(f"[PI0Framework.load_state_dict] 真正缺失的 key(非权重绑定):{real_missing}")
            if unexpected:
                logger.warning(f"[PI0Framework.load_state_dict] 多余的 key:{set(unexpected)}")
            logger.info("[PI0Framework.load_state_dict] pi0 权重加载完成(embed_tokens.weight 由权重绑定自动同步)")
            from torch.nn.modules.module import _IncompatibleKeys

            return _IncompatibleKeys(missing, unexpected)

        return super().load_state_dict(state_dict, strict=strict, assign=assign)

    # ──────────────────────────────────────────────────────────────
    # from_pretrained override (adds the bare pi05_libero_pytorch layout)
    # ──────────────────────────────────────────────────────────────

    @classmethod
    def from_pretrained(cls, pretrained_checkpoint: str, **kwargs):
        """
        Build a ``PI0Framework`` from a checkpoint file (used by server_policy.py).

        Weights are loaded straight into ``self.pi0_model`` to avoid the
        ``pi0_model.`` key-prefix mismatch described in ``load_state_dict``.

        Two layouts are recognised:

        **Format A: starVLA wrapper directory** (recommended for deployment)
        ```
        <run_dir>/
        ├── config.yaml              # framework.name=PI0, framework.pi0.*, ...
        ├── dataset_statistics.json  # action normalisation statistics
        └── checkpoints/
            └── model.safetensors    # pi0 weights (may be a symlink)
        ```
        Pass ``<run_dir>/checkpoints/model.safetensors``.

        **Format B: bare pi05_libero_pytorch directory** (quick tests only)
        ```
        <dir>/
        ├── model.safetensors   # output of convert_jax_model_to_pytorch.py
        └── config.json         # {"action_dim": 32, "action_horizon": 10, ...}
        ```
        In this layout ``norm_stats`` is empty, so action un-normalisation must
        be supplied by the caller (or use Format A).

        If neither layout matches, the parent ``from_pretrained`` is tried.

        Args:
            pretrained_checkpoint: Path to the weight file (.safetensors or .pt).
            **kwargs: Forwarded to the parent implementation on fallback.

        Returns:
            A ``PI0Framework`` instance with weights loaded.
        """
        pretrained_checkpoint = Path(pretrained_checkpoint)

        # ─── Format A: starVLA wrapper directory ─────────────────
        # `.parent.parent` (unlike `.parents[1]`) never raises IndexError for
        # short paths such as "model.safetensors".
        wrapper_run_dir = pretrained_checkpoint.parent.parent
        wrapper_cfg_yaml = wrapper_run_dir / "config.yaml"
        wrapper_stats_json = wrapper_run_dir / "dataset_statistics.json"

        if wrapper_cfg_yaml.exists() and wrapper_stats_json.exists():
            from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
            from starVLA.model.framework import build_framework

            model_config, norm_stats = read_mode_config(pretrained_checkpoint)
            config = dict_to_namespace(model_config)
            config.trainer.pretrained_checkpoint = None
            # Prevent __init__ from loading weights a second time when
            # config.yaml also sets pi0_checkpoint.
            fw_cfg_ref = getattr(config, "framework", config)
            if hasattr(fw_cfg_ref, "pi0_checkpoint"):
                fw_cfg_ref.pi0_checkpoint = None

            model = build_framework(cfg=config)
            model.norm_stats = norm_stats

            if pretrained_checkpoint.exists():
                model.load_pi0_weights(str(pretrained_checkpoint))
            else:
                # Do not hand back a randomly initialised model without a trace.
                logger.warning(
                    f"[from_pretrained format A] checkpoint file not found, "
                    f"weights were NOT loaded: {pretrained_checkpoint}"
                )
            logger.info(f"[from_pretrained 格式A] 从 wrapper 目录加载:{wrapper_run_dir}")
            return model

        # ─── Format B: bare pi05_libero_pytorch directory ────────
        pi0_dir = pretrained_checkpoint.parent
        pi0_config_json = pi0_dir / "config.json"

        if pi0_config_json.exists():
            import json
            from omegaconf import OmegaConf

            with open(pi0_config_json) as _f:
                pi0_cfg_dict = json.load(_f)

            # Infer pi05 from the checkpoint keys: the presence of
            # "time_mlp_in.weight" is used as the marker.
            pi05_detected = False
            if pretrained_checkpoint.suffix == ".safetensors":
                try:
                    from safetensors import safe_open

                    with safe_open(str(pretrained_checkpoint), framework="pt") as _sf:
                        pi05_detected = "time_mlp_in.weight" in set(_sf.keys())
                except Exception as e:
                    # Best effort: keep pi05=False, but make the failure visible.
                    logger.warning(
                        f"[from_pretrained format B] could not inspect checkpoint keys "
                        f"({e}); assuming pi05=False."
                    )

            config = OmegaConf.create({
                "framework": {
                    "name": "PI0",
                    "pi0": {
                        "paligemma_variant": pi0_cfg_dict.get("paligemma_variant", "gemma_2b"),
                        "action_expert_variant": pi0_cfg_dict.get("action_expert_variant", "gemma_300m"),
                        "pi05": pi05_detected,
                        "action_dim": pi0_cfg_dict.get("action_dim", 32),
                        "action_horizon": pi0_cfg_dict.get("action_horizon", 50),
                        "dtype": pi0_cfg_dict.get("precision", "bfloat16"),
                    },
                    "pi0_checkpoint": None,
                    "num_inference_steps": 10,
                },
                "trainer": {"pretrained_checkpoint": None},
            })

            model = cls(config=config)
            model.load_pi0_weights(str(pretrained_checkpoint))
            model.norm_stats = {}
            logger.warning(
                "[from_pretrained 格式B] 直接加载 pi05_libero_pytorch 目录,"
                "norm_stats 为空。ModelClient.get_action_stats 会失败。"
                "建议创建 starVLA wrapper 目录(含 config.yaml + dataset_statistics.json)。"
            )
            logger.info(f"[from_pretrained 格式B] pi05_detected={pi05_detected},路径:{pretrained_checkpoint}")
            return model

        # ─── Fallback ────────────────────────────────────────────
        logger.warning(
            f"PI0Framework.from_pretrained:无法识别 checkpoint 格式({pretrained_checkpoint})。"
            "请参考 docstring 创建 starVLA wrapper 目录(格式 A)。"
            "尝试调用父类 from_pretrained(大概率因 key 不匹配而失败)。"
        )
        return super().from_pretrained(pretrained_checkpoint, **kwargs)

    # ── PI0Pytorch structural adaptation ─────────────────────────
    # Per the original author's notes, PI0Pytorch hard-codes these layers to
    # 32 dims instead of reading config.action_dim:
    #   self.action_in_proj  = nn.Linear(32, width)
    #   self.action_out_proj = nn.Linear(width, 32)
    #   self.state_proj      = nn.Linear(32, width)   # only when pi05=False
    _PI0_HARDCODED_ACTION_DIM: int = 32

    def _replace_pi0_projection_layers(self, action_dim: int, state_dim: Optional[int] = None) -> None:
        """
        Replace the 32-D projection layers of ``self.pi0_model``.

        - ``action_in_proj`` / ``action_out_proj`` are rebuilt for ``action_dim``.
        - ``state_proj`` is rebuilt for ``state_dim`` (only when pi05 is False
          and the model has that attribute).
        - A layer is left untouched when its target dim equals 32.
        - New layers are randomly initialised; checkpoint entries for them are
          later dropped by ``_filter_ckpt_by_shape`` because shapes differ.

        Args:
            action_dim: Action dimension.
            state_dim: State dimension; defaults to ``action_dim`` when None.
        """
        if state_dim is None:
            state_dim = action_dim

        hard_dim = self._PI0_HARDCODED_ACTION_DIM

        if action_dim != hard_dim:
            proj_width = self.pi0_model.action_in_proj.out_features
            self.pi0_model.action_in_proj = nn.Linear(action_dim, proj_width)
            self.pi0_model.action_out_proj = nn.Linear(proj_width, action_dim)
            logger.info(
                f"[PI0Framework] action_in/out_proj 已替换:"
                f"Linear({self._PI0_HARDCODED_ACTION_DIM}, {proj_width}) "
                f"→ Linear({action_dim}, {proj_width}) "
                f"[随机初始化,不加载 checkpoint 权重]"
            )

        has_state_proj = not self._pi0_cfg.pi05 and hasattr(self.pi0_model, "state_proj")
        if has_state_proj and state_dim != hard_dim:
            state_proj_width = self.pi0_model.state_proj.out_features
            self.pi0_model.state_proj = nn.Linear(state_dim, state_proj_width)
            logger.info(
                f"[PI0Framework] state_proj 已替换:"
                f"Linear({self._PI0_HARDCODED_ACTION_DIM}, {state_proj_width}) "
                f"→ Linear({state_dim}, {state_proj_width}) "
                f"[随机初始化,不加载 checkpoint 权重]"
            )

    def _filter_ckpt_by_shape(self, state_dict: dict) -> dict:
        """
        Drop checkpoint entries whose shape differs from the current model.

        ``load_state_dict(strict=False)`` still raises on shape mismatches, so
        such keys are removed up front and the matching layers keep their
        current (random) initialisation.

        Args:
            state_dict: Raw weights with bare key names (no wrapper prefix).

        Returns:
            A new dict holding only keys whose shape matches, plus keys that do
            not exist in the model at all.
        """
        model_sd = self.pi0_model.state_dict()
        filtered: dict = {}
        skipped: list = []
        for k, v in state_dict.items():
            if k in model_sd and model_sd[k].shape != v.shape:
                skipped.append(
                    f" {k}: ckpt{tuple(v.shape)} ≠ model{tuple(model_sd[k].shape)}"
                )
            else:
                filtered[k] = v

        if skipped:
            logger.info(
                f"[load_pi0_weights] 跳过 {len(skipped)} 个 shape 不匹配的 key"
                f"(这些层保持随机初始化,将在训练中从头学习):"
            )
            for s in skipped:
                logger.info(s)
        return filtered

    def load_pi0_weights(self, checkpoint_path: str) -> None:
        """
        Load pretrained pi0 weights into ``self.pi0_model``.

        Supported formats:
          - ``.safetensors``: ``safetensors.torch.load_file``
          - anything else (``.pt`` / ``.pth``): ``torch.load(map_location="cpu")``;
            a top-level ``"model"`` dict wrapper is unwrapped.

        In both cases entries with mismatching shapes are dropped via
        ``_filter_ckpt_by_shape`` and the rest is loaded with ``strict=False``.

        Args:
            checkpoint_path: Path to the weight file.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"pi0 checkpoint 不存在:{checkpoint_path}")

        logger.info(f"加载 pi0 预训练权重:{checkpoint_path}")

        if checkpoint_path.suffix == ".safetensors":
            import safetensors.torch as sf_torch

            state_dict = sf_torch.load_file(str(checkpoint_path))
            state_dict = self._filter_ckpt_by_shape(state_dict)
            missing, unexpected = self.pi0_model.load_state_dict(state_dict, strict=False)
            # The tied embedding key is expected to be absent (see _TIED_EMBED_KEY).
            real_missing = set(missing) - {_TIED_EMBED_KEY}
            if real_missing:
                logger.warning(f"加载后真正缺失的 key({len(real_missing)} 个):{list(real_missing)[:10]} ...")
            if unexpected:
                logger.warning(f"加载后多余的 key({len(unexpected)} 个):{list(unexpected)[:10]} ...")
        else:
            # SECURITY NOTE: torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(str(checkpoint_path), map_location="cpu")
            # Some checkpoints wrap the weights in an extra {"model": ...} level.
            if "model" in state_dict and isinstance(state_dict["model"], dict):
                state_dict = state_dict["model"]
            state_dict = self._filter_ckpt_by_shape(state_dict)
            model_keys = set(self.pi0_model.state_dict().keys())
            checkpoint_keys = set(state_dict.keys())
            missing = model_keys - checkpoint_keys
            unexpected = checkpoint_keys - model_keys
            if missing:
                logger.warning(f"权重中缺少的 key({len(missing)} 个):{list(missing)[:10]} ...")
            if unexpected:
                logger.warning(f"权重中多余的 key({len(unexpected)} 个):{list(unexpected)[:10]} ...")
            self.pi0_model.load_state_dict(state_dict, strict=False)

        logger.info("pi0 预训练权重加载完毕。")

    def forward(
        self,
        examples: List[dict] = None,
        **kwargs,
    ) -> dict:
        """
        Training forward pass returning the mean flow-matching loss.

        Steps:
          1. Build a PI0 ``Observation`` from the samples.
          2. Align every target action array to
             ``(action_horizon, action_dim)`` by truncating or zero-padding.
          3. Call ``self.pi0_model.forward(observation, actions)``, which is
             expected to return a per-element loss (per the original author's
             notes: MSE between target and predicted flow velocity), and
             average it over all elements.

        Args:
            examples: List of sample dicts (a single dict is wrapped), each with:
                - "image": list of per-view images
                - "lang": language instruction
                - "action": array of shape (H, D) or (D,)
                - "state": optional robot state
            **kwargs: Reserved, unused.

        Returns:
            dict with "action_loss": scalar ``torch.Tensor``.
        """
        if not isinstance(examples, list):
            examples = [examples]

        device = next(self.pi0_model.parameters()).device
        action_dim = self._pi0_cfg.action_dim
        action_horizon = self._pi0_cfg.action_horizon

        # ── 1. Build the PI0 Observation ─────────────────────────
        observation = self._preprocess_examples(examples, device)

        # ── 2. Align action targets to (action_horizon, action_dim) ──
        actions_list = []
        for example in examples:
            a = np.array(example["action"], dtype=np.float32)
            if a.ndim == 1:
                a = a[np.newaxis, :]  # no horizon axis: treat as a single step

            # Horizon axis: truncate or zero-pad.
            H, D = a.shape
            if H > action_horizon:
                a = a[:action_horizon]
            elif H < action_horizon:
                a = np.concatenate(
                    [a, np.zeros((action_horizon - H, D), dtype=np.float32)], axis=0
                )

            # Action-dim axis: truncate or zero-pad.
            if D > action_dim:
                a = a[:, :action_dim]
            elif D < action_dim:
                a = np.concatenate(
                    [a, np.zeros((action_horizon, action_dim - D), dtype=np.float32)], axis=1
                )
            actions_list.append(a)

        actions_target = torch.tensor(
            np.stack(actions_list, axis=0), dtype=torch.float32, device=device
        )  # [B, action_horizon, action_dim]

        # ── 3. Flow-matching loss ────────────────────────────────
        # CUDA autocast only has an effect when the model lives on a CUDA
        # device; enabling it conditionally avoids a spurious warning on
        # CPU-only hosts while keeping the CUDA behaviour unchanged.
        with torch.autocast("cuda", dtype=torch.bfloat16, enabled=device.type == "cuda"):
            loss_per_element = self.pi0_model.forward(observation, actions_target)

        # Mean over batch / horizon / action dims.
        action_loss = loss_per_element.mean()
        return {"action_loss": action_loss}

    @torch.inference_mode()
    def predict_action(
        self,
        examples: List[dict],
        **kwargs,
    ) -> dict:
        """
        Inference: predict an action sequence from observations.

        Steps:
          1. Build a PI0 ``Observation`` (images, tokenized prompt, state).
          2. Call ``self.pi0_model.sample_actions`` with the requested number
             of steps.
          3. Optionally truncate the action dimension to
             ``effective_action_dim``.

        Args:
            examples: List of sample dicts (a single dict is wrapped), each with:
                - "image": list of PIL images or (H, W, 3) uint8 arrays, in
                  the order of ``framework.image_keys``
                - "lang": task instruction
                - "state": optional robot state; in pi05 mode it becomes part
                  of the tokenized prompt, so omitting it changes the prompt
                  format
            **kwargs: ``num_steps`` overrides ``framework.num_inference_steps``.

        Returns:
            dict with "normalized_actions": float32 ``np.ndarray`` of shape
            ``[B, action_horizon, D]`` where D is ``effective_action_dim`` if
            set, otherwise the model's action dim.
        """
        if not isinstance(examples, list):
            examples = [examples]

        device = next(self.pi0_model.parameters()).device
        observation = self._preprocess_examples(examples, device)

        num_steps = kwargs.get("num_steps", self.num_inference_steps)
        pred_actions = self.pi0_model.sample_actions(
            device, observation, num_steps=num_steps
        )  # [B, action_horizon, action_dim]

        # NumPy has no bfloat16, so cast before leaving torch (no-op for float32).
        normalized_actions = pred_actions.detach().to(torch.float32).cpu().numpy()

        # Keep only the first N action dims when configured, so the output
        # matches the dimensionality of downstream un-normalisation statistics.
        if self.effective_action_dim is not None:
            normalized_actions = normalized_actions[:, :, : self.effective_action_dim]

        return {"normalized_actions": normalized_actions}


# ────────────────────────────────────────────────────────────────
# Quick smoke test (debugging aid)
# ────────────────────────────────────────────────────────────────

def _main() -> None:
    """Build the framework, run one ``predict_action`` call and print statistics."""
    import argparse
    import time

    from omegaconf import OmegaConf

    # Default paths (pi05_libero model on the author's machine).
    _DEFAULT_CKPT = (
        "/mnt/data/fangyu/model/openpi/openpi-assets/checkpoints"
        "/pi05_libero_pytorch/model.safetensors"
    )
    _DEFAULT_TOKENIZER = (
        "/root/.cache/openpi/big_vision/paligemma_tokenizer.model"
    )

    parser = argparse.ArgumentParser(description="PI0Framework 快速冒烟测试")
    parser.add_argument("--config_yaml", type=str, default=None, help="YAML 配置文件路径(优先于内联配置)")
    parser.add_argument("--pi0_checkpoint", type=str, default=_DEFAULT_CKPT, help="pi0 safetensors 权重路径")
    parser.add_argument("--tokenizer_path", type=str, default=_DEFAULT_TOKENIZER, help="PaliGemma tokenizer .model 路径")
    parser.add_argument("--device", type=str, default=None, help="运行设备,如 cuda:0 / cpu(默认自动检测)")
    parser.add_argument("--steps", type=int, default=10, help="流匹配推理步数")
    parser.add_argument("--batch_size", type=int, default=1, help="测试 batch size")
    args, _ = parser.parse_known_args()

    if args.config_yaml:
        cfg = OmegaConf.load(args.config_yaml)
    else:
        # Inline pi05_libero configuration.
        cfg = OmegaConf.create({
            "framework": {
                "name": "PI0",
                "pi0": {
                    "paligemma_variant": "gemma_2b",
                    "action_expert_variant": "gemma_300m",
                    "pi05": True,
                    "action_dim": 32,
                    "action_horizon": 10,
                    "dtype": "bfloat16",
                },
                "tokenizer_path": args.tokenizer_path,
                "pi0_checkpoint": args.pi0_checkpoint,
                "image_keys": ["base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb"],
                "num_inference_steps": args.steps,
            }
        })

    print("=" * 60)
    print("PI0Framework 测试")
    print("=" * 60)
    print(f" checkpoint : {args.pi0_checkpoint}")
    print(f" tokenizer : {args.tokenizer_path}")
    print(f" infer steps : {args.steps}")
    print(f" batch size : {args.batch_size}")

    # ── 1. Build the model ───────────────────────────────────────
    print("\n[1/3] 初始化 PI0Framework 并加载权重...")
    model = PI0Framework(cfg)
    total_params = sum(p.numel() for p in model.pi0_model.parameters())
    print(f" 模型参数量: {total_params / 1e9:.2f}B")

    # ── 2. Move to the target device ─────────────────────────────
    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n[2/3] 移至设备: {device}")
    model = model.to(device)
    # pi05_libero runs selected parameters in bfloat16.
    model.pi0_model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")

    # ── 3. Fake samples + inference ──────────────────────────────
    print(f"\n[3/3] 构造 batch_size={args.batch_size} 的假样本并调用 predict_action...")
    fake_img = Image.fromarray(
        np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    )
    sample = {
        "image": [fake_img, fake_img, fake_img],  # base + left wrist + right wrist
        "lang": "put the red cup on the plate",
        "state": np.random.uniform(-1, 1, size=(32,)).astype(np.float32),
    }
    batch = [sample] * args.batch_size

    t0 = time.time()
    result = model.predict_action(batch)
    elapsed = time.time() - t0

    actions = result["normalized_actions"]
    print(f"\n 输出 normalized_actions shape : {actions.shape}")
    print(f" 推理耗时 : {elapsed*1000:.1f} ms")
    print(f" 动作值域 [min, max] : [{actions.min():.4f}, {actions.max():.4f}]")
    print(f" 动作均值 ± 标准差 : {actions.mean():.4f} ± {actions.std():.4f}")
    print("\n[OK] PI0Framework 测试完成!")


if __name__ == "__main__":
    _main()