# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License (the "License");
"""
PI0 Framework
ๅฐ† openpi ็š„ PI0Pytorch ๆจกๅž‹ๅฐ่ฃ…ไธบ starVLA framework ๆŽฅๅฃใ€‚
ๆ”ฏๆŒ๏ผš
- ไปŽ pi0 safetensors checkpoint ๅŠ ่ฝฝ้ข„่ฎญ็ปƒๅ‚ๆ•ฐ
- ไธŽๅ…ถไป– starVLA framework ็›ธๅŒ็š„ __init__ ๅ’Œ predict_action ๆŽฅๅฃ
- ไฝฟ็”จ PaliGemma SentencePiece tokenizer ๅค„็†่ฏญ่จ€ๆŒ‡ไปค
- ๅฐ† starVLA ๆ ทๆœฌๆ ผๅผ๏ผˆPIL ๅ›พๅƒๅˆ—่กจ + lang ๅญ—็ฌฆไธฒ๏ผ‰่ฝฌๆขไธบ PI0 Observation ๆ ผๅผ
ๅ‚่€ƒๆฅๆบ๏ผš
- openpi/src/openpi/models_pytorch/pi0_pytorch.py (PI0Pytorch)
- openpi/src/openpi/policies/policy.py (Policy.infer)
"""
import sys
import logging
from pathlib import Path
from typing import List, Optional, Dict, Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from starVLA.model.framework.base_framework import baseframework
from starVLA.model.tools import FRAMEWORK_REGISTRY
from starVLA.training.trainer_utils import initialize_overwatch
logger = initialize_overwatch(__name__)
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Constants
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Default PI0 image resolution
_IMAGE_RESOLUTION = (224, 224)
# Default image keys used by PI0 (kept consistent with openpi)
_DEFAULT_IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Helper functions
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def _pil_to_tensor_normalized(img, resolution=(224, 224)) -> torch.Tensor:
"""
ๅฐ†ๅ›พๅƒ่ฝฌๆขไธบ PI0 ๆ‰€้œ€็š„ๅผ ้‡ๆ ผๅผใ€‚
่พ“ๅ‡บไธบ channels-first ๆ ผๅผ [C, H, W]๏ผŒๅฝ’ไธ€ๅŒ–่‡ณ [-1, 1]ใ€‚
PI0 ็š„ preprocess_observation_pytorch ๆฃ€ๆต‹ๅˆฐ channels-first๏ผˆshape[1]==3๏ผ‰ๆ—ถ
ไผšๅ…ˆ่ฝฌไธบ [B, H, W, C] ๅšๅขžๅนฟ๏ผŒๅ†่ฝฌๅ›ž [B, C, H, W]๏ผŒๆœ€็ปˆ้€ๅ…ฅ SigLIP conv2d
๏ผˆSigLIP ็š„ patch_embedding ๆœŸๆœ› [B, C, H, W] ๆ ผๅผ๏ผ‰ใ€‚
Args:
img: PIL.Image.Image ๆˆ– np.ndarray (H, W, 3) uint8ใ€‚
ๅŒๆ—ถๅ…ผๅฎนๆฅ่‡ช eval_libero.py ็š„ numpy uint8 ๆ•ฐ็ป„ๅ’Œ PIL Imageใ€‚
resolution: ็›ฎๆ ‡ๅˆ†่พจ็އ (H, W)๏ผŒ้ป˜่ฎค (224, 224)ใ€‚
Returns:
torch.Tensor: shape [C, H, W], dtype float32, ๅ€ผๅŸŸ [-1, 1]ใ€‚
"""
    # ── Normalize the input to a PIL Image ──────────────────────────────
if isinstance(img, np.ndarray):
        # eval_libero.py passes (H, W, 3) uint8 numpy arrays
img = Image.fromarray(img.astype(np.uint8))
    # ── PIL preprocessing ────────────────────────────────────────────────
if img.mode != "RGB":
img = img.convert("RGB")
if img.size != (resolution[1], resolution[0]):
img = img.resize((resolution[1], resolution[0]), Image.BILINEAR)
arr = np.array(img, dtype=np.float32) / 255.0 # [H, W, C], [0, 1]
arr = arr * 2.0 - 1.0 # [H, W, C], [-1, 1]
t = torch.from_numpy(arr) # [H, W, C]
return t.permute(2, 0, 1) # [C, H, W] channels-first
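# Minimal usage sketch for _pil_to_tensor_normalized (illustrative only; the random
# frame below is a stand-in for a real camera image, not part of the pipeline):
#
#     frame = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
#     t = _pil_to_tensor_normalized(frame, _IMAGE_RESOLUTION)
#     assert t.shape == (3, 224, 224)
#     assert t.min() >= -1.0 and t.max() <= 1.0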
def _build_pi0_config_obj(framework_cfg):
"""
ไปŽ starVLA ็š„ framework ้…็ฝฎ่Š‚็‚นไธญๆž„้€  PI0Pytorch ๆ‰€้œ€็š„่ฝป้‡้…็ฝฎๅฏน่ฑกใ€‚
PI0Pytorch.__init__ ้œ€่ฆ็š„ๅญ—ๆฎต๏ผš
- paligemma_variant : str (e.g. "gemma_2b")
- action_expert_variant: str (e.g. "gemma_300m")
- pi05 : bool
- action_dim : int
- action_horizon : int
- dtype : str (e.g. "bfloat16")
่ฟ™ไบ›ๅญ—ๆฎตๅฏๅ†™ๅœจ yaml ็š„ framework.pi0 ่Š‚็‚นไธ‹๏ผŒ็ผบ็œๅ€ผๆฅ่‡ช openpi ็š„ Pi0Configใ€‚
"""
pi0_node = getattr(framework_cfg, "pi0", None) or {}
def _get(key, default):
if hasattr(pi0_node, key):
return getattr(pi0_node, key)
if isinstance(pi0_node, dict):
return pi0_node.get(key, default)
return default
class _PI0Config:
paligemma_variant = _get("paligemma_variant", "gemma_2b")
        action_expert_variant = _get("action_expert_variant", "gemma_300m")
pi05 = _get("pi05", False)
action_dim = _get("action_dim", 37)
state_dim = _get("state_dim", 74)
action_horizon = _get("action_horizon", 15)
dtype = _get("dtype", "bfloat16")
        # max_token_len is inferred from the pi05 flag (consistent with Pi0Config.__post_init__)
_max_token_len_raw = _get("max_token_len", None)
@property
def max_token_len(self):
if self._max_token_len_raw is not None:
return self._max_token_len_raw
return 200 if self.pi05 else 48
return _PI0Config()
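# Sketch of how a framework.pi0 yaml node maps onto the object built above (values here
# are hypothetical, not defaults taken from any shipped config):
#
#     from omegaconf import OmegaConf
#     fw = OmegaConf.create({"pi0": {"pi05": True, "action_dim": 32, "action_horizon": 10}})
#     cfg = _build_pi0_config_obj(fw)
#     # cfg.action_dim == 32, cfg.pi05 is True, and cfg.max_token_len falls back to 200
#     # because pi05=True and max_token_len was not set explicitly.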
class _SentencePieceTokenizer:
"""
่ฝป้‡ๅฐ่ฃ… SentencePiece tokenizer๏ผŒ็”จไบŽ PaliGemma ้ฃŽๆ ผ็š„ๆ็คบ่ฏ็ผ–็ ใ€‚
ไธŽ openpi.models.tokenizer.PaligemmaTokenizer ้€ป่พ‘็›ธๅŒ๏ผŒ
ไฝ†ไธไพ่ต– openpi ็š„ GCS ไธ‹่ฝฝ้€ป่พ‘๏ผŒๆ”นไธบๆŽฅๅ—ๆœฌๅœฐๆ–‡ไปถ่ทฏๅพ„ใ€‚
"""
def __init__(self, model_path: str, max_len: int = 48):
import sentencepiece
self._max_len = max_len
with open(model_path, "rb") as f:
self._sp = sentencepiece.SentencePieceProcessor(model_proto=f.read())
def tokenize(self, prompt: str, state: Optional[np.ndarray] = None):
"""
่ฟ”ๅ›ž (tokens: np.ndarray[int32], mask: np.ndarray[bool])๏ผŒ้•ฟๅบฆๅ‡ไธบ max_lenใ€‚
"""
cleaned = prompt.strip().replace("_", " ").replace("\n", " ")
if state is not None:
            # pi05 format: discretize the continuous state and append it to the prompt
bins = np.linspace(-1, 1, 257)[:-1]
disc = np.digitize(state, bins=bins) - 1
state_str = " ".join(map(str, disc))
full_prompt = f"Task: {cleaned}, State: {state_str};\nAction: "
tokens = self._sp.encode(full_prompt, add_bos=True)
else:
tokens = self._sp.encode(cleaned, add_bos=True) + self._sp.encode("\n")
tokens_len = len(tokens)
if tokens_len < self._max_len:
pad = [False] * (self._max_len - tokens_len)
mask = [True] * tokens_len + pad
tokens = tokens + [0] * (self._max_len - tokens_len)
else:
if tokens_len > self._max_len:
logger.warning(
f"Token length ({tokens_len}) > max_len ({self._max_len}), truncating."
)
tokens = tokens[: self._max_len]
mask = [True] * self._max_len
return np.array(tokens, dtype=np.int32), np.array(mask, dtype=bool)
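# Illustrative tokenizer usage (the .model path below is a placeholder):
#
#     tok = _SentencePieceTokenizer("/path/to/paligemma_tokenizer.model", max_len=48)
#     tokens, mask = tok.tokenize("pick up the red block")
#     # tokens.shape == (48,); mask marks the real tokens, the rest is zero padding.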
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# PI0 Framework
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
@FRAMEWORK_REGISTRY.register("PI0")
class PI0Framework(baseframework):
"""
starVLA framework ๅฐ่ฃ…๏ผŒๅฐ† openpi ็š„ PI0Pytorch ๆจกๅž‹้›†ๆˆๅˆฐ starVLA ็”Ÿๆ€ใ€‚
Config ่Š‚็‚น๏ผˆyaml ็คบไพ‹๏ผ‰๏ผš
```yaml
framework:
name: "PI0"
pi0:
paligemma_variant: "gemma_2b" # PaliGemma backbone ๅ˜ไฝ“
action_expert_variant: "gemma_300m" # Action Expert ๅ˜ไฝ“
pi05: false # ๆ˜ฏๅฆไฝฟ็”จ Pi0.5 ็‰ˆๆœฌ
action_dim: 32 # ๅŠจไฝœ็ปดๅบฆ
action_horizon: 50 # ้ข„ๆต‹ๅŠจไฝœๆญฅๆ•ฐ
dtype: "bfloat16" # ๆจกๅž‹ๆƒ้‡็ฒพๅบฆ
tokenizer_path: "/path/to/paligemma_tokenizer.model" # SentencePiece ๆจกๅž‹
pi0_checkpoint: "/path/to/model.safetensors" # ๅฏ้€‰๏ผŒpi0 ้ข„่ฎญ็ปƒๆƒ้‡
image_keys: # ๅ›พๅƒ้”ฎๅ้กบๅบ๏ผˆๅฏนๅบ” examples["image"] ไธญ็š„้กบๅบ๏ผ‰
- "base_0_rgb"
- "left_wrist_0_rgb"
- "right_wrist_0_rgb"
num_inference_steps: 10 # ๆตๅŒน้…ๆŽจ็†ๆญฅๆ•ฐ
```
ๆŽฅๅฃ๏ผš
__init__(config) : ๆž„ๅปบๆจกๅž‹ใ€ๅŠ ่ฝฝ tokenizerใ€ๅฏ้€‰ๅŠ ่ฝฝ้ข„่ฎญ็ปƒๆƒ้‡
predict_action(examples) : ๆŽจ็†๏ผŒ่ฟ”ๅ›ž {"normalized_actions": np.ndarray}
forward(examples) : ่ฎญ็ปƒๅ‰ๅ‘๏ผˆๆš‚ๆœชๅฎž็Žฐ๏ผ‰
"""
def __init__(
self,
config: Optional[Any] = None,
**kwargs,
) -> None:
"""
ๅˆๅง‹ๅŒ– PI0Frameworkใ€‚
Args:
config: starVLA ๅฑ‚็บง้…็ฝฎ๏ผˆOmegaConf DictConfig ๆˆ–ๅ…ผๅฎนๅฏน่ฑก๏ผ‰๏ผŒ
้กปๅŒ…ๅซ framework.pi0ใ€framework.tokenizer_path ็ญ‰ๅญ—ๆฎตใ€‚
**kwargs: ้ข„็•™ใ€‚
"""
super().__init__()
self.config = config
fw_cfg = getattr(config, "framework", config)
        # ── 1. Build the PI0Pytorch config ──────────────────────────────
pi0_cfg = _build_pi0_config_obj(fw_cfg)
self._pi0_cfg = pi0_cfg
        # ── 2. Initialize the PI0Pytorch model ──────────────────────────
try:
sys.path.append("/mnt/data/fangyu/code/MixtureOfHorizons/src")
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch
except ImportError as e:
            raise ImportError(
                "PI0Framework depends on the openpi package. Make sure openpi is installed "
                "or that openpi/src is on PYTHONPATH. Original error: " + str(e)
            ) from e
self.pi0_model: nn.Module = PI0Pytorch(config=pi0_cfg)
        logger.info(
            f"PI0Pytorch initialized: variant={pi0_cfg.paligemma_variant}, "
            f"action_dim={pi0_cfg.action_dim}, action_horizon={pi0_cfg.action_horizon}, "
            f"pi05={pi0_cfg.pi05}"
        )
        # ── 2b. Replace the hardcoded 32D projection layers ─────────────────
        # In the PI0Pytorch source, action_in_proj / action_out_proj / state_proj are all
        # hardcoded as Linear(32, ...) and do not read config.action_dim / config.state_dim.
        # action_in_proj / action_out_proj are replaced according to action_dim;
        # state_proj is replaced according to state_dim (which may differ from action_dim, e.g. a unified 74D state).
        # The replacement layers are randomly initialized; when a checkpoint is loaded they are skipped
        # automatically because of the shape mismatch (guaranteed by _filter_ckpt_by_shape).
self._replace_pi0_projection_layers(pi0_cfg.action_dim, pi0_cfg.state_dim)
        # ── 3. Image key mapping ─────────────────────────────────────────
_ik = getattr(fw_cfg, "image_keys", None)
if _ik is not None:
self.image_keys = list(_ik)
else:
self.image_keys = list(_DEFAULT_IMAGE_KEYS)
        # ── 4. Number of inference steps ─────────────────────────────────
self.num_inference_steps = getattr(fw_cfg, "num_inference_steps", 10)
        # ── 4b. Effective action dimension (used to truncate the model output) ──
        # PI0 uses action_dim=32, but LIBERO only uses the first 7 dims (3 pos + 3 rot + 1 gripper).
        # If config.framework.effective_action_dim is set, predict_action truncates the output to the
        # first N dims to match the dimensions expected by model2libero_interface.py's unnormalization.
self.effective_action_dim = getattr(fw_cfg, "effective_action_dim", None)
        # ── 4c. Single-view replication (for single-view datasets such as gr1 video.ego_view) ──
        # If True and example["image"] has only 1 image, it is replicated to len(image_keys) views.
self._replicate_single_view = getattr(fw_cfg, "replicate_single_view", False)
        # ── 4d. Whether to use the state input (for training/inference) ─────────
        # If False, example["state"] is not read; the tokenizer and Observation.state use None/zeros.
self._use_state = getattr(fw_cfg, "use_state", True)
        # ── 4e. Dynamic number of views (optional) ───────────────────────────────
        # If True, only the first N image_keys are used, where N is the actual number of images
        # in example["image"], with no zero padding.
        # If False, all image_keys are always used and missing views are filled with zeros + mask=False.
self._dynamic_image_keys = getattr(fw_cfg, "dynamic_image_keys", False)
        # ── 5. Set up the tokenizer ──────────────────────────────────────
tokenizer_path = getattr(fw_cfg, "tokenizer_path", None)
self._tokenizer = self._load_tokenizer(tokenizer_path, pi0_cfg.max_token_len)
        # ── 6. Optional: load pi0 pretrained weights ─────────────────────
pi0_ckpt = getattr(fw_cfg, "pi0_checkpoint", None)
if pi0_ckpt:
self.load_pi0_weights(pi0_ckpt)
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    # Internal utilities
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def _load_tokenizer(self, tokenizer_path: Optional[str], max_len: int):
"""
ๅŠ ่ฝฝ PaliGemma SentencePiece tokenizerใ€‚
ไผ˜ๅ…ˆ็บง๏ผš
1. ไฝฟ็”จ tokenizer_path ๆŒ‡ๅฎš็š„ๆœฌๅœฐ .model ๆ–‡ไปถ
2. ๅฐ่ฏ•้€š่ฟ‡ openpi ็š„ไธ‹่ฝฝๅทฅๅ…ท่‡ชๅŠจ่Žทๅ– paligemma_tokenizer.model
3. ่‹ฅๅ‡ๅคฑ่ดฅ๏ผŒtokenizer ่ฎพไธบ None๏ผŒpredict_action ๆ—ถไผšๆŠฅ้”™ๆ็คบ็”จๆˆท
Args:
tokenizer_path: ๆœฌๅœฐ sentencepiece .model ๆ–‡ไปถ่ทฏๅพ„๏ผŒๅฏไธบ Noneใ€‚
max_len: ๆœ€ๅคง token ้•ฟๅบฆใ€‚
Returns:
_SentencePieceTokenizer ๅฎžไพ‹๏ผŒๆˆ– Noneใ€‚
"""
if tokenizer_path and Path(tokenizer_path).exists():
logger.info(f"ไปŽๆœฌๅœฐ่ทฏๅพ„ๅŠ ่ฝฝ tokenizer๏ผš{tokenizer_path}")
return _SentencePieceTokenizer(tokenizer_path, max_len=max_len)
        # Try openpi's download utility (caches the file locally)
try:
from openpi.shared import download as _download
path = _download.maybe_download(
"gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}
)
logger.info(f"้€š่ฟ‡ openpi ไธ‹่ฝฝๅทฅๅ…ทๅŠ ่ฝฝ tokenizer๏ผš{path}")
return _SentencePieceTokenizer(str(path), max_len=max_len)
except Exception as e:
            logger.warning(
                f"Unable to download the paligemma tokenizer automatically: {e}. "
                "Set framework.tokenizer_path in the config to a local .model file."
            )
return None
def _preprocess_examples(self, examples: List[dict], device: torch.device):
"""
ๅฐ† starVLA ๆ ทๆœฌๆ ผๅผ่ฝฌๆขไธบ PI0 Observation ๅฏน่ฑกใ€‚
starVLA ๆ ทๆœฌๆ ผๅผ๏ผš
examples[i]["image"] : List[PIL.Image] โ€”โ€” ๅ„่ง†่ง’ๅ›พๅƒ
examples[i]["lang"] : str โ€”โ€” ่ฏญ่จ€ๆŒ‡ไปค
examples[i]["state"] : np.ndarray (ๅฏ้€‰) โ€”โ€” ๆœบๅ™จไบบๆœฌไฝ“็Šถๆ€๏ผŒshape (1, state_dim)
PI0 Observation ๆ ผๅผ๏ผš
images : dict[key -> Tensor[B, H, W, C]], ๅ€ผๅŸŸ [-1, 1]
image_masks : dict[key -> Tensor[B]], bool
state : Tensor[B, action_dim]
tokenized_prompt : Tensor[B, max_token_len], int32
tokenized_prompt_mask : Tensor[B, max_token_len], bool
Args:
examples: List[dict]๏ผŒๆฏไธช dict ไธบไธ€ไธชๆ ทๆœฌใ€‚
device: ็›ฎๆ ‡ torch ่ฎพๅค‡ใ€‚
Returns:
Observation ๅฏน่ฑก๏ผˆopenpi.models.model.Observation๏ผ‰ใ€‚
"""
from openpi.models.model import Observation
batch_size = len(examples)
        # ── Images ───────────────────────────────────────────────────────
        # The number of views is configurable: len(image_keys) sets the maximum; with
        # dynamic_image_keys the keys are truncated to the actual number of images.
replicate_single_view = getattr(self, "_replicate_single_view", False)
dynamic_image_keys = getattr(self, "_dynamic_image_keys", False)
        # Determine the keys for this batch: in dynamic mode, follow the first sample's image count
num_views = len(examples[0].get("image", [])) if examples else len(self.image_keys)
if replicate_single_view and num_views == 1 and len(self.image_keys) > 1:
num_views = len(self.image_keys)
if dynamic_image_keys:
active_keys = list(self.image_keys)[: max(1, num_views)]
else:
active_keys = list(self.image_keys)
images_batch: Dict[str, List[torch.Tensor]] = {k: [] for k in active_keys}
masks_batch: Dict[str, List[bool]] = {k: [] for k in active_keys}
for example in examples:
imgs: List[Image.Image] = example.get("image", [])
if replicate_single_view and len(imgs) == 1 and len(active_keys) > 1:
imgs = imgs * len(active_keys)
for idx, key in enumerate(active_keys):
if idx < len(imgs):
t = _pil_to_tensor_normalized(imgs[idx], _IMAGE_RESOLUTION)
images_batch[key].append(t)
masks_batch[key].append(True)
else:
                    # Missing view: zero placeholder, mask=False
t = torch.zeros(
(3, _IMAGE_RESOLUTION[0], _IMAGE_RESOLUTION[1]), dtype=torch.float32
)
images_batch[key].append(t)
masks_batch[key].append(False)
images_tensor: Dict[str, torch.Tensor] = {
k: torch.stack(v, dim=0).to(device, dtype=torch.float32) # [B, C, H, W]
for k, v in images_batch.items()
}
image_masks_tensor: Dict[str, torch.Tensor] = {
k: torch.tensor(v, dtype=torch.bool, device=device) # [B]
for k, v in masks_batch.items()
}
        # ── Language tokenization ────────────────────────────────────────
if self._tokenizer is None:
            raise RuntimeError(
                "Tokenizer is not initialized. Set framework.tokenizer_path in the config, "
                "or make sure openpi can reach the network for automatic download."
            )
pi05 = self._pi0_cfg.pi05
use_state = getattr(self, "_use_state", True)
_pi05_missing_state_warned = False
all_tokens = []
all_masks = []
for example in examples:
lang = example.get("lang", example.get("language", ""))
state_for_tok = None
if pi05 and use_state:
                # pi0.5: discretize the state and merge it into the prompt
                # ⚠️ When use_state=False, no state is fed and the tokenizer uses the non-pi05 format
raw_state = example.get("state", None)
if raw_state is not None:
s = np.array(raw_state)
if s.ndim > 1:
s = s[0]
s = s.flatten()[: self._pi0_cfg.state_dim]
state_for_tok = s
                elif not _pi05_missing_state_warned:
                    # logger.warning(
                    #     "PI0Framework [pi05=True]: example has no 'state' field! "
                    #     "The tokenizer will fall back to the non-pi05 prompt format, which does not "
                    #     "match the training format and may significantly degrade model performance. "
                    #     "Add 'state': robot_state_array to the example_dict."
                    # )
                    _pi05_missing_state_warned = True
toks, mask = self._tokenizer.tokenize(lang, state_for_tok)
all_tokens.append(toks)
all_masks.append(mask)
tokenized_prompt = torch.tensor(
np.stack(all_tokens, axis=0), dtype=torch.int32, device=device
) # [B, max_token_len]
tokenized_prompt_mask = torch.tensor(
np.stack(all_masks, axis=0), dtype=torch.bool, device=device
) # [B, max_token_len]
        # ── State ─────────────────────────────────────────────────────────
        # When use_state=False, example["state"] is not read and the state is all zeros.
        # The state is aligned to state_dim (not action_dim); the two may differ, e.g. with a
        # unified 74D state and a 37D action, state_proj is Linear(74, width) and
        # state_tensor has shape [B, 74] without truncation.
state_dim = self._pi0_cfg.state_dim
state_list = []
for example in examples:
raw = example.get("state", None) if use_state else None
if raw is not None:
s = np.array(raw, dtype=np.float32)
if s.ndim > 1:
                    s = s[0]  # take the first frame (state is [T, state_dim] when chunked)
s = s.flatten()
else:
s = np.zeros(state_dim, dtype=np.float32)
            # Align to state_dim (truncate or zero-pad)
if len(s) >= state_dim:
s = s[:state_dim]
else:
s = np.concatenate([s, np.zeros(state_dim - len(s), dtype=np.float32)])
state_list.append(s)
state_tensor = torch.tensor(
np.stack(state_list, axis=0), dtype=torch.float32, device=device
) # [B, state_dim]
return Observation(
images=images_tensor,
image_masks=image_masks_tensor,
state=state_tensor,
tokenized_prompt=tokenized_prompt,
tokenized_prompt_mask=tokenized_prompt_mask,
)
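    # Shape sketch for _preprocess_examples (illustrative; assumes the defaults in this file:
    # the 3 _DEFAULT_IMAGE_KEYS, state_dim=74, and pi05=False so max_token_len=48):
    #
    #     example = {"image": [pil_base, pil_wrist], "lang": "close the drawer",
    #                "state": np.zeros(74, dtype=np.float32)}
    #     obs = self._preprocess_examples([example], device)
    #     # obs.images["base_0_rgb"].shape       == (1, 3, 224, 224)
    #     # obs.image_masks["right_wrist_0_rgb"] == tensor([False])   (missing view, zero-filled)
    #     # obs.state.shape                      == (1, 74)
    #     # obs.tokenized_prompt.shape           == (1, 48)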
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    # Public interface
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    # load_state_dict override (fixes the key prefix mismatch)
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def load_state_dict(self, state_dict, strict=True, assign=False):
"""
้‡ๅ†™ load_state_dict๏ผŒ่งฃๅ†ณ PI0 checkpoint key ๅ‰็ผ€ไธๅŒน้…้—ฎ้ข˜ใ€‚
**้—ฎ้ข˜ๆ นๅ› ๏ผš**
`baseframework.from_pretrained` ่ฐƒ็”จ `FrameworkModel.load_state_dict(ckpt_dict, strict=True)`ใ€‚
PI0Framework ็š„ state_dict key ๅธฆ `pi0_model.` ๅ‰็ผ€๏ผˆ`self.pi0_model` ๆ˜ฏๅญๆจกๅ—๏ผ‰๏ผŒ
ไฝ† `convert_jax_model_to_pytorch.py` ่พ“ๅ‡บ็š„ safetensors ๆ˜ฏ่ฃธ key๏ผˆๆ— ๅ‰็ผ€๏ผ‰ใ€‚
็›ดๆŽฅ `load_state_dict` ไธฅๆ ผๆจกๅผๅฟ…็„ถๅคฑ่ดฅใ€‚
**ไฟฎๅค๏ผš**
ๆฃ€ๆต‹ๅˆฐ state_dict ไธญๅ‡ไธบ่ฃธ key๏ผˆไธๅซ `pi0_model.` ๅ‰็ผ€๏ผ‰ๆ—ถ๏ผŒ
่‡ชๅŠจๅฐ†ๆƒ้‡ๅŠ ่ฝฝๅˆฐ `self.pi0_model`๏ผˆ็ป•่ฟ‡ `PI0Framework` ๅฑ‚็š„ๅ‰็ผ€้—ฎ้ข˜๏ผ‰ใ€‚
"""
if state_dict and not any(k.startswith("pi0_model.") for k in state_dict.keys()):
            logger.info(
                "[PI0Framework.load_state_dict] detected bare keys (pi05_libero_pytorch format); "
                "loading directly into self.pi0_model (skipping the pi0_model. prefix)"
            )
            # In PaliGemma, embed_tokens.weight and lm_head.weight are tied (the same tensor object);
            # convert_jax_model_to_pytorch.py only saves lm_head.weight and does not duplicate embed_tokens.weight.
            # After lm_head.weight is loaded, weight tying syncs embed_tokens.weight automatically,
            # so strict=False is used to skip this check.
missing, unexpected = self.pi0_model.load_state_dict(state_dict, strict=False)
expected_missing = {"paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"}
real_missing = set(missing) - expected_missing
            if real_missing:
                logger.warning(f"[PI0Framework.load_state_dict] genuinely missing keys (not weight-tied): {real_missing}")
            if unexpected:
                logger.warning(f"[PI0Framework.load_state_dict] unexpected keys: {set(unexpected)}")
            logger.info("[PI0Framework.load_state_dict] pi0 weights loaded (embed_tokens.weight synced via weight tying)")
from torch.nn.modules.module import _IncompatibleKeys
return _IncompatibleKeys(missing, unexpected)
return super().load_state_dict(state_dict, strict=strict, assign=assign)
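    # Key-naming sketch: the converted safetensors uses bare module paths, e.g. a key of the form
    #     "paligemma_with_expert.paligemma.model.language_model....weight"
    # whereas PI0Framework.state_dict() prefixes the same path with the submodule name:
    #     "pi0_model.paligemma_with_expert.paligemma.model.language_model....weight"
    # which is why the bare-key case above is loaded straight into self.pi0_model.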
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
    # from_pretrained override (adds support for loading a pi05_libero_pytorch directory directly)
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
@classmethod
def from_pretrained(cls, pretrained_checkpoint: str, **kwargs):
"""
ไปŽ checkpoint ๅŠ ่ฝฝ PI0Framework๏ผŒไพ› server_policy.py ่ฐƒ็”จใ€‚
**ไธบไป€ไนˆ้œ€่ฆ้‡ๅ†™๏ผŸ**
`baseframework.from_pretrained` ๆœ€็ปˆ่ฐƒ็”จ `FrameworkModel.load_state_dict(ckpt_dict)`๏ผŒ
่€Œ PI0Framework ็š„ state_dict key ๅธฆ `pi0_model.` ๅ‰็ผ€๏ผˆๅ› ไธบ self.pi0_model ๆ˜ฏๅญๆจกๅ—๏ผ‰๏ผŒ
ไฝ† convert_jax_model_to_pytorch.py ่พ“ๅ‡บ็š„ safetensors ๆ˜ฏ่ฃธ key๏ผˆๆ— ๅ‰็ผ€๏ผ‰ใ€‚
็›ดๆŽฅ load_state_dict ไผšๅ›  key ไธๅŒน้…่€Œๅคฑ่ดฅใ€‚
ๆœฌๆ–นๆณ•ๅฐ†ๆƒ้‡็›ดๆŽฅๅŠ ่ฝฝๅˆฐ self.pi0_model๏ผŒ็ป•่ฟ‡ๅ‰็ผ€้—ฎ้ข˜ใ€‚
ๆ”ฏๆŒไธค็ง checkpoint ๆ ผๅผ๏ผš
**ๆ ผๅผ A๏ผšstarVLA wrapper ็›ฎๅฝ•**๏ผˆๆŽจ่็”จไบŽ server_policy.py ้ƒจ็ฝฒ๏ผ‰
```
<run_dir>/
โ”œโ”€โ”€ config.yaml # framework.name=PI0, framework.pi0.*, framework.effective_action_dim=7
โ”œโ”€โ”€ dataset_statistics.json # franka โ†’ action โ†’ min/max/mask๏ผˆ7 ็ปด๏ผŒ็”ฑ pi05_libero q01/q99 ่ฝฌๆข๏ผ‰
โ””โ”€โ”€ checkpoints/
โ””โ”€โ”€ model.safetensors # ๅณ pi05_libero_pytorch/model.safetensors๏ผˆๅฏ่ฝฏ้“พๆŽฅ๏ผ‰
```
server_policy.py ๅฏๅŠจๅ‘ฝไปค๏ผš
`--ckpt_path <run_dir>/checkpoints/model.safetensors`
**ๆ ผๅผ B๏ผš็›ดๆŽฅ pi05_libero_pytorch ็›ฎๅฝ•**๏ผˆไธๅซ dataset_statistics๏ผŒไป…ไพ›ๅฟซ้€Ÿๆต‹่ฏ•๏ผ‰
```
<pi0_dir>/
โ”œโ”€โ”€ model.safetensors # convert_jax_model_to_pytorch.py ่พ“ๅ‡บ
โ””โ”€โ”€ config.json # {"action_dim": 32, "action_horizon": 10, ...}
```
ๆณจๆ„๏ผšๆญคๆ ผๅผไธ‹ norm_stats ไธบ็ฉบ๏ผŒModelClient.get_action_stats ไผšๅคฑ่ดฅ๏ผŒ
้œ€่ฆๆ‰‹ๅŠจๆไพ›็ปŸ่ฎก้‡ๆˆ–ๆ”น็”จๆ ผๅผ Aใ€‚
Args:
pretrained_checkpoint: checkpoint ๆ–‡ไปถ่ทฏๅพ„๏ผˆ.safetensors ๆˆ– .pt๏ผ‰ใ€‚
Returns:
PI0Framework ๅฎžไพ‹๏ผŒๅทฒๅŠ ่ฝฝๆƒ้‡ใ€‚
"""
from pathlib import Path as _Path
pretrained_checkpoint = _Path(pretrained_checkpoint)
        # ─── Layout A: starVLA wrapper directory ─────────────────────────
wrapper_run_dir = pretrained_checkpoint.parents[1]
wrapper_cfg_yaml = wrapper_run_dir / "config.yaml"
wrapper_stats_json = wrapper_run_dir / "dataset_statistics.json"
if wrapper_cfg_yaml.exists() and wrapper_stats_json.exists():
from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace
from starVLA.model.framework import build_framework
model_config, norm_stats = read_mode_config(pretrained_checkpoint)
config = dict_to_namespace(model_config)
config.trainer.pretrained_checkpoint = None
            # Prevent a second load inside __init__ (in case config.yaml also sets pi0_checkpoint)
fw_cfg_ref = getattr(config, "framework", config)
if hasattr(fw_cfg_ref, "pi0_checkpoint"):
fw_cfg_ref.pi0_checkpoint = None
model = build_framework(cfg=config)
model.norm_stats = norm_stats
            # Load the safetensors/pt weights directly into pi0_model (bypassing the PI0Framework prefix issue)
if pretrained_checkpoint.exists():
model.load_pi0_weights(str(pretrained_checkpoint))
logger.info(f"[from_pretrained ๆ ผๅผA] ไปŽ wrapper ็›ฎๅฝ•ๅŠ ่ฝฝ๏ผš{wrapper_run_dir}")
return model
        # ─── Layout B: plain pi05_libero_pytorch directory ────────────────
pi0_dir = pretrained_checkpoint.parent
pi0_config_json = pi0_dir / "config.json"
if pi0_config_json.exists():
import json
from omegaconf import OmegaConf
with open(pi0_config_json) as _f:
pi0_cfg_dict = json.load(_f)
            # Try to infer pi05 automatically from the safetensors keys
pi05_detected = False
if pretrained_checkpoint.suffix == ".safetensors":
try:
import safetensors
with safetensors.safe_open(str(pretrained_checkpoint), framework="pt") as _sf:
_keys = list(_sf.keys())
pi05_detected = "time_mlp_in.weight" in _keys
except Exception:
pass
config = OmegaConf.create({
"framework": {
"name": "PI0",
"pi0": {
"paligemma_variant": pi0_cfg_dict.get("paligemma_variant", "gemma_2b"),
"action_expert_variant": pi0_cfg_dict.get("action_expert_variant", "gemma_300m"),
"pi05": pi05_detected,
"action_dim": pi0_cfg_dict.get("action_dim", 32),
"action_horizon": pi0_cfg_dict.get("action_horizon", 50),
"dtype": pi0_cfg_dict.get("precision", "bfloat16"),
},
"pi0_checkpoint": None,
"num_inference_steps": 10,
},
"trainer": {"pretrained_checkpoint": None},
})
model = cls(config=config)
model.load_pi0_weights(str(pretrained_checkpoint))
model.norm_stats = {}
            logger.warning(
                "[from_pretrained layout B] loading a plain pi05_libero_pytorch directory; "
                "norm_stats is empty and ModelClient.get_action_stats will fail. "
                "Consider creating a starVLA wrapper directory (config.yaml + dataset_statistics.json)."
            )
            logger.info(f"[from_pretrained layout B] pi05_detected={pi05_detected}, path: {pretrained_checkpoint}")
return model
# โ”€โ”€โ”€ Fallback โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
        logger.warning(
            f"PI0Framework.from_pretrained: unrecognized checkpoint layout ({pretrained_checkpoint}). "
            "See the docstring for creating a starVLA wrapper directory (layout A). "
            "Falling back to the parent from_pretrained (likely to fail due to key mismatch)."
        )
return super().from_pretrained(pretrained_checkpoint, **kwargs)
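    # Illustrative deployment call under layout A (paths are placeholders; add a "state"
    # entry to each example when running a pi05 checkpoint):
    #
    #     model = PI0Framework.from_pretrained("<run_dir>/checkpoints/model.safetensors")
    #     model = model.to("cuda").eval()
    #     out = model.predict_action([{"image": [frame], "lang": "open the drawer"}])
    #     out["normalized_actions"]  # [B, action_horizon, effective_action_dim or action_dim]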
    # ── PI0Pytorch structure-adaptation utilities ──────────────────────────────
    # In the PI0Pytorch source, the following three projection layers are hardcoded to 32D
    # and do not read config.action_dim:
    #     self.action_in_proj  = nn.Linear(32, width)
    #     self.action_out_proj = nn.Linear(width, 32)
    #     self.state_proj      = nn.Linear(32, width)   # only when pi05=False
_PI0_HARDCODED_ACTION_DIM: int = 32
    def _replace_pi0_projection_layers(self, action_dim: int, state_dim: Optional[int] = None) -> None:
"""
ๅฐ† PI0Pytorch ็กฌ็ผ–็ ็š„ 32D ๆŠ•ๅฝฑๅฑ‚ๆ›ฟๆขไธบ็›ฎๆ ‡็ปดๅบฆใ€‚
- action_in_proj / action_out_proj ๆŒ‰ action_dim ๆ›ฟๆขใ€‚
- state_proj ๆŒ‰ state_dim ๆ›ฟๆข๏ผˆstate_dim ๅฏไธŽ action_dim ไธๅŒ๏ผŒ
ๅฆ‚ unified 74D state + 37D action ็š„ๅœบๆ™ฏ๏ผ‰ใ€‚
- ๅฝ“ๆŸ็ปดๅบฆไธŽ็กฌ็ผ–็ ็š„ 32D ็›ธๅŒๆ—ถ๏ผŒๅฏนๅบ”ๅฑ‚ไธๅšๆ›ฟๆขใ€‚
- ๆ›ฟๆขๅŽ็š„ๅฑ‚ไธบ้šๆœบๅˆๅง‹ๅŒ–๏ผ›ๅŠ ่ฝฝ checkpoint ๆ—ถ๏ผŒ
`_filter_ckpt_by_shape` ไผšๅ›  shape ไธๅŒน้…่€Œ่‡ชๅŠจ่ทณ่ฟ‡่ฟ™ไบ› keyใ€‚
Args:
action_dim: ๅŠจไฝœ็ปดๅบฆ๏ผˆ็”จไบŽ action_in_proj / action_out_proj๏ผ‰ใ€‚
state_dim: ็Šถๆ€็ปดๅบฆ๏ผˆ็”จไบŽ state_proj๏ผ‰ใ€‚็ผบ็œๆ—ถไธŽ action_dim ็›ธๅŒใ€‚
"""
if state_dim is None:
state_dim = action_dim
        # action_in_proj / action_out_proj always exist
if action_dim != self._PI0_HARDCODED_ACTION_DIM:
proj_width = self.pi0_model.action_in_proj.out_features # e.g. 1024
self.pi0_model.action_in_proj = nn.Linear(action_dim, proj_width)
self.pi0_model.action_out_proj = nn.Linear(proj_width, action_dim)
            logger.info(
                f"[PI0Framework] action_in/out_proj replaced: "
                f"Linear({self._PI0_HARDCODED_ACTION_DIM}, {proj_width}) "
                f"→ Linear({action_dim}, {proj_width}) "
                f"[randomly initialized, checkpoint weights not loaded]"
            )
        # state_proj only exists when pi05=False
if not self._pi0_cfg.pi05 and hasattr(self.pi0_model, "state_proj"):
if state_dim != self._PI0_HARDCODED_ACTION_DIM:
state_proj_width = self.pi0_model.state_proj.out_features
self.pi0_model.state_proj = nn.Linear(state_dim, state_proj_width)
                logger.info(
                    f"[PI0Framework] state_proj replaced: "
                    f"Linear({self._PI0_HARDCODED_ACTION_DIM}, {state_proj_width}) "
                    f"→ Linear({state_dim}, {state_proj_width}) "
                    f"[randomly initialized, checkpoint weights not loaded]"
                )
def _filter_ckpt_by_shape(self, state_dict: dict) -> dict:
"""
่ฟ‡ๆปค checkpoint state_dict๏ผŒ็งป้™คไธŽๅฝ“ๅ‰ๆจกๅž‹ shape ไธไธ€่‡ด็š„ keyใ€‚
load_state_dict(strict=False) ้‡ๅˆฐ shape ไธๅŒน้…ๆ—ถไปไผšๆŠ›ๅ‡บ RuntimeError๏ผ›
ๆœฌๆ–นๆณ•ๆๅ‰่ฟ‡ๆปค๏ผŒ่ฎฉ่ฟ™ไบ› key ไฟๆŒ้šๆœบๅˆๅง‹ๅŒ–๏ผŒๅ…ถไฝ™ key ๆญฃๅธธๅŠ ่ฝฝใ€‚
Args:
state_dict: ไปŽๆ–‡ไปถ่ฏปๅ–็š„ๅŽŸๅง‹ๆƒ้‡ๅญ—ๅ…ธ๏ผˆkey ไธบ่ฃธๅ๏ผŒไธๅซๅ‰็ผ€๏ผ‰ใ€‚
Returns:
่ฟ‡ๆปคๅŽ็š„ๅญ—ๅ…ธ๏ผŒๅชไฟ็•™ shape ๅŒน้…๏ผˆๆˆ–ๆจกๅž‹ไธญไธๅญ˜ๅœจ๏ผ‰็š„ keyใ€‚
"""
model_sd = self.pi0_model.state_dict()
filtered: dict = {}
skipped: list = []
for k, v in state_dict.items():
if k in model_sd and model_sd[k].shape != v.shape:
skipped.append(
f" {k}: ckpt{tuple(v.shape)} โ‰  model{tuple(model_sd[k].shape)}"
)
else:
filtered[k] = v
        if skipped:
            logger.info(
                f"[load_pi0_weights] skipping {len(skipped)} keys with mismatched shapes "
                f"(these layers stay randomly initialized and are learned from scratch during training):"
            )
            for s in skipped:
                logger.info(s)
return filtered
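    # Behaviour sketch for _filter_ckpt_by_shape (illustrative shapes, assuming action_dim != 32
    # so action_in_proj was rebuilt above; 1024 stands in for the projection width):
    #
    #     ckpt = {"action_in_proj.weight": torch.zeros(1024, 32),
    #             "action_in_proj.bias": torch.zeros(1024)}
    #     kept = self._filter_ckpt_by_shape(ckpt)
    #     # "action_in_proj.weight" is dropped (the model expects (1024, action_dim)),
    #     # "action_in_proj.bias" is kept because its shape still matches.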
def load_pi0_weights(self, checkpoint_path: str) -> None:
"""
ไปŽ pi0 ้ข„่ฎญ็ปƒ checkpoint ๅŠ ่ฝฝๆƒ้‡ๅˆฐ self.pi0_modelใ€‚
ๆ”ฏๆŒๆ ผๅผ๏ผš
- .safetensors : ไฝฟ็”จ safetensors.torch.load_file + ่ฟ‡ๆปคๅŽ load_state_dict
- .pt / .pth : ไฝฟ็”จ torch.load๏ผˆmap_location="cpu"๏ผ‰
ๅฝ“ config.action_dim ไธŽ checkpoint ็š„็กฌ็ผ–็ ็ปดๅบฆ๏ผˆ32๏ผ‰ไธไธ€่‡ดๆ—ถ๏ผŒ
_filter_ckpt_by_shape ไผš่‡ชๅŠจ่ทณ่ฟ‡ action_in_proj / action_out_proj / state_proj๏ผŒ
่ฟ™ไบ›ๅฑ‚ไฟๆŒ _replace_pi0_projection_layers ็š„้šๆœบๅˆๅง‹ๅŒ–๏ผŒ็”ฑ่ฎญ็ปƒ่‡ช่กŒๅญฆไน ใ€‚
Args:
checkpoint_path: ๆƒ้‡ๆ–‡ไปถ่ทฏๅพ„ใ€‚
Raises:
FileNotFoundError: ๆ–‡ไปถไธๅญ˜ๅœจใ€‚
RuntimeError: ๆƒ้‡ๅŠ ่ฝฝๅคฑ่ดฅใ€‚
"""
checkpoint_path = Path(checkpoint_path)
if not checkpoint_path.exists():
            raise FileNotFoundError(f"pi0 checkpoint does not exist: {checkpoint_path}")
        logger.info(f"Loading pi0 pretrained weights from: {checkpoint_path}")
if checkpoint_path.suffix == ".safetensors":
import safetensors.torch as sf_torch
state_dict = sf_torch.load_file(str(checkpoint_path))
state_dict = self._filter_ckpt_by_shape(state_dict)
missing, unexpected = self.pi0_model.load_state_dict(state_dict, strict=False)
            # embed_tokens.weight is tied to lm_head.weight; saving only one copy is expected
expected_missing = {
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
}
real_missing = set(missing) - expected_missing
            if real_missing:
                logger.warning(f"Keys genuinely missing after loading ({len(real_missing)}): {list(real_missing)[:10]} ...")
            if unexpected:
                logger.warning(f"Unexpected keys after loading ({len(unexpected)}): {list(unexpected)[:10]} ...")
else:
state_dict = torch.load(str(checkpoint_path), map_location="cpu")
            # Handle wrapped checkpoints (some checkpoints nest the weights under an extra key)
if "model" in state_dict and isinstance(state_dict["model"], dict):
state_dict = state_dict["model"]
state_dict = self._filter_ckpt_by_shape(state_dict)
model_keys = set(self.pi0_model.state_dict().keys())
checkpoint_keys = set(state_dict.keys())
missing = model_keys - checkpoint_keys
unexpected = checkpoint_keys - model_keys
            if missing:
                logger.warning(f"Keys missing from the checkpoint ({len(missing)}): {list(missing)[:10]} ...")
            if unexpected:
                logger.warning(f"Unexpected keys in the checkpoint ({len(unexpected)}): {list(unexpected)[:10]} ...")
self.pi0_model.load_state_dict(state_dict, strict=False)
logger.info("pi0 ้ข„่ฎญ็ปƒๆƒ้‡ๅŠ ่ฝฝๅฎŒๆฏ•ใ€‚")
def forward(
self,
examples: List[dict] = None,
**kwargs,
) -> dict:
"""
่ฎญ็ปƒๅ‰ๅ‘ไผ ๆ’ญ๏ผŒ่ฎก็ฎ—ๆตๅŒน้…๏ผˆflow matching๏ผ‰ๆŸๅคฑใ€‚
ๅŸบไบŽ PI0Pytorch.forward ็š„่ฎญ็ปƒ้€ป่พ‘๏ผš
1. ไปŽ examples ไธญ่ฏปๅ– image / lang / action / state
2. ๅฐ† action pad ๅˆฐ model ็š„ action_dim๏ผˆๅฆ‚ 32๏ผ‰๏ผŒไธ่ถณ็š„็ปดๅบฆ่กฅ้›ถ
3. ้€š่ฟ‡ _preprocess_examples ๆž„ๅปบ PI0 Observation ๅฏน่ฑก
4. ่ฐƒ็”จ PI0Pytorch.forward(observation, actions_target)๏ผš
- ้‡‡ๆ ท noise ๅ’Œ time
- ็บฟๆ€งๆ’ๅ€ผๅŠ ๅ™ช๏ผšx_t = t * noise + (1-t) * actions
- ็›ฎๆ ‡้€Ÿๅบฆ๏ผšu_t = noise - actions
- ้ข„ๆต‹้€Ÿๅบฆ๏ผšv_t = model(x_t, t)
- ๆŸๅคฑ๏ผšMSE(u_t, v_t) per element โ†’ ๅ†ๅ–ๅ‡ๅ€ผ
5. ่ฟ”ๅ›ž {"action_loss": scalar}
Args:
examples: List[dict]๏ผŒๆฏไธช dict ๅŒ…ๅซ๏ผš
- "image" (List[PIL.Image | np.ndarray]): ๅ„่ง†่ง’ๅ›พๅƒ
- "lang" (str): ่ฏญ่จ€ๆŒ‡ไปค
- "action" (np.ndarray): ็›ฎๆ ‡ๅŠจไฝœ๏ผŒshape (action_horizon, D)๏ผŒ
D ้€šๅธธๆ˜ฏ effective_action_dim๏ผˆๅฆ‚ 7๏ผ‰๏ผŒไธ่ถณ action_dim ไผš่‡ชๅŠจ zero-pad
- "state" (np.ndarray, ๅฏ้€‰): ๆœบๅ™จไบบๆœฌไฝ“็Šถๆ€
**kwargs: ้ข„็•™๏ผŒๆš‚ๆœชไฝฟ็”จใ€‚
Returns:
dict:
"action_loss" (torch.Tensor): ๆ ‡้‡๏ผŒbatch ๅนณๅ‡ๆตๅŒน้… MSE ๆŸๅคฑใ€‚
"""
if not isinstance(examples, list):
examples = [examples]
device = next(self.pi0_model.parameters()).device
        action_dim = self._pi0_cfg.action_dim          # 32 (full model dimension)
        action_horizon = self._pi0_cfg.action_horizon  # e.g. 10
        # ── 1. Build the PI0 Observation ─────────────────────────────────
observation = self._preprocess_examples(examples, device)
        # ── 2. Arrange the action target and pad it to action_dim ──────────
        # examples[i]["action"]: np.ndarray (action_horizon, D); D may be < action_dim
actions_list = []
for example in examples:
a = np.array(example["action"], dtype=np.float32)
            # If the horizon dimension is missing, treat it as a single step
if a.ndim == 1:
a = a[np.newaxis, :] # (1, D)
            # Align to action_horizon (truncate or zero-pad)
H, D = a.shape
if H > action_horizon:
a = a[:action_horizon]
elif H < action_horizon:
a = np.concatenate(
[a, np.zeros((action_horizon - H, D), dtype=np.float32)], axis=0
)
            # Align to action_dim (truncate or zero-pad)
D = a.shape[1]
if D > action_dim:
a = a[:, :action_dim]
elif D < action_dim:
a = np.concatenate(
[a, np.zeros((action_horizon, action_dim - D), dtype=np.float32)], axis=1
)
actions_list.append(a)
actions_target = torch.tensor(
np.stack(actions_list, axis=0), dtype=torch.float32, device=device
) # [B, action_horizon, action_dim]
        # ── 3. Call PI0Pytorch.forward to compute the flow-matching loss ────
        # PI0Pytorch.forward returns F.mse_loss(u_t, v_t, reduction="none"),
        # shape: [B, action_horizon, action_dim].
        # It handles the bfloat16 cast internally (for the embedding layers) and returns the loss in float32.
with torch.autocast("cuda", dtype=torch.bfloat16):
loss_per_element = self.pi0_model.forward(observation, actions_target)
        # Average over batch / horizon / dim to get a scalar loss
action_loss = loss_per_element.mean()
return {"action_loss": action_loss}
@torch.inference_mode()
def predict_action(
self,
examples: List[dict],
**kwargs,
) -> dict:
"""
ๆŽจ็†๏ผšๆ นๆฎ่ง‚ๆต‹้ข„ๆต‹ๅŠจไฝœๅบๅˆ—ใ€‚
ๆต็จ‹๏ผš
1. ๅฐ† examples ไธญ็š„ๅ›พๅƒ่ฝฌๆขไธบๅฝ’ไธ€ๅŒ–ๅผ ้‡
2. Tokenize ่ฏญ่จ€ๆŒ‡ไปค
3. ๆๅ–/ๅฏน้ฝๆœฌไฝ“็Šถๆ€
4. ๆž„้€  PI0 Observation ๅฏน่ฑก
5. ่ฐƒ็”จ PI0Pytorch.sample_actions ่ฟ›่กŒๆตๅŒน้…ๅŽปๅ™ชๆŽจ็†
6. ่ฟ”ๅ›žๅฝ’ไธ€ๅŒ–ๅŠจไฝœ
Args:
examples: List[dict]๏ผŒๆฏไธช dict ๅŒ…ๅซ๏ผš
- "image" (List[PIL.Image | np.ndarray]): ๅ„่ง†่ง’ๅ›พๅƒ๏ผŒ้กบๅบไธŽ config.framework.image_keys ๅฏนๅบ”ใ€‚
ๆŽฅๅ— PIL Image ๆˆ– (H, W, 3) uint8 numpy ๆ•ฐ็ป„๏ผˆeval_libero.py ๆ ผๅผ๏ผ‰ใ€‚
- "lang" (str) : ไปปๅŠก่ฏญ่จ€ๆŒ‡ไปค
- "state" (np.ndarray, ๅฏ้€‰): ๆœบๅ™จไบบๆœฌไฝ“็Šถๆ€๏ผŒshape (state_dim,) ๆˆ– (1, state_dim)ใ€‚
โš ๏ธ pi05 ๆจกๅผไธ‹ state ๆ˜ฏ tokenized prompt ็š„ไธ€้ƒจๅˆ†๏ผŒ็ผบๅคฑไผšๅฏผ่‡ด prompt ๆ ผๅผ
้”™่ฏฏใ€ๆ€ง่ƒฝไธ‹้™ใ€‚eval_libero.py ้œ€่ฆๅœจ example_dict ไธญๆทปๅŠ ๆญคๅญ—ๆฎตใ€‚
**kwargs: ้ขๅค–ๅ‚ๆ•ฐ๏ผˆๆš‚ๆœชไฝฟ็”จ๏ผ‰ใ€‚
Returns:
dict:
"normalized_actions" (np.ndarray): shape [B, action_horizon, action_dim]๏ผŒ
ๅฝ’ไธ€ๅŒ–่‡ณ [-1, 1] ็š„้ข„ๆต‹ๅŠจไฝœๅบๅˆ—ใ€‚
"""
if not isinstance(examples, list):
examples = [examples]
        # Determine the device to run on
device = next(self.pi0_model.parameters()).device
        # Build the PI0 Observation
observation = self._preprocess_examples(examples, device)
        # Run flow-matching denoising inference
num_steps = kwargs.get("num_steps", self.num_inference_steps)
pred_actions = self.pi0_model.sample_actions(
device, observation, num_steps=num_steps
) # [B, action_horizon, action_dim]
normalized_actions = pred_actions.detach().cpu().numpy() # [B, action_horizon, action_dim]
        # Truncate to the effective action dimension (for setups such as LIBERO that only need the first N dims).
        # PI0 pi05_libero uses action_dim=32, but the LIBERO robot only has 7 DOF.
        # ModelClient.unnormalize_actions only has 7-dim min/max stats, so without truncation
        # numpy broadcasting fails with a shape mismatch.
        # Set config.framework.effective_action_dim (e.g. 7) to enable the truncation.
if self.effective_action_dim is not None:
normalized_actions = normalized_actions[:, :, : self.effective_action_dim]
return {"normalized_actions": normalized_actions}
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# Quick smoke test (debugging only; uses the default local checkpoint/tokenizer paths below)
# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
if __name__ == "__main__":
import argparse
from omegaconf import OmegaConf
    # ── Default paths (for the pi05_libero model) ────────────────────────────
_DEFAULT_CKPT = (
"/mnt/data/fangyu/model/openpi/openpi-assets/checkpoints"
"/pi05_libero_pytorch/model.safetensors"
)
_DEFAULT_TOKENIZER = (
"/root/.cache/openpi/big_vision/paligemma_tokenizer.model"
)
    parser = argparse.ArgumentParser(description="PI0Framework quick smoke test")
    parser.add_argument("--config_yaml", type=str, default=None, help="Path to a YAML config file (takes precedence over the inline config)")
    parser.add_argument("--pi0_checkpoint", type=str, default=_DEFAULT_CKPT, help="Path to the pi0 safetensors weights")
    parser.add_argument("--tokenizer_path", type=str, default=_DEFAULT_TOKENIZER, help="Path to the PaliGemma tokenizer .model file")
    parser.add_argument("--device", type=str, default=None, help="Device to run on, e.g. cuda:0 / cpu (auto-detected by default)")
    parser.add_argument("--steps", type=int, default=10, help="Number of flow-matching inference steps")
    parser.add_argument("--batch_size", type=int, default=1, help="Test batch size")
args, _ = parser.parse_known_args()
if args.config_yaml:
cfg = OmegaConf.load(args.config_yaml)
else:
        # inline pi05_libero config
cfg = OmegaConf.create({
"framework": {
"name": "PI0",
"pi0": {
"paligemma_variant": "gemma_2b",
"action_expert_variant": "gemma_300m",
"pi05": True, # pi05_libero ไฝฟ็”จ pi0.5
"action_dim": 32,
"action_horizon": 10, # pi05_libero action_horizon=10
"dtype": "bfloat16",
},
"tokenizer_path": args.tokenizer_path,
"pi0_checkpoint": args.pi0_checkpoint,
"image_keys": ["base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb"],
"num_inference_steps": args.steps,
}
})
print("=" * 60)
print("PI0Framework ๆต‹่ฏ•")
print("=" * 60)
print(f" checkpoint : {args.pi0_checkpoint}")
print(f" tokenizer : {args.tokenizer_path}")
print(f" infer steps : {args.steps}")
print(f" batch size : {args.batch_size}")
    # ── 1. Build the model ────────────────────────────────────────────────────
    print("\n[1/3] Initializing PI0Framework and loading weights...")
    model = PI0Framework(cfg)
    total_params = sum(p.numel() for p in model.pi0_model.parameters())
    print(f"  Model parameters: {total_params / 1e9:.2f}B")
    # ── 2. Move to the target device ──────────────────────────────────────────
if args.device:
device = torch.device(args.device)
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\n[2/3] ็งป่‡ณ่ฎพๅค‡: {device}")
model = model.to(device)
    # pi05_libero uses bfloat16
model.pi0_model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
    # ── 3. Build fake samples and run inference ───────────────────────────────
    print(f"\n[3/3] Building a fake batch of size {args.batch_size} and calling predict_action...")
fake_img = Image.fromarray(
np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
)
sample = {
"image": [fake_img, fake_img, fake_img], # 3 ่ง†่ง’๏ผšbase + left_wrist + right_wrist
"lang": "put the red cup on the plate",
"state": np.random.uniform(-1, 1, size=(32,)).astype(np.float32),
}
batch = [sample] * args.batch_size
import time
t0 = time.time()
result = model.predict_action(batch)
elapsed = time.time() - t0
actions = result["normalized_actions"]
print(f"\n ่พ“ๅ‡บ normalized_actions shape : {actions.shape}")
print(f" ๆŽจ็†่€—ๆ—ถ : {elapsed*1000:.1f} ms")
print(f" ๅŠจไฝœๅ€ผๅŸŸ [min, max] : [{actions.min():.4f}, {actions.max():.4f}]")
print(f" ๅŠจไฝœๅ‡ๅ€ผ ยฑ ๆ ‡ๅ‡†ๅทฎ : {actions.mean():.4f} ยฑ {actions.std():.4f}")
print("\n[OK] PI0Framework ๆต‹่ฏ•ๅฎŒๆˆ๏ผ")