import os

import tensorflow as tf
import torch
from torch.utils.data import IterableDataset, get_worker_info


def _parse_function(example_proto):
    """Parse one serialized episode record and decode its RGB frames.

    Args:
        example_proto: A scalar ``tf.string`` tensor holding one serialized
            ``tf.train.Example``.

    Returns:
        A dict keyed like ``feature_description`` below. Every feature is
        parsed as a ``FixedLenSequenceFeature``, so each value's leading axis
        indexes steps. ``'steps/observation/rgb'`` is replaced by the
        JPEG-decoded ``uint8`` frames stacked along that step axis.
    """
    # Each feature is a variable-length sequence of fixed-shape elements;
    # allow_missing=True makes a feature absent from a record parse as an
    # empty sequence instead of raising.
    feature_description = {
        'steps/observation/rgb': tf.io.FixedLenSequenceFeature(
            [], tf.string, allow_missing=True),
        'steps/observation/instruction': tf.io.FixedLenSequenceFeature(
            [512], tf.int64, allow_missing=True),
        'steps/observation/effector_translation': tf.io.FixedLenSequenceFeature(
            [2], tf.float32, allow_missing=True),
        'steps/action': tf.io.FixedLenSequenceFeature(
            [2], tf.float32, allow_missing=True),
        'steps/reward': tf.io.FixedLenSequenceFeature(
            [], tf.float32, allow_missing=True),
        'steps/is_first': tf.io.FixedLenSequenceFeature(
            [], tf.int64, allow_missing=True),
        'steps/is_last': tf.io.FixedLenSequenceFeature(
            [], tf.int64, allow_missing=True),
    }
    parsed = tf.io.parse_single_example(example_proto, feature_description)

    # Decode each encoded frame. map_fn stacks the per-frame results, so all
    # frames within one record must decode to the same (H, W, C) shape.
    def decode_images(rgb_sequence):
        return tf.map_fn(lambda x: tf.io.decode_jpeg(x), rgb_sequence,
                         fn_output_signature=tf.uint8)

    parsed['steps/observation/rgb'] = decode_images(parsed['steps/observation/rgb'])
    return parsed


class LanguageTableDataset(IterableDataset):
    """Stream episodes from ``language_table-train.tfrecord*`` shards as torch tensors.

    Shards are read in sorted filename order. When iterated by a
    ``torch.utils.data.DataLoader`` with several workers, each worker reads a
    disjoint, strided subset of the shards. As a result, every episode is
    yielded once per epoch rather than once per worker. Workers that receive
    no shard (more workers than files) yield nothing.

    Args:
        data_dir: Directory searched for the TFRecord shards (any path
            ``tf.io.gfile.glob`` accepts).
        num_shards: If truthy, use only the first ``num_shards`` files in
            sorted order; otherwise (``None`` or 0) use every matching file.

    Raises:
        FileNotFoundError: If no file matches the shard pattern, instead of
            silently producing an empty dataset.
    """

    def __init__(self, data_dir, num_shards=None):
        self.data_dir = data_dir
        self.file_pattern = os.path.join(data_dir, "language_table-train.tfrecord*")
        # Always sort: glob order is unspecified, and a fixed order makes both
        # iteration and the num_shards prefix reproducible.
        self.files = sorted(tf.io.gfile.glob(self.file_pattern))
        if num_shards:
            self.files = self.files[:num_shards]
        if not self.files:
            raise FileNotFoundError(
                f"No TFRecord files match pattern {self.file_pattern!r}")

    def __iter__(self):
        """Yield one dict of tensors per parsed record (episode).

        Keys and contents (leading axis = step):
            'obs': uint8 decoded RGB frames, stacked per step.
            'actions': float32, shape (steps, 2).
            'rewards': float32, shape (steps,).
            'effector': float32 effector translation, shape (steps, 2).
        """
        files = self.files
        worker = get_worker_info()
        if worker is not None:
            # Each DataLoader worker holds its own copy of this dataset;
            # without splitting, every worker would replay every shard.
            files = files[worker.id::worker.num_workers]
            if not files:
                return
        dataset = tf.data.TFRecordDataset(files)
        # Parallel parsing keeps output order (tf.data is deterministic by
        # default), so only throughput changes.
        dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
        for item in dataset:
            yield {
                'obs': torch.from_numpy(item['steps/observation/rgb'].numpy()),
                'actions': torch.from_numpy(item['steps/action'].numpy()),
                'rewards': torch.from_numpy(item['steps/reward'].numpy()),
                'effector': torch.from_numpy(
                    item['steps/observation/effector_translation'].numpy()),
            }