lsnu's picture
Add files using upload-large-folder tool
5ce8761 verified
import open3d # DON'T DELETE THIS!
from multiprocessing import Process, Manager
from pyrep.const import RenderMode
from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointVelocity
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import (
task_file_to_task_class,
float_array_to_rgb_image
)
import rlbench.backend.task as task
import os
import pickle
from PIL import Image
from rlbench.backend.const import *
import numpy as np
import random
from data_generation.customized_rlbench import CustomizedEnvironment
from absl import app
from absl import flags
MESH_POINT_FOLDER = 'mesh_points'
MESH_POINT_FORMAT = '%d.pkl'
FLAGS = flags.FLAGS
flags.DEFINE_string('save_path',
'data/train_dataset/microsteps/seed{seed}',
'Where to save the demos.')
flags.DEFINE_list('tasks', [],
'The tasks to collect. If empty, all tasks are collected.')
flags.DEFINE_list('image_size', [128, 128],
'The size of the images tp save.')
flags.DEFINE_enum('renderer', 'opengl3', ['opengl', 'opengl3'],
'The renderer to use. opengl does not include shadows, '
'but is faster.')
flags.DEFINE_integer('processes', 1,
'The number of parallel processes during collection.')
flags.DEFINE_integer('episodes_per_task', 10,
'The number of episodes to collect per task.')
flags.DEFINE_integer('variations', -1,
'Number of variations to collect per task. -1 for all.')
flags.DEFINE_integer('offset', 0,
'First variation id.')
flags.DEFINE_boolean('state', False,
'Record the state (not available for all tasks).')
flags.DEFINE_integer('seed', 0,
'Seed of randomness')
def check_and_make(dir):
os.makedirs(dir, exist_ok=True)
class DemoSaver:
def __init__(self, demo, example_path):
self.demo = demo
self.example_path = example_path
def store(self, folder, attr):
# Create folder
path_ = os.path.join(self.example_path, folder)
os.makedirs(path_, exist_ok=True)
# Loop over demo and store
for i, obs in enumerate(self.demo):
# Read image
img = obs.__getattribute__(attr)
if 'rgb' in attr:
img = Image.fromarray(img)
elif 'depth' in attr:
img = float_array_to_rgb_image(img, scale_factor=DEPTH_SCALE)
elif 'mask' in attr:
img = Image.fromarray((img * 255).astype(np.uint8))
# Save image
img.save(os.path.join(path_, IMAGE_FORMAT % i))
# Set to None for pickling later
obs.__setattr__(attr, None)
def save_demo(demo, example_path):
ds = DemoSaver(demo, example_path)
paths_attrs = [
(LEFT_SHOULDER_RGB_FOLDER, 'left_shoulder_rgb'),
(LEFT_SHOULDER_DEPTH_FOLDER, 'left_shoulder_depth'),
(LEFT_SHOULDER_MASK_FOLDER, 'left_shoulder_mask'),
(RIGHT_SHOULDER_RGB_FOLDER, 'right_shoulder_rgb'),
(RIGHT_SHOULDER_DEPTH_FOLDER, 'right_shoulder_depth'),
(RIGHT_SHOULDER_MASK_FOLDER, 'right_shoulder_mask'),
(OVERHEAD_RGB_FOLDER, 'overhead_rgb'),
(OVERHEAD_DEPTH_FOLDER, 'overhead_depth'),
(OVERHEAD_MASK_FOLDER, 'overhead_mask'),
(WRIST_RGB_FOLDER, 'wrist_rgb'),
(WRIST_DEPTH_FOLDER, 'wrist_depth'),
(WRIST_MASK_FOLDER, 'wrist_mask'),
(FRONT_RGB_FOLDER, 'front_rgb'),
(FRONT_DEPTH_FOLDER, 'front_depth'),
(FRONT_MASK_FOLDER, 'front_mask')
]
# Save image data first and then None them
for folder, attr in paths_attrs:
ds.store(folder, attr)
# Store object point clouds
mesh_point_path = os.path.join(example_path, MESH_POINT_FOLDER)
os.makedirs(mesh_point_path, exist_ok=True)
for i, obs in enumerate(demo):
mesh_points = obs.mesh_points
with open(os.path.join(mesh_point_path, MESH_POINT_FORMAT % i), 'wb') as f:
pickle.dump(mesh_points, f)
obs.__delattr__('mesh_points')
# Save the low-dimension data
with open(os.path.join(example_path, LOW_DIM_PICKLE), 'wb') as f:
pickle.dump(demo, f)
def run(i, lock, task_index, variation_count, results, file_lock, tasks):
"""Each thread will choose one task and variation, and then gather
all the episodes_per_task for that variation."""
# Initialize each thread with random seed
np.random.seed(FLAGS.seed)
random.seed(FLAGS.seed)
num_tasks = len(tasks)
img_size = list(map(int, FLAGS.image_size))
obs_config = ObservationConfig()
obs_config.set_all(True)
obs_config.right_shoulder_camera.image_size = img_size
obs_config.left_shoulder_camera.image_size = img_size
obs_config.overhead_camera.image_size = img_size
obs_config.wrist_camera.image_size = img_size
obs_config.front_camera.image_size = img_size
# Store depth as 0 - 1
obs_config.right_shoulder_camera.depth_in_meters = False
obs_config.left_shoulder_camera.depth_in_meters = False
obs_config.overhead_camera.depth_in_meters = False
obs_config.wrist_camera.depth_in_meters = False
obs_config.front_camera.depth_in_meters = False
# We want to save the masks as rgb encodings.
obs_config.left_shoulder_camera.masks_as_one_channel = False
obs_config.right_shoulder_camera.masks_as_one_channel = False
obs_config.overhead_camera.masks_as_one_channel = False
obs_config.wrist_camera.masks_as_one_channel = False
obs_config.front_camera.masks_as_one_channel = False
# No need to save point cloud, we'll unproject them from depth
obs_config.left_shoulder_camera.point_cloud = False
obs_config.right_shoulder_camera.point_cloud = False
obs_config.overhead_camera.point_cloud = False
obs_config.wrist_camera.point_cloud = False
obs_config.front_camera.point_cloud = False
if FLAGS.renderer == 'opengl':
obs_config.right_shoulder_camera.render_mode = RenderMode.OPENGL
obs_config.left_shoulder_camera.render_mode = RenderMode.OPENGL
obs_config.overhead_camera.render_mode = RenderMode.OPENGL
obs_config.wrist_camera.render_mode = RenderMode.OPENGL
obs_config.front_camera.render_mode = RenderMode.OPENGL
rlbench_env = CustomizedEnvironment(
action_mode=MoveArmThenGripper(JointVelocity(), Discrete()),
obs_config=obs_config,
headless=True
)
rlbench_env.launch()
task_env = None
tasks_with_problems = results[i] = ''
while True:
# Figure out what task/variation this thread is going to do
with lock:
if task_index.value >= num_tasks:
print('Process', i, 'finished')
break
my_variation_count = variation_count.value
t = tasks[task_index.value]
task_env = rlbench_env.get_task(t)
var_target = task_env.variation_count()
if FLAGS.variations >= 0:
var_target = np.minimum(FLAGS.variations+FLAGS.offset, var_target)
if my_variation_count >= var_target:
# If we have reached the required number of variations for this
# task, then move on to the next task.
variation_count.value = my_variation_count = FLAGS.offset
task_index.value += 1
variation_count.value += 1
if task_index.value >= num_tasks:
print('Process', i, 'finished')
break
t = tasks[task_index.value]
task_env = rlbench_env.get_task(t)
task_env.set_variation(my_variation_count)
descriptions, obs = task_env.reset()
variation_path = os.path.join(
FLAGS.save_path, task_env.get_name(),
VARIATIONS_FOLDER % my_variation_count
)
print(variation_path)
check_and_make(variation_path)
with open(os.path.join(variation_path, VARIATION_DESCRIPTIONS), 'wb') as f:
pickle.dump(descriptions, f)
episodes_path = os.path.join(variation_path, EPISODES_FOLDER)
check_and_make(episodes_path)
abort_variation = False
print("episode per task", FLAGS.episodes_per_task)
for ex_idx in range(FLAGS.episodes_per_task):
print('Process', i, '// Task:', task_env.get_name(),
'// Variation:', my_variation_count, '// Demo:', ex_idx)
attempts = 10
while attempts > 0:
episode_path = os.path.join(episodes_path, EPISODE_FOLDER % ex_idx)
if os.path.exists(episode_path):
break
try:
print("starting demo")
demo, = task_env.get_demos(amount=1, live_demos=True)
print("demo collected")
except Exception as e:
attempts -= 1
if attempts > 0:
print('Process %d failed collecting task %s (variation: %d, '
'example: %d). Retrying...\n%s\n' % (
i, task_env.get_name(), my_variation_count, ex_idx,
str(e)))
continue
problem = (
'Process %d failed collecting task %s (variation: %d, '
'example: %d). Skipping this task/variation.\n%s\n' % (
i, task_env.get_name(), my_variation_count, ex_idx,
str(e))
)
print(problem)
tasks_with_problems += problem
abort_variation = True
break
with file_lock:
print("saving demo")
save_demo(demo, episode_path)
break
if abort_variation:
break
results[i] = tasks_with_problems
rlbench_env.shutdown()
def main(argv):
FLAGS.save_path = FLAGS.save_path.format(seed=FLAGS.seed)
task_files = [t.replace('.py', '') for t in os.listdir(task.TASKS_PATH)
if t != '__init__.py' and t.endswith('.py')]
if len(FLAGS.tasks) > 0:
for t in FLAGS.tasks:
if t not in task_files:
raise ValueError('Task %s not recognised!.' % t)
task_files = FLAGS.tasks
tasks = [task_file_to_task_class(t) for t in task_files]
manager = Manager()
result_dict = manager.dict()
file_lock = manager.Lock()
task_index = manager.Value('i', 0)
variation_count = manager.Value('i', FLAGS.offset)
lock = manager.Lock()
check_and_make(FLAGS.save_path)
processes = [Process(
target=run, args=(
i, lock, task_index, variation_count, result_dict, file_lock,
tasks))
for i in range(FLAGS.processes)]
[t.start() for t in processes]
[t.join() for t in processes]
print('Data collection done!')
for i in range(FLAGS.processes):
print(result_dict[i])
if __name__ == '__main__':
app.run(main)