FreeArt3D / preprocess_partnet_joint.py
MorPhLingXD's picture
Upload demo
bf912fe
import xml.etree.ElementTree as ET
import json
import os
import argparse
from tqdm import tqdm
def parse_urdf(file_path):
"""
Parse a URDF file and extract information about joints of a specific type.
Args:
file_path (str): Path to the URDF file
target_joint_type (str): Joint type to look for ('revolute', 'prismatic', etc.)
Returns:
tuple: (list of joint names of the target type, dict of descendants for each joint)
"""
# Parse the XML file
tree = ET.parse(file_path)
root = tree.getroot()
# Find all joints
joints = root.findall('.//joint')
# Create dictionaries to store joint information
joint_types = {} # joint_name -> joint_type
child_links = {} # joint_name -> child_link_name
parent_links = {} # joint_name -> parent_link_name
# Extract information from joints
for joint in joints:
joint_name = joint.get('name')
joint_type = joint.get('type')
child_elem = joint.find('child')
parent_elem = joint.find('parent')
if child_elem is not None and parent_elem is not None:
child_link = child_elem.get('link')
parent_link = parent_elem.get('link')
joint_types[joint_name] = joint_type
child_links[joint_name] = child_link
parent_links[joint_name] = parent_link
# Create reverse mapping from link to joint
link_to_joint = {} # link_name -> joint_name (where link is the child of the joint)
for joint_name, child_link in child_links.items():
link_to_joint[child_link] = joint_name
# Find joints with the target type
target_joints = [(j, t) for j, t in joint_types.items()]
# Build descendant chains for each joint
descendants = {joint_name: [] for joint_name in joint_types.keys()}
# For each joint
for joint_name in joint_types.keys():
# Get its child link
child_link = child_links.get(joint_name)
# Look for joints that have this link as parent
for other_joint, parent in parent_links.items():
if parent == child_link:
# Add to descendants
descendants[joint_name].append(other_joint)
# Recursively find all descendants
to_process = [other_joint]
while to_process:
current = to_process.pop(0)
current_child_link = child_links.get(current)
for j, p in parent_links.items():
if p == current_child_link and j not in descendants[joint_name]:
descendants[joint_name].append(j)
to_process.append(j)
return target_joints, descendants
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--data_list', type=str, default='configs/partnet.json')
parser.add_argument('--target_info', type=str, default='configs/partnet_target.json')
parser.add_argument('--input_dir', type=str, default='datasets/PartNet_raw')
parser.add_argument('--output_dir', type=str, default='datasets/PartNet')
args = parser.parse_args()
print(f"Processing PartNet joints...")
# Preprocess all PartNet models
for model_id in tqdm(os.listdir(args.input_dir)):
urdf_path = f'{args.input_dir}/{model_id}/mobility.urdf'
meta_path = f'{args.input_dir}/{model_id}/meta.json'
joint_meta_path = f'{args.input_dir}/{model_id}/mobility_v2.json'
output_dir = f'{args.input_dir}/{model_id}'
joints_raw, descendants = parse_urdf(urdf_path)
joints = []
with open(meta_path, 'r') as f:
meta_info = json.load(f)
with open(joint_meta_path, 'r') as f:
joint_meta_info = json.load(f)
for joint_name, joint_type in joints_raw:
id = joint_name.split('_')[-1]
for joint_meta in joint_meta_info:
if joint_meta['id'] == int(id):
joint_tag = joint_meta['name']
joints.append((joint_name, joint_type, joint_tag))
output = {
"category": meta_info['model_cat'],
"joints": joints,
"descendants": descendants
}
with open(f'{output_dir}/joints.json', 'w') as f:
json.dump(output, f, indent=4)
# Preprocess target models
with open(args.data_list) as f:
data_info = json.load(f)
with open(args.target_info) as f:
target_info = json.load(f)
model_ids = data_info['total_obj_ids']
for model_id in tqdm(model_ids):
with open(f'{args.input_dir}/{model_id}/joints.json', 'r') as f:
joints_info = json.load(f)
category = joints_info['category']
target_joint_type = target_info[category]['joint_type']
target_joint_tags = target_info[category]['joint_tags']
valid_joints = []
for joint in joints_info['joints']:
joint_name, joint_type, joint_tag = joint
if joint_type == target_joint_type and joint_tag in target_joint_tags:
valid_joints.append(joint_name)
joints_info['joints'] = valid_joints
os.makedirs(f'{args.output_dir}/{model_id}', exist_ok=True)
with open(f'{args.output_dir}/{model_id}/joints.json', 'w') as f:
json.dump(joints_info, f, indent=4)
os.system(f"cp {args.input_dir}/{model_id}/mobility.urdf {args.output_dir}/{model_id}/mobility.urdf")