import os
from typing import Dict, Iterator, List, Optional

import tensorflow as tf
import torch
from torch.utils.data import IterableDataset
| |
|
def _parse_function(example_proto):
    """Parse one serialized episode record and decode its JPEG frames.

    Args:
        example_proto: Scalar string tensor holding a serialized example, as
            produced by ``tf.data.TFRecordDataset``.

    Returns:
        Dict keyed by the ``steps/...`` feature names. Every entry is a
        per-step sequence; ``steps/observation/rgb`` is replaced by the
        decoded ``uint8`` images, the other entries are returned as parsed.
    """

    def sequence_of(shape, dtype):
        # Every feature is stored as a variable-length sequence of steps.
        return tf.io.FixedLenSequenceFeature(shape, dtype, allow_missing=True)

    schema = {
        'steps/observation/rgb': sequence_of([], tf.string),
        'steps/observation/instruction': sequence_of([512], tf.int64),
        'steps/observation/effector_translation': sequence_of([2], tf.float32),
        'steps/action': sequence_of([2], tf.float32),
        'steps/reward': sequence_of([], tf.float32),
        'steps/is_first': sequence_of([], tf.int64),
        'steps/is_last': sequence_of([], tf.int64),
    }
    episode = tf.io.parse_single_example(example_proto, schema)

    # Frames arrive as encoded JPEG byte strings; decode them one step at a
    # time and stack the results along the leading (step) axis.
    encoded_frames = episode['steps/observation/rgb']
    episode['steps/observation/rgb'] = tf.map_fn(
        tf.io.decode_jpeg,
        encoded_frames,
        fn_output_signature=tf.uint8,
    )
    return episode
| |
|
class LanguageTableDataset(IterableDataset):
    """Stream Language-Table episodes from TFRecord shards as torch tensors.

    Shard files matching ``language_table-train.tfrecord*`` inside
    ``data_dir`` are discovered once, at construction time. Iterating yields
    one dict per record with the keys ``obs``, ``actions``, ``rewards`` and
    ``effector``.

    When used with a multi-worker ``DataLoader`` the shard files are divided
    between the workers, so each record is produced by exactly one worker
    rather than once per worker.
    """

    def __init__(self, data_dir: str, num_shards: Optional[int] = None) -> None:
        """Locate the shard files.

        Args:
            data_dir: Directory containing the TFRecord shards.
            num_shards: If truthy, keep only the first ``num_shards`` files
                in sorted order. ``None`` (or ``0``) keeps every match.
        """
        super().__init__()
        self.data_dir = data_dir
        self.file_pattern = os.path.join(data_dir, "language_table-train.tfrecord*")
        self.files = tf.io.gfile.glob(self.file_pattern)
        if num_shards:
            self.files = sorted(self.files)[:num_shards]

    def _worker_files(self) -> List[str]:
        """Return the shard files the current DataLoader worker should read."""
        worker = torch.utils.data.get_worker_info()
        if worker is None:
            # Single-process loading: this iterator owns every shard.
            return list(self.files)
        # Round-robin split keeps the per-worker subsets disjoint.
        return list(self.files[worker.id::worker.num_workers])

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Yield one dict of tensors per record in this worker's shards."""
        files = self._worker_files()
        if not files:
            # Nothing assigned (no shards found, or more workers than
            # shards): finish without building a TF pipeline.
            return

        dataset = tf.data.TFRecordDataset(files).map(_parse_function)
        for item in dataset:
            yield {
                'obs': torch.from_numpy(item['steps/observation/rgb'].numpy()),
                'actions': torch.from_numpy(item['steps/action'].numpy()),
                'rewards': torch.from_numpy(item['steps/reward'].numpy()),
                'effector': torch.from_numpy(
                    item['steps/observation/effector_translation'].numpy()
                ),
            }
| |
|