| | import jax.numpy as jnp |
| |
|
| | from openpi.shared import image_tools |
| |
|
| |
|
def test_resize_with_pad_shapes():
    """Check the output shape and content of ``image_tools.resize_with_pad``.

    Every case builds an all-zero ``uint8`` 4-D array, resizes it to the
    requested height and width, and asserts that:

    * the result shape is ``(batch, height, width, 3)``, where ``batch`` is
      the leading dimension of the input, and
    * every element of the result is still zero.
    """
    # Each entry: (input array shape, requested height, requested width).
    cases = (
        ((2, 10, 10, 3), 20, 20),  # target larger than the input
        ((3, 30, 30, 3), 15, 15),  # target smaller than the input
        ((1, 50, 50, 3), 50, 50),  # target identical to the input
        # Non-square 256x320 input into a 60x80 target (4:5 vs 3:4).
        # NOTE(review): presumably this is the case that exercises padding --
        # confirm against resize_with_pad.
        ((1, 256, 320, 3), 60, 80),
    )

    for input_shape, height, width in cases:
        batch_size = input_shape[0]
        blank = jnp.zeros(input_shape, dtype=jnp.uint8)

        result = image_tools.resize_with_pad(blank, height, width)

        assert result.shape == (batch_size, height, width, 3)
        assert jnp.all(result == 0)
| |
|