| """Utils for evaluating the OpenVLA policy.""" |
|
|
| import json |
| import os |
| import time |
| from collections import deque |
| from glob import glob |
| from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import torchvision |
| from lerobot.configs.policies import PreTrainedConfig |
| from PIL import Image |
| from safetensors import safe_open |
| from safetensors.torch import load_file |
| from torch import Tensor, nn |
| from tqdm import tqdm |
| from transformers import ( |
| AutoConfig, |
| AutoProcessor, |
| PretrainedConfig, |
| PreTrainedModel, |
| ) |
| from transformers.models.auto.tokenization_auto import AutoTokenizer |
| from veomni.models.vla.pi0 import PI0Policy, QwenPI0Policy |
|
|
# Camera names looked up under observation["image"]; images are stacked in this order.
IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
|
|
|
|
class AdaptiveEnsembler:
    """Blend recent predictions for the current step into a single action.

    Every stored prediction is weighted by ``exp(alpha * cos_sim)``, where
    ``cos_sim`` is its cosine similarity to the newest prediction, and the
    normalized weighted sum is returned.
    """

    def __init__(self, pred_action_horizon, adaptive_ensemble_alpha=0.0):
        """Create an ensembler remembering at most ``pred_action_horizon`` predictions."""
        self.pred_action_horizon = pred_action_horizon
        self.adaptive_ensemble_alpha = adaptive_ensemble_alpha
        # Bounded deque: the oldest prediction is dropped once the horizon is full.
        self.action_history = deque(maxlen=self.pred_action_horizon)

    def reset(self):
        """Forget all stored predictions."""
        self.action_history.clear()

    def ensemble_action(self, cur_action):
        """Store ``cur_action`` and return the similarity-weighted ensemble.

        Args:
            cur_action: numpy array. A 1-D array is used as-is; otherwise row
                ``k`` of a stored entry is used once that entry is ``k`` calls old.

        Returns:
            numpy array: weighted sum over the stored predictions.
        """
        self.action_history.append(cur_action)
        history_len = len(self.action_history)

        if cur_action.ndim == 1:
            candidates = np.stack(self.action_history)
        else:
            # History is ordered oldest -> newest, so ages count down to 0.
            rows = []
            for age, chunk in zip(reversed(range(history_len)), self.action_history):
                rows.append(chunk[age])
            candidates = np.stack(rows)

        # Cosine similarity of every candidate against the newest one.
        newest = candidates[history_len - 1, :]
        inner = np.sum(candidates * newest, axis=1)
        candidate_norms = np.linalg.norm(candidates, axis=1)
        newest_norm = np.linalg.norm(newest)
        similarity = inner / (candidate_norms * newest_norm + 1e-7)

        # Exponential weighting, normalized to sum to one.
        weights = np.exp(self.adaptive_ensemble_alpha * similarity)
        weights = weights / weights.sum()

        return np.sum(weights[:, None] * candidates, axis=0)
|
|
|
|
def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
    """Center-crop an image to 90% of its area and resize it to 224x224 RGB.

    Args:
        image: A numpy array or a PIL image. Float arrays whose values all lie
            in [0, 1] are scaled by 255; other float arrays are clipped to
            [0, 255]; uint16 arrays are divided by 257; any other non-uint8
            dtype is cast to uint8.

    Returns:
        The cropped, bilinearly resized 224x224 RGB PIL image.

    Raises:
        TypeError: If ``image`` is neither a numpy array nor a PIL image.
    """
    crop_scale = 0.9
    # crop_scale is an area fraction, so each side shrinks by its square root.
    side_scale = float(np.sqrt(np.clip(crop_scale, 0.0, 1.0)))
    out_size = (224, 224)

    if isinstance(image, np.ndarray):
        pixels = image
        if pixels.dtype.kind == "f":
            if pixels.max() <= 1.0 and pixels.min() >= 0.0:
                pixels = (np.clip(pixels, 0.0, 1.0) * 255.0).astype(np.uint8)
            else:
                pixels = np.clip(pixels, 0.0, 255.0).astype(np.uint8)
        elif pixels.dtype == np.uint16:
            # 65535 / 257 == 255, mapping the 16-bit range onto 8 bits.
            pixels = (pixels / 257).astype(np.uint8)
        elif pixels.dtype != np.uint8:
            pixels = pixels.astype(np.uint8)
        picture = Image.fromarray(pixels)
    elif isinstance(image, Image.Image):
        picture = image
    else:
        raise TypeError("image must be a numpy array or PIL.Image.Image")

    picture = picture.convert("RGB")
    full_w, full_h = picture.size

    # Crop box centered in the image, at least one pixel per side.
    box_w = max(1, int(round(full_w * side_scale)))
    box_h = max(1, int(round(full_h * side_scale)))
    x0 = (full_w - box_w) // 2
    y0 = (full_h - box_h) // 2
    box = (x0, y0, x0 + box_w, y0 + box_h)

    return picture.crop(box).resize(out_size, resample=Image.BILINEAR)
|
|
|
|
def resize_with_pad(img, width, height, pad_value=-1):
    """Resize a 4-D image batch keeping aspect ratio, then pad on the left/top.

    Args:
        img: 4-D tensor. If dim 1 is not 1 or 3 but the last dim is, the tensor
            is treated as channel-last and permuted to channel-first.
        width: Target width.
        height: Target height.
        pad_value: Fill value for the padded region.

    Returns:
        The bilinearly resized tensor, padded on the left and top with
        ``pad_value`` up to ``(height, width)``.

    Raises:
        ValueError: If ``img`` is not 4-dimensional.
    """
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

    channels_first = img.shape[1] in (1, 3)
    channels_last = img.shape[-1] in (1, 3)
    if channels_last and not channels_first:
        img = img.permute(0, 3, 1, 2)

    src_h, src_w = img.shape[2:]

    # The larger of the two ratios makes the resized image fit inside the target.
    scale = max(src_w / width, src_h / height)
    new_h = int(src_h / scale)
    new_w = int(src_w / scale)
    scaled = F.interpolate(
        img,
        size=(new_h, new_w),
        mode="bilinear",
        align_corners=False,
    )

    pad_top = max(0, int(height - new_h))
    pad_left = max(0, int(width - new_w))

    # F.pad order for the last two dims: (left, right, top, bottom).
    return F.pad(scaled, (pad_left, 0, pad_top, 0), value=pad_value)
|
|
|
|
| class PolicyPreprocessMixin: |
| """ |
| A mixin class that provides preprocessing utilities for observations. |
| Can be mixed into any policy class to add image, state, action, language handling. |
| """ |
|
|
| def prepare_images(self, observation: dict[str, Tensor]): |
| """Normalize, resize, and pad images and stack them into a tensor. |
| |
| Args: |
| observation (dict[str, Tensor]) |
| |
| Returns: |
| images (torch.Tensor): (*b, n, c, h, w) images in range [-1.0, 1.0] |
| img_masks (torch.Tensor): (*b, n) masks for images, True if image is present, False if missing |
| """ |
| dtype = observation["state"].dtype |
| bsize = observation["state"].shape[0] |
| device = observation["state"].device |
| images, img_masks = [], [] |
| for key in IMAGE_KEYS: |
| if key in observation["image"]: |
| |
| img = observation["image"][key] |
|
|
| if isinstance(img, np.ndarray): |
| img = torch.from_numpy(img) |
|
|
| img = resize_with_pad(img, |
| *self.config.resize_imgs_with_padding, |
| pad_value=0) |
| img = self.image_processor(img)['pixel_values'] |
| images.append(img) |
| img_masks.append(True) |
| else: |
| img = np.zeros_like(img) |
| images.append(img) |
| img_masks.append(False) |
| |
| if isinstance(images[0], torch.Tensor): |
| images = torch.stack(images, dim=0).to(device=device) |
| elif isinstance(images[0], np.ndarray): |
| images = torch.from_numpy(np.stack(images, axis=0)).to( |
| device=device) |
| img_masks = torch.tensor(img_masks, |
| dtype=torch.bool).to(device=device) |
|
|
| return images, img_masks |
|
|
| def prepare_state(self, observation): |
| state = torch.from_numpy(observation["state"]) |
| if isinstance(state, np.ndarray): |
| state = torch.from_numpy(state) |
| state = F.pad(state, (0, self.config.max_state_dim - state.shape[1])) |
| return state |
|
|
| def prepare_language(self, observation: dict[str, Tensor]): |
| """If `prompt` is provided, modify it to PaliGemma format and tokenize it. |
| If `lang_tokens` and `lang_masks` are provided, use them directly. |
| |
| PaliGemma expects prefix prompts to be formatted as: |
| <images> .... <images> <bos> prompt <sep>, where <sep> uses `\\n`. |
| So here we format the prompt to start with `<bos>` and end with `\\n`. |
| Later, we will concatenate the images and language tokens into a single sequence. |
| |
| Args: |
| observation (dict[str, Tensor]) |
| |
| Returns: |
| lang_tokens (torch.Tensor): (*b, l) language tokens |
| lang_masks (torch.Tensor): (*b, l) masks for language tokens, True if token is present, False if missing |
| """ |
| lang_tokens = observation.get("lang_tokens", None) |
| lang_masks = observation.get("lang_masks", None) |
| prompt = observation.get("prompt", None) |
|
|
| |
| if prompt is None and (lang_tokens is None or lang_masks is None): |
| raise ValueError( |
| "Either 'prompt' or ('lang_tokens', 'lang_masks') must be provided in the observation." |
| ) |
|
|
| device = observation["state"].device |
| if prompt is not None and (lang_tokens is None or lang_masks is None): |
| prompt = [ |
| p if p.startswith("<bos>") else f"<bos>{p}" for p in prompt |
| ] |
| prompt = [p if p.endswith("\n") else f"{p}\n" for p in prompt] |
| tokenized_prompt = self.language_tokenizer.__call__( |
| prompt, |
| padding="max_length", |
| padding_side="right", |
| max_length=self.config.tokenizer_max_length, |
| return_tensors="pt", |
| ) |
| lang_tokens = tokenized_prompt["input_ids"].to(device=device) |
| lang_masks = tokenized_prompt["attention_mask"].to( |
| device=device, dtype=torch.bool) |
| else: |
| lang_tokens = observation["lang_tokens"].to(device=device) |
| lang_masks = observation["lang_masks"].to(device=device, |
| dtype=torch.bool) |
|
|
| return lang_tokens, lang_masks |
|
|
| @torch.no_grad |
| def select_action(self, |
| observation: dict[str, Tensor], |
| noise: Tensor | None = None): |
| self.eval() |
| images, img_masks = self.prepare_images(observation) |
| state = self.prepare_state(observation) |
| lang_tokens, lang_masks = self.prepare_language(observation) |
| device = 'cuda' |
| dtype = torch.bfloat16 |
|
|
| actions = self.model.sample_actions( |
| images.to(dtype=dtype, device=device), |
| img_masks.to(device=device), |
| lang_tokens.to(device=device), |
| lang_masks.to(device=device), |
| state.to(dtype=dtype, device=device), |
| ) |
| return actions |
|
|
|
|
class QwenPI0InferencePolicy(PolicyPreprocessMixin, QwenPI0Policy):
    """QwenPI0Policy extended with the PolicyPreprocessMixin observation helpers."""
|
|
|
|
class PI0InfernecePolicy(PolicyPreprocessMixin, PI0Policy):
    """PI0Policy extended with the PolicyPreprocessMixin observation helpers.

    The class name is misspelled ("Infernece"); it is kept unchanged so that
    existing references keep working.
    """


# Correctly spelled alias for new code; refers to the very same class.
PI0InferencePolicy = PI0InfernecePolicy
|
|
|
|
def merge_qwen_config(policy_config, qwen_config):
    """Copy language-model fields and the vision config onto ``policy_config``.

    Args:
        policy_config: Object whose attributes are set in place.
        qwen_config: Either a mapping, or an object exposing ``to_dict()``.

    Returns:
        The same ``policy_config`` object, after mutation.
    """
    if hasattr(qwen_config, 'to_dict'):
        config_dict = qwen_config.to_dict()
    else:
        config_dict = qwen_config

    # A tuple (not a set) so the merge log is printed in a stable order.
    text_keys = (
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "rms_norm_eps",
        "rope_theta",
        "vocab_size",
        "max_position_embeddings",
        "hidden_act",
        "tie_word_embeddings",
        "tokenizer_path",
    )

    for key in text_keys:
        if key in config_dict:
            setattr(policy_config, key, config_dict[key])
            print(f"✅ Merged LLM: {key} = {config_dict[key]}")

    if "vision_config" in config_dict:
        if hasattr(qwen_config, "vision_config"):
            # Keep the original attribute object rather than its to_dict() form.
            policy_config.vision_config = qwen_config.vision_config
        else:
            # Plain mappings have no attribute access; use the entry itself.
            policy_config.vision_config = config_dict["vision_config"]
    else:
        print("⚠️ Warning: 'vision_config' not found in qwen_config!")

    return policy_config
|
|
|
|
class QwenPiServer:
    '''
    Policy wrapper intended to support action ensemble or chunk execution.

    As written, ``load_vla`` builds a dataset and ``infer`` returns the first
    recorded action of successive dataset items; the action ensembler is
    created and reset but not applied in ``infer``.
    '''

    def __init__(
        self,
        path_to_pi_model="",
        adaptive_ensemble_alpha=0.1,
        action_ensemble_horizon=8,
        use_length=1,
        use_bf16=True,
    ) -> None:
        """Store the settings, build the ensembler and load the data source.

        ``use_bf16`` is accepted but not used by this class.
        """
        self.adaptive_ensemble_alpha = adaptive_ensemble_alpha
        self.action_ensemble_horizon = action_ensemble_horizon
        self.use_length = use_length

        self.task_description = None

        self.action_ensembler = AdaptiveEnsembler(self.action_ensemble_horizon,
                                                  self.adaptive_ensemble_alpha)

        self.vla = self.load_vla(path_to_pi_model)
        self.global_step = 0
        self.last_action_chunk = None

    def init_norm(
            self,
            states_path='/home/yangshuai/yangshuai_ssd0/checkpoint/qwen_pi0/norm_stats.json',
            state_dim=14,
            action_dim=14,
            dataset_key='hanging_mug-aloha-agilex_clean_50_rep'):
        '''
        Load state/action mean and std from a JSON statistics file.

        The file is indexed as ``[dataset_key]["norm_stats"][name]["mean"|"std"]``
        with ``name`` in ``("state", "actions")``; each list is truncated to
        ``state_dim`` / ``action_dim`` entries and stored as a float32 array.

        TODO: should be rewritten as a dict
        '''
        with open(states_path, encoding="utf-8") as f:
            norm_stats = json.load(f)[dataset_key]
        stats = norm_stats["norm_stats"]
        self.state_mean = np.array(stats["state"]["mean"][:state_dim],
                                   dtype=np.float32)
        self.state_std = np.array(stats["state"]["std"][:state_dim],
                                  dtype=np.float32)
        self.action_mean = np.array(stats["actions"]["mean"][:action_dim],
                                    dtype=np.float32)
        self.action_std = np.array(stats["actions"]["std"][:action_dim],
                                   dtype=np.float32)

    def state_normalizer(self, unnorm_state):
        """Return ``(unnorm_state - mean) / (std + 1e-6)`` using the loaded state stats."""
        state = (unnorm_state - self.state_mean) / (self.state_std + 1e-6)
        return state

    def action_unnormalizer(self, norm_action):
        """Return ``norm_action * (std + 1e-6) + mean`` using the loaded action stats."""
        action = norm_action * (self.action_std + 1e-6) + self.action_mean
        return action

    def load_vla(self, path_to_pi_model) -> Any:
        """Build and return the dataset that ``infer`` reads actions from.

        ``path_to_pi_model`` is currently not used: the tokenizer and dataset
        locations are hard-coded below.
        """
        # Imported lazily so that importing this module does not require them.
        from types import SimpleNamespace

        from transformers import AutoTokenizer
        from veomni.data.dataset import build_vla_dataset

        tokenizer = AutoTokenizer.from_pretrained(
            '/home/yangshuai/yangshuai_ssd0/rep/VLA_pretraining/checkpoints/Qwen2.5-VL-3B-Instruct'
        )

        config = SimpleNamespace()
        config.max_state_dim = 14
        config.max_action_dim = 14
        config.tokenizer_max_length = 128
        config.resize_imgs_with_padding = (224, 224)

        dataset = build_vla_dataset(
            datasets_type='agilex',
            repo_id=
            '/home/yangshuai/yangshuai_ssd0/cache/huggingface/lerobot/hanging_mug-aloha-agilex_clean_50_rep',
            config=config,
            chunk_size=50,
            tokenizer=tokenizer,
        )
        return dataset

    def reset(self, task_description: str) -> None:
        """Start a new episode: store the task text and rewind the step counter.

        The ensembler history is cleared only when ``use_length == -1``.
        """
        self.task_description = task_description
        if self.use_length == -1:
            self.action_ensembler.reset()

        self.global_step = 0
        self.last_action_chunk = None

    def infer(self, observation, center_crop=True):
        """Return the next recorded action as ``{"action": ndarray}``.

        Reads ``self.vla[self.global_step]['actions'][0]`` and advances
        ``global_step``; ``observation`` and ``center_crop`` are not used.
        Indexing past the end of ``self.vla`` raises whatever that container
        raises.
        """
        action = self.vla[self.global_step]['actions'][0]
        self.global_step += 1

        return dict(action=action.numpy())
|
|
|
|
if __name__ == "__main__":
    # Relative import: this entry point has to be started as a module of its package.
    from .websocket_policy_server import WebsocketPolicyServer

    PATH_TO_PI_MODEL = "/home/yangshuai/yangshuai_ssd0/checkpoint/qwen_pi0/8GPU_cotraining/checkpoints/global_step_35000/hf_ckpt"

    # Wrap the policy and serve it over a websocket on port 8000 until stopped.
    policy = QwenPiServer(PATH_TO_PI_MODEL, use_length=50)
    server = WebsocketPolicyServer(policy, port=8000)
    server.serve_forever()
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|