| import argparse |
| import importlib |
| import tqdm |
| import numpy as np |
| import os |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' |
| import tensorflow as tf |
| import tensorflow_datasets as tfds |
|
|
| from example_transform.transform import transform_step |
|
|
# Command-line interface: one positional argument naming the TFDS dataset whose
# episodes are run through transform_step and checked against TARGET_SPEC.
parser = argparse.ArgumentParser(
    description='Check that transform_step output matches TARGET_SPEC.')
parser.add_argument('dataset_name', help='name of the dataset to check')
args = parser.parse_args()
|
|
|
|
def _feature_spec(shape, dtype, value_range=None):
    """Return one leaf entry of TARGET_SPEC.

    Args:
        shape: expected array shape; ``()`` denotes a scalar.
        dtype: expected element type (a numpy scalar type, or ``str`` for text).
        value_range: ``None`` to skip range checking, a ``(low, high)`` tuple
            applied to every element, or a ``[lows, highs]`` list holding
            per-dimension bounds.
    """
    return {'shape': shape, 'dtype': dtype, 'range': value_range}


# Expected structure of every transformed step. Nested dicts group features;
# leaves carry 'shape', 'dtype' and 'range'. The action bounds are a *list*
# of (lows, highs) so each of the 8 dimensions gets its own limits.
# NOTE(review): dims 3-5 allow +/-2*pi, which suggests angles in radians;
# confirm the exact action layout with the dataset documentation.
TARGET_SPEC = {
    'observation': {
        'image': _feature_spec((128, 128, 3), np.uint8, (0, 255)),
    },
    'action': _feature_spec(
        (8,),
        np.float32,
        [(-1, -1, -1, -2 * np.pi, -2 * np.pi, -2 * np.pi, -1, 0),
         (+1, +1, +1, +2 * np.pi, +2 * np.pi, +2 * np.pi, +1, 1)],
    ),
    'discount': _feature_spec((), np.float32, (0, 1)),
    'reward': _feature_spec((), np.float32, (0, 1)),
    'is_first': _feature_spec((), np.bool_),
    'is_last': _feature_spec((), np.bool_),
    'is_terminal': _feature_spec((), np.bool_),
    'language_instruction': _feature_spec((), str),
    'language_embedding': _feature_spec((512,), np.float32),
}
|
|
|
|
def _is_leaf_spec(spec):
    """Return True if ``spec`` describes a single feature rather than a group."""
    return isinstance(spec, dict) and {'shape', 'dtype', 'range'} <= spec.keys()


def _check_leaf(elem, spec, value):
    """Validate one feature ``value`` (named ``elem``) against its leaf ``spec``.

    Raises:
        ValueError: if the shape, dtype or range of ``value`` does not match.
    """
    arr = np.asarray(value)

    # Compare shapes explicitly even for scalar specs: ``()`` is falsy, and a
    # truthiness test would silently skip the check (e.g. accept a (1,) reward).
    if tuple(arr.shape) != tuple(spec['shape']):
        raise ValueError(
            f"Shape of {elem} should be {spec['shape']} but is {tuple(arr.shape)}")

    expected = spec['dtype']
    if expected in (str, bytes):
        # Text may arrive as ``bytes`` (as produced by tfds.as_numpy for string
        # tensors) or as ``str``; accept either kind of string.
        dtype_ok = arr.dtype.kind in ('S', 'U')
    else:
        dtype_ok = arr.dtype == expected
    if not dtype_ok:
        raise ValueError(f"Dtype of {elem} should be {expected} but is {arr.dtype}")

    if spec['range'] is not None:
        low, high = spec['range']
        # ``range`` is either a (low, high) pair applied to every element or a
        # [lows, highs] pair of per-dimension bounds; broadcasting covers both.
        # NaNs compare False and are therefore reported as out of range.
        if not (np.all(arr >= np.asarray(low)) and np.all(arr <= np.asarray(high))):
            raise ValueError(
                f"{elem} is out of range. Should be in {spec['range']} but is {value}.")


def check_elements(target, values):
    """Recursively check that ``values`` matches the (possibly nested) ``target`` spec.

    ``target`` has the layout of ``TARGET_SPEC``: a leaf is a dict with
    ``'shape'``, ``'dtype'`` and ``'range'`` keys; any other dict is a group of
    nested specs whose corresponding value must itself be a dict. Keys present
    in ``values`` but absent from ``target`` are ignored.

    Args:
        target: spec dict to validate against.
        values: dict of (numpy) values, e.g. the output of ``transform_step``.

    Raises:
        ValueError: if a key is missing, a group is not a dict, or a feature's
            shape, dtype or range does not match its spec.
    """
    for elem, spec in target.items():
        if elem not in values:
            raise ValueError(f"Missing key {elem} in transformed step.")
        value = values[elem]
        if _is_leaf_spec(spec):
            _check_leaf(elem, spec, value)
        elif isinstance(value, dict):
            check_elements(spec, value)
        else:
            raise ValueError(
                f"{elem} should be a dict of features but is {type(value)}")
|
|
|
|
| |
dataset_name = args.dataset_name
print(f"Checking transformed data from dataset: {dataset_name}")
# The module object is not used below; it is imported for its side effects
# (presumably registering the dataset builder so tfds.load can resolve the name).
module = importlib.import_module(dataset_name)
ds = tfds.load(dataset_name, split='train')
ds = ds.shuffle(100)

# Validate every step of up to 50 shuffled episodes against TARGET_SPEC.
num_episodes = 0
num_steps = 0
for episode in tqdm.tqdm(ds.take(50)):
    num_episodes += 1
    steps = tfds.as_numpy(episode['steps'])
    for step in steps:
        transformed_step = transform_step(step)
        check_elements(TARGET_SPEC, transformed_step)
        num_steps += 1

# Without this guard an empty split would report success without checking anything.
if num_steps == 0:
    raise SystemExit(
        f"No steps found in the 'train' split of {dataset_name}; nothing was checked.")
print(f"Checked {num_steps} steps from {num_episodes} episodes.")
print("Test passed! You're ready to submit!")
|
|