import tensorflow as tf
import os
import torch
from torch.utils.data import IterableDataset
def _parse_function(example_proto):
    """Parse one serialized Language-Table episode from a TFRecord.

    Args:
        example_proto: a scalar string tensor holding one serialized
            ``tf.train.Example`` (one episode; every feature is a
            variable-length sequence over the episode's steps).

    Returns:
        A dict of tensors keyed by the feature names below, with the
        ``steps/observation/rgb`` entry decoded from JPEG bytes into a
        ``uint8`` image tensor per step.
    """
    # Schema of one episode: per-step sequences, hence allow_missing=True
    # (sequence length varies between episodes).
    spec = {
        'steps/observation/rgb': tf.io.FixedLenSequenceFeature([], tf.string, allow_missing=True),
        'steps/observation/instruction': tf.io.FixedLenSequenceFeature([512], tf.int64, allow_missing=True),
        'steps/observation/effector_translation': tf.io.FixedLenSequenceFeature([2], tf.float32, allow_missing=True),
        'steps/action': tf.io.FixedLenSequenceFeature([2], tf.float32, allow_missing=True),
        'steps/reward': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
        'steps/is_first': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        'steps/is_last': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    }
    features = tf.io.parse_single_example(example_proto, spec)
    # Each rgb element is a JPEG-encoded byte string; decode every step's
    # frame to a uint8 tensor. decode_jpeg is passed directly to map_fn.
    features['steps/observation/rgb'] = tf.map_fn(
        tf.io.decode_jpeg,
        features['steps/observation/rgb'],
        fn_output_signature=tf.uint8,
    )
    return features
class LanguageTableDataset(IterableDataset):
    """Streams Language-Table episodes from TFRecord shards as torch tensors.

    Each yielded item is one full episode with keys:
        'obs'      — uint8 RGB frames, one per step
        'actions'  — float32 [T, 2] actions
        'rewards'  — float32 [T] rewards
        'effector' — float32 [T, 2] effector translations
    """

    def __init__(self, data_dir, num_shards=None):
        """
        Args:
            data_dir: directory containing ``language_table-train.tfrecord*``
                shard files (any filesystem tf.io.gfile understands).
            num_shards: if given (and truthy), restrict to the first
                ``num_shards`` shards in sorted order.
        """
        super().__init__()  # was missing: IterableDataset base init
        self.data_dir = data_dir
        self.file_pattern = os.path.join(data_dir, "language_table-train.tfrecord*")
        # Sort unconditionally so iteration order is deterministic; glob
        # order is filesystem-dependent (original only sorted when slicing).
        self.files = sorted(tf.io.gfile.glob(self.file_pattern))
        if num_shards:
            self.files = self.files[:num_shards]

    def __iter__(self):
        # Shard files across DataLoader workers. Without this, every worker
        # of a multi-worker DataLoader would iterate ALL files and each
        # episode would be yielded num_workers times.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            files = self.files
        else:
            files = self.files[worker_info.id::worker_info.num_workers]
        dataset = tf.data.TFRecordDataset(files)
        dataset = dataset.map(_parse_function)
        for item in dataset:
            # .numpy() materializes the eager tensor; from_numpy shares the
            # resulting buffer (no extra copy).
            yield {
                'obs': torch.from_numpy(item['steps/observation/rgb'].numpy()),
                'actions': torch.from_numpy(item['steps/action'].numpy()),
                'rewards': torch.from_numpy(item['steps/reward'].numpy()),
                'effector': torch.from_numpy(item['steps/observation/effector_translation'].numpy()),
            }
|