"""Utils for evaluating the OpenVLA policy."""

import json
import os
import time
from collections import deque
from glob import glob
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from lerobot.configs.policies import PreTrainedConfig
from PIL import Image
from safetensors import safe_open
from safetensors.torch import load_file
from torch import Tensor, nn
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoProcessor,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.models.auto.tokenization_auto import AutoTokenizer
from veomni.models.vla.pi0 import PI0Policy, QwenPI0Policy

# Camera slots, in the order in which they are stacked by
# `PolicyPreprocessMixin.prepare_images`.
IMAGE_KEYS = (
    "base_0_rgb",
    "left_wrist_0_rgb",
    "right_wrist_0_rgb",
)


class AdaptiveEnsembler:
    """Blend overlapping action predictions with cosine-similarity weights.

    A bounded history of the most recent predictions is kept. On every call
    the predictions that refer to the current timestep are averaged, each one
    weighted by ``exp(alpha * cos_sim(prediction, newest_prediction))``.
    """

    def __init__(self, pred_action_horizon, adaptive_ensemble_alpha=0.0):
        """Create an ensembler.

        Args:
            pred_action_horizon: Maximum number of past predictions retained.
            adaptive_ensemble_alpha: Exponent scale applied to the cosine
                similarity; ``0.0`` yields a uniform average.
        """
        self.pred_action_horizon = pred_action_horizon
        self.action_history = deque(maxlen=self.pred_action_horizon)
        self.adaptive_ensemble_alpha = adaptive_ensemble_alpha

    def reset(self):
        """Forget all previously recorded predictions."""
        self.action_history.clear()

    def ensemble_action(self, cur_action):
        """Record ``cur_action`` and return the ensembled action.

        Args:
            cur_action: Either a 1-D action vector or a 2-D action chunk
                indexed as ``chunk[step]``.

        Returns:
            A 1-D numpy array: the weighted average of the stored predictions
            for the current timestep.
        """
        self.action_history.append(cur_action)
        num_actions = len(self.action_history)
        if cur_action.ndim == 1:
            curr_act_preds = np.stack(self.action_history)
        else:
            # The oldest chunk was predicted `num_actions - 1` steps ago, so
            # its row `num_actions - 1` is the one aimed at the current step;
            # the newest chunk contributes its row 0.
            curr_act_preds = np.stack([
                pred_actions[i]
                for (i, pred_actions) in zip(range(num_actions - 1, -1, -1), self.action_history)
            ])

        # calculate cosine similarity between the current prediction and all previous predictions
        ref = curr_act_preds[num_actions - 1, :]
        previous_pred = curr_act_preds
        dot_product = np.sum(previous_pred * ref, axis=1)
        norm_previous_pred = np.linalg.norm(previous_pred, axis=1)
        norm_ref = np.linalg.norm(ref)
        # The epsilon guards against division by zero for all-zero actions.
        cos_similarity = dot_product / (norm_previous_pred * norm_ref + 1e-7)

        # compute the weights for each prediction
        weights = np.exp(self.adaptive_ensemble_alpha * cos_similarity)
        weights = weights / weights.sum()

        # compute the weighted average across all predictions for this timestep
        cur_action = np.sum(weights[:, None] * curr_act_preds, axis=0)
        return cur_action


def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
    """Center-crop an image to 90% of its area and resize it to 224x224 RGB.

    Args:
        image: A numpy array (float, uint16, uint8 or any dtype castable to
            uint8) or a PIL image.

    Returns:
        A 224x224 RGB ``PIL.Image.Image``.

    Raises:
        TypeError: If ``image`` is neither a numpy array nor a PIL image.
    """
    crop_scale = 0.9
    # `crop_scale` is an area fraction, so each side scales by its square root.
    side_scale = float(np.sqrt(np.clip(crop_scale, 0.0, 1.0)))  # side length scale
    out_size = (224, 224)

    # Convert input to PIL Image
    if isinstance(image, np.ndarray):
        arr = image
        if arr.dtype.kind == "f":
            # If floats likely in [0,1], map to [0,255]
            if arr.max() <= 1.0 and arr.min() >= 0.0:
                arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)
            else:
                arr = np.clip(arr, 0.0, 255.0).astype(np.uint8)
        elif arr.dtype == np.uint16:
            # Map 16-bit to 8-bit
            arr = (arr / 257).astype(np.uint8)
        elif arr.dtype != np.uint8:
            arr = arr.astype(np.uint8)
        pil = Image.fromarray(arr)
    elif isinstance(image, Image.Image):
        pil = image
    else:
        raise TypeError("image must be a numpy array or PIL.Image.Image")

    # Force RGB for consistent output
    pil = pil.convert("RGB")
    W, H = pil.size

    # Compute centered crop box (integer pixels)
    crop_w = max(1, int(round(W * side_scale)))
    crop_h = max(1, int(round(H * side_scale)))
    left = (W - crop_w) // 2
    top = (H - crop_h) // 2
    right = left + crop_w
    bottom = top + crop_h

    cropped = pil.crop((left, top, right, bottom))
    resized = cropped.resize(out_size, resample=Image.BILINEAR)
    return resized


def resize_with_pad(img, width, height, pad_value=-1):
    """Resize a batch of images to fit ``width`` x ``height``, keeping aspect ratio.

    The image is scaled so that it fits inside the target size, then padded on
    the left and top with ``pad_value`` up to the target size.

    Args:
        img: 4-D tensor, either (B, C, H, W) or (B, H, W, C) with C in {1, 3}.
        width: Target width in pixels.
        height: Target height in pixels.
        pad_value: Fill value for the padded region.

    Returns:
        A (B, C, height, width) tensor.

    Raises:
        ValueError: If ``img`` is not 4-dimensional.
    """
    # assume no-op when width height fits already
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

    # channel last to channel first if necessary
    if img.shape[1] not in (1, 3) and img.shape[-1] in (1, 3):
        img = img.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)

    cur_height, cur_width = img.shape[2:]

    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    resized_img = F.interpolate(img, size=(resized_height, resized_width), mode="bilinear", align_corners=False)

    pad_height = max(0, int(height - resized_height))
    pad_width = max(0, int(width - resized_width))

    # pad on left and top of image
    padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
    return padded_img


class PolicyPreprocessMixin:
    """
    A mixin class that provides preprocessing utilities for observations.
    Can be mixed into any policy class to add image, state, action, language handling.

    The host class is expected to provide ``self.config``,
    ``self.image_processor``, ``self.language_tokenizer``, ``self.model`` and
    ``self.eval()``.
    """

    @staticmethod
    def _observation_device(observation):
        """Return the device of ``observation["state"]`` (CPU for non-tensors)."""
        state = observation["state"]
        if isinstance(state, torch.Tensor):
            return state.device
        return torch.device("cpu")

    def prepare_images(self, observation: dict[str, Any]):
        """Resize, pad, and process the camera images and stack them into a tensor.

        Every key of ``IMAGE_KEYS`` gets one slot in the output. Cameras that
        are absent from ``observation["image"]`` are filled with zeros shaped
        like a processed image and flagged ``False`` in the mask.

        Args:
            observation: Mapping with an ``"image"`` dict (camera key ->
                4-D array/tensor accepted by ``resize_with_pad``) and a
                ``"state"`` entry used to pick the target device.

        Returns:
            images (torch.Tensor): processed images stacked along dim 0, one
                entry per key in ``IMAGE_KEYS``.
            img_masks (torch.Tensor): (n,) bool mask, True if the image is
                present, False if missing.

        Raises:
            ValueError: If none of the ``IMAGE_KEYS`` cameras is present.
        """
        device = self._observation_device(observation)
        camera_images = observation["image"]

        processed = {}
        for key in IMAGE_KEYS:
            if key not in camera_images:
                continue
            # resize, pad, and normalize
            img = camera_images[key]
            if isinstance(img, np.ndarray):
                img = torch.from_numpy(img)
            img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
            img = self.image_processor(img)["pixel_values"]
            # The processor may return numpy or torch data; unify to tensors so
            # present and placeholder images can always be stacked together.
            processed[key] = torch.as_tensor(img)

        if not processed:
            raise ValueError(f"observation['image'] must contain at least one of {IMAGE_KEYS}.")

        # Placeholders are derived from a real processed image so that they
        # are guaranteed to have a matching shape and dtype.
        template = next(iter(processed.values()))
        images, img_masks = [], []
        for key in IMAGE_KEYS:
            present = key in processed
            images.append(processed[key] if present else torch.zeros_like(template))
            img_masks.append(present)

        images = torch.stack(images, dim=0).to(device=device)
        img_masks = torch.tensor(img_masks, dtype=torch.bool).to(device=device)
        return images, img_masks

    def prepare_state(self, observation):
        """Return the state as a tensor padded to ``config.max_state_dim``.

        Args:
            observation: Mapping whose ``"state"`` entry is a numpy array or a
                tensor; the last dimension is the state dimension.

        Returns:
            The state tensor, zero-padded on the right of its last dimension.
        """
        state = observation["state"]
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state)
        state = F.pad(state, (0, self.config.max_state_dim - state.shape[-1]))
        return state

    def prepare_language(self, observation: dict[str, Any]):
        """Tokenize the prompt, or pass through pre-tokenized language inputs.

        If `lang_tokens` and `lang_masks` are both provided, they are used
        directly. Otherwise `prompt` is tokenized: every prompt is made to end
        with a newline and is right-padded to `config.tokenizer_max_length`.

        NOTE(review): the upstream PaliGemma recipe also prepends a BOS marker
        to each prompt; this code only guarantees the trailing newline.
        Confirm whether a BOS prefix is required for the tokenizer in use.

        Args:
            observation: Mapping with either ``"prompt"`` (a string or a
                sequence of strings) or both ``"lang_tokens"`` and
                ``"lang_masks"``, plus a ``"state"`` entry used to pick the
                target device.

        Returns:
            lang_tokens (torch.Tensor): (*b, l) language tokens
            lang_masks (torch.Tensor): (*b, l) masks for language tokens, True if token is present, False if missing

        Raises:
            ValueError: If neither a prompt nor the token/mask pair is given.
        """
        lang_tokens = observation.get("lang_tokens", None)
        lang_masks = observation.get("lang_masks", None)
        prompt = observation.get("prompt", None)

        needs_tokenizing = lang_tokens is None or lang_masks is None

        # either provide `prompt` or (`lang_tokens`, `lang_masks`)
        if prompt is None and needs_tokenizing:
            raise ValueError(
                "Either 'prompt' or ('lang_tokens', 'lang_masks') must be provided in the observation."
            )

        device = self._observation_device(observation)
        if needs_tokenizing:
            # A bare string would otherwise be iterated character by character.
            if isinstance(prompt, str):
                prompt = [prompt]
            prompt = [p if p.endswith("\n") else f"{p}\n" for p in prompt]
            tokenized_prompt = self.language_tokenizer(
                prompt,
                padding="max_length",
                padding_side="right",
                max_length=self.config.tokenizer_max_length,
                return_tensors="pt",
            )
            lang_tokens = tokenized_prompt["input_ids"].to(device=device)
            lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
        else:
            lang_tokens = lang_tokens.to(device=device)
            lang_masks = lang_masks.to(device=device, dtype=torch.bool)

        return lang_tokens, lang_masks

    @torch.no_grad()
    def select_action(self, observation: dict[str, Tensor], noise: Tensor | None = None):
        """Preprocess ``observation`` and sample actions from ``self.model``.

        Args:
            observation: Mapping with ``"image"``, ``"state"`` and language
                inputs (see the ``prepare_*`` methods).
            noise: Currently unused; kept for interface compatibility.

        Returns:
            Whatever ``self.model.sample_actions`` returns.
        """
        self.eval()

        images, img_masks = self.prepare_images(observation)
        state = self.prepare_state(observation)
        lang_tokens, lang_masks = self.prepare_language(observation)

        # NOTE: inference is pinned to CUDA in bfloat16.
        device = 'cuda'
        dtype = torch.bfloat16
        actions = self.model.sample_actions(
            images.to(dtype=dtype, device=device),
            img_masks.to(device=device),
            lang_tokens.to(device=device),
            lang_masks.to(device=device),
            state.to(dtype=dtype, device=device),
        )
        return actions


class QwenPI0InferencePolicy(PolicyPreprocessMixin, QwenPI0Policy):
    """``QwenPI0Policy`` combined with the observation preprocessing mixin."""

    pass  # Only combine necessary functions


class PI0InfernecePolicy(PolicyPreprocessMixin, PI0Policy):
    """``PI0Policy`` combined with the observation preprocessing mixin.

    The class name is misspelled; it is kept for backward compatibility.
    Prefer the ``PI0InferencePolicy`` alias in new code.
    """

    pass  # Only combine necessary functions


# Correctly spelled alias for `PI0InfernecePolicy`.
PI0InferencePolicy = PI0InfernecePolicy


def merge_qwen_config(policy_config, qwen_config):
    """Copy language-model and vision settings from ``qwen_config`` onto ``policy_config``.

    Args:
        policy_config: Object that receives the settings as attributes
            (mutated in place).
        qwen_config: Either a config object exposing ``to_dict()`` or a plain
            dict with the same keys.

    Returns:
        ``policy_config``, for convenience.
    """
    if hasattr(qwen_config, 'to_dict'):
        config_dict = qwen_config.to_dict()
    else:
        config_dict = qwen_config

    # A tuple keeps the merge (and its log output) in a deterministic order.
    text_keys = (
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "rms_norm_eps",
        "rope_theta",
        "vocab_size",
        "max_position_embeddings",
        "hidden_act",
        "tie_word_embeddings",
        "tokenizer_path",
    )

    for key in text_keys:
        if key in config_dict:
            setattr(policy_config, key, config_dict[key])
            print(f"✅ Merged LLM: {key} = {config_dict[key]}")

    if "vision_config" in config_dict:
        # Prefer the attribute (a config object when available); plain dicts
        # have no such attribute, so fall back to the dict entry.
        policy_config.vision_config = getattr(qwen_config, "vision_config", config_dict["vision_config"])
    else:
        print("⚠️ Warning: 'vision_config' not found in qwen_config!")

    return policy_config


class QwenPiServer:
    """Policy wrapper to support action ensemble or chunk execution.

    NOTE: in its current form `load_vla` builds a dataset rather than a
    policy, and `infer` replays the dataset's recorded actions.
    """

    def __init__(
        self,
        path_to_pi_model="",
        adaptive_ensemble_alpha=0.1,
        action_ensemble_horizon=8,
        use_length=1,  # to control the execution length of the action chunk, -1 denotes using action ensemble
        use_bf16=True,
    ) -> None:
        """Set up the ensembler and load the action source.

        Args:
            path_to_pi_model: Checkpoint path, forwarded to `load_vla`.
            adaptive_ensemble_alpha: Alpha for the `AdaptiveEnsembler`.
            action_ensemble_horizon: History length of the `AdaptiveEnsembler`.
            use_length: Execution length of the action chunk; -1 denotes
                using action ensemble.
            use_bf16: Stored on the instance; not otherwise used here.
        """
        self.adaptive_ensemble_alpha = adaptive_ensemble_alpha
        self.action_ensemble_horizon = action_ensemble_horizon
        self.use_length = use_length
        self.use_bf16 = use_bf16

        self.task_description = None

        self.action_ensembler = AdaptiveEnsembler(self.action_ensemble_horizon, self.adaptive_ensemble_alpha)
        self.vla = self.load_vla(path_to_pi_model)

        self.global_step = 0
        self.last_action_chunk = None

    def init_norm(
            self,
            states_path='/home/yangshuai/yangshuai_ssd0/checkpoint/qwen_pi0/norm_stats.json',
            state_dim=14,
            action_dim=14):
        """Load state/action normalization statistics from a JSON file.

        TODO: should be rewritten to keep the statistics in a dict.

        Args:
            states_path: Path to the normalization statistics JSON.
            state_dim: Number of leading state dimensions to keep.
            action_dim: Number of leading action dimensions to keep.
        """
        with open(states_path) as f:
            norm_stats = json.load(f)['hanging_mug-aloha-agilex_clean_50_rep']

        self.state_mean = np.array(
            norm_stats["norm_stats"]["state"]["mean"][:state_dim], dtype=np.float32)
        self.state_std = np.array(
            norm_stats["norm_stats"]["state"]["std"][:state_dim], dtype=np.float32)
        self.action_mean = np.array(
            norm_stats["norm_stats"]["actions"]["mean"][:action_dim], dtype=np.float32)
        self.action_std = np.array(
            norm_stats["norm_stats"]["actions"]["std"][:action_dim], dtype=np.float32)

    def state_normalizer(self, unnorm_state):
        """Standardize a raw state; requires `init_norm` to have been called."""
        state = (unnorm_state - self.state_mean) / (self.state_std + 1e-6)
        return state

    def action_unnormalizer(self, norm_action):
        """Invert action standardization; requires `init_norm` to have been called."""
        action = norm_action * (self.action_std + 1e-6) + self.action_mean
        return action

    def load_vla(self, path_to_pi_model) -> Any:
        """Build the dataset whose recorded actions are served by `infer`.

        Args:
            path_to_pi_model: Currently unused; the tokenizer and dataset
                locations are hard-coded below.

        Returns:
            The object returned by ``build_vla_dataset``.
        """
        # Imported lazily so that importing this module does not pull in the
        # dataset machinery.
        from types import SimpleNamespace

        from transformers import AutoTokenizer
        from veomni.data.dataset import build_vla_dataset

        tokenizer = AutoTokenizer.from_pretrained(
            '/home/yangshuai/yangshuai_ssd0/rep/VLA_pretraining/checkpoints/Qwen2.5-VL-3B-Instruct'
        )

        config = SimpleNamespace()
        config.max_state_dim = 14
        config.max_action_dim = 14
        config.tokenizer_max_length = 128
        config.resize_imgs_with_padding = (224, 224)

        dataset = build_vla_dataset(
            datasets_type='agilex',
            repo_id='/home/yangshuai/yangshuai_ssd0/cache/huggingface/lerobot/hanging_mug-aloha-agilex_clean_50_rep',
            config=config,
            chunk_size=50,
            tokenizer=tokenizer,
        )
        return dataset

    def reset(self, task_description: str) -> None:
        """Start a new episode for ``task_description``."""
        self.task_description = task_description
        if self.use_length == -1:
            self.action_ensembler.reset()
        self.global_step = 0
        self.last_action_chunk = None

    def infer(self, observation, center_crop=True):
        """Return the next recorded action from the loaded dataset.

        ``observation`` and ``center_crop`` are accepted for interface
        compatibility but are not used: the sample at index
        ``self.global_step`` is read and the first entry of its ``'actions'``
        is returned.

        Returns:
            ``{"action": <numpy array>}``.
        """
        action = self.vla[self.global_step]['actions'][0]
        self.global_step += 1
        return dict(action=action.numpy())


if __name__ == "__main__":
    from .websocket_policy_server import WebsocketPolicyServer

    # PATH_TO_PI_MODEL = "/home/yangshuai/yangshuai_ssd0/checkpoint/qwen_pi0/qwenpi0_libero_48token_6bs_4node_bf16vlm_fp32_fsdp2_compile_tuneV/checkpoints/global_step_30260/hf_ckpt"
    PATH_TO_PI_MODEL = "/home/yangshuai/yangshuai_ssd0/checkpoint/qwen_pi0/8GPU_cotraining/checkpoints/global_step_35000/hf_ckpt"
    model = QwenPiServer(PATH_TO_PI_MODEL, use_length=50)

    # To debug model with server
    model_server = WebsocketPolicyServer(model, port=8000)
    model_server.serve_forever()

    # # To debug model only
    # import torch
    # import numpy as np
    # from PIL import Image
    # from .image_tools import convert_to_uint8
    # device = torch.device("cuda")
    # base_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
    # left_wrist_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
    # state = np.random.rand(1,8).astype(np.float32)
    # prompt = ["do something"]
    # observation = {
    #     "image": {
    #         "base_0_rgb": convert_to_uint8(base_0_rgb),
    #         "left_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
    #         "right_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
    #     },
    #     "state": state,
    #     "prompt": prompt,
    # }
    # model.infer(observation)