"""
PI0 Framework

Wraps openpi's PI0Pytorch model behind the starVLA framework interface.

Supports:
- Loading pretrained parameters from a pi0 safetensors checkpoint
- The same __init__ / predict_action interface as other starVLA frameworks
- PaliGemma SentencePiece tokenization of language instructions
- Converting starVLA samples (list of PIL images + "lang" string) into the
  PI0 Observation format

References:
- openpi/src/openpi/models_pytorch/pi0_pytorch.py (PI0Pytorch)
- openpi/src/openpi/policies/policy.py (Policy.infer)
"""
|
|
| import sys |
| import logging |
| from pathlib import Path |
| from typing import List, Optional, Dict, Any |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
|
|
| from starVLA.model.framework.base_framework import baseframework |
| from starVLA.model.tools import FRAMEWORK_REGISTRY |
| from starVLA.training.trainer_utils import initialize_overwatch |
|
|
# Module-level logger (starVLA's overwatch wrapper around stdlib logging).
logger = initialize_overwatch(__name__)


# Target (H, W) resolution of images fed to the model's vision tower.
_IMAGE_RESOLUTION = (224, 224)


# Default PI0 camera keys, in the order images appear in examples[i]["image"].
_DEFAULT_IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
|
|
|
|
| |
| |
| |
|
|
def _pil_to_tensor_normalized(img, resolution=(224, 224)) -> torch.Tensor:
    """Convert an image to the tensor format PI0 expects.

    Output is channels-first [C, H, W], normalized to [-1, 1]. PI0's
    preprocess_observation_pytorch detects channels-first (shape[1] == 3),
    converts to [B, H, W, C] for augmentation, then back to [B, C, H, W]
    before the SigLIP conv2d (SigLIP's patch_embedding expects [B, C, H, W]).

    Args:
        img: PIL.Image.Image or np.ndarray (H, W, 3) uint8. Accepts both the
            numpy uint8 arrays produced by eval_libero.py and PIL images.
        resolution: target (H, W) resolution, default (224, 224).

    Returns:
        torch.Tensor: shape [C, H, W], dtype float32, values in [-1, 1].
    """
    if isinstance(img, np.ndarray):
        # Avoid a redundant copy when the array is already uint8.
        if img.dtype != np.uint8:
            img = img.astype(np.uint8)
        img = Image.fromarray(img)

    if img.mode != "RGB":
        img = img.convert("RGB")
    # PIL sizes are (W, H); resolution is (H, W).
    if img.size != (resolution[1], resolution[0]):
        img = img.resize((resolution[1], resolution[0]), Image.BILINEAR)
    arr = np.array(img, dtype=np.float32) / 255.0
    arr = arr * 2.0 - 1.0
    return torch.from_numpy(arr).permute(2, 0, 1)
|
|
|
|
| def _build_pi0_config_obj(framework_cfg): |
| """ |
| ไป starVLA ็ framework ้
็ฝฎ่็นไธญๆ้ PI0Pytorch ๆ้็่ฝป้้
็ฝฎๅฏน่ฑกใ |
| |
| PI0Pytorch.__init__ ้่ฆ็ๅญๆฎต๏ผ |
| - paligemma_variant : str (e.g. "gemma_2b") |
| - action_expert_variant: str (e.g. "gemma_300m") |
| - pi05 : bool |
| - action_dim : int |
| - action_horizon : int |
| - dtype : str (e.g. "bfloat16") |
| |
| ่ฟไบๅญๆฎตๅฏๅๅจ yaml ็ framework.pi0 ่็นไธ๏ผ็ผบ็ๅผๆฅ่ช openpi ็ Pi0Configใ |
| """ |
| pi0_node = getattr(framework_cfg, "pi0", None) or {} |
|
|
| def _get(key, default): |
| if hasattr(pi0_node, key): |
| return getattr(pi0_node, key) |
| if isinstance(pi0_node, dict): |
| return pi0_node.get(key, default) |
| return default |
|
|
| class _PI0Config: |
| paligemma_variant = _get("paligemma_variant", "gemma_2b") |
| action_expert_variant= _get("action_expert_variant","gemma_300m") |
| pi05 = _get("pi05", False) |
| action_dim = _get("action_dim", 37) |
| state_dim = _get("state_dim", 74) |
| action_horizon = _get("action_horizon", 15) |
| dtype = _get("dtype", "bfloat16") |
| |
| _max_token_len_raw = _get("max_token_len", None) |
|
|
| @property |
| def max_token_len(self): |
| if self._max_token_len_raw is not None: |
| return self._max_token_len_raw |
| return 200 if self.pi05 else 48 |
|
|
| return _PI0Config() |
|
|
|
|
| class _SentencePieceTokenizer: |
| """ |
| ่ฝป้ๅฐ่ฃ
SentencePiece tokenizer๏ผ็จไบ PaliGemma ้ฃๆ ผ็ๆ็คบ่ฏ็ผ็ ใ |
| |
| ไธ openpi.models.tokenizer.PaligemmaTokenizer ้ป่พ็ธๅ๏ผ |
| ไฝไธไพ่ต openpi ็ GCS ไธ่ฝฝ้ป่พ๏ผๆนไธบๆฅๅๆฌๅฐๆไปถ่ทฏๅพใ |
| """ |
|
|
| def __init__(self, model_path: str, max_len: int = 48): |
| import sentencepiece |
| self._max_len = max_len |
| with open(model_path, "rb") as f: |
| self._sp = sentencepiece.SentencePieceProcessor(model_proto=f.read()) |
|
|
| def tokenize(self, prompt: str, state: Optional[np.ndarray] = None): |
| """ |
| ่ฟๅ (tokens: np.ndarray[int32], mask: np.ndarray[bool])๏ผ้ฟๅบฆๅไธบ max_lenใ |
| """ |
| cleaned = prompt.strip().replace("_", " ").replace("\n", " ") |
| if state is not None: |
| |
| bins = np.linspace(-1, 1, 257)[:-1] |
| disc = np.digitize(state, bins=bins) - 1 |
| state_str = " ".join(map(str, disc)) |
| full_prompt = f"Task: {cleaned}, State: {state_str};\nAction: " |
| tokens = self._sp.encode(full_prompt, add_bos=True) |
| else: |
| tokens = self._sp.encode(cleaned, add_bos=True) + self._sp.encode("\n") |
|
|
| tokens_len = len(tokens) |
| if tokens_len < self._max_len: |
| pad = [False] * (self._max_len - tokens_len) |
| mask = [True] * tokens_len + pad |
| tokens = tokens + [0] * (self._max_len - tokens_len) |
| else: |
| if tokens_len > self._max_len: |
| logger.warning( |
| f"Token length ({tokens_len}) > max_len ({self._max_len}), truncating." |
| ) |
| tokens = tokens[: self._max_len] |
| mask = [True] * self._max_len |
|
|
| return np.array(tokens, dtype=np.int32), np.array(mask, dtype=bool) |
|
|
|
|
| |
| |
| |
|
|
| @FRAMEWORK_REGISTRY.register("PI0") |
| class PI0Framework(baseframework): |
| """ |
| starVLA framework ๅฐ่ฃ
๏ผๅฐ openpi ็ PI0Pytorch ๆจกๅ้ๆๅฐ starVLA ็ๆใ |
| |
| Config ่็น๏ผyaml ็คบไพ๏ผ๏ผ |
| ```yaml |
| framework: |
| name: "PI0" |
| pi0: |
| paligemma_variant: "gemma_2b" # PaliGemma backbone ๅไฝ |
| action_expert_variant: "gemma_300m" # Action Expert ๅไฝ |
| pi05: false # ๆฏๅฆไฝฟ็จ Pi0.5 ็ๆฌ |
| action_dim: 32 # ๅจไฝ็ปดๅบฆ |
| action_horizon: 50 # ้ขๆตๅจไฝๆญฅๆฐ |
| dtype: "bfloat16" # ๆจกๅๆ้็ฒพๅบฆ |
| tokenizer_path: "/path/to/paligemma_tokenizer.model" # SentencePiece ๆจกๅ |
| pi0_checkpoint: "/path/to/model.safetensors" # ๅฏ้๏ผpi0 ้ข่ฎญ็ปๆ้ |
| image_keys: # ๅพๅ้ฎๅ้กบๅบ๏ผๅฏนๅบ examples["image"] ไธญ็้กบๅบ๏ผ |
| - "base_0_rgb" |
| - "left_wrist_0_rgb" |
| - "right_wrist_0_rgb" |
| num_inference_steps: 10 # ๆตๅน้
ๆจ็ๆญฅๆฐ |
| ``` |
| |
| ๆฅๅฃ๏ผ |
| __init__(config) : ๆๅปบๆจกๅใๅ ่ฝฝ tokenizerใๅฏ้ๅ ่ฝฝ้ข่ฎญ็ปๆ้ |
| predict_action(examples) : ๆจ็๏ผ่ฟๅ {"normalized_actions": np.ndarray} |
| forward(examples) : ่ฎญ็ปๅๅ๏ผๆๆชๅฎ็ฐ๏ผ |
| """ |
|
|
| def __init__( |
| self, |
| config: Optional[Any] = None, |
| **kwargs, |
| ) -> None: |
| """ |
| ๅๅงๅ PI0Frameworkใ |
| |
| Args: |
| config: starVLA ๅฑ็บง้
็ฝฎ๏ผOmegaConf DictConfig ๆๅ
ผๅฎนๅฏน่ฑก๏ผ๏ผ |
| ้กปๅ
ๅซ framework.pi0ใframework.tokenizer_path ็ญๅญๆฎตใ |
| **kwargs: ้ข็ใ |
| """ |
| super().__init__() |
| self.config = config |
|
|
| fw_cfg = getattr(config, "framework", config) |
|
|
| |
| pi0_cfg = _build_pi0_config_obj(fw_cfg) |
| self._pi0_cfg = pi0_cfg |
|
|
| |
| try: |
| sys.path.append("/mnt/data/fangyu/code/MixtureOfHorizons/src") |
| from openpi.models_pytorch.pi0_pytorch import PI0Pytorch |
| except ImportError as e: |
| raise ImportError( |
| "PI0Framework ไพ่ต openpi ๅ
ใ่ฏท็กฎไฟ openpi ๅทฒๅฎ่ฃ
ๆ " |
| "openpi/src ๅทฒๆทปๅ ๅฐ PYTHONPATHใๅๅง้่ฏฏ: " + str(e) |
| ) from e |
|
|
| self.pi0_model: nn.Module = PI0Pytorch(config=pi0_cfg) |
| logger.info( |
| f"PI0Pytorch ๅทฒๅๅงๅ๏ผvariant={pi0_cfg.paligemma_variant}, " |
| f"action_dim={pi0_cfg.action_dim}, action_horizon={pi0_cfg.action_horizon}, " |
| f"pi05={pi0_cfg.pi05}" |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| self._replace_pi0_projection_layers(pi0_cfg.action_dim, pi0_cfg.state_dim) |
|
|
| |
| _ik = getattr(fw_cfg, "image_keys", None) |
| if _ik is not None: |
| self.image_keys = list(_ik) |
| else: |
| self.image_keys = list(_DEFAULT_IMAGE_KEYS) |
|
|
| |
| self.num_inference_steps = getattr(fw_cfg, "num_inference_steps", 10) |
|
|
| |
| |
| |
| |
| self.effective_action_dim = getattr(fw_cfg, "effective_action_dim", None) |
|
|
| |
| |
| self._replicate_single_view = getattr(fw_cfg, "replicate_single_view", False) |
|
|
| |
| |
| self._use_state = getattr(fw_cfg, "use_state", True) |
|
|
| |
| |
| |
| self._dynamic_image_keys = getattr(fw_cfg, "dynamic_image_keys", False) |
|
|
| |
| tokenizer_path = getattr(fw_cfg, "tokenizer_path", None) |
| self._tokenizer = self._load_tokenizer(tokenizer_path, pi0_cfg.max_token_len) |
|
|
| |
| pi0_ckpt = getattr(fw_cfg, "pi0_checkpoint", None) |
| if pi0_ckpt: |
| self.load_pi0_weights(pi0_ckpt) |
|
|
| |
| |
| |
|
|
| def _load_tokenizer(self, tokenizer_path: Optional[str], max_len: int): |
| """ |
| ๅ ่ฝฝ PaliGemma SentencePiece tokenizerใ |
| |
| ไผๅ
็บง๏ผ |
| 1. ไฝฟ็จ tokenizer_path ๆๅฎ็ๆฌๅฐ .model ๆไปถ |
| 2. ๅฐ่ฏ้่ฟ openpi ็ไธ่ฝฝๅทฅๅ
ท่ชๅจ่ทๅ paligemma_tokenizer.model |
| 3. ่ฅๅๅคฑ่ดฅ๏ผtokenizer ่ฎพไธบ None๏ผpredict_action ๆถไผๆฅ้ๆ็คบ็จๆท |
| |
| Args: |
| tokenizer_path: ๆฌๅฐ sentencepiece .model ๆไปถ่ทฏๅพ๏ผๅฏไธบ Noneใ |
| max_len: ๆๅคง token ้ฟๅบฆใ |
| |
| Returns: |
| _SentencePieceTokenizer ๅฎไพ๏ผๆ Noneใ |
| """ |
| if tokenizer_path and Path(tokenizer_path).exists(): |
| logger.info(f"ไปๆฌๅฐ่ทฏๅพๅ ่ฝฝ tokenizer๏ผ{tokenizer_path}") |
| return _SentencePieceTokenizer(tokenizer_path, max_len=max_len) |
|
|
| |
| try: |
| from openpi.shared import download as _download |
| path = _download.maybe_download( |
| "gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"} |
| ) |
| logger.info(f"้่ฟ openpi ไธ่ฝฝๅทฅๅ
ทๅ ่ฝฝ tokenizer๏ผ{path}") |
| return _SentencePieceTokenizer(str(path), max_len=max_len) |
| except Exception as e: |
| logger.warning( |
| f"ๆ ๆณ่ชๅจไธ่ฝฝ paligemma tokenizer๏ผ{e}ใ" |
| "่ฏทๅจ config ไธญ่ฎพ็ฝฎ framework.tokenizer_path ๆๅๆฌๅฐ .model ๆไปถใ" |
| ) |
| return None |
|
|
| def _preprocess_examples(self, examples: List[dict], device: torch.device): |
| """ |
| ๅฐ starVLA ๆ ทๆฌๆ ผๅผ่ฝฌๆขไธบ PI0 Observation ๅฏน่ฑกใ |
| |
| starVLA ๆ ทๆฌๆ ผๅผ๏ผ |
| examples[i]["image"] : List[PIL.Image] โโ ๅ่ง่งๅพๅ |
| examples[i]["lang"] : str โโ ่ฏญ่จๆไปค |
| examples[i]["state"] : np.ndarray (ๅฏ้) โโ ๆบๅจไบบๆฌไฝ็ถๆ๏ผshape (1, state_dim) |
| |
| PI0 Observation ๆ ผๅผ๏ผ |
| images : dict[key -> Tensor[B, H, W, C]], ๅผๅ [-1, 1] |
| image_masks : dict[key -> Tensor[B]], bool |
| state : Tensor[B, action_dim] |
| tokenized_prompt : Tensor[B, max_token_len], int32 |
| tokenized_prompt_mask : Tensor[B, max_token_len], bool |
| |
| Args: |
| examples: List[dict]๏ผๆฏไธช dict ไธบไธไธชๆ ทๆฌใ |
| device: ็ฎๆ torch ่ฎพๅคใ |
| |
| Returns: |
| Observation ๅฏน่ฑก๏ผopenpi.models.model.Observation๏ผใ |
| """ |
| from openpi.models.model import Observation |
|
|
| batch_size = len(examples) |
|
|
| |
| |
| replicate_single_view = getattr(self, "_replicate_single_view", False) |
| dynamic_image_keys = getattr(self, "_dynamic_image_keys", False) |
|
|
| |
| num_views = len(examples[0].get("image", [])) if examples else len(self.image_keys) |
| if replicate_single_view and num_views == 1 and len(self.image_keys) > 1: |
| num_views = len(self.image_keys) |
| if dynamic_image_keys: |
| active_keys = list(self.image_keys)[: max(1, num_views)] |
| else: |
| active_keys = list(self.image_keys) |
|
|
| images_batch: Dict[str, List[torch.Tensor]] = {k: [] for k in active_keys} |
| masks_batch: Dict[str, List[bool]] = {k: [] for k in active_keys} |
|
|
| for example in examples: |
| imgs: List[Image.Image] = example.get("image", []) |
| if replicate_single_view and len(imgs) == 1 and len(active_keys) > 1: |
| imgs = imgs * len(active_keys) |
| for idx, key in enumerate(active_keys): |
| if idx < len(imgs): |
| t = _pil_to_tensor_normalized(imgs[idx], _IMAGE_RESOLUTION) |
| images_batch[key].append(t) |
| masks_batch[key].append(True) |
| else: |
| |
| t = torch.zeros( |
| (3, _IMAGE_RESOLUTION[0], _IMAGE_RESOLUTION[1]), dtype=torch.float32 |
| ) |
| images_batch[key].append(t) |
| masks_batch[key].append(False) |
|
|
| images_tensor: Dict[str, torch.Tensor] = { |
| k: torch.stack(v, dim=0).to(device, dtype=torch.float32) |
| for k, v in images_batch.items() |
| } |
| image_masks_tensor: Dict[str, torch.Tensor] = { |
| k: torch.tensor(v, dtype=torch.bool, device=device) |
| for k, v in masks_batch.items() |
| } |
|
|
| |
| if self._tokenizer is None: |
| raise RuntimeError( |
| "Tokenizer ๆชๅๅงๅใ่ฏทๅจ config ไธญ่ฎพ็ฝฎ framework.tokenizer_path๏ผ" |
| "ๆ็กฎไฟ openpi ็ฝ็ปๅฏ่ฎฟ้ฎไปฅ่ชๅจไธ่ฝฝใ" |
| ) |
|
|
| pi05 = self._pi0_cfg.pi05 |
| use_state = getattr(self, "_use_state", True) |
| _pi05_missing_state_warned = False |
| all_tokens = [] |
| all_masks = [] |
| for example in examples: |
| lang = example.get("lang", example.get("language", "")) |
| state_for_tok = None |
| if pi05 and use_state: |
| |
| |
| raw_state = example.get("state", None) |
| if raw_state is not None: |
| s = np.array(raw_state) |
| if s.ndim > 1: |
| s = s[0] |
| s = s.flatten()[: self._pi0_cfg.state_dim] |
| state_for_tok = s |
| elif not _pi05_missing_state_warned: |
| |
| |
| |
| |
| |
| |
| _pi05_missing_state_warned = True |
| toks, mask = self._tokenizer.tokenize(lang, state_for_tok) |
| all_tokens.append(toks) |
| all_masks.append(mask) |
|
|
| tokenized_prompt = torch.tensor( |
| np.stack(all_tokens, axis=0), dtype=torch.int32, device=device |
| ) |
| tokenized_prompt_mask = torch.tensor( |
| np.stack(all_masks, axis=0), dtype=torch.bool, device=device |
| ) |
|
|
| |
| |
| |
| |
| |
| state_dim = self._pi0_cfg.state_dim |
| state_list = [] |
| for example in examples: |
| raw = example.get("state", None) if use_state else None |
| if raw is not None: |
| s = np.array(raw, dtype=np.float32) |
| if s.ndim > 1: |
| s = s[0] |
| s = s.flatten() |
| else: |
| s = np.zeros(state_dim, dtype=np.float32) |
| |
| if len(s) >= state_dim: |
| s = s[:state_dim] |
| else: |
| s = np.concatenate([s, np.zeros(state_dim - len(s), dtype=np.float32)]) |
| state_list.append(s) |
| state_tensor = torch.tensor( |
| np.stack(state_list, axis=0), dtype=torch.float32, device=device |
| ) |
|
|
| return Observation( |
| images=images_tensor, |
| image_masks=image_masks_tensor, |
| state=state_tensor, |
| tokenized_prompt=tokenized_prompt, |
| tokenized_prompt_mask=tokenized_prompt_mask, |
| ) |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| def load_state_dict(self, state_dict, strict=True, assign=False): |
| """ |
| ้ๅ load_state_dict๏ผ่งฃๅณ PI0 checkpoint key ๅ็ผไธๅน้
้ฎ้ขใ |
| |
| **้ฎ้ขๆ นๅ ๏ผ** |
| `baseframework.from_pretrained` ่ฐ็จ `FrameworkModel.load_state_dict(ckpt_dict, strict=True)`ใ |
| PI0Framework ็ state_dict key ๅธฆ `pi0_model.` ๅ็ผ๏ผ`self.pi0_model` ๆฏๅญๆจกๅ๏ผ๏ผ |
| ไฝ `convert_jax_model_to_pytorch.py` ่พๅบ็ safetensors ๆฏ่ฃธ key๏ผๆ ๅ็ผ๏ผใ |
| ็ดๆฅ `load_state_dict` ไธฅๆ ผๆจกๅผๅฟ
็ถๅคฑ่ดฅใ |
| |
| **ไฟฎๅค๏ผ** |
| ๆฃๆตๅฐ state_dict ไธญๅไธบ่ฃธ key๏ผไธๅซ `pi0_model.` ๅ็ผ๏ผๆถ๏ผ |
| ่ชๅจๅฐๆ้ๅ ่ฝฝๅฐ `self.pi0_model`๏ผ็ป่ฟ `PI0Framework` ๅฑ็ๅ็ผ้ฎ้ข๏ผใ |
| """ |
| if state_dict and not any(k.startswith("pi0_model.") for k in state_dict.keys()): |
| logger.info( |
| "[PI0Framework.load_state_dict] ๆฃๆตๅฐ่ฃธ key๏ผpi05_libero_pytorch ๆ ผๅผ๏ผ๏ผ" |
| "็ดๆฅๅ ่ฝฝๅฐ self.pi0_model๏ผ่ทณ่ฟ pi0_model. ๅ็ผ๏ผ" |
| ) |
| |
| |
| |
| missing, unexpected = self.pi0_model.load_state_dict(state_dict, strict=False) |
| expected_missing = {"paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"} |
| real_missing = set(missing) - expected_missing |
| if real_missing: |
| logger.warning(f"[PI0Framework.load_state_dict] ็ๆญฃ็ผบๅคฑ็ key๏ผ้ๆ้็ปๅฎ๏ผ๏ผ{real_missing}") |
| if unexpected: |
| logger.warning(f"[PI0Framework.load_state_dict] ๅคไฝ็ key๏ผ{set(unexpected)}") |
| logger.info("[PI0Framework.load_state_dict] pi0 ๆ้ๅ ่ฝฝๅฎๆ๏ผembed_tokens.weight ็ฑๆ้็ปๅฎ่ชๅจๅๆญฅ๏ผ") |
| from torch.nn.modules.module import _IncompatibleKeys |
| return _IncompatibleKeys(missing, unexpected) |
| return super().load_state_dict(state_dict, strict=strict, assign=assign) |
|
|
| |
| |
| |
|
|
| @classmethod |
| def from_pretrained(cls, pretrained_checkpoint: str, **kwargs): |
| """ |
| ไป checkpoint ๅ ่ฝฝ PI0Framework๏ผไพ server_policy.py ่ฐ็จใ |
| |
| **ไธบไปไน้่ฆ้ๅ๏ผ** |
| `baseframework.from_pretrained` ๆ็ป่ฐ็จ `FrameworkModel.load_state_dict(ckpt_dict)`๏ผ |
| ่ PI0Framework ็ state_dict key ๅธฆ `pi0_model.` ๅ็ผ๏ผๅ ไธบ self.pi0_model ๆฏๅญๆจกๅ๏ผ๏ผ |
| ไฝ convert_jax_model_to_pytorch.py ่พๅบ็ safetensors ๆฏ่ฃธ key๏ผๆ ๅ็ผ๏ผใ |
| ็ดๆฅ load_state_dict ไผๅ key ไธๅน้
่ๅคฑ่ดฅใ |
| ๆฌๆนๆณๅฐๆ้็ดๆฅๅ ่ฝฝๅฐ self.pi0_model๏ผ็ป่ฟๅ็ผ้ฎ้ขใ |
| |
| ๆฏๆไธค็ง checkpoint ๆ ผๅผ๏ผ |
| |
| **ๆ ผๅผ A๏ผstarVLA wrapper ็ฎๅฝ**๏ผๆจ่็จไบ server_policy.py ้จ็ฝฒ๏ผ |
| ``` |
| <run_dir>/ |
| โโโ config.yaml # framework.name=PI0, framework.pi0.*, framework.effective_action_dim=7 |
| โโโ dataset_statistics.json # franka โ action โ min/max/mask๏ผ7 ็ปด๏ผ็ฑ pi05_libero q01/q99 ่ฝฌๆข๏ผ |
| โโโ checkpoints/ |
| โโโ model.safetensors # ๅณ pi05_libero_pytorch/model.safetensors๏ผๅฏ่ฝฏ้พๆฅ๏ผ |
| ``` |
| server_policy.py ๅฏๅจๅฝไปค๏ผ |
| `--ckpt_path <run_dir>/checkpoints/model.safetensors` |
| |
| **ๆ ผๅผ B๏ผ็ดๆฅ pi05_libero_pytorch ็ฎๅฝ**๏ผไธๅซ dataset_statistics๏ผไป
ไพๅฟซ้ๆต่ฏ๏ผ |
| ``` |
| <pi0_dir>/ |
| โโโ model.safetensors # convert_jax_model_to_pytorch.py ่พๅบ |
| โโโ config.json # {"action_dim": 32, "action_horizon": 10, ...} |
| ``` |
| ๆณจๆ๏ผๆญคๆ ผๅผไธ norm_stats ไธบ็ฉบ๏ผModelClient.get_action_stats ไผๅคฑ่ดฅ๏ผ |
| ้่ฆๆๅจๆไพ็ป่ฎก้ๆๆน็จๆ ผๅผ Aใ |
| |
| Args: |
| pretrained_checkpoint: checkpoint ๆไปถ่ทฏๅพ๏ผ.safetensors ๆ .pt๏ผใ |
| |
| Returns: |
| PI0Framework ๅฎไพ๏ผๅทฒๅ ่ฝฝๆ้ใ |
| """ |
| from pathlib import Path as _Path |
|
|
| pretrained_checkpoint = _Path(pretrained_checkpoint) |
|
|
| |
| wrapper_run_dir = pretrained_checkpoint.parents[1] |
| wrapper_cfg_yaml = wrapper_run_dir / "config.yaml" |
| wrapper_stats_json = wrapper_run_dir / "dataset_statistics.json" |
|
|
| if wrapper_cfg_yaml.exists() and wrapper_stats_json.exists(): |
| from starVLA.model.framework.share_tools import read_mode_config, dict_to_namespace |
| from starVLA.model.framework import build_framework |
|
|
| model_config, norm_stats = read_mode_config(pretrained_checkpoint) |
| config = dict_to_namespace(model_config) |
| config.trainer.pretrained_checkpoint = None |
|
|
| |
| fw_cfg_ref = getattr(config, "framework", config) |
| if hasattr(fw_cfg_ref, "pi0_checkpoint"): |
| fw_cfg_ref.pi0_checkpoint = None |
|
|
| model = build_framework(cfg=config) |
| model.norm_stats = norm_stats |
|
|
| |
| if pretrained_checkpoint.exists(): |
| model.load_pi0_weights(str(pretrained_checkpoint)) |
|
|
| logger.info(f"[from_pretrained ๆ ผๅผA] ไป wrapper ็ฎๅฝๅ ่ฝฝ๏ผ{wrapper_run_dir}") |
| return model |
|
|
| |
| pi0_dir = pretrained_checkpoint.parent |
| pi0_config_json = pi0_dir / "config.json" |
|
|
| if pi0_config_json.exists(): |
| import json |
| from omegaconf import OmegaConf |
|
|
| with open(pi0_config_json) as _f: |
| pi0_cfg_dict = json.load(_f) |
|
|
| |
| pi05_detected = False |
| if pretrained_checkpoint.suffix == ".safetensors": |
| try: |
| import safetensors |
| with safetensors.safe_open(str(pretrained_checkpoint), framework="pt") as _sf: |
| _keys = list(_sf.keys()) |
| pi05_detected = "time_mlp_in.weight" in _keys |
| except Exception: |
| pass |
|
|
| config = OmegaConf.create({ |
| "framework": { |
| "name": "PI0", |
| "pi0": { |
| "paligemma_variant": pi0_cfg_dict.get("paligemma_variant", "gemma_2b"), |
| "action_expert_variant": pi0_cfg_dict.get("action_expert_variant", "gemma_300m"), |
| "pi05": pi05_detected, |
| "action_dim": pi0_cfg_dict.get("action_dim", 32), |
| "action_horizon": pi0_cfg_dict.get("action_horizon", 50), |
| "dtype": pi0_cfg_dict.get("precision", "bfloat16"), |
| }, |
| "pi0_checkpoint": None, |
| "num_inference_steps": 10, |
| }, |
| "trainer": {"pretrained_checkpoint": None}, |
| }) |
|
|
| model = cls(config=config) |
| model.load_pi0_weights(str(pretrained_checkpoint)) |
| model.norm_stats = {} |
|
|
| logger.warning( |
| "[from_pretrained ๆ ผๅผB] ็ดๆฅๅ ่ฝฝ pi05_libero_pytorch ็ฎๅฝ๏ผ" |
| "norm_stats ไธบ็ฉบใModelClient.get_action_stats ไผๅคฑ่ดฅใ" |
| "ๅปบ่ฎฎๅๅปบ starVLA wrapper ็ฎๅฝ๏ผๅซ config.yaml + dataset_statistics.json๏ผใ" |
| ) |
| logger.info(f"[from_pretrained ๆ ผๅผB] pi05_detected={pi05_detected}๏ผ่ทฏๅพ๏ผ{pretrained_checkpoint}") |
| return model |
|
|
| |
| logger.warning( |
| f"PI0Framework.from_pretrained๏ผๆ ๆณ่ฏๅซ checkpoint ๆ ผๅผ๏ผ{pretrained_checkpoint}๏ผใ" |
| "่ฏทๅ่ docstring ๅๅปบ starVLA wrapper ็ฎๅฝ๏ผๆ ผๅผ A๏ผใ" |
| "ๅฐ่ฏ่ฐ็จ็ถ็ฑป from_pretrained๏ผๅคงๆฆ็ๅ key ไธๅน้
่ๅคฑ่ดฅ๏ผใ" |
| ) |
| return super().from_pretrained(pretrained_checkpoint, **kwargs) |
|
|
| |
|
|
| |
| |
| |
| |
| _PI0_HARDCODED_ACTION_DIM: int = 32 |
|
|
| def _replace_pi0_projection_layers(self, action_dim: int, state_dim: int = None) -> None: |
| """ |
| ๅฐ PI0Pytorch ็กฌ็ผ็ ็ 32D ๆๅฝฑๅฑๆฟๆขไธบ็ฎๆ ็ปดๅบฆใ |
| |
| - action_in_proj / action_out_proj ๆ action_dim ๆฟๆขใ |
| - state_proj ๆ state_dim ๆฟๆข๏ผstate_dim ๅฏไธ action_dim ไธๅ๏ผ |
| ๅฆ unified 74D state + 37D action ็ๅบๆฏ๏ผใ |
| - ๅฝๆ็ปดๅบฆไธ็กฌ็ผ็ ็ 32D ็ธๅๆถ๏ผๅฏนๅบๅฑไธๅๆฟๆขใ |
| - ๆฟๆขๅ็ๅฑไธบ้ๆบๅๅงๅ๏ผๅ ่ฝฝ checkpoint ๆถ๏ผ |
| `_filter_ckpt_by_shape` ไผๅ shape ไธๅน้
่่ชๅจ่ทณ่ฟ่ฟไบ keyใ |
| |
| Args: |
| action_dim: ๅจไฝ็ปดๅบฆ๏ผ็จไบ action_in_proj / action_out_proj๏ผใ |
| state_dim: ็ถๆ็ปดๅบฆ๏ผ็จไบ state_proj๏ผใ็ผบ็ๆถไธ action_dim ็ธๅใ |
| """ |
| if state_dim is None: |
| state_dim = action_dim |
|
|
| |
| if action_dim != self._PI0_HARDCODED_ACTION_DIM: |
| proj_width = self.pi0_model.action_in_proj.out_features |
| self.pi0_model.action_in_proj = nn.Linear(action_dim, proj_width) |
| self.pi0_model.action_out_proj = nn.Linear(proj_width, action_dim) |
| logger.info( |
| f"[PI0Framework] action_in/out_proj ๅทฒๆฟๆข๏ผ" |
| f"Linear({self._PI0_HARDCODED_ACTION_DIM}, {proj_width}) " |
| f"โ Linear({action_dim}, {proj_width}) " |
| f"[้ๆบๅๅงๅ๏ผไธๅ ่ฝฝ checkpoint ๆ้]" |
| ) |
|
|
| |
| if not self._pi0_cfg.pi05 and hasattr(self.pi0_model, "state_proj"): |
| if state_dim != self._PI0_HARDCODED_ACTION_DIM: |
| state_proj_width = self.pi0_model.state_proj.out_features |
| self.pi0_model.state_proj = nn.Linear(state_dim, state_proj_width) |
| logger.info( |
| f"[PI0Framework] state_proj ๅทฒๆฟๆข๏ผ" |
| f"Linear({self._PI0_HARDCODED_ACTION_DIM}, {state_proj_width}) " |
| f"โ Linear({state_dim}, {state_proj_width}) " |
| f"[้ๆบๅๅงๅ๏ผไธๅ ่ฝฝ checkpoint ๆ้]" |
| ) |
|
|
| def _filter_ckpt_by_shape(self, state_dict: dict) -> dict: |
| """ |
| ่ฟๆปค checkpoint state_dict๏ผ็งป้คไธๅฝๅๆจกๅ shape ไธไธ่ด็ keyใ |
| |
| load_state_dict(strict=False) ้ๅฐ shape ไธๅน้
ๆถไปไผๆๅบ RuntimeError๏ผ |
| ๆฌๆนๆณๆๅ่ฟๆปค๏ผ่ฎฉ่ฟไบ key ไฟๆ้ๆบๅๅงๅ๏ผๅ
ถไฝ key ๆญฃๅธธๅ ่ฝฝใ |
| |
| Args: |
| state_dict: ไปๆไปถ่ฏปๅ็ๅๅงๆ้ๅญๅ
ธ๏ผkey ไธบ่ฃธๅ๏ผไธๅซๅ็ผ๏ผใ |
| |
| Returns: |
| ่ฟๆปคๅ็ๅญๅ
ธ๏ผๅชไฟ็ shape ๅน้
๏ผๆๆจกๅไธญไธๅญๅจ๏ผ็ keyใ |
| """ |
| model_sd = self.pi0_model.state_dict() |
| filtered: dict = {} |
| skipped: list = [] |
|
|
| for k, v in state_dict.items(): |
| if k in model_sd and model_sd[k].shape != v.shape: |
| skipped.append( |
| f" {k}: ckpt{tuple(v.shape)} โ model{tuple(model_sd[k].shape)}" |
| ) |
| else: |
| filtered[k] = v |
|
|
| if skipped: |
| logger.info( |
| f"[load_pi0_weights] ่ทณ่ฟ {len(skipped)} ไธช shape ไธๅน้
็ key" |
| f"๏ผ่ฟไบๅฑไฟๆ้ๆบๅๅงๅ๏ผๅฐๅจ่ฎญ็ปไธญไปๅคดๅญฆไน ๏ผ๏ผ" |
| ) |
| for s in skipped: |
| logger.info(s) |
|
|
| return filtered |
|
|
| def load_pi0_weights(self, checkpoint_path: str) -> None: |
| """ |
| ไป pi0 ้ข่ฎญ็ป checkpoint ๅ ่ฝฝๆ้ๅฐ self.pi0_modelใ |
| |
| ๆฏๆๆ ผๅผ๏ผ |
| - .safetensors : ไฝฟ็จ safetensors.torch.load_file + ่ฟๆปคๅ load_state_dict |
| - .pt / .pth : ไฝฟ็จ torch.load๏ผmap_location="cpu"๏ผ |
| |
| ๅฝ config.action_dim ไธ checkpoint ็็กฌ็ผ็ ็ปดๅบฆ๏ผ32๏ผไธไธ่ดๆถ๏ผ |
| _filter_ckpt_by_shape ไผ่ชๅจ่ทณ่ฟ action_in_proj / action_out_proj / state_proj๏ผ |
| ่ฟไบๅฑไฟๆ _replace_pi0_projection_layers ็้ๆบๅๅงๅ๏ผ็ฑ่ฎญ็ป่ช่กๅญฆไน ใ |
| |
| Args: |
| checkpoint_path: ๆ้ๆไปถ่ทฏๅพใ |
| |
| Raises: |
| FileNotFoundError: ๆไปถไธๅญๅจใ |
| RuntimeError: ๆ้ๅ ่ฝฝๅคฑ่ดฅใ |
| """ |
| checkpoint_path = Path(checkpoint_path) |
| if not checkpoint_path.exists(): |
| raise FileNotFoundError(f"pi0 checkpoint ไธๅญๅจ๏ผ{checkpoint_path}") |
|
|
| logger.info(f"ๅ ่ฝฝ pi0 ้ข่ฎญ็ปๆ้๏ผ{checkpoint_path}") |
|
|
| if checkpoint_path.suffix == ".safetensors": |
| import safetensors.torch as sf_torch |
| state_dict = sf_torch.load_file(str(checkpoint_path)) |
| state_dict = self._filter_ckpt_by_shape(state_dict) |
| missing, unexpected = self.pi0_model.load_state_dict(state_dict, strict=False) |
| |
| expected_missing = { |
| "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" |
| } |
| real_missing = set(missing) - expected_missing |
| if real_missing: |
| logger.warning(f"ๅ ่ฝฝๅ็ๆญฃ็ผบๅคฑ็ key๏ผ{len(real_missing)} ไธช๏ผ๏ผ{list(real_missing)[:10]} ...") |
| if unexpected: |
| logger.warning(f"ๅ ่ฝฝๅๅคไฝ็ key๏ผ{len(unexpected)} ไธช๏ผ๏ผ{list(unexpected)[:10]} ...") |
| else: |
| state_dict = torch.load(str(checkpoint_path), map_location="cpu") |
| |
| if "model" in state_dict and isinstance(state_dict["model"], dict): |
| state_dict = state_dict["model"] |
| state_dict = self._filter_ckpt_by_shape(state_dict) |
| model_keys = set(self.pi0_model.state_dict().keys()) |
| checkpoint_keys = set(state_dict.keys()) |
| missing = model_keys - checkpoint_keys |
| unexpected = checkpoint_keys - model_keys |
| if missing: |
| logger.warning(f"ๆ้ไธญ็ผบๅฐ็ key๏ผ{len(missing)} ไธช๏ผ๏ผ{list(missing)[:10]} ...") |
| if unexpected: |
| logger.warning(f"ๆ้ไธญๅคไฝ็ key๏ผ{len(unexpected)} ไธช๏ผ๏ผ{list(unexpected)[:10]} ...") |
| self.pi0_model.load_state_dict(state_dict, strict=False) |
|
|
| logger.info("pi0 ้ข่ฎญ็ปๆ้ๅ ่ฝฝๅฎๆฏใ") |
|
|
| def forward( |
| self, |
| examples: List[dict] = None, |
| **kwargs, |
| ) -> dict: |
| """ |
| ่ฎญ็ปๅๅไผ ๆญ๏ผ่ฎก็ฎๆตๅน้
๏ผflow matching๏ผๆๅคฑใ |
| |
| ๅบไบ PI0Pytorch.forward ็่ฎญ็ป้ป่พ๏ผ |
| 1. ไป examples ไธญ่ฏปๅ image / lang / action / state |
| 2. ๅฐ action pad ๅฐ model ็ action_dim๏ผๅฆ 32๏ผ๏ผไธ่ถณ็็ปดๅบฆ่กฅ้ถ |
| 3. ้่ฟ _preprocess_examples ๆๅปบ PI0 Observation ๅฏน่ฑก |
| 4. ่ฐ็จ PI0Pytorch.forward(observation, actions_target)๏ผ |
| - ้ๆ ท noise ๅ time |
| - ็บฟๆงๆๅผๅ ๅช๏ผx_t = t * noise + (1-t) * actions |
| - ็ฎๆ ้ๅบฆ๏ผu_t = noise - actions |
| - ้ขๆต้ๅบฆ๏ผv_t = model(x_t, t) |
| - ๆๅคฑ๏ผMSE(u_t, v_t) per element โ ๅๅๅๅผ |
| 5. ่ฟๅ {"action_loss": scalar} |
| |
| Args: |
| examples: List[dict]๏ผๆฏไธช dict ๅ
ๅซ๏ผ |
| - "image" (List[PIL.Image | np.ndarray]): ๅ่ง่งๅพๅ |
| - "lang" (str): ่ฏญ่จๆไปค |
| - "action" (np.ndarray): ็ฎๆ ๅจไฝ๏ผshape (action_horizon, D)๏ผ |
| D ้ๅธธๆฏ effective_action_dim๏ผๅฆ 7๏ผ๏ผไธ่ถณ action_dim ไผ่ชๅจ zero-pad |
| - "state" (np.ndarray, ๅฏ้): ๆบๅจไบบๆฌไฝ็ถๆ |
| **kwargs: ้ข็๏ผๆๆชไฝฟ็จใ |
| |
| Returns: |
| dict: |
| "action_loss" (torch.Tensor): ๆ ้๏ผbatch ๅนณๅๆตๅน้
MSE ๆๅคฑใ |
| """ |
| if not isinstance(examples, list): |
| examples = [examples] |
|
|
| device = next(self.pi0_model.parameters()).device |
| action_dim = self._pi0_cfg.action_dim |
| action_horizon = self._pi0_cfg.action_horizon |
|
|
| |
| observation = self._preprocess_examples(examples, device) |
|
|
| |
| |
| actions_list = [] |
| for example in examples: |
| a = np.array(example["action"], dtype=np.float32) |
|
|
| |
| if a.ndim == 1: |
| a = a[np.newaxis, :] |
|
|
| |
| H, D = a.shape |
| if H > action_horizon: |
| a = a[:action_horizon] |
| elif H < action_horizon: |
| a = np.concatenate( |
| [a, np.zeros((action_horizon - H, D), dtype=np.float32)], axis=0 |
| ) |
|
|
| |
| D = a.shape[1] |
| if D > action_dim: |
| a = a[:, :action_dim] |
| elif D < action_dim: |
| a = np.concatenate( |
| [a, np.zeros((action_horizon, action_dim - D), dtype=np.float32)], axis=1 |
| ) |
|
|
| actions_list.append(a) |
|
|
| actions_target = torch.tensor( |
| np.stack(actions_list, axis=0), dtype=torch.float32, device=device |
| ) |
|
|
| |
| |
| |
| |
| with torch.autocast("cuda", dtype=torch.bfloat16): |
| loss_per_element = self.pi0_model.forward(observation, actions_target) |
|
|
| |
| action_loss = loss_per_element.mean() |
|
|
| return {"action_loss": action_loss} |
|
|
| @torch.inference_mode() |
| def predict_action( |
| self, |
| examples: List[dict], |
| **kwargs, |
| ) -> dict: |
| """ |
| ๆจ็๏ผๆ นๆฎ่งๆต้ขๆตๅจไฝๅบๅใ |
| |
| ๆต็จ๏ผ |
| 1. ๅฐ examples ไธญ็ๅพๅ่ฝฌๆขไธบๅฝไธๅๅผ ้ |
| 2. Tokenize ่ฏญ่จๆไปค |
| 3. ๆๅ/ๅฏน้ฝๆฌไฝ็ถๆ |
| 4. ๆ้ PI0 Observation ๅฏน่ฑก |
| 5. ่ฐ็จ PI0Pytorch.sample_actions ่ฟ่กๆตๅน้
ๅปๅชๆจ็ |
| 6. ่ฟๅๅฝไธๅๅจไฝ |
| |
| Args: |
| examples: List[dict]๏ผๆฏไธช dict ๅ
ๅซ๏ผ |
| - "image" (List[PIL.Image | np.ndarray]): ๅ่ง่งๅพๅ๏ผ้กบๅบไธ config.framework.image_keys ๅฏนๅบใ |
| ๆฅๅ PIL Image ๆ (H, W, 3) uint8 numpy ๆฐ็ป๏ผeval_libero.py ๆ ผๅผ๏ผใ |
| - "lang" (str) : ไปปๅก่ฏญ่จๆไปค |
| - "state" (np.ndarray, ๅฏ้): ๆบๅจไบบๆฌไฝ็ถๆ๏ผshape (state_dim,) ๆ (1, state_dim)ใ |
| โ ๏ธ pi05 ๆจกๅผไธ state ๆฏ tokenized prompt ็ไธ้จๅ๏ผ็ผบๅคฑไผๅฏผ่ด prompt ๆ ผๅผ |
| ้่ฏฏใๆง่ฝไธ้ใeval_libero.py ้่ฆๅจ example_dict ไธญๆทปๅ ๆญคๅญๆฎตใ |
| **kwargs: ้ขๅคๅๆฐ๏ผๆๆชไฝฟ็จ๏ผใ |
| |
| Returns: |
| dict: |
| "normalized_actions" (np.ndarray): shape [B, action_horizon, action_dim]๏ผ |
| ๅฝไธๅ่ณ [-1, 1] ็้ขๆตๅจไฝๅบๅใ |
| """ |
| if not isinstance(examples, list): |
| examples = [examples] |
|
|
| |
| device = next(self.pi0_model.parameters()).device |
|
|
| |
| observation = self._preprocess_examples(examples, device) |
|
|
| |
| num_steps = kwargs.get("num_steps", self.num_inference_steps) |
| pred_actions = self.pi0_model.sample_actions( |
| device, observation, num_steps=num_steps |
| ) |
|
|
| normalized_actions = pred_actions.detach().cpu().numpy() |
|
|
| |
| |
| |
| |
| |
| if self.effective_action_dim is not None: |
| normalized_actions = normalized_actions[:, :, : self.effective_action_dim] |
| |
| return {"normalized_actions": normalized_actions} |
|
|
|
|
| |
| |
| |
|
|
| if __name__ == "__main__": |
| import argparse |
| from omegaconf import OmegaConf |
|
|
| |
| _DEFAULT_CKPT = ( |
| "/mnt/data/fangyu/model/openpi/openpi-assets/checkpoints" |
| "/pi05_libero_pytorch/model.safetensors" |
| ) |
| _DEFAULT_TOKENIZER = ( |
| "/root/.cache/openpi/big_vision/paligemma_tokenizer.model" |
| ) |
|
|
| parser = argparse.ArgumentParser(description="PI0Framework ๅฟซ้ๅ็ๆต่ฏ") |
| parser.add_argument("--config_yaml", type=str, default=None, help="YAML ้
็ฝฎๆไปถ่ทฏๅพ๏ผไผๅ
ไบๅ
่้
็ฝฎ๏ผ") |
| parser.add_argument("--pi0_checkpoint", type=str, default=_DEFAULT_CKPT, help="pi0 safetensors ๆ้่ทฏๅพ") |
| parser.add_argument("--tokenizer_path", type=str, default=_DEFAULT_TOKENIZER, help="PaliGemma tokenizer .model ่ทฏๅพ") |
| parser.add_argument("--device", type=str, default=None, help="่ฟ่ก่ฎพๅค๏ผๅฆ cuda:0 / cpu๏ผ้ป่ฎค่ชๅจๆฃๆต๏ผ") |
| parser.add_argument("--steps", type=int, default=10, help="ๆตๅน้
ๆจ็ๆญฅๆฐ") |
| parser.add_argument("--batch_size", type=int, default=1, help="ๆต่ฏ batch size") |
| args, _ = parser.parse_known_args() |
|
|
| if args.config_yaml: |
| cfg = OmegaConf.load(args.config_yaml) |
| else: |
| |
| cfg = OmegaConf.create({ |
| "framework": { |
| "name": "PI0", |
| "pi0": { |
| "paligemma_variant": "gemma_2b", |
| "action_expert_variant": "gemma_300m", |
| "pi05": True, |
| "action_dim": 32, |
| "action_horizon": 10, |
| "dtype": "bfloat16", |
| }, |
| "tokenizer_path": args.tokenizer_path, |
| "pi0_checkpoint": args.pi0_checkpoint, |
| "image_keys": ["base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb"], |
| "num_inference_steps": args.steps, |
| } |
| }) |
|
|
| print("=" * 60) |
| print("PI0Framework ๆต่ฏ") |
| print("=" * 60) |
| print(f" checkpoint : {args.pi0_checkpoint}") |
| print(f" tokenizer : {args.tokenizer_path}") |
| print(f" infer steps : {args.steps}") |
| print(f" batch size : {args.batch_size}") |
|
|
| |
| print("\n[1/3] ๅๅงๅ PI0Framework ๅนถๅ ่ฝฝๆ้...") |
| model = PI0Framework(cfg) |
| total_params = sum(p.numel() for p in model.pi0_model.parameters()) |
| print(f" ๆจกๅๅๆฐ้: {total_params / 1e9:.2f}B") |
|
|
| |
| if args.device: |
| device = torch.device(args.device) |
| else: |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| print(f"\n[2/3] ็งป่ณ่ฎพๅค: {device}") |
| model = model.to(device) |
| |
| model.pi0_model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16") |
|
|
| |
| print(f"\n[3/3] ๆ้ batch_size={args.batch_size} ็ๅๆ ทๆฌๅนถ่ฐ็จ predict_action...") |
| fake_img = Image.fromarray( |
| np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) |
| ) |
| sample = { |
| "image": [fake_img, fake_img, fake_img], |
| "lang": "put the red cup on the plate", |
| "state": np.random.uniform(-1, 1, size=(32,)).astype(np.float32), |
| } |
| batch = [sample] * args.batch_size |
|
|
| import time |
| t0 = time.time() |
| result = model.predict_action(batch) |
| elapsed = time.time() - t0 |
|
|
| actions = result["normalized_actions"] |
| print(f"\n ่พๅบ normalized_actions shape : {actions.shape}") |
| print(f" ๆจ็่ๆถ : {elapsed*1000:.1f} ms") |
| print(f" ๅจไฝๅผๅ [min, max] : [{actions.min():.4f}, {actions.max():.4f}]") |
| print(f" ๅจไฝๅๅผ ยฑ ๆ ๅๅทฎ : {actions.mean():.4f} ยฑ {actions.std():.4f}") |
| print("\n[OK] PI0Framework ๆต่ฏๅฎๆ๏ผ") |
|
|