| from collections.abc import Sequence |
| import logging |
|
|
| import torch |
|
|
| from openpi.shared import image_tools |
|
|
logger = logging.getLogger("openpi")

# Image entries that preprocess_observation_pytorch requires and processes, in iteration order.
# Keys containing "wrist" are skipped by the geometric (crop/rotate) training augmentations.
IMAGE_KEYS: tuple[str, ...] = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")

# Target (height, width) every processed image is brought to; mismatching images are resized with padding.
IMAGE_RESOLUTION: tuple[int, int] = (224, 224)
|
|
|
|
class _SimpleProcessedObservation:
    """Plain attribute bag returned by ``preprocess_observation_pytorch``.

    Defined once at module level instead of being re-created on every call, and
    kept deliberately minimal so torch.compile does not have to trace a complex
    Observation type.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)


def _random_crop_and_resize(image):
    """Randomly crop 95% of each spatial dim of a [B, H, W, C] batch and resize back to [H, W].

    A single crop offset is drawn for the whole batch.
    """
    height, width = image.shape[1:3]
    crop_height = int(height * 0.95)
    crop_width = int(width * 0.95)
    max_h = height - crop_height
    max_w = width - crop_width
    if max_h > 0 and max_w > 0:
        # Offsets stay tensors (no .item()) to keep this torch.compile friendly.
        start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
        start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
        image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
    return torch.nn.functional.interpolate(
        image.permute(0, 3, 1, 2),  # [B, H, W, C] -> [B, C, H, W]
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    ).permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]


def _random_rotate(image):
    """Rotate a [B, H, W, C] batch by one random angle in [-5, 5) degrees (zero padding).

    Angles with magnitude <= 0.1 degrees are skipped.
    """
    angle = torch.rand(1, device=image.device) * 10 - 5
    if torch.abs(angle) > 0.1:
        angle_rad = angle * torch.pi / 180.0
        cos_a = torch.cos(angle_rad)
        sin_a = torch.sin(angle_rad)

        batch, height, width = image.shape[:3]
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(-1, 1, height, device=image.device),
            torch.linspace(-1, 1, width, device=image.device),
            indexing="ij",
        )
        grid_x = grid_x.unsqueeze(0).expand(batch, -1, -1)
        grid_y = grid_y.unsqueeze(0).expand(batch, -1, -1)

        # Normalized x spans W/2 pixels and normalized y spans H/2 pixels, so
        # compensate by the aspect ratio to keep the rotation rigid in pixel
        # space for non-square images (a no-op when H == W).
        aspect = width / height
        grid_x_rot = grid_x * cos_a - grid_y * sin_a / aspect
        grid_y_rot = grid_x * sin_a * aspect + grid_y * cos_a

        # grid_sample requires the grid to share the input's dtype; linspace
        # defaults to float32, which would fail for e.g. bfloat16/float64 images.
        grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1).to(image.dtype)

        image = torch.nn.functional.grid_sample(
            image.permute(0, 3, 1, 2),  # [B, H, W, C] -> [B, C, H, W]
            grid,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=False,
        ).permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
    return image


def _random_color_jitter(image):
    """Apply random brightness, contrast and saturation to a [B, H, W, C] batch, then clamp to [0, 1].

    Each factor is a single random scalar shared by the whole batch.
    """
    # Brightness: factor in [0.7, 1.3).
    brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6
    image = image * brightness_factor

    # Contrast: scale deviations from each image's mean (over H, W, C) by a factor in [0.6, 1.4).
    contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8
    mean = image.mean(dim=[1, 2, 3], keepdim=True)
    image = (image - mean) * contrast_factor + mean

    # Saturation (approximation): blend each pixel with its channel mean using a factor in [0.5, 1.5).
    saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0
    gray = image.mean(dim=-1, keepdim=True)
    image = gray + (image - gray) * saturation_factor

    return torch.clamp(image, 0, 1)


def preprocess_observation_pytorch(
    observation,
    *,
    train: bool = False,
    image_keys: Sequence[str] = IMAGE_KEYS,
    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
):
    """Resize, optionally augment, and build masks for an observation's images (torch.compile friendly).

    Type annotations are kept deliberately simple because complex annotations can
    cause torch.compile issues.

    Args:
        observation: Object exposing ``images`` (mapping key -> 4-D image tensor),
            ``image_masks`` (mapping key -> mask), ``state`` (tensor whose leading
            dims are the batch shape) and the ``tokenized_prompt``,
            ``tokenized_prompt_mask``, ``token_ar_mask`` and ``token_loss_mask``
            attributes, which are passed through unchanged.
        train: If True, apply random augmentations: crop/resize and rotation for
            keys not containing "wrist", and brightness/contrast/saturation jitter
            for every key. Random factors are drawn once per key and shared across
            the batch.
        image_keys: Keys of ``observation.images`` to process; all must be present.
        image_resolution: Target (height, width). Images of any other size are
            resized with padding via ``image_tools.resize_with_pad_torch``.

    Returns:
        A plain attribute object with ``images`` and ``image_masks`` dicts (one entry
        per key in ``image_keys``) plus the pass-through fields listed above. Keys
        without an entry in ``observation.image_masks`` get an all-True bool mask of
        the batch shape.

    Raises:
        ValueError: If any of ``image_keys`` is missing from ``observation.images``.

    An image whose dim 1 has size 3 is treated as channels-first [B, C, H, W]. It is
    processed channels-last and permuted back, so each output keeps its input layout.
    """
    if not set(image_keys).issubset(observation.images):
        raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")

    batch_shape = observation.state.shape[:-1]
    # torch.Size is a tuple subclass: comparing it to a list is always unequal, so
    # normalize the target to avoid a redundant resize when a list is passed.
    target_resolution = tuple(image_resolution)

    out_images = {}
    for key in image_keys:
        image = observation.images[key]

        is_channels_first = image.shape[1] == 3
        if is_channels_first:
            image = image.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]

        if image.shape[1:3] != target_resolution:
            logger.info("Resizing image %s from %s to %s", key, image.shape[1:3], image_resolution)
            image = image_tools.resize_with_pad_torch(image, *image_resolution)

        if train:
            # The /2 + 0.5 mapping (and the inverse below) implies inputs in [-1, 1];
            # augmentations operate on [0, 1] values.
            image = image / 2.0 + 0.5
            if "wrist" not in key:
                image = _random_crop_and_resize(image)
                image = _random_rotate(image)
            image = _random_color_jitter(image)
            image = image * 2.0 - 1.0

        if is_channels_first:
            image = image.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]

        out_images[key] = image

    out_masks = {}
    for key in out_images:
        if key in observation.image_masks:
            out_masks[key] = observation.image_masks[key]
        else:
            # No explicit mask: treat the image as valid for every batch element.
            out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)

    return _SimpleProcessedObservation(
        images=out_images,
        image_masks=out_masks,
        state=observation.state,
        tokenized_prompt=observation.tokenized_prompt,
        tokenized_prompt_mask=observation.tokenized_prompt_mask,
        token_ar_mask=observation.token_ar_mask,
        token_loss_mask=observation.token_loss_mask,
    )
|
|