File size: 3,365 Bytes
97aa5af | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 | import time
import copy
import os
import torch
from thirdparty.learning3d.models import PointNet, PointNetLK, DCP, DGCNN, iPCRNet, PRNet, PPFNet, RPMNet
from r3pm_net.model import R3PMNet
from r3pm_net.feature_extractor import feature_extractor # import your feature extractor here
from tools import metrics, l3d_helper
def _build_model(method, args):
    '''Instantiate the (not yet loaded) registration network selected by ``method``.

    Args:
        method (str): one of 'dcp', 'rpmnet', 'pcrnet', 'pointnetlk', 'prnet', 'r3pmnet'.
        args (argparse.Namespace): reads ``emb_dims`` and, for 'prnet', ``num_iterations``.
    Returns:
        torch.nn.Module: the constructed model.
    Raises:
        ValueError: if ``method`` is not a supported name.
    '''
    if method == 'dcp':
        return DCP(feature_model=DGCNN(emb_dims=args.emb_dims), cycle=True)
    if method == 'rpmnet':
        return RPMNet(feature_model=PPFNet())
    if method == 'pcrnet':
        return iPCRNet(feature_model=PointNet(emb_dims=args.emb_dims))
    if method == 'pointnetlk':
        return PointNetLK(feature_model=PointNet(emb_dims=args.emb_dims, use_bn=True))
    if method == 'prnet':
        return PRNet(emb_dims=args.emb_dims, num_iters=args.num_iterations)
    if method == 'r3pmnet':
        # The feature extractor is whatever r3pm_net.feature_extractor exports.
        return R3PMNet(feature_model=feature_extractor)
    raise ValueError(f"Unknown method: {method}")


def _load_checkpoint(model, path, device):
    '''Load the state dict stored at ``path`` into ``model`` (non-strict).

    Args:
        model (torch.nn.Module): model to populate.
        path (str): checkpoint file path.
        device: device the tensors are mapped onto while loading.
    Raises:
        FileNotFoundError: if ``path`` is not an existing file.
    '''
    import warnings

    if not os.path.isfile(path):
        raise FileNotFoundError(f"Pretrained checkpoint not found: {path}")
    # NOTE: weights_only=False unpickles arbitrary objects -- only load trusted checkpoints.
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only device.
    try:
        payload = torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # Older torch releases do not accept the ``weights_only`` keyword.
        payload = torch.load(path, map_location=device)
    incompatible = model.load_state_dict(payload, strict=False)
    # strict=False would otherwise silently leave unmatched parameters at their
    # random initialisation; surface that instead of hiding it.
    missing = getattr(incompatible, 'missing_keys', None)
    if missing:
        warnings.warn(
            f"{len(missing)} model parameter(s) not found in checkpoint {path}; "
            f"they keep their initial values (e.g. {missing[:3]})"
        )


def _synchronize(device):
    '''Wait for queued CUDA work on ``device`` so wall-clock timings are accurate.

    CUDA kernels launch asynchronously, so without this the timer would only
    measure kernel launch time. No-op for non-CUDA devices.
    '''
    if str(device).startswith('cuda'):
        torch.cuda.synchronize(device)


def l3d_reg_and_eval(source, target, method, gt_transformation, args, source_normals=None, target_normals=None):
    '''
    Perform registration and evaluation for a given Learning3D method.
    Args:
        source (o3d.geometry.PointCloud): source point cloud
        target (o3d.geometry.PointCloud): target point cloud
        method (str): method name ('dcp', 'rpmnet', 'pcrnet', 'pointnetlk', 'prnet', 'r3pmnet')
        gt_transformation (np.ndarray): ground truth transformation matrix
        args (argparse.Namespace): arguments; reads ``device``, ``pretrained``,
            ``emb_dims`` and (for 'prnet') ``num_iterations``
        source_normals: optional normals for ``source``; used only when
            ``target_normals`` is also given (passed to ``l3d_helper.add_normal``)
        target_normals: optional normals for ``target``; see ``source_normals``
    Returns:
        tuple: ``(pc_result, evaluation_results)`` where ``pc_result`` is a copy of
        ``source`` transformed by the estimated 4x4 transform, and
        ``evaluation_results`` is whatever ``metrics.all_evaluations`` returns
        (rmse, rotation_error, translation_error, computation_time).
    Raises:
        ValueError: unknown ``method``.
        FileNotFoundError: ``args.pretrained`` is set but is not a file.
    '''
    model = _build_model(method, args)
    if args.pretrained:
        _load_checkpoint(model, args.pretrained, args.device)
    # move model to device and set to eval mode
    model = model.to(args.device)
    model.eval()

    if source_normals is not None and target_normals is not None:
        source_tensor = l3d_helper.add_normal(source, source_normals, args.device)
        target_tensor = l3d_helper.add_normal(target, target_normals, args.device)
    else:
        source_tensor = l3d_helper.convert_data(source, args.device)
        target_tensor = l3d_helper.convert_data(target, args.device)

    # Pure inference: no autograd graph is needed, which also keeps the timing honest.
    with torch.no_grad():
        # Warm-up run so one-time initialisation costs are not timed.
        model(target_tensor, source_tensor)
        _synchronize(args.device)
        start_time = time.perf_counter()
        output = model(target_tensor, source_tensor)
        _synchronize(args.device)
        computation_time = time.perf_counter() - start_time

    # Estimated transform of the first (only) batch element as a 4x4 matrix.
    est_transform = output['est_T'].detach().cpu().numpy()[0].reshape(4, 4)
    pc_result = copy.deepcopy(source).transform(est_transform)
    evaluation_results = metrics.all_evaluations(source, target, pc_result, computation_time, gt_transformation, output, corres=None)
    return pc_result, evaluation_results