import numpy as np
import os
import cv2
from multiprocessing import Pool, cpu_count
import tensorflow as tf
from tqdm import tqdm
import json

def _parse_function(proto):
    """Parse one serialized tf.train.Example into a feature dict.

    Note: 'image' is deserialized as a raw uint8 tensor here; records written
    by write_task below store JPEG bytes instead, which would need
    out_type=tf.string plus a JPEG decode on the read side.
    """
    feature_description = {
        'joint': tf.io.FixedLenFeature([], tf.string),
        'image': tf.io.FixedLenFeature([], tf.string),
        'instruction': tf.io.FixedLenFeature([], tf.string),
        'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
        # Optional features: an empty-string default marks them as absent.
        'gripper': tf.io.FixedLenFeature([], tf.string, default_value=""),
        'tcp': tf.io.FixedLenFeature([], tf.string, default_value=""),
        'tcp_base': tf.io.FixedLenFeature([], tf.string, default_value="")
    }
    parsed_features = tf.io.parse_single_example(proto, feature_description)
    # Parse tensors
    parsed_features['joint'] = tf.io.parse_tensor(parsed_features['joint'], out_type=tf.float64)
    parsed_features['image'] = tf.io.parse_tensor(parsed_features['image'], out_type=tf.uint8)
    parsed_features['instruction'] = tf.io.parse_tensor(parsed_features['instruction'], out_type=tf.string)
    parsed_features['gripper'] = tf.cond(
        tf.math.equal(parsed_features['gripper'], ""),
        lambda: tf.constant([], dtype=tf.float64),
        lambda: tf.io.parse_tensor(parsed_features['gripper'], out_type=tf.float64)
    )
    parsed_features['tcp'] = tf.cond(
        tf.math.equal(parsed_features['tcp'], ""),
        lambda: tf.constant([], dtype=tf.float64),
        lambda: tf.io.parse_tensor(parsed_features['tcp'], out_type=tf.float64)
    )
    parsed_features['tcp_base'] = tf.cond(
        tf.math.equal(parsed_features['tcp_base'], ""),
        lambda: tf.constant([], dtype=tf.float64),
        lambda: tf.io.parse_tensor(parsed_features['tcp_base'], out_type=tf.float64)
    )
    return parsed_features
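
# Reading-side sketch (illustrative, not part of the original pipeline):
# records written by write_task below store the image as JPEG bytes rather
# than a raw uint8 array, so _parse_function's out_type=tf.uint8 does not
# apply to them. Assuming the same feature schema, a minimal reader for the
# JPEG-encoded format could look like this:
def _parse_jpeg_example(proto):
    parsed = tf.io.parse_single_example(proto, {
        'joint': tf.io.FixedLenFeature([], tf.string),
        'image': tf.io.FixedLenFeature([], tf.string),
        'instruction': tf.io.FixedLenFeature([], tf.string),
        'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
    })
    parsed['joint'] = tf.io.parse_tensor(parsed['joint'], out_type=tf.float64)
    parsed['instruction'] = tf.io.parse_tensor(parsed['instruction'], out_type=tf.string)
    jpeg_bytes = tf.io.parse_tensor(parsed['image'], out_type=tf.string)
    parsed['image'] = tf.io.decode_jpeg(jpeg_bytes)  # uint8 HxWx3, BGR order from cv2
    return parsed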

def convert_color(color_file, color_timestamps):
    """Decode every frame of a color video, resized to 640x360.

    Args:
    - color_file: path to the color video file;
    - color_timestamps: the color timestamps (currently unused; kept so the
      call site can pass them alongside the video).

    Returns:
    - a list of BGR frames as uint8 arrays.
    """
    cap = cv2.VideoCapture(color_file)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(cv2.resize(frame, (640, 360)))
    cap.release()
    return frames
    
def _bytes_feature(value):
    """Returns a bytes_list Feature from bytes / a scalar string tensor."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # unwrap EagerTensor to bytes
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _bool_feature(value):
    """Returns an int64_list Feature from a bool."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))

def serialize_example(joint, gripper, tcp, tcp_base, image, instruction, terminate_episode):
    """Serialize one timestep into a tf.train.Example string.

    gripper / tcp / tcp_base may be None; absent features are simply omitted
    and come back as the empty-string default in _parse_function.
    """
    feature = {
        'joint': _bytes_feature(tf.io.serialize_tensor(joint)),
        'image': _bytes_feature(tf.io.serialize_tensor(image)),
        'instruction': _bytes_feature(tf.io.serialize_tensor(instruction)),
        'terminate_episode': _bool_feature(terminate_episode),
    }
    if gripper is not None:
        feature['gripper'] = _bytes_feature(tf.io.serialize_tensor(gripper))
    if tcp is not None:
        feature['tcp'] = _bytes_feature(tf.io.serialize_tensor(tcp))
    if tcp_base is not None:
        feature['tcp_base'] = _bytes_feature(tf.io.serialize_tensor(tcp_base))
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()
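
# Round-trip sketch (illustrative, not part of the original pipeline): build
# one Example from placeholder values and parse it back with _parse_function.
# The shapes below are assumptions for demonstration, not real RH20T data.
def _roundtrip_demo():
    serialized = serialize_example(
        joint=np.zeros(7, dtype=np.float64),
        gripper=np.zeros(1, dtype=np.float64),
        tcp=np.zeros(7, dtype=np.float64),
        tcp_base=np.zeros(7, dtype=np.float64),
        image=np.zeros((360, 640, 3), dtype=np.uint8),  # raw frame, not JPEG
        instruction='pick up the cube',
        terminate_episode=False,
    )
    parsed = _parse_function(serialized)
    assert parsed['joint'].shape == (7,)
    assert parsed['image'].shape == (360, 640, 3)
    return parsed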

def compress_tfrecord(tfrecord_path):
    """Re-encode raw uint8 images in a TFRecord file as JPEG, in place."""
    raw_dataset = tf.data.TFRecordDataset(tfrecord_path)
    parsed_dataset = raw_dataset.map(_parse_function)

    # Serialize into a temporary file first: opening a TFRecordWriter on the
    # input path would truncate it while the dataset is still being read.
    tmp_path = tfrecord_path + '.tmp'
    already_compressed = False
    with tf.io.TFRecordWriter(tmp_path) as writer:
        for features in parsed_dataset:
            image_np = features['image'].numpy()
            if image_np.ndim <= 1:  # already compressed, leave the file as-is
                already_compressed = True
                break
            _, compressed_image = cv2.imencode('.jpg', image_np)
            image_bytes = tf.convert_to_tensor(compressed_image.tobytes(), dtype=tf.string)

            feature_dict = {
                'joint': _bytes_feature(tf.io.serialize_tensor(features['joint'])),
                'image': _bytes_feature(tf.io.serialize_tensor(image_bytes)),
                'instruction': _bytes_feature(tf.io.serialize_tensor(features['instruction'])),
                'terminate_episode': _bool_feature(features['terminate_episode']),
                'gripper': _bytes_feature(tf.io.serialize_tensor(features['gripper'])),
                'tcp': _bytes_feature(tf.io.serialize_tensor(features['tcp'])),
                'tcp_base': _bytes_feature(tf.io.serialize_tensor(features['tcp_base']))
            }
            example_proto = tf.train.Example(features=tf.train.Features(feature=feature_dict))
            writer.write(example_proto.SerializeToString())
    if already_compressed:
        os.remove(tmp_path)
        return
    os.replace(tmp_path, tfrecord_path)
    print(f"compressed {tfrecord_path}")
    
def write_task(args):
    """Convert one task directory into per-camera TFRecord files."""
    task_dir, output_dir = args

    # Look up the English instruction for this task; skip tasks without one.
    with open('./instruction.json') as f:
        all_instructions = json.load(f)
    instruction = None
    for taskid in all_instructions:
        if taskid in task_dir:
            instruction = all_instructions[taskid]['task_description_english']
            break
    if instruction is None:
        return

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    joints = np.load(os.path.join(task_dir, "transformed/joint.npy"), allow_pickle=True).item()
    if not os.path.exists(os.path.join(task_dir, "transformed/gripper.npy")):
        return
    grippers = np.load(os.path.join(task_dir, "transformed/gripper.npy"), allow_pickle=True).item()
    tcps = np.load(os.path.join(task_dir, "transformed/tcp.npy"), allow_pickle=True).item()
    tcp_bases = np.load(os.path.join(task_dir, "transformed/tcp_base.npy"), allow_pickle=True).item()

    for camid in joints.keys():
        timesteps = joints[camid]
        if len(timesteps) == 0:
            continue
        tfrecord_path = os.path.join(output_dir, f'cam_{camid}.tfrecord')
        timesteps_file = os.path.join(task_dir, f'cam_{camid}/timestamps.npy')

        if not os.path.exists(timesteps_file):
            continue
        # Skip cameras that already have a non-empty record file.
        if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
            continue

        timesteps_file = np.load(timesteps_file, allow_pickle=True).item()
        images = convert_color(os.path.join(task_dir, f'cam_{camid}/color.mp4'), timesteps_file['color'])
        if len(timesteps) != len(images):  # BUG FROM RH20T: frame/timestamp count mismatch
            continue
        with tf.io.TFRecordWriter(tfrecord_path) as writer:
            for i, timestep in enumerate(timesteps):
                image = cv2.imencode('.jpg', images[i])[1].tobytes()
                joint_pos = joints[camid][timestep]
                tcp_entry = next((item for item in tcps[camid] if item['timestamp'] == timestep), None)
                tcp = tcp_entry['tcp'] if tcp_entry is not None else None
                tcp_base_entry = next((item for item in tcp_bases[camid] if item['timestamp'] == timestep), None)
                tcp_base = tcp_base_entry['tcp'] if tcp_base_entry is not None else None
                gripper_pos = grippers[camid][timestep]['gripper_info'] if timestep in grippers[camid] else None
                terminate_episode = i == len(timesteps) - 1
                serialized_example = serialize_example(joint_pos, gripper_pos, tcp, tcp_base, image, instruction, terminate_episode)
                writer.write(serialized_example)

def write_tfrecords(root_dir, output_dir, num_processes=None):
    """Walk root_dir and convert every robot task to TFRecords in parallel."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    if num_processes is None:
        num_processes = cpu_count()

    # Collect (task_dir, task_out) pairs, skipping human demonstrations and
    # tasks without transformed joint data.
    args = []
    for dirs in os.listdir(root_dir):
        for task in os.listdir(os.path.join(root_dir, dirs)):
            if 'human' in task:
                continue
            task_dir = os.path.join(root_dir, dirs, task)
            joint_path = os.path.join(task_dir, "transformed/joint.npy")
            if not os.path.exists(joint_path):
                continue
            task_out = os.path.join(output_dir, dirs, task)
            os.makedirs(task_out, exist_ok=True)
            args.append((task_dir, task_out))

    with tqdm(total=len(args), desc="Processing files") as pbar:
        with Pool(num_processes) as pool:
            for _ in pool.imap_unordered(write_task, args):
                pbar.update(1)

# Guard the entry point so worker processes spawned by Pool do not re-run it.
if __name__ == '__main__':
    write_tfrecords('../datasets/rh20t/raw_data/', '../datasets/rh20t/tfrecords/')