| from typing import Dict |
|
|
| import numpy as np |
| import torch |
| import math |
| import einops |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
|
|
|
|
| IMAGE_KEYS = ( |
| "base_0_rgb", |
| "left_wrist_0_rgb", |
| "right_wrist_0_rgb", |
| ) |
|
|
|
|
| def dict_apply(func, d): |
| """ |
| Apply a function to all values in a dictionary recursively. |
| If the value is a dictionary, it will apply the function to its values. |
| """ |
| for key, value in d.items(): |
| if isinstance(value, dict): |
| dict_apply(func, value) |
| else: |
| d[key] = func(value) |
| return d |
|
|
| class Normalizer: |
| def __init__( |
| self, |
| norm_stats: Dict[str, Dict[str, np.ndarray]], |
| from_file: bool=False, |
| data_type: str=None, |
| norm_type: Dict[str, str] | None = None, |
| ): |
| if from_file: |
| if data_type == 'libero': |
| norm_stats['state']['mean'] = np.array(norm_stats['state']['mean'][:8]) |
| norm_stats['state']['std'] = np.array(norm_stats['state']['std'][:8]) |
| norm_stats['actions']['mean'] = np.array(norm_stats['actions']['mean'][:7]) |
| norm_stats['actions']['std'] = np.array(norm_stats['actions']['std'][:7]) |
| elif data_type == 'robotwin': |
| norm_stats['observation.state'], norm_stats['action'] = {}, {} |
| norm_stats['observation.state']['q01'] = np.array(norm_stats['observation.state.arm.position']['q01'][:6] + norm_stats['observation.state.effector.position']['q01'][:1] + norm_stats['observation.state.arm.position']['q01'][6:] + norm_stats['observation.state.effector.position']['q01'][1:]) |
| norm_stats['observation.state']['q99'] = np.array(norm_stats['observation.state.arm.position']['q99'][:6] + norm_stats['observation.state.effector.position']['q99'][:1] + norm_stats['observation.state.arm.position']['q99'][6:] + norm_stats['observation.state.effector.position']['q99'][1:]) |
| norm_stats['action']['q01'] = np.array(norm_stats['action.arm.position']['q01'][:6] + norm_stats['action.effector.position']['q01'][:1] + norm_stats['action.arm.position']['q01'][6:] + norm_stats['action.effector.position']['q01'][1:]) |
| norm_stats['action']['q99'] = np.array(norm_stats['action.arm.position']['q99'][:6] + norm_stats['action.effector.position']['q99'][:1] + norm_stats['action.arm.position']['q99'][6:] + norm_stats['action.effector.position']['q99'][1:]) |
| elif data_type == 'robotwin_rep': |
| norm_stats['observation.state'], norm_stats['action'] = {}, {} |
| norm_stats['observation.state']['q01'] = np.array(norm_stats['observation.state.arm.position']['q01'] + norm_stats['observation.state.effector.position']['q01']) |
| norm_stats['observation.state']['q99'] = np.array(norm_stats['observation.state.arm.position']['q99'] + norm_stats['observation.state.effector.position']['q99']) |
| norm_stats['action']['q01'] = np.array(norm_stats['action.arm.position']['q01'][:6] + norm_stats['action.effector.position']['q01'][:1] + norm_stats['action.arm.position']['q01'][6:] + norm_stats['action.effector.position']['q01'][1:]) |
| norm_stats['action']['q99'] = np.array(norm_stats['action.arm.position']['q99'][:6] + norm_stats['action.effector.position']['q99'][:1] + norm_stats['action.arm.position']['q99'][6:] + norm_stats['action.effector.position']['q99'][1:]) |
| elif data_type == 'customized': |
| for key in norm_stats: |
| if isinstance(norm_stats[key], dict): |
| for sub_key in norm_stats[key]: |
| norm_stats[key][sub_key] = np.array(norm_stats[key][sub_key]) |
| self.norm_stats = norm_stats |
| else: |
| self.norm_stats = dict_apply(lambda x: x.astype(np.float32), norm_stats) |
| self.norm_type = norm_type or {} |
| self.from_file = from_file |
|
|
| def normalize(self, data: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]: |
| normalized_data = {} |
| for key, value in data.items(): |
| if key in self.norm_stats: |
| norm_type = self.norm_type.get(key, "identity") |
| if norm_type == "meanstd": |
| mean = self.norm_stats[key]["mean"] |
| std = self.norm_stats[key]["std"] |
| normalized_value = (value - mean) / (std + 1e-6) |
| elif norm_type == "bounds_99_woclip": |
| low = self.norm_stats[key]["q01"] |
| high = self.norm_stats[key]["q99"] |
| normalized_value = (value - low) / (high - low + 1e-6) * 2.0 - 1.0 |
| elif norm_type == "std": |
| std = self.norm_stats[key]["std"] |
| normalized_value = value / (std + 1e-6) |
| elif norm_type == "minmax": |
| min_val = self.norm_stats[key]["min"] |
| max_val = self.norm_stats[key]["max"] |
| normalized_value = (value - min_val) / ( |
| max_val - min_val + 1e-6 |
| ) * 2 - 1 |
| elif norm_type == "identity": |
| normalized_value = value |
| else: |
| raise ValueError( |
| f"Unknown normalization type: {norm_type}. Supported types are 'meanstd', 'minmax', and 'identity'." |
| ) |
| normalized_data[key] = normalized_value |
| else: |
| |
| normalized_data[key] = value |
| return normalized_data |
|
|
| def unnormalize(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: |
| """ |
| Unnormalize the given data using stored normalization statistics. |
| |
| Args: |
| data (Dict[str, np.ndarray]): Dictionary of normalized arrays to unnormalize. |
| |
| Returns: |
| Dict[str, np.ndarray]: Dictionary of unnormalized arrays. |
| """ |
| unnormalized_data = {} |
| for key, value in data.items(): |
| if key in self.norm_stats: |
| norm_type = self.norm_type.get(key, "identity") |
| stats = self.norm_stats[key] |
| if norm_type == "meanstd": |
| mean = stats["mean"] |
| std = stats["std"] |
| unnormalized_value = value * (std + 1e-6) + mean |
| elif norm_type == "bounds_98" or norm_type == 'bounds_98_woclip': |
| low = self.norm_stats[key]["q02"] |
| high = self.norm_stats[key]["q98"] |
| unnormalized_value = ((value + 1.0) / 2.0) * (high - low + 1e-6) + low |
| elif norm_type == "bounds_99" or norm_type == "bounds_99_woclip": |
| low = self.norm_stats[key]["q01"] |
| high = self.norm_stats[key]["q99"] |
| unnormalized_value = ((value + 1.0) / 2.0) * (high - low + 1e-6) + low |
| elif norm_type == "std": |
| std = stats["std"] |
| unnormalized_value = value * (std + 1e-6) |
| elif norm_type == "minmax": |
| min_val = stats["min"] |
| max_val = stats["max"] |
| |
| unnormalized_value = (value + 1) / 2.0 * (max_val - min_val + 1e-6) + min_val |
| elif norm_type == "identity": |
| unnormalized_value = value |
| else: |
| raise ValueError( |
| f"Unknown normalization type: {norm_type}. Supported types are 'meanstd', 'minmax', and 'identity'." |
| ) |
| unnormalized_data[key] = unnormalized_value |
| else: |
| |
| unnormalized_data[key] = value |
| return unnormalized_data |
|
|
| def resize_with_pad_item(img, width, height, pad_value=-1): |
| |
| if img.ndim != 3: |
| raise ValueError(f"(c,h,w) expected, but {img.shape}") |
|
|
| cur_height, cur_width = img.shape[1:] |
|
|
| ratio = max(cur_width / width, cur_height / height) |
| resized_height = int(cur_height / ratio) |
| resized_width = int(cur_width / ratio) |
| resized_img = F.interpolate( |
| img.unsqueeze(0), size=(resized_height, resized_width), mode="bilinear", align_corners=False |
| ).squeeze(0) |
|
|
| pad_height = max(0, int(height - resized_height)) |
| pad_width = max(0, int(width - resized_width)) |
|
|
| |
| padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) |
| return padded_img |
|
|
| def prepare_images(config, image_processor, observation: dict[str, Tensor], use_depth_align=False): |
| """Normalize, resize, and pad images and stack them into a tensor. |
| |
| Args: |
| observation (dict[str, Tensor]) |
| |
| Returns: |
| images (torch.Tensor): (*b, n, c, h, w) images in range [-1.0, 1.0] |
| img_masks (torch.Tensor): (*b, n) masks for images, True if image is present, False if missing |
| """ |
| dtype = observation["state"].dtype |
| images, img_masks = [], [] |
| if use_depth_align: |
| pil_images = [] |
|
|
| for key in IMAGE_KEYS: |
| if key in observation["image"]: |
| |
| img = observation["image"][key] |
| assert img.ndim == 3, f"Expected 3D image, got {img.shape}" |
| pil_img = img.cpu().numpy() |
| if image_processor is None: |
| img = img.to(dtype) / 127.5 - 1.0 |
| img = resize_with_pad_item( |
| img, *config.resize_imgs_with_padding, pad_value=-1.0 |
| ) |
| else: |
| img = resize_with_pad_item( |
| img, *config.resize_imgs_with_padding, pad_value=0 |
| ) |
| img = image_processor(img)['pixel_values'] |
| images.append(img) |
| img_masks.append(True) |
| if use_depth_align: |
| pil_images.append(pil_img) |
| else: |
| |
| if image_processor is None: |
| img = torch.full_like(img, fill_value=-1.0) |
| if use_depth_align: |
| pil_img = torch.full_like(pil_img, fill_value=-1.0) |
| else: |
| img = np.zeros_like(img) |
| if use_depth_align: |
| pil_img = np.zeros_like(pil_img) |
| images.append(img) |
| if use_depth_align: |
| pil_images.append(pil_img) |
| img_masks.append(False) |
| if isinstance(images[0], torch.Tensor): |
| images = torch.stack(images, dim=0) |
| elif isinstance(images[0], np.ndarray): |
| images = torch.from_numpy(np.stack(images, axis=0)) |
| img_masks = torch.tensor(img_masks, dtype=torch.bool) |
|
|
| if use_depth_align: |
| pil_images = torch.from_numpy(np.stack(pil_images, axis=0)) |
| else: |
| pil_images = [] |
|
|
| return images, img_masks, pil_images |
|
|
| def prepare_state(config, observation: dict[str, Tensor]): |
| """Pad the state to the maximum state dimension. |
| |
| Args: |
| observation (dict[str, Tensor]) |
| |
| Returns: |
| state (torch.Tensor): (*b, max_state_dim) padded state tensor |
| """ |
| state = observation["state"] |
| state = F.pad(state, (0, config.max_state_dim - state.shape[-1])) |
| return state |
|
|
| def prepare_action(config, observation: dict[str, Tensor]): |
| """Pad the action to the maximum action dimension. |
| |
| Args: |
| observation (dict[str, Tensor]) |
| |
| Returns: |
| action (torch.Tensor): (*b, n, max_action_dim) padded action tensor |
| action_dim (int): the actual dimension of the action before padding |
| """ |
| |
| action = observation["action"] |
| action = F.pad(action, (0, config.max_action_dim - action.shape[-1])) |
| return action |
|
|
| def prepare_language(config, language_tokenizer, observation: dict[str, Tensor]): |
| """If `prompt` is provided, modify it to PaliGemma format and tokenize it. |
| If `lang_tokens` and `lang_masks` are provided, use them directly. |
| |
| PaliGemma expects prefix prompts to be formatted as: |
| <images> .... <images> <bos> prompt <sep>, where <sep> uses `\\n`. |
| So here we format the prompt to start with `<bos>` and end with `\\n`. |
| Later, we will concatenate the images and language tokens into a single sequence. |
| |
| Args: |
| observation (dict[str, Tensor]) |
| |
| Returns: |
| lang_tokens (torch.Tensor): (*b, l) language tokens |
| lang_masks (torch.Tensor): (*b, l) masks for language tokens, True if token is present, False if missing |
| """ |
| lang_tokens = observation.get("lang_tokens", None) |
| lang_masks = observation.get("lang_masks", None) |
| prompt = observation.get("prompt", None) |
|
|
| |
| if prompt is None and (lang_tokens is None or lang_masks is None): |
| raise ValueError( |
| "Either 'prompt' or ('lang_tokens', 'lang_masks') must be provided in the observation." |
| ) |
|
|
| device = observation["state"].device |
| if prompt is not None and (lang_tokens is None or lang_masks is None): |
| prompt = [p if p.startswith("<bos>") else f"<bos>{p}" for p in prompt] |
| prompt = [p if p.endswith("\n") else f"{p}\n" for p in prompt] |
| tokenized_prompt = language_tokenizer.__call__( |
| prompt, |
| padding="max_length", |
| padding_side="right", |
| max_length=config.tokenizer_max_length, |
| truncation=True, |
| return_tensors="pt", |
| ) |
| lang_tokens = tokenized_prompt["input_ids"].to(device=device) |
| lang_masks = tokenized_prompt["attention_mask"].to( |
| device=device, dtype=torch.bool |
| ) |
| else: |
| lang_tokens = observation["lang_tokens"].to(device=device) |
| lang_masks = observation["lang_masks"].to(device=device, dtype=torch.bool) |
|
|
| return lang_tokens.squeeze(0), lang_masks.squeeze(0) |