Spaces:
Running
on
Zero
Running
on
Zero
| import xml.etree.ElementTree as ET | |
| import json | |
| import os | |
| import argparse | |
| from tqdm import tqdm | |
def parse_urdf(file_path, target_joint_type=None):
    """
    Parse a URDF file and collect its joints together with their descendant joints.

    Only ``<joint>`` elements that have both a ``<child>`` and a ``<parent>``
    sub-element are considered. Any other ``<joint>`` element found in the
    document (for example a reference inside ``<transmission>``) is ignored.

    Args:
        file_path (str): Path to the URDF file.
        target_joint_type (str, optional): If given, only joints whose ``type``
            attribute equals this value (e.g. 'revolute', 'prismatic') are
            returned in the first element of the result. Defaults to None,
            which returns every parsed joint.

    Returns:
        tuple: ``(joints, descendants)`` where

            - ``joints`` is a list of ``(joint_name, joint_type)`` pairs in
              document order, filtered by ``target_joint_type`` when given;
            - ``descendants`` maps every parsed joint name (regardless of the
              filter) to the list of joint names below it. These are the joints
              whose parent link is this joint's child link, applied transitively.
              Each direct child is listed, followed by its own descendants in
              breadth-first order.
    """
    from collections import deque

    root = ET.parse(file_path).getroot()

    joint_types = {}   # joint_name -> joint_type
    child_links = {}   # joint_name -> child_link_name
    parent_links = {}  # joint_name -> parent_link_name
    for joint in root.findall('.//joint'):
        child_elem = joint.find('child')
        parent_elem = joint.find('parent')
        if child_elem is None or parent_elem is None:
            continue
        joint_name = joint.get('name')
        joint_types[joint_name] = joint.get('type')
        child_links[joint_name] = child_elem.get('link')
        parent_links[joint_name] = parent_elem.get('link')

    # link_name -> joints hanging directly below that link, in document order.
    # Built once so the tree walk below does not rescan every joint per step.
    children_of_link = {}
    for joint_name, parent_link in parent_links.items():
        children_of_link.setdefault(parent_link, []).append(joint_name)

    if target_joint_type is None:
        target_joints = list(joint_types.items())
    else:
        target_joints = [
            (name, jtype) for name, jtype in joint_types.items()
            if jtype == target_joint_type
        ]

    descendants = {}
    for joint_name in joint_types:
        found = []
        seen = set()
        for direct_child in children_of_link.get(child_links[joint_name], ()):
            # Direct children are always recorded. Deeper joints are
            # de-duplicated via `seen`, so a malformed (cyclic) URDF cannot
            # make the walk loop forever.
            found.append(direct_child)
            seen.add(direct_child)
            queue = deque([direct_child])
            while queue:
                current = queue.popleft()
                for below in children_of_link.get(child_links[current], ()):
                    if below not in seen:
                        seen.add(below)
                        found.append(below)
                        queue.append(below)
        descendants[joint_name] = found

    return target_joints, descendants
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--data_list', type=str, default='configs/partnet.json') | |
| parser.add_argument('--target_info', type=str, default='configs/partnet_target.json') | |
| parser.add_argument('--input_dir', type=str, default='datasets/PartNet_raw') | |
| parser.add_argument('--output_dir', type=str, default='datasets/PartNet') | |
| args = parser.parse_args() | |
| print(f"Processing PartNet joints...") | |
| # Preprocess all PartNet models | |
| for model_id in tqdm(os.listdir(args.input_dir)): | |
| urdf_path = f'{args.input_dir}/{model_id}/mobility.urdf' | |
| meta_path = f'{args.input_dir}/{model_id}/meta.json' | |
| joint_meta_path = f'{args.input_dir}/{model_id}/mobility_v2.json' | |
| output_dir = f'{args.input_dir}/{model_id}' | |
| joints_raw, descendants = parse_urdf(urdf_path) | |
| joints = [] | |
| with open(meta_path, 'r') as f: | |
| meta_info = json.load(f) | |
| with open(joint_meta_path, 'r') as f: | |
| joint_meta_info = json.load(f) | |
| for joint_name, joint_type in joints_raw: | |
| id = joint_name.split('_')[-1] | |
| for joint_meta in joint_meta_info: | |
| if joint_meta['id'] == int(id): | |
| joint_tag = joint_meta['name'] | |
| joints.append((joint_name, joint_type, joint_tag)) | |
| output = { | |
| "category": meta_info['model_cat'], | |
| "joints": joints, | |
| "descendants": descendants | |
| } | |
| with open(f'{output_dir}/joints.json', 'w') as f: | |
| json.dump(output, f, indent=4) | |
| # Preprocess target models | |
| with open(args.data_list) as f: | |
| data_info = json.load(f) | |
| with open(args.target_info) as f: | |
| target_info = json.load(f) | |
| model_ids = data_info['total_obj_ids'] | |
| for model_id in tqdm(model_ids): | |
| with open(f'{args.input_dir}/{model_id}/joints.json', 'r') as f: | |
| joints_info = json.load(f) | |
| category = joints_info['category'] | |
| target_joint_type = target_info[category]['joint_type'] | |
| target_joint_tags = target_info[category]['joint_tags'] | |
| valid_joints = [] | |
| for joint in joints_info['joints']: | |
| joint_name, joint_type, joint_tag = joint | |
| if joint_type == target_joint_type and joint_tag in target_joint_tags: | |
| valid_joints.append(joint_name) | |
| joints_info['joints'] = valid_joints | |
| os.makedirs(f'{args.output_dir}/{model_id}', exist_ok=True) | |
| with open(f'{args.output_dir}/{model_id}/joints.json', 'w') as f: | |
| json.dump(joints_info, f, indent=4) | |
| os.system(f"cp {args.input_dir}/{model_id}/mobility.urdf {args.output_dir}/{model_id}/mobility.urdf") | |