Spaces:
Sleeping
Sleeping
| import pybullet as p | |
| import PySimpleGUI as sg | |
| import pickle | |
| from os import getcwd | |
| from urdfpy import URDF | |
| from os.path import abspath, dirname, basename, splitext | |
| from transforms3d.affines import decompose | |
| from transforms3d.quaternions import mat2quat | |
| import numpy as np | |
class PyBulletRecorder:
    """Record the world poses of PyBullet bodies' visual elements over time.

    Bodies are registered with ``register_object``; each ``add_keyframe`` call
    snapshots the pose of every registered visual, and ``save`` pickles the
    accumulated episode in the layout produced by ``get_formatted_output``.
    """

    # Maps the internal geometry tag stored in ``self.links`` to the
    # ``'type'`` value written to the exported dictionary.
    _OUTPUT_TYPES = {
        'mesh': 'mesh',
        'box': 'cube',
        'cylinder': 'cylinder',
        'sphere': 'sphere',
    }

    class LinkTracker:
        """Track one visual element (mesh or primitive) attached to a PyBullet link."""

        def __init__(self,
                     name,
                     body_id,
                     link_id,
                     link_origin,
                     mesh_path,
                     mesh_scale,
                     mesh_material=None):
            """
            Args:
                name: Key for this visual in each recorded state and in the
                    exported dictionary.
                body_id: PyBullet body unique id.
                link_id: PyBullet link index; -1 denotes the base.
                link_origin: 4x4 homogeneous transform of the visual relative
                    to the frame PyBullet reports for this link.
                mesh_path: Mesh file path, or the primitive tag ('box', ...).
                mesh_scale: Per-axis scale / primitive dimensions to export.
                mesh_material: Optional urdfpy material whose name and color
                    are exported.
            """
            self.body_id = body_id
            self.link_id = link_id
            self.mesh_path = mesh_path
            self.mesh_scale = mesh_scale
            self.mesh_material = mesh_material
            translation, rotation, _, _ = decompose(link_origin)
            # transforms3d returns (w, x, y, z); PyBullet expects (x, y, z, w).
            w, x, y, z = mat2quat(rotation)
            self.link_pose = [translation, [x, y, z, w]]
            self.name = name

        def transform(self, position, orientation):
            """Compose a link's world pose with this visual's local offset."""
            return p.multiplyTransforms(
                position, orientation,
                self.link_pose[0], self.link_pose[1],
            )

        def get_keyframe(self):
            """Return the visual's current world pose.

            Returns:
                dict with 'position' (list of 3 floats) and 'orientation'
                (quaternion list in PyBullet's x, y, z, w order).
            """
            if self.link_id == -1:
                position, orientation = p.getBasePositionAndOrientation(
                    self.body_id)
            else:
                link_state = p.getLinkState(self.body_id,
                                            self.link_id,
                                            computeForwardKinematics=True)
                # Indices 4/5 are the world position/orientation of the link frame.
                position, orientation = link_state[4], link_state[5]
            position, orientation = self.transform(
                position=position, orientation=orientation)
            return {
                'position': list(position),
                'orientation': list(orientation)
            }

    def __init__(self):
        # One dict per keyframe, mapping tracker name -> pose dict.
        self.states = []
        # (geometry tag, LinkTracker) pairs in registration order.
        self.links = []

    @staticmethod
    def _link_id_map(body_id):
        """Map each PyBullet link name of ``body_id`` to its link index (base -> -1)."""
        # NOTE(review): names are decoded as GB2312, presumably to match
        # non-ASCII URDF link names; confirm against the URDF file encoding.
        link_id_map = {p.getBodyInfo(body_id)[0].decode('gb2312'): -1}
        for link_id in range(p.getNumJoints(body_id)):
            link_name = p.getJointInfo(body_id, link_id)[12].decode('gb2312')
            link_id_map[link_name] = link_id
        return link_id_map

    @staticmethod
    def _visual_material(link_visual, body_id, color):
        """Return the material to record for a visual, applying a color override.

        The override is applied to a copy, so a material shared by several
        visuals is not renamed repeatedly. The copy gets a deterministic
        ``_<body_id>`` suffix, so differently colored bodies get distinct names.
        """
        material = link_visual.material
        if material is None or color is None:
            return material
        import copy
        material = copy.copy(material)
        material.name = f"{material.name}_{body_id}"
        material.color = color
        return material

    @staticmethod
    def _visual_origin(link, link_id, link_visual, global_scaling):
        """Return the 4x4 visual transform relative to the frame PyBullet reports."""
        # For the base link (link_id == -1) PyBullet's reported pose already
        # includes inertial_origin (inertial_origin @ visual_origin), so undo it.
        frame_fix = (np.linalg.inv(link.inertial.origin)
                     if link_id == -1 else np.identity(4))
        return frame_fix @ link_visual.origin * global_scaling

    @staticmethod
    def _geometry_specs(geometry, dir_path, global_scaling):
        """Return ``(geo_name, mesh_path, mesh_scale)`` for each shape set on a geometry.

        Primitive dimensions are multiplied by ``global_scaling`` like mesh
        scales and visual origins, so scaled bodies export consistent sizes.
        The box scale is half the box size.
        """
        s = global_scaling
        specs = []
        if geometry.mesh is not None:
            mesh = geometry.mesh
            scale = [s, s, s] if mesh.scale is None else mesh.scale * s
            specs.append(('mesh', dir_path + '/' + mesh.filename, scale))
        if geometry.box is not None:
            specs.append(('box', 'box', geometry.box.size / 2 * s))
        if geometry.cylinder is not None:
            radius = geometry.cylinder.radius
            specs.append(('cylinder', 'cylinder',
                          [radius * s, radius * s, geometry.cylinder.length * s]))
        if geometry.sphere is not None:
            radius = geometry.sphere.radius
            specs.append(('sphere', 'sphere', [radius * s, radius * s, radius * s]))
        return specs

    def register_object(self, body_id, urdf_path, global_scaling=1, color=None):
        """Register every visual of a loaded body so later keyframes record it.

        Args:
            body_id: PyBullet id of a body loaded from ``urdf_path``.
            urdf_path: URDF file the body was loaded from. It is re-parsed with
                urdfpy, and mesh paths are resolved relative to its directory.
            global_scaling: Scale applied to visual origins and sizes;
                presumably the same value given to ``p.loadURDF``.
            color: Optional color override for visuals that have a material.

        URDF links whose names PyBullet does not report are skipped with a
        printed warning.
        """
        link_id_map = self._link_id_map(body_id)
        dir_path = dirname(abspath(urdf_path))
        file_name = splitext(basename(urdf_path))[0]
        robot = URDF.load(urdf_path)
        for link in robot.links:
            if link.name not in link_id_map:
                print("skip links !! ", link.name, link_id_map, len(robot.links),
                      p.getBodyInfo(body_id)[0].decode('gb2312'))
                continue
            link_id = link_id_map[link.name]
            for i, link_visual in enumerate(link.visuals):
                specs = self._geometry_specs(
                    link_visual.geometry, dir_path, global_scaling)
                if not specs:
                    continue
                mesh_material = self._visual_material(link_visual, body_id, color)
                link_origin = self._visual_origin(
                    link, link_id, link_visual, global_scaling)
                for geo_name, mesh_path, mesh_scale in specs:
                    self.links.append((geo_name, PyBulletRecorder.LinkTracker(
                        name=file_name + f'_{body_id}_{link.name}_{i}',
                        body_id=body_id,
                        link_id=link_id,
                        link_origin=link_origin,
                        mesh_path=mesh_path,
                        mesh_scale=mesh_scale,
                        mesh_material=mesh_material)))

    def add_keyframe(self):
        """Append the current pose of every registered visual to ``self.states``.

        Ideally called after every ``p.stepSimulation()``.
        """
        self.states.append(
            {link.name: link.get_keyframe() for _, link in self.links})

    def prompt_save(self):
        """Ask via PySimpleGUI dialogs whether and where to save, then reset.

        Recorded states are cleared afterwards whether or not they were saved.
        """
        layout = [[sg.Text('Do you want to save previous episode?')],
                  [sg.Button('Yes'), sg.Button('No')]]
        window = sg.Window('PyBullet Recorder', layout)
        save = False
        while True:
            event, values = window.read()
            if event in (None, 'No'):
                break
            elif event == 'Yes':
                save = True
                break
        window.close()
        if save:
            layout = [[sg.Text('Where do you want to save it?')],
                      [sg.Text('Path'), sg.InputText(getcwd())],
                      [sg.Button('OK')]]
            window = sg.Window('PyBullet Recorder', layout)
            event, values = window.read()
            window.close()
            # Closing the dialog (event None) yields no values. Fall back to
            # None, which save() treats as "don't save", instead of crashing.
            # An empty path is treated the same way.
            path = values[0] if event == 'OK' and values else None
            self.save(path or None)
        self.reset()

    def reset(self):
        """Discard recorded states; registered links are kept."""
        self.states = []

    def get_formatted_output(self):
        """Return the recording as ``{tracker name: entry}``.

        Each entry has 'type' ('mesh', 'cube', 'cylinder' or 'sphere'),
        'mesh_scale', and 'frames' (one pose dict per keyframe). Mesh entries
        add 'mesh_path'; primitive entries add 'name'. If a material is set,
        'mesh_material_name' and 'mesh_material_color' are also included.
        """
        retval = {}
        for geo_name, link in self.links:
            entry = {'type': self._OUTPUT_TYPES[geo_name]}
            if geo_name == 'mesh':
                entry['mesh_path'] = link.mesh_path
            else:
                entry['name'] = link.name
            entry['mesh_scale'] = link.mesh_scale
            entry['frames'] = [state[link.name] for state in self.states]
            if link.mesh_material is not None:
                entry['mesh_material_name'] = link.mesh_material.name
                entry['mesh_material_color'] = link.mesh_material.color
            retval[link.name] = entry
        return retval

    def save(self, path):
        """Pickle ``get_formatted_output()`` to ``path``; do nothing if it is None."""
        if path is None:
            print("[Recorder] Path is None.. not saving")
            return
        print("[Recorder] Saving state to {}".format(path))
        with open(path, 'wb') as f:
            pickle.dump(self.get_formatted_output(), f)