| | import argparse |
| | import cv2 |
| | import datetime |
| | import h5py |
| | import init_path |
| | import json |
| | import numpy as np |
| | import os |
| | import robosuite as suite |
| | import time |
| | from glob import glob |
| | from robosuite import load_controller_config |
| | from robosuite.wrappers import DataCollectionWrapper, VisualizationWrapper |
| | from robosuite.utils.input_utils import input2action |
| |
|
| |
|
| | import libero.libero.envs.bddl_utils as BDDLUtils |
| | from libero.libero.envs import * |
| |
|
| |
|
def collect_human_trajectory(
    env, device, arm, env_configuration, problem_info, remove_directory=None
):
    """
    Use the device (keyboard or SpaceNav 3D mouse) to collect one demonstration.

    The rollout trajectory is saved to files in npz format by the
    DataCollectionWrapper around ``env``. Modify that wrapper to add new
    fields or change data formats.

    Args:
        env (MujocoEnv): environment to control. It must expose
            ``ep_directory`` (read when a demonstration is discarded), as the
            DataCollectionWrapper does.
        device (Device): device used to receive control inputs.
        arm (str): which arm to control (eg bimanual) 'right' or 'left'.
        env_configuration (str): specified environment configuration.
        problem_info (dict): parsed BDDL problem info; not used by this
            function.
        remove_directory (list[str] | None): list that receives the episode
            directory name when the user aborts the demonstration, so the
            episode can be skipped when the HDF5 file is assembled. A fresh
            list is used when omitted (avoids a shared mutable default).

    Returns:
        bool: True if the demonstration should be kept, False if it was
        aborted because ``input2action`` returned ``None``.
    """
    if remove_directory is None:
        remove_directory = []

    # env.reset() may raise; keep retrying until a reset succeeds. Catch
    # Exception (not a bare except) so Ctrl-C / KeyboardInterrupt still
    # terminates the script instead of being swallowed by this loop.
    while True:
        try:
            env.reset()
        except Exception:
            continue
        break

    env.render()

    # -1 means "task not currently successful". On the first successful step
    # the counter is set to 10; each further successful step decrements it,
    # and the loop exits once it has reached 0, so the task must stay solved
    # for about 10 consecutive steps. Any unsuccessful step resets it to -1.
    task_completion_hold_count = -1
    device.start_control()

    saving = True
    count = 0

    while True:
        count += 1

        # Bimanual: a single robot owns both arms. Otherwise index 1 for the
        # left arm (True -> 1) and index 0 for the right arm (False -> 0).
        active_robot = (
            env.robots[0]
            if env_configuration == "bimanual"
            else env.robots[arm == "left"]
        )

        action, grasp = input2action(
            device=device,
            robot=active_robot,
            active_arm=arm,
            env_configuration=env_configuration,
        )

        # A None action is robosuite's signal that the user requested a
        # reset; treat it as discarding this demonstration.
        if action is None:
            print("Break")
            saving = False
            break

        env.step(action)
        env.render()

        if task_completion_hold_count == 0:
            break

        if env._check_success():
            if task_completion_hold_count > 0:
                task_completion_hold_count -= 1
            else:
                task_completion_hold_count = 10
        else:
            task_completion_hold_count = -1

    print(count)

    if not saving:
        # Remember the aborted episode's directory name so it is excluded
        # when the demonstrations are gathered into HDF5.
        remove_directory.append(env.ep_directory.split("/")[-1])
    env.close()
    return saving
| |
|
| |
|
def gather_demonstrations_as_hdf5(
    directory, out_dir, env_info, args, remove_directory=None, problem_info=None
):
    """
    Gather the demonstrations saved in @directory into a single hdf5 file.

    The file is written to ``<out_dir>/demo.hdf5`` (overwritten if present)
    with the following structure:

    data (group)
        date (attribute) - date of collection
        time (attribute) - time of collection
        repository_version (attribute) - robosuite version used during collection
        env (attribute) - environment name on which demos were collected
            (omitted when no episode was found)
        env_info (attribute) - JSON string with controller and robot info
        problem_info (attribute) - JSON-encoded BDDL problem info
        bddl_file_name (attribute) - path of the BDDL file
        bddl_file_content (attribute) - text content of the BDDL file

        demo_1 (group) - every demonstration has a group
            model_file (attribute) - model xml string for demonstration
            states (dataset) - flattened mujoco states
            actions (dataset) - actions applied during demonstration

        demo_2 (group)
        ...

    Args:
        directory (str): Path to the directory containing raw demonstrations,
            one sub-directory per episode.
        out_dir (str): Path to where to store the hdf5 file.
        env_info (str): JSON-encoded string containing environment
            information, including controller and robot info.
        args (argparse.Namespace): parsed CLI args; ``args.bddl_file`` is read.
        remove_directory (list[str] | None): episode directory names to skip.
        problem_info (dict | None): BDDL problem info to store. If omitted, it
            is parsed from ``args.bddl_file``.

    Raises:
        ValueError: if an episode's state and action counts do not line up.
    """
    skipped = set(remove_directory or ())
    if problem_info is None:
        problem_info = BDDLUtils.get_problem_info(args.bddl_file)

    # Read the BDDL text up front (and close the handle) so a missing file
    # fails before the existing demo.hdf5 is truncated.
    with open(args.bddl_file, "r", encoding="utf-8") as bddl_f:
        bddl_file_content = bddl_f.read()

    hdf5_path = os.path.join(out_dir, "demo.hdf5")
    # The context manager guarantees the HDF5 file is flushed and closed.
    # The original code rebound `f` to the xml file, so the HDF5 handle was
    # never explicitly closed.
    with h5py.File(hdf5_path, "w") as h5_file:
        grp = h5_file.create_group("data")

        num_eps = 0
        env_name = None

        # Sorted so demo numbering is deterministic (os.listdir order is not).
        for ep_directory in sorted(os.listdir(directory)):
            if ep_directory in skipped:
                continue
            state_paths = os.path.join(directory, ep_directory, "state_*.npz")
            states = []
            actions = []

            for state_file in sorted(glob(state_paths)):
                # allow_pickle is required because action_infos holds
                # dict-like objects. NOTE(review): these files are presumably
                # written locally by DataCollectionWrapper; never point this
                # at untrusted data.
                with np.load(state_file, allow_pickle=True) as dic:
                    env_name = str(dic["env"])
                    states.extend(dic["states"])
                    actions.extend(ai["actions"] for ai in dic["action_infos"])

            if len(states) == 0:
                continue

            # One more state is recorded than actions. Drop the final state
            # so the two sequences align.
            del states[-1]
            if len(states) != len(actions):
                raise ValueError(
                    "state/action count mismatch in {}: {} states vs {} actions".format(
                        ep_directory, len(states), len(actions)
                    )
                )

            num_eps += 1
            ep_data_grp = grp.create_group("demo_{}".format(num_eps))

            # Store the model xml as an attribute.
            xml_path = os.path.join(directory, ep_directory, "model.xml")
            with open(xml_path, "r") as xml_f:
                ep_data_grp.attrs["model_file"] = xml_f.read()

            ep_data_grp.create_dataset("states", data=np.array(states))
            ep_data_grp.create_dataset("actions", data=np.array(actions))

        # Metadata.
        now = datetime.datetime.now()
        grp.attrs["date"] = "{}-{}-{}".format(now.month, now.day, now.year)
        grp.attrs["time"] = "{}:{}:{}".format(now.hour, now.minute, now.second)
        grp.attrs["repository_version"] = suite.__version__
        # h5py cannot store None; only write the env name if an episode was found.
        if env_name is not None:
            grp.attrs["env"] = env_name
        grp.attrs["env_info"] = env_info

        grp.attrs["problem_info"] = json.dumps(problem_info)
        grp.attrs["bddl_file_name"] = args.bddl_file
        grp.attrs["bddl_file_content"] = bddl_file_content
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: build a LIBERO task env from a BDDL file,
    # collect human demonstrations, and write them to <directory>/<run>/demo.hdf5.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--directory",
        type=str,
        default="demonstration_data",
        help="Directory in which the per-run output folder (with demo.hdf5) is created",
    )
    parser.add_argument(
        "--robots",
        nargs="+",
        type=str,
        default="Panda",
        help="Which robot(s) to use in the env",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="single-arm-opposed",
        help="Specified environment configuration if necessary",
    )
    parser.add_argument(
        "--arm",
        type=str,
        default="right",
        help="Which arm to control (eg bimanual) 'right' or 'left'",
    )
    parser.add_argument(
        "--camera",
        type=str,
        default="agentview",
        help="Which camera to use for collecting demos",
    )
    parser.add_argument(
        "--controller",
        type=str,
        default="OSC_POSE",
        help="Choice of controller. Can be 'IK_POSE' or 'OSC_POSE'",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="spacemouse",
        choices=["keyboard", "spacemouse"],
        help="Input device used to teleoperate the robot",
    )
    parser.add_argument(
        "--pos-sensitivity",
        type=float,
        default=1.5,
        help="How much to scale position user inputs",
    )
    parser.add_argument(
        "--rot-sensitivity",
        type=float,
        default=1.0,
        help="How much to scale rotation user inputs",
    )
    parser.add_argument(
        "--num-demonstration",
        type=int,
        default=50,
        help="Number of successful demonstrations to collect",
    )
    parser.add_argument(
        "--bddl-file",
        type=str,
        required=True,
        help="Path to the BDDL task file",
    )

    parser.add_argument("--vendor-id", type=int, default=9583)
    parser.add_argument("--product-id", type=int, default=50734)

    args = parser.parse_args()

    # Get controller config.
    controller_config = load_controller_config(default_controller=args.controller)

    # Create argument configuration.
    config = {
        "robots": args.robots,
        "controller_configs": controller_config,
    }

    # An explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(args.bddl_file):
        parser.error("BDDL file not found: {}".format(args.bddl_file))

    # NOTE: kept as a module-level name; gather_demonstrations_as_hdf5 may
    # read it as a global.
    problem_info = BDDLUtils.get_problem_info(args.bddl_file)

    problem_name = problem_info["problem_name"]
    domain_name = problem_info["domain_name"]
    language_instruction = problem_info["language_instruction"]
    # Filesystem-friendly form of the instruction, used in directory names.
    # (strip('"') is identical to the former strip('""'): strip takes a char set.)
    instruction_slug = language_instruction.replace(" ", "_").strip('"')

    if "TwoArm" in problem_name:
        config["env_configuration"] = args.config
    print(language_instruction)
    env = TASK_MAPPING[problem_name](
        bddl_file_name=args.bddl_file,
        **config,
        has_renderer=True,
        has_offscreen_renderer=False,
        render_camera=args.camera,
        ignore_done=True,
        use_camera_obs=False,
        reward_shaping=True,
        control_freq=20,
    )

    # Wrap this with visualization wrapper.
    env = VisualizationWrapper(env)

    # Grab reference to controller config and convert it to json-encoded string.
    env_info = json.dumps(config)

    # Raw per-episode data is written here by DataCollectionWrapper.
    tmp_directory = "demonstration_data/tmp/{}_ln_{}/{}".format(
        problem_name,
        instruction_slug,
        str(time.time()).replace(".", "_"),
    )

    env = DataCollectionWrapper(env, tmp_directory)

    # Initialize the teleoperation device.
    if args.device == "keyboard":
        from robosuite.devices import Keyboard

        device = Keyboard(
            pos_sensitivity=args.pos_sensitivity, rot_sensitivity=args.rot_sensitivity
        )
        env.viewer.add_keypress_callback("any", device.on_press)
        env.viewer.add_keyup_callback("any", device.on_release)
        env.viewer.add_keyrepeat_callback("any", device.on_press)
    elif args.device == "spacemouse":
        from robosuite.devices import SpaceMouse

        device = SpaceMouse(
            args.vendor_id,
            args.product_id,
            pos_sensitivity=args.pos_sensitivity,
            rot_sensitivity=args.rot_sensitivity,
        )
    else:
        # Unreachable while argparse `choices` is in place; kept defensively.
        raise ValueError(
            "Invalid device choice: choose either 'keyboard' or 'spacemouse'."
        )

    # Unique, timestamped output directory for this collection run.
    t1, t2 = str(time.time()).split(".")
    new_dir = os.path.join(
        args.directory,
        f"{domain_name}_ln_{problem_name}_{t1}_{t2}_" + instruction_slug,
    )

    os.makedirs(new_dir)

    # Names of episode directories the user aborted; excluded from the HDF5.
    remove_directory = []
    i = 0
    while i < args.num_demonstration:
        print(i)
        saving = collect_human_trajectory(
            env, device, args.arm, args.config, problem_info, remove_directory
        )
        if saving:
            print(remove_directory)
            # Rebuild demo.hdf5 from all kept episodes after every successful
            # demonstration so progress is persisted incrementally.
            gather_demonstrations_as_hdf5(
                tmp_directory, new_dir, env_info, args, remove_directory
            )
            i += 1
| |
|