| """Utils for evaluating OpenVLA or fine-tuned OpenVLA policies.""" |
|
|
| import filecmp |
| import json |
| import os |
| import shutil |
| import time |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import json_numpy |
| import numpy as np |
| import requests |
| import tensorflow as tf |
| import torch |
| from huggingface_hub import HfApi, hf_hub_download |
| from PIL import Image |
| from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor |
|
|
| |
# Globally patch JSON (de)serialization via json_numpy so numpy arrays can be encoded/decoded.
# NOTE(review): presumably this is what lets `get_action_from_server` send observation dicts
# containing numpy arrays through `requests.post(json=...)` — confirm.
json_numpy.patch()
|
|
| from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig |
| from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction |
| from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor |
| from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead |
| from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone |
| from prismatic.models.projectors import NoisyActionProjector, ProprioProjector |
| from prismatic.vla.constants import ( |
| ACTION_DIM, |
| ACTION_PROPRIO_NORMALIZATION_TYPE, |
| ) |
| from prismatic.vla.datasets.rlds.utils.data_utils import NormalizationType |
|
|
| |
# Timestamps captured once at import time. A single clock snapshot is used for both
# strings so DATE is always the date prefix of DATE_TIME (two separate clock reads could
# straddle a second/day boundary and disagree).
_IMPORT_TIME = time.localtime()
DATE = time.strftime("%Y_%m_%d", _IMPORT_TIME)
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S", _IMPORT_TIME)
del _IMPORT_TIME

# First CUDA device when available, otherwise CPU; models/projectors/heads are moved here.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Square image resolution (pixels) that prepare_images_for_vla / crop_and_resize produce.
OPENVLA_IMAGE_SIZE = 224

# Print numpy floats with 3 decimals (process-wide setting).
np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
|
|
|
|
def model_is_on_hf_hub(model_path: str) -> bool:
    """
    Check whether a model path refers to a model repository on the Hugging Face Hub.

    An existing local directory always takes precedence and is reported as *not* being on
    the Hub. This avoids a network round-trip for local checkpoints. It also prevents a
    relative local path that happens to look like a repo ID (e.g. "runs/my-ckpt") from being
    misclassified.

    Args:
        model_path: Local checkpoint directory or Hugging Face Hub repo ID

    Returns:
        bool: True if the Hub returns model info for `model_path`; False otherwise, including
            when the lookup fails for any reason (invalid ID, missing repo, no network, ...)
    """
    if os.path.isdir(model_path):
        return False

    try:
        HfApi().model_info(model_path)
    except Exception:
        # Deliberately best-effort: any lookup failure is treated as "not on the Hub".
        return False
    return True
|
|
|
|
def update_auto_map(pretrained_checkpoint: str) -> None:
    """
    Update the `auto_map` entry in the checkpoint's config.json file.

    This loads the config.json file inside the checkpoint directory. It points the
    `AutoConfig` and `AutoModelForVision2Seq` entries at the OpenVLA-specific classes and
    keeps any other `auto_map` entries. Before the file is rewritten, a timestamped backup
    (`config.json.back.<YYYYmmdd_HHMMSS>`) is written next to it. If both entries already
    have the expected values, the file is left untouched and no backup is created, so
    repeated calls do not accumulate backup files.

    Args:
        pretrained_checkpoint: Path to the checkpoint directory. Paths that are not
            directories (e.g. Hugging Face Hub repo IDs) are ignored.
    """
    if not os.path.isdir(pretrained_checkpoint):
        return

    config_path = os.path.join(pretrained_checkpoint, "config.json")
    if not os.path.exists(config_path):
        print(f"Warning: No config.json found at {config_path}")
        return

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    required_auto_map = {
        "AutoConfig": "configuration_prismatic.OpenVLAConfig",
        "AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction",
    }
    existing_auto_map = config.get("auto_map")
    if not isinstance(existing_auto_map, dict):
        existing_auto_map = {}

    if all(existing_auto_map.get(key) == value for key, value in required_auto_map.items()):
        # Already up to date: skip both the backup and the rewrite.
        print(f"config.json at {os.path.abspath(config_path)} already has the expected auto_map; no changes made.")
        return

    # Back up the original config before overwriting it
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = os.path.join(pretrained_checkpoint, f"config.json.back.{timestamp}")
    shutil.copy2(config_path, backup_path)
    print(f"Created backup of original config at: {os.path.abspath(backup_path)}")

    # Override only the two OpenVLA entries; keep any other auto_map entries
    config["auto_map"] = {**existing_auto_map, **required_auto_map}

    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    print(f"Updated config.json at: {os.path.abspath(config_path)}")
    print("Changes made:")
    print(' - Set AutoConfig to "configuration_prismatic.OpenVLAConfig"')
    print(' - Set AutoModelForVision2Seq to "modeling_prismatic.OpenVLAForActionPrediction"')
|
|
|
|
def check_identical_files(path1: Union[str, Path], path2: Union[str, Path]) -> bool:
    """
    Report whether two files contain exactly the same bytes.

    A cheap size comparison rejects files that obviously differ. Only files of equal size
    get the full byte-by-byte content comparison.

    Args:
        path1: First file to compare
        path2: Second file to compare

    Returns:
        bool: True when both files have byte-identical contents, otherwise False
    """
    first, second = Path(path1), Path(path2)
    sizes_differ = first.stat().st_size != second.stat().st_size
    return False if sizes_differ else filecmp.cmp(first, second, shallow=False)
|
|
|
|
| def _handle_file_sync(curr_filepath: str, checkpoint_filepath: str, file_type: str) -> None: |
| """ |
| Handle syncing of files between current directory and checkpoint. |
| |
| Creates backups if files exist but differ, and copies current versions to checkpoint. |
| |
| Args: |
| curr_filepath: Path to the current file version |
| checkpoint_filepath: Path where the file should be in the checkpoint |
| file_type: Description of the file type for logging |
| """ |
| if os.path.exists(checkpoint_filepath): |
| |
| match = check_identical_files(curr_filepath, checkpoint_filepath) |
|
|
| if not match: |
| print( |
| "\n------------------------------------------------------------------------------------------------\n" |
| f"Found mismatch between:\n" |
| f"Current: {curr_filepath}\n" |
| f"Checkpoint: {checkpoint_filepath}\n" |
| ) |
|
|
| |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
| backup_path = f"{checkpoint_filepath}.back.{timestamp}" |
| shutil.copy2(checkpoint_filepath, backup_path) |
| print(f"Created backup of original checkpoint file at: {os.path.abspath(backup_path)}") |
|
|
| |
| shutil.copy2(curr_filepath, checkpoint_filepath) |
| print(f"Copied current version to checkpoint at: {os.path.abspath(checkpoint_filepath)}") |
| print( |
| f"Changes complete. The checkpoint will now use the current version of {file_type}" |
| "\n------------------------------------------------------------------------------------------------\n" |
| ) |
| else: |
| |
| shutil.copy2(curr_filepath, checkpoint_filepath) |
| print( |
| "\n------------------------------------------------------------------------------------------------\n" |
| f"No {file_type} found in checkpoint directory.\n" |
| f"Copied current version from: {curr_filepath}\n" |
| f"To checkpoint location: {os.path.abspath(checkpoint_filepath)}" |
| "\n------------------------------------------------------------------------------------------------\n" |
| ) |
|
|
|
|
def check_model_logic_mismatch(pretrained_checkpoint: str) -> None:
    """
    Check and sync model logic files between current code and checkpoint.

    Handles the relationship between current and checkpoint versions of both
    modeling_prismatic.py and configuration_prismatic.py:
    - If checkpoint file exists and differs: creates backup and copies current version
    - If checkpoint file doesn't exist: copies current version

    The current versions are located by walking `./prismatic/` (relative to the current
    working directory). Subdirectories are visited in sorted order, so if several copies of
    a file exist, the same one (the first match) is picked on every run.

    Args:
        pretrained_checkpoint: Path to the checkpoint directory. Non-directory paths
            (e.g. Hugging Face Hub repo IDs) are ignored.
    """
    if not os.path.isdir(pretrained_checkpoint):
        return

    search_root = "./prismatic/"
    curr_files: Dict[str, Optional[str]] = {"modeling_prismatic.py": None, "configuration_prismatic.py": None}

    for root, dirs, files in os.walk(search_root):
        # Sort in place so os.walk descends deterministically (listdir order is filesystem-dependent)
        dirs.sort()
        for filename in curr_files:
            if curr_files[filename] is None and filename in files:
                curr_files[filename] = os.path.join(root, filename)
        if all(path is not None for path in curr_files.values()):
            break  # Everything found; no need to walk further

    for filename, curr_filepath in curr_files.items():
        if curr_filepath is None:
            print(f"WARNING: `{filename}` is not found anywhere under `{search_root}` (relative to the current directory).")
            continue

        checkpoint_filepath = os.path.join(pretrained_checkpoint, filename)
        _handle_file_sync(curr_filepath, checkpoint_filepath, filename)
|
|
|
|
def find_checkpoint_file(pretrained_checkpoint: str, file_pattern: str) -> str:
    """
    Find the single checkpoint file whose name contains `file_pattern`.

    A file matches when its name contains both `file_pattern` and the word "checkpoint".

    Args:
        pretrained_checkpoint: Path to the checkpoint directory
        file_pattern: Substring to match in filenames (e.g. "action_head")

    Returns:
        str: Path to the matching checkpoint file

    Raises:
        AssertionError: If `pretrained_checkpoint` is not a directory, or if zero or several
            files match. Raised explicitly (not via `assert`) so the checks still run under
            `python -O`.
    """
    if not os.path.isdir(pretrained_checkpoint):
        raise AssertionError(f"Checkpoint path must be a directory: {pretrained_checkpoint}")

    # Sorted so that any error message lists matches in a stable order
    checkpoint_files = sorted(
        os.path.join(pretrained_checkpoint, filename)
        for filename in os.listdir(pretrained_checkpoint)
        if file_pattern in filename and "checkpoint" in filename
    )

    if len(checkpoint_files) != 1:
        matches = [os.path.basename(path) for path in checkpoint_files]
        raise AssertionError(
            f"Expected exactly 1 {file_pattern} checkpoint but found {len(checkpoint_files)} in directory: "
            f"{pretrained_checkpoint}" + (f" (matches: {matches})" if matches else "")
        )

    return checkpoint_files[0]
|
|
|
|
def load_component_state_dict(checkpoint_path: str) -> Dict[str, torch.Tensor]:
    """
    Load a component's state dict from checkpoint and handle DDP prefix if present.

    Tensors are loaded onto the CPU (`map_location="cpu"`), so checkpoints saved from a GPU
    process can be loaded on CPU-only machines without an extra GPU memory spike.
    `nn.Module.load_state_dict` then copies the values into the module's parameters on
    whatever device the module lives on. A leading "module." key prefix (added when a model
    is wrapped for DDP) is stripped.

    Args:
        checkpoint_path: Path to the checkpoint file

    Returns:
        Dict: The processed state dictionary for loading
    """
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    prefix = "module."
    return {(key[len(prefix):] if key.startswith(prefix) else key): value for key, value in state_dict.items()}
|
|
|
|
def get_vla(cfg: Any) -> torch.nn.Module:
    """
    Build the VLA policy from `cfg.pretrained_checkpoint` and prepare it for inference.

    Steps:
    1. For checkpoints not on the Hugging Face Hub, register the OpenVLA classes with the
       HF Auto* factories and sync config.json and the model code files in the checkpoint
       directory.
    2. Load the model in bfloat16, optionally 8-/4-bit quantized.
    3. Optionally apply FiLM.
    4. Set the number of input images.
    5. Switch to eval mode and, when not quantized, move the model to DEVICE.
    6. Attach the dataset statistics.

    Args:
        cfg: Configuration object. Reads pretrained_checkpoint, load_in_8bit, load_in_4bit,
            use_film and num_images_in_input (plus whatever _apply_film_to_vla reads).

    Returns:
        torch.nn.Module: The VLA model, in eval mode
    """
    print("Instantiating pretrained VLA policy...")

    is_local_checkpoint = not model_is_on_hf_hub(cfg.pretrained_checkpoint)
    if is_local_checkpoint:
        # Make the Auto* factories resolve the "openvla" model type to the local classes
        AutoConfig.register("openvla", OpenVLAConfig)
        AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
        AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
        AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)

        # Keep the checkpoint's config.json and model code in sync with the current code
        update_auto_map(cfg.pretrained_checkpoint)
        check_model_logic_mismatch(cfg.pretrained_checkpoint)

    model = AutoModelForVision2Seq.from_pretrained(
        cfg.pretrained_checkpoint,
        torch_dtype=torch.bfloat16,
        load_in_8bit=cfg.load_in_8bit,
        load_in_4bit=cfg.load_in_4bit,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

    if cfg.use_film:
        model = _apply_film_to_vla(model, cfg)

    model.vision_backbone.set_num_images_in_input(cfg.num_images_in_input)
    model.eval()

    # Quantized models are placed on devices by the loader and must not be moved manually
    quantized = cfg.load_in_8bit or cfg.load_in_4bit
    if not quantized:
        model = model.to(DEVICE)

    _load_dataset_stats(model, cfg.pretrained_checkpoint)
    return model
|
|
|
|
def _apply_film_to_vla(vla: torch.nn.Module, cfg: Any) -> torch.nn.Module:
    """
    Apply FiLM (Feature-wise Linear Modulation) to the VLA vision backbone.

    The model is wrapped with LoRA adapters (via peft), its vision backbone is replaced by
    a FiLMedPrismaticVisionBackbone, and that backbone is loaded from the checkpoint's
    "vision_backbone" file. Any DDP "module." key prefix is stripped, as for the other
    components. Finally the wrapped model (`vla.model`) is returned, with its vision
    backbone cast to bfloat16.

    Args:
        vla: The VLA model
        cfg: Configuration object; reads lora_rank and pretrained_checkpoint

    Returns:
        torch.nn.Module: VLA model with FiLM applied
    """
    from peft import LoraConfig, get_peft_model

    # LoRA wrapping. NOTE(review): presumably mirrors the fine-tuning setup so the module
    # structure matches the saved vision-backbone weights — confirm.
    lora_config = LoraConfig(
        r=cfg.lora_rank,
        lora_alpha=min(cfg.lora_rank, 16),
        lora_dropout=0.0,
        target_modules="all-linear",
        init_lora_weights="gaussian",
    )
    vla = get_peft_model(vla, lora_config)

    # Swap in the FiLM-conditioned wrapper around the existing vision backbone
    new_vision_backbone = FiLMedPrismaticVisionBackbone(
        vision_backbone=vla.vision_backbone,
        llm_dim=vla.llm_dim,
    )
    vla.model.vision_backbone = new_vision_backbone

    # Load the fine-tuned FiLM vision backbone weights (strips a DDP "module." prefix if present)
    checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "vision_backbone")
    state_dict = load_component_state_dict(checkpoint_path)
    vla.model.vision_backbone.load_state_dict(state_dict)

    # Drop the outer PEFT wrapper (presumably leaving the underlying model with LoRA layers
    # injected — confirm) and make sure the new backbone is in bfloat16
    vla = vla.model
    vla.vision_backbone = vla.vision_backbone.to(torch.bfloat16)

    return vla
|
|
|
|
def _load_dataset_stats(vla: torch.nn.Module, checkpoint_path: str) -> None:
    """
    Load dataset statistics used during training for action normalization.

    Reads `dataset_statistics.json` and stores the parsed JSON as `vla.norm_stats`. The file
    is downloaded from the Hub when `checkpoint_path` is a Hub repo ID; otherwise it is looked
    up inside the local checkpoint directory. If the local file is missing, a warning is
    printed and `vla.norm_stats` is left unchanged.

    Args:
        vla: The VLA model
        checkpoint_path: Path to the checkpoint directory (or Hub repo ID)
    """
    if model_is_on_hf_hub(checkpoint_path):
        dataset_statistics_path = hf_hub_download(
            repo_id=checkpoint_path,
            filename="dataset_statistics.json",
        )
    else:
        dataset_statistics_path = os.path.join(checkpoint_path, "dataset_statistics.json")

    if not os.path.isfile(dataset_statistics_path):
        print(
            "WARNING: No local dataset_statistics.json file found for current checkpoint.\n"
            "You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint.\n"
            "Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`."
        )
        return

    with open(dataset_statistics_path, "r", encoding="utf-8") as f:
        vla.norm_stats = json.load(f)
|
|
|
|
def get_processor(cfg: Any) -> AutoProcessor:
    """
    Load the Hugging Face processor that belongs to `cfg.pretrained_checkpoint`.

    Args:
        cfg: Configuration object; only `pretrained_checkpoint` is read

    Returns:
        AutoProcessor: Processor used to turn (prompt, image) pairs into model inputs
    """
    checkpoint = cfg.pretrained_checkpoint
    processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    return processor
|
|
|
|
def get_proprio_projector(cfg: Any, llm_dim: int, proprio_dim: int) -> ProprioProjector:
    """
    Get proprioception projector for the VLA model.

    Builds the projector in bfloat16 on DEVICE, puts it in eval mode and loads its weights.
    For a Hub checkpoint the weights file is looked up in a fixed table of known
    repositories; for a local checkpoint it is located with find_checkpoint_file.

    Args:
        cfg: Configuration object; reads pretrained_checkpoint
        llm_dim: Dimension of the language model
        proprio_dim: Dimension of proprioception data

    Returns:
        ProprioProjector: The initialized proprio projector

    Raises:
        ValueError: If `cfg.pretrained_checkpoint` is a Hub repo not in the known table
    """
    # Single transfer straight to the target device and dtype
    proprio_projector = ProprioProjector(
        llm_dim=llm_dim,
        proprio_dim=proprio_dim,
    ).to(device=DEVICE, dtype=torch.bfloat16)
    proprio_projector.eval()

    if model_is_on_hf_hub(cfg.pretrained_checkpoint):
        # Known Hub checkpoints -> filename of their proprio projector weights
        model_path_to_proprio_projector_name = {
            "moojink/openvla-7b-oft-finetuned-libero-spatial": "proprio_projector--150000_checkpoint.pt",
            "moojink/openvla-7b-oft-finetuned-libero-object": "proprio_projector--150000_checkpoint.pt",
            "moojink/openvla-7b-oft-finetuned-libero-goal": "proprio_projector--50000_checkpoint.pt",
            "moojink/openvla-7b-oft-finetuned-libero-10": "proprio_projector--150000_checkpoint.pt",
            "moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10": "proprio_projector--300000_checkpoint.pt",
        }
        if cfg.pretrained_checkpoint not in model_path_to_proprio_projector_name:
            raise ValueError(f"Unsupported HF Hub pretrained checkpoint found! Got: {cfg.pretrained_checkpoint}")
        checkpoint_path = hf_hub_download(
            repo_id=cfg.pretrained_checkpoint,
            filename=model_path_to_proprio_projector_name[cfg.pretrained_checkpoint],
        )
    else:
        checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "proprio_projector")

    proprio_projector.load_state_dict(load_component_state_dict(checkpoint_path))

    return proprio_projector
|
|
|
|
def get_noisy_action_projector(cfg: Any, llm_dim: int) -> NoisyActionProjector:
    """
    Get noisy action projector for diffusion-based action prediction.

    Builds the projector in bfloat16 on DEVICE, puts it in eval mode and loads its weights
    from the local checkpoint directory.

    NOTE(review): unlike the proprio projector and action head, there is no Hugging Face Hub
    lookup here. find_checkpoint_file requires a local directory, so a Hub repo ID will fail.

    Args:
        cfg: Configuration object; reads pretrained_checkpoint
        llm_dim: Dimension of the language model

    Returns:
        NoisyActionProjector: The initialized noisy action projector
    """
    # Single transfer straight to the target device and dtype
    noisy_action_projector = NoisyActionProjector(
        llm_dim=llm_dim,
    ).to(device=DEVICE, dtype=torch.bfloat16)
    noisy_action_projector.eval()

    checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "noisy_action_projector")
    state_dict = load_component_state_dict(checkpoint_path)
    noisy_action_projector.load_state_dict(state_dict)

    return noisy_action_projector
|
|
|
|
def get_action_head(cfg: Any, llm_dim: int) -> Union[L1RegressionActionHead, DiffusionActionHead]:
    """
    Get action head for continuous value prediction.

    Builds either an L1-regression or a diffusion action head, casts it to bfloat16 on
    DEVICE, puts it in eval mode and loads its weights. For a Hub checkpoint the weights
    file comes from a fixed table of known repositories; for a local checkpoint it is
    located with find_checkpoint_file.

    Args:
        cfg: Configuration object. Reads use_l1_regression, use_diffusion and
            pretrained_checkpoint, plus num_diffusion_steps_train and
            num_diffusion_steps_inference when use_diffusion is set.
        llm_dim: Dimension of the language model (used as input and hidden size)

    Returns:
        Union[L1RegressionActionHead, DiffusionActionHead]: The initialized action head

    Raises:
        AssertionError: If both L1 regression and diffusion are specified (raised explicitly
            so the check survives `python -O`)
        ValueError: If neither head type is enabled, or the Hub checkpoint is not a known one
    """
    if cfg.use_l1_regression and cfg.use_diffusion:
        raise AssertionError("Cannot use both L1 regression and diffusion action head!")

    if cfg.use_l1_regression:
        action_head = L1RegressionActionHead(input_dim=llm_dim, hidden_dim=llm_dim, action_dim=ACTION_DIM)
    elif cfg.use_diffusion:
        action_head = DiffusionActionHead(
            input_dim=llm_dim,
            hidden_dim=llm_dim,
            action_dim=ACTION_DIM,
            num_diffusion_steps_train=cfg.num_diffusion_steps_train,
        )
        # The number of denoising steps used at inference is configured separately from training
        action_head.noise_scheduler.set_timesteps(cfg.num_diffusion_steps_inference)
    else:
        raise ValueError("Either use_l1_regression or use_diffusion must be True")

    # Single transfer straight to the target device and dtype
    action_head = action_head.to(device=DEVICE, dtype=torch.bfloat16)
    action_head.eval()

    if model_is_on_hf_hub(cfg.pretrained_checkpoint):
        # Known Hub checkpoints -> filename of their action head weights
        model_path_to_action_head_name = {
            "moojink/openvla-7b-oft-finetuned-libero-spatial": "action_head--150000_checkpoint.pt",
            "moojink/openvla-7b-oft-finetuned-libero-object": "action_head--150000_checkpoint.pt",
            "moojink/openvla-7b-oft-finetuned-libero-goal": "action_head--50000_checkpoint.pt",
            "moojink/openvla-7b-oft-finetuned-libero-10": "action_head--150000_checkpoint.pt",
            "moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10": "action_head--300000_checkpoint.pt",
        }
        if cfg.pretrained_checkpoint not in model_path_to_action_head_name:
            raise ValueError(f"Unsupported HF Hub pretrained checkpoint found! Got: {cfg.pretrained_checkpoint}")
        checkpoint_path = hf_hub_download(
            repo_id=cfg.pretrained_checkpoint,
            filename=model_path_to_action_head_name[cfg.pretrained_checkpoint],
        )
    else:
        checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "action_head")

    action_head.load_state_dict(load_component_state_dict(checkpoint_path))

    return action_head
|
|
|
|
def resize_image_for_policy(img: np.ndarray, resize_size: Union[int, Tuple[int, int]]) -> np.ndarray:
    """
    Resize an image to the resolution the policy expects.

    Uses the same resizing scheme as in the training data pipeline for distribution matching:
    a JPEG encode/decode round-trip, then Lanczos-3 resizing with antialiasing, then rounding
    and clipping back to uint8.

    Args:
        img: Numpy array containing the image
        resize_size: Target size, either an int (square) or a (height, width) tuple

    Returns:
        np.ndarray: The resized uint8 image
    """
    assert isinstance(resize_size, (int, tuple))
    target_hw = (resize_size, resize_size) if isinstance(resize_size, int) else resize_size

    # JPEG round-trip; NOTE(review): presumably mimics JPEG-compressed training images — confirm
    encoded = tf.image.encode_jpeg(img)
    decoded = tf.io.decode_image(encoded, expand_animations=False, dtype=tf.uint8)

    resized = tf.image.resize(decoded, target_hw, method="lanczos3", antialias=True)
    as_uint8 = tf.cast(tf.clip_by_value(tf.round(resized), 0, 255), tf.uint8)
    return as_uint8.numpy()
|
|
|
|
def crop_and_resize(image: tf.Tensor, crop_scale: float, batch_size: int) -> tf.Tensor:
    """
    Center-crop image(s) and resize the crop to OPENVLA_IMAGE_SIZE x OPENVLA_IMAGE_SIZE.

    Uses the same logic as in the training data pipeline for distribution matching. The
    output size is always OPENVLA_IMAGE_SIZE, so the crop only returns to the input's own
    dimensions when the input already has that size (as prepare_images_for_vla ensures).

    Args:
        image: TF Tensor of shape (batch_size, H, W, C) or (H, W, C) with values in [0,1]
        crop_scale: Area of the center crop relative to the original image. Either a scalar
            (applied to every image in the batch) or one value per image.
        batch_size: Number of images in `image` (1 for a 3D input)

    Returns:
        tf.Tensor: The cropped and resized image(s), with the same rank as the input
    """
    assert image.shape.ndims in (3, 4), "Image must be 3D or 4D tensor"
    expanded_dims = False
    if image.shape.ndims == 3:
        image = tf.expand_dims(image, axis=0)
        expanded_dims = True

    # Relative side length of the crop = sqrt(relative area). Flatten and broadcast to
    # (batch_size,) so a scalar crop_scale works for any batch size; a plain reshape of a
    # scalar to (batch_size,) fails when batch_size > 1. Cast to float32 because
    # crop_and_resize requires float32 boxes.
    side_fractions = tf.clip_by_value(tf.sqrt(tf.cast(crop_scale, tf.float32)), 0, 1)
    side_fractions = tf.broadcast_to(tf.reshape(side_fractions, [-1]), [batch_size])
    new_heights = side_fractions
    new_widths = side_fractions

    # Normalized [y1, x1, y2, x2] boxes centered in each image
    height_offsets = (1 - new_heights) / 2
    width_offsets = (1 - new_widths) / 2
    bounding_boxes = tf.stack(
        [
            height_offsets,
            width_offsets,
            height_offsets + new_heights,
            width_offsets + new_widths,
        ],
        axis=1,
    )

    # Box i is taken from image i
    image = tf.image.crop_and_resize(
        image, bounding_boxes, tf.range(batch_size), (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE)
    )

    if expanded_dims:
        image = image[0]

    return image
|
|
|
|
def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
    """
    Center-crop a single image (90% of its area) to match the training data distribution.

    The image is converted to a float tensor in [0, 1], cropped and resized by
    crop_and_resize, then converted back to its original dtype and returned as RGB PIL.

    Args:
        image: Input image (PIL, numpy array, or an already-built tf.Tensor)

    Returns:
        Image.Image: Cropped RGB PIL image
    """
    crop_scale = 0.9

    tensor = image if isinstance(image, tf.Tensor) else tf.convert_to_tensor(np.array(image))
    original_dtype = tensor.dtype

    # crop_and_resize operates on float images in [0, 1]
    as_float = tf.image.convert_image_dtype(tensor, tf.float32)
    cropped = crop_and_resize(as_float, crop_scale, 1)

    # Clamp interpolation overshoot before converting back to the original dtype
    restored = tf.image.convert_image_dtype(tf.clip_by_value(cropped, 0, 1), original_dtype, saturate=True)

    return Image.fromarray(restored.numpy()).convert("RGB")
|
|
|
|
def check_image_format(image: Any) -> None:
    """
    Validate input image format.

    Args:
        image: Image to check

    Raises:
        AssertionError: If `image` is not a numpy array with shape (H, W, 3) and dtype
            np.uint8. It is raised explicitly rather than via `assert`, so validation also
            runs under `python -O`. Inputs without a `.shape` (lists, PIL images, None)
            therefore get this error instead of an AttributeError.
    """
    # Short-circuit: shape/dtype are only inspected once we know it is an ndarray
    is_valid = (
        isinstance(image, np.ndarray)
        and image.ndim == 3
        and image.shape[-1] == 3
        and image.dtype == np.uint8
    )
    if not is_valid:
        raise AssertionError(
            "Incorrect image format detected! Make sure that the input image is a "
            "numpy array with shape (H, W, 3) and dtype np.uint8!"
        )
|
|
|
|
def normalize_proprio(proprio: np.ndarray, norm_stats: Dict[str, Any]) -> np.ndarray:
    """
    Normalize proprioception data to match training distribution.

    Each dimension selected by `mask` is mapped linearly from [low, high] to [-1, 1] and
    clipped to that range. Dimensions where `mask` is False are returned unchanged, neither
    scaled nor clipped. This follows the `where(mask, clip(normalized), x)` form used for
    RLDS bounds normalization (NOTE(review): confirm against prismatic's
    data_utils.normalize_action_and_proprio). When `norm_stats` has no "mask" entry, every
    dimension is normalized.

    The bounds are (min, max) for NormalizationType.BOUNDS and (q01, q99) for
    NormalizationType.BOUNDS_Q99, chosen by ACTION_PROPRIO_NORMALIZATION_TYPE.

    Args:
        proprio: Raw proprioception data
        norm_stats: Normalization statistics (per-dimension bounds and optional "mask")

    Returns:
        np.ndarray: Normalized proprioception data

    Raises:
        ValueError: If ACTION_PROPRIO_NORMALIZATION_TYPE is not a supported bounds type
    """
    if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
        low_key, high_key = "min", "max"
    elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
        low_key, high_key = "q01", "q99"
    else:
        raise ValueError("Unsupported action/proprio normalization type detected!")

    proprio = np.asarray(proprio)
    proprio_low = np.asarray(norm_stats[low_key])
    proprio_high = np.asarray(norm_stats[high_key])
    mask = np.asarray(norm_stats.get("mask", np.ones_like(proprio_low, dtype=bool)), dtype=bool)

    # The 1e-8 guards against division by zero for constant dimensions (low == high)
    scaled = np.clip(
        2 * (proprio - proprio_low) / (proprio_high - proprio_low + 1e-8) - 1,
        a_min=-1.0,
        a_max=1.0,
    )

    # Only normalized dimensions are clipped; unmasked dimensions pass through as-is
    return np.where(mask, scaled, proprio)
|
|
|
|
def prepare_images_for_vla(images: List[np.ndarray], cfg: Any) -> List[Image.Image]:
    """
    Turn raw camera images into PIL images ready for the VLA processor.

    Each image is validated, resized to OPENVLA_IMAGE_SIZE x OPENVLA_IMAGE_SIZE when it is
    not already that size, converted to an RGB PIL image and, if `cfg.center_crop` is set,
    center-cropped.

    Args:
        images: Input images as numpy arrays of shape (H, W, 3), dtype uint8
        cfg: Configuration object; only `center_crop` is read

    Returns:
        List[Image.Image]: Processed images, in the same order as `images`
    """
    target_shape = (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE, 3)

    def _prepare_one(raw: np.ndarray) -> Image.Image:
        check_image_format(raw)
        if raw.shape != target_shape:
            raw = resize_image_for_policy(raw, OPENVLA_IMAGE_SIZE)
        pil = Image.fromarray(raw).convert("RGB")
        return center_crop_image(pil) if cfg.center_crop else pil

    return [_prepare_one(raw) for raw in images]
|
|
|
|
def get_vla_action(
    cfg: Any,
    vla: torch.nn.Module,
    processor: Any,
    obs: Dict[str, Any],
    task_label: str,
    action_head: Optional[torch.nn.Module] = None,
    proprio_projector: Optional[torch.nn.Module] = None,
    noisy_action_projector: Optional[torch.nn.Module] = None,
    use_film: bool = False,
) -> List[np.ndarray]:
    """
    Generate action predictions with the VLA policy.

    `obs` is not modified. The normalized proprio state is kept local to this call; writing
    it back into `obs["state"]` would double-normalize the state if the same dict were
    passed again.

    Args:
        cfg: Configuration object. Reads num_images_in_input, center_crop (via image
            preparation), use_proprio and unnorm_key.
        vla: The VLA model
        processor: Model processor for inputs
        obs: Observation dictionary. Needs "full_image"; also "state" when cfg.use_proprio
            is set; any keys containing "wrist" are used as extra images when
            cfg.num_images_in_input > 1.
        task_label: Text description of the task
        action_head: Optional action head for continuous actions
        proprio_projector: Optional proprioception projector
        noisy_action_projector: Optional noisy action projector for diffusion
        use_film: Whether to use FiLM

    Returns:
        List[np.ndarray]: Predicted actions, one list entry per element along the first axis
            of the model's action output (presumably one per step of the action chunk)
    """
    with torch.inference_mode():
        # Primary image first, then wrist images (in obs insertion order) when multi-image input is on
        all_images = [obs["full_image"]]
        if cfg.num_images_in_input > 1:
            all_images.extend([obs[k] for k in obs.keys() if "wrist" in k])

        all_images = prepare_images_for_vla(all_images, cfg)
        primary_image = all_images.pop(0)

        prompt = f"In: What action should the robot take to {task_label.lower()}?\nOut:"

        inputs = processor(prompt, primary_image).to(DEVICE, dtype=torch.bfloat16)

        if all_images:
            # Process each extra image with the same prompt and stack pixel values along dim 1
            all_wrist_inputs = [
                processor(prompt, image_wrist).to(DEVICE, dtype=torch.bfloat16) for image_wrist in all_images
            ]
            primary_pixel_values = inputs["pixel_values"]
            all_wrist_pixel_values = [wrist_inputs["pixel_values"] for wrist_inputs in all_wrist_inputs]
            inputs["pixel_values"] = torch.cat([primary_pixel_values] + all_wrist_pixel_values, dim=1)

        # Normalize proprio with the training statistics, without mutating the caller's obs
        proprio = None
        if cfg.use_proprio:
            proprio_norm_stats = vla.norm_stats[cfg.unnorm_key]["proprio"]
            proprio = normalize_proprio(obs["state"], proprio_norm_stats)

        if action_head is None:
            # Default OpenVLA decoding path
            action, _ = vla.predict_action(**inputs, unnorm_key=cfg.unnorm_key, do_sample=False)
        else:
            # Continuous-action path using the auxiliary modules
            action, _ = vla.predict_action(
                **inputs,
                unnorm_key=cfg.unnorm_key,
                do_sample=False,
                proprio=proprio,
                proprio_projector=proprio_projector,
                noisy_action_projector=noisy_action_projector,
                action_head=action_head,
                use_film=use_film,
            )

    return [action[i] for i in range(len(action))]
|
|
|
|
def get_action_from_server(
    observation: Dict[str, Any],
    server_endpoint: str = "http://0.0.0.0:8777/act",
    timeout: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Get VLA action from remote inference server.

    Args:
        observation: Observation data to send to server (JSON-encoded in the request body)
        server_endpoint: URL of the inference server
        timeout: Optional request timeout in seconds, passed to `requests.post`. The default
            None waits indefinitely.

    Returns:
        Dict[str, Any]: Action response from server (decoded JSON body)

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status, so an error payload
            is never mistaken for an action
    """
    response = requests.post(
        server_endpoint,
        json=observation,
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()
|
|