File size: 6,094 Bytes
d899b9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
"""
Converts data from the BridgeData numpy format to TFRecord format.
Consider the following directory structure for the input data:
bridgedata_numpy/
rss/
toykitchen2/
set_table/
00/
train/
out.npy
val/
out.npy
icra/
...
The --depth parameter controls how much of the data to process at the
--input_path; for example, if --depth=5, then --input_path should be
"bridgedata_numpy", and all data will be processed. If --depth=3, then
--input_path should be "bridgedata_numpy/rss/toykitchen2", and only data
under "toykitchen2" will be processed.
The same directory structure will be replicated under --output_path. For
example, in the second case, the output will be written to
"{output_path}/set_table/00/...".
Can read/write directly from/to Google Cloud Storage.
Written by Kevin Black (kvablack@berkeley.edu).
"""
import os
from multiprocessing import Pool
import numpy as np
import tensorflow as tf
import tqdm
from absl import app, flags, logging
import pickle
from multiprocessing import cpu_count
FLAGS = flags.FLAGS
flags.DEFINE_string("input_path", None, "Input path", required=True)
flags.DEFINE_string("output_path", None, "Output path", required=True)
flags.DEFINE_integer(
"depth",
5,
"Number of directories deep to traverse. Looks for {input_path}/dir_1/dir_2/.../dir_{depth-1}/train/out.npy",
)
flags.DEFINE_bool("overwrite", False, "Overwrite existing files")
num_workers = 8
flags.DEFINE_integer("num_workers", num_workers, "Number of threads to use")
print(f"using {num_workers} workers")
def tensor_feature(value):
return tf.train.Feature(
bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(value).numpy()])
)
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode('utf-8')]))
def _strings_feature(string_list):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=s.encode('utf-8')))
def _bool_feature(value):
"""Returns a bool_list from a boolean."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))
def process(path):
# with tf.io.gfile.GFile(path, "rb") as f:
# arr = np.load(f, allow_pickle=True)
try:
with tf.io.gfile.GFile(path, "rb") as f:
arr = np.load(path, allow_pickle=True)
except Exception as e:
print(f"Error loading {path}: {e}")
return
dirname = os.path.dirname(os.path.abspath(path))
outpath = os.path.join(FLAGS.output_path, *dirname.split(os.sep)[-FLAGS.depth :])
if tf.io.gfile.exists(outpath):
if FLAGS.overwrite:
logging.info(f"Deleting {outpath}")
tf.io.gfile.rmtree(outpath)
else:
logging.info(f"Skipping {outpath}")
return
if len(arr) == 0:
logging.info(f"Skipping {path}, empty")
return
tf.io.gfile.makedirs(outpath)
for i,traj in enumerate(arr):
write_path = f"{outpath}/out_{i}.tfrecord"
with tf.io.TFRecordWriter(write_path) as writer:
truncates = np.zeros(len(traj["actions"]), dtype=np.bool_)
truncates[-1] = True
frames_num = len(traj["observations"])
# remove empty string
traj["language"] = [x for x in traj["language"] if x != ""]
if len(traj["language"]) == 0:
traj["language"] = [""]
instr = traj["language"][0]
if(len(traj["language"]) > 2):
print(len(traj["language"]))
for i in range(frames_num):
tf_features = {
"observations/images0": tensor_feature(
np.array(
[traj["observations"][i]["images0"]],
dtype=np.uint8,
)
),
"observations/state": tensor_feature(
np.array(
[traj["observations"][i]["state"]],
dtype=np.float32,
)
),
"observations/qpos": tensor_feature(
np.array(
[traj["observations"][i]["qpos"]],
dtype=np.float32,
)
),
"observations/eef_transform": tensor_feature(
np.array(
[traj["observations"][i]["eef_transform"]],
dtype=np.float32,
)
),
"language": _bytes_feature(instr),
"actions": tensor_feature(
np.array(traj["actions"][i], dtype=np.float32)
),
"truncates": _bool_feature(i == frames_num - 1),
}
example = tf.train.Example(
features=tf.train.Features(
feature = tf_features
)
)
writer.write(example.SerializeToString())
def main(_):
assert FLAGS.depth >= 1
paths = tf.io.gfile.glob(
tf.io.gfile.join(FLAGS.input_path, *("*" * (FLAGS.depth - 1)))
)
paths = [f"{p}/train/out.npy" for p in paths] + [f"{p}/val/out.npy" for p in paths]
# num_episodes = 0
# for dirpath in paths:
# with tf.io.gfile.GFile(dirpath, "rb") as f:
# arr = np.load(dirpath, allow_pickle=True)
# num_episodes += len(arr)
# print(num_episodes)
with Pool(FLAGS.num_workers) as p:
list(tqdm.tqdm(p.imap(process, paths), total=len(paths)))
if __name__ == "__main__":
app.run(main)
|