File size: 2,165 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import tensorflow as tf
import os
import torch
from torch.utils.data import IterableDataset

def _parse_function(example_proto):
    """Deserialize one serialized TFRecord example into a dict of tensors.

    The schema follows the Language Table episode layout: every key holds
    the full per-step sequence for a single episode (hence the
    ``allow_missing=True`` variable-length sequence features).

    Args:
        example_proto: A scalar string tensor containing one serialized
            ``tf.train.Example``.

    Returns:
        Dict mapping feature keys to dense tensors, with
        ``'steps/observation/rgb'`` already decoded from JPEG bytes to a
        ``uint8`` image tensor per step.
    """
    schema = {
        'steps/observation/rgb': tf.io.FixedLenSequenceFeature([], tf.string, allow_missing=True),
        'steps/observation/instruction': tf.io.FixedLenSequenceFeature([512], tf.int64, allow_missing=True),
        'steps/observation/effector_translation': tf.io.FixedLenSequenceFeature([2], tf.float32, allow_missing=True),
        'steps/action': tf.io.FixedLenSequenceFeature([2], tf.float32, allow_missing=True),
        'steps/reward': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
        'steps/is_first': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        'steps/is_last': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
    }

    features = tf.io.parse_single_example(example_proto, schema)

    # The rgb feature arrives as a sequence of JPEG-encoded byte strings;
    # decode each frame independently into a uint8 tensor.
    features['steps/observation/rgb'] = tf.map_fn(
        tf.io.decode_jpeg,
        features['steps/observation/rgb'],
        fn_output_signature=tf.uint8,
    )

    return features

class LanguageTableDataset(IterableDataset):
    """Streams Language Table episodes from TFRecord shards as torch tensors.

    Each yielded item is one full episode, with per-step sequences stacked
    along the leading axis:

    - ``'obs'``: decoded RGB frames (uint8)
    - ``'actions'``: 2-D actions (float32)
    - ``'rewards'``: per-step rewards (float32)
    - ``'effector'``: 2-D effector translations (float32)
    """

    def __init__(self, data_dir, num_shards=None):
        """Collect the shard files to iterate.

        Args:
            data_dir: Directory containing ``language_table-train.tfrecord*``
                shard files (any filesystem tf.io.gfile understands).
            num_shards: Optional cap — keep only the first ``num_shards``
                files (in sorted order).
        """
        self.data_dir = data_dir
        self.file_pattern = os.path.join(data_dir, "language_table-train.tfrecord*")
        # Sort unconditionally so iteration order is deterministic; glob
        # order is filesystem-dependent. (Previously only sorted when
        # num_shards was given.)
        self.files = sorted(tf.io.gfile.glob(self.file_pattern))
        if num_shards:
            self.files = self.files[:num_shards]

    def __iter__(self):
        """Yield one dict of torch tensors per episode.

        When used with a multi-worker DataLoader, each worker reads a
        disjoint round-robin slice of the shard files — without this,
        every worker would iterate the full file list and the loader
        would yield each episode ``num_workers`` times.
        """
        files = self.files
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            files = files[worker_info.id::worker_info.num_workers]

        dataset = tf.data.TFRecordDataset(files)
        # AUTOTUNE lets tf.data parallelize parsing/JPEG decoding.
        dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)

        for item in dataset:
            yield {
                'obs': torch.from_numpy(item['steps/observation/rgb'].numpy()),
                'actions': torch.from_numpy(item['steps/action'].numpy()),
                'rewards': torch.from_numpy(item['steps/reward'].numpy()),
                'effector': torch.from_numpy(item['steps/observation/effector_translation'].numpy()),
            }