| import os |
| import glob |
| import argparse |
| import numpy as np |
| import laspy |
| import torch |
| import h5py |
| import hydra |
| import torch.nn.functional as F |
|
|
| from src.utils import init_config |
| from src.transforms import ( |
| instantiate_datamodule_transforms, |
| SampleRecursiveMainXYAxisTiling, |
| NAGRemoveKeys |
| ) |
| from src.datasets.gridnet import read_gridnet_tile |
|
|
|
|
def run_inference(model, cfg, transforms_dict, root_dir, split, scale, pc_tiling):
    """Write a per-point semantic prediction next to every LAS tile of a split.

    Input tiles are discovered with the glob pattern
    ``<root_dir>/<split>/*/lidar/*.las``. Each tile is cut into
    ``2**pc_tiling`` chunks; every chunk is pre-processed, moved to the GPU and
    passed through ``model``. The full-resolution predictions of all chunks are
    put back in the original point order and written to
    ``<tile>_classified.las`` with an extra ``classif`` (uint8) dimension.

    Args:
        model: Callable network (assumed to already be in eval mode on the
            GPU). Its output must provide ``full_res_semantic_pred``.
        cfg: Config exposing ``datamodule.point_load_keys`` and
            ``datamodule.segment_load_keys``.
        transforms_dict: Mapping holding the ``'pre_transform'`` and
            ``'on_device_test_transform'`` callables.
        root_dir: Root directory of the dataset.
        split: Sub-directory of ``root_dir`` to process.
        scale: Three per-axis LAS scale factors used to quantize positions.
        pc_tiling: Number of recursive tiling steps.

    Returns:
        None. Results are written to disk.
    """
    # Outputs are written next to their inputs with a '.las' extension, so they
    # would be matched (and re-processed) by the glob on a second run.
    generated_suffixes = ('_classified.las', '_with_softmax.las')
    pattern = os.path.join(root_dir, split, "*", "lidar", "*.las")
    las_files = sorted(
        path for path in glob.glob(pattern)
        if not path.endswith(generated_suffixes)
    )
    if not las_files:
        print(f"[Inference] No LAS files found matching: {pattern}")
        return

    scale_xyz = np.asarray(scale, dtype=np.float64)
    num_chunks = 2 ** pc_tiling

    for filepath in las_files:
        print(f"\n[Inference] Processing: {filepath}")

        # Only the header offset is needed from the raw file: read the header
        # alone instead of loading every point a second time.
        with laspy.open(filepath) as reader:
            offset_initial_las = np.array(reader.header.offsets, dtype=np.float64)

        data_las = read_gridnet_tile(
            filepath, xyz=True, intensity=True, rgb=True,
            semantic=False, instance=False, remap=True,
        )
        # Tag every raw point so the chunk outputs can be re-ordered afterwards
        data_las.initial_index = torch.arange(data_las.pos.shape[0])

        pos_list = []
        pred_list = []
        indices_list = []
        pos_offset_init = None

        for x in range(num_chunks):
            data = SampleRecursiveMainXYAxisTiling(x=x, steps=pc_tiling)(data_las)
            if data.pos.shape[0] == 0:
                # Nothing to predict for an empty chunk
                continue

            nag = transforms_dict['pre_transform'](data)
            # Drop the attributes that the config does not list as loaded
            point_keys = [
                k for k in nag[0].keys
                if k not in cfg.datamodule.point_load_keys
            ]
            nag = NAGRemoveKeys(level=0, keys=point_keys)(nag)
            segment_keys = [
                k for k in nag[1].keys
                if k not in cfg.datamodule.segment_load_keys
            ]
            nag = NAGRemoveKeys(level='1+', keys=segment_keys)(nag)

            nag = nag.cuda()
            nag = transforms_dict['on_device_test_transform'](nag)
            with torch.no_grad():
                output = model(nag)

            semantic_pred = output.full_res_semantic_pred(
                super_index_level0_to_level1=nag[0].super_index,
                sub_level0_to_raw=nag[0].sub,
            )

            pos_list.append(data.pos.cpu())
            indices_list.append(data.initial_index.cpu())
            pred_list.append(semantic_pred.cpu())

            # NOTE(review): only the first chunk's pos_offset is kept; this
            # assumes all chunks of a tile share the same offset - confirm.
            if pos_offset_init is None:
                pos_offset_init = nag[0].pos_offset.cpu()

        if not pred_list:
            print(f"[Inference] No points to classify in: {filepath}")
            continue

        merged_pos = torch.cat(pos_list, dim=0)
        merged_pred = torch.cat(pred_list, dim=0)

        # Sum the offsets explicitly in float64 numpy instead of relying on
        # implicit Tensor + ndarray arithmetic.
        merged_pos_offset = (
            pos_offset_init.numpy().astype(np.float64).reshape(-1)
            + offset_initial_las
        )

        # Restore the point order of the input tile
        order = torch.argsort(torch.cat(indices_list, dim=0))
        merged_pos = merged_pos[order]
        merged_pred = merged_pred[order]

        # Round to the nearest LAS unit: plain truncation shifts points by one
        # unit whenever float noise puts a value just below an integer.
        pos_data = np.rint(merged_pos.numpy() / scale_xyz).astype(np.int32)

        header = laspy.LasHeader(point_format=3, version="1.2")
        header.scales = scale_xyz
        header.offsets = merged_pos_offset
        las = laspy.LasData(header)
        las.X, las.Y, las.Z = pos_data[:, 0], pos_data[:, 1], pos_data[:, 2]

        las.add_extra_dim(
            laspy.ExtraBytesParams(
                name="classif", type=np.uint8, description="Predicted class")
        )
        las.classif = merged_pred.numpy().astype(np.uint8)

        # Swap only the file extension; str.replace would also rewrite any
        # '.las' found in a directory name and leave an upper-case '.LAS'
        # path unchanged, overwriting the input.
        output_las = os.path.splitext(filepath)[0] + '_classified.las'
        las.write(output_las)
        print(f"[Inference] Saved classified LAS to: {output_las}")
|
|
def export_logits(model, cfg, transforms_dict, root_dir, scale, pc_tiling):
    """Write the per-voxel class probabilities of every LAS tile under a root.

    Input tiles are discovered with the glob pattern
    ``<root_dir>/*/*/*/*.las``. Each tile is cut into ``2**pc_tiling`` chunks;
    every chunk is pre-processed, moved to the GPU and passed through
    ``model``. The voxel logits of all chunks are concatenated, turned into
    softmax probabilities, rescaled to ``[0, 255]`` and written to
    ``<tile>_with_softmax.las`` as one ``sof_log<i>`` (uint8) extra dimension
    per class. The positions written are those of ``nag[0]`` after the
    on-device transform, not the raw input points.

    Args:
        model: Callable network (assumed to already be in eval mode on the
            GPU). Its output must provide ``voxel_logits_pred``.
        cfg: Config exposing ``datamodule.point_load_keys`` and
            ``datamodule.segment_load_keys``.
        transforms_dict: Mapping holding the ``'pre_transform'`` and
            ``'on_device_test_transform'`` callables.
        root_dir: Root directory of the dataset.
        scale: Three per-axis LAS scale factors used to quantize positions.
        pc_tiling: Number of recursive tiling steps.

    Returns:
        None. Results are written to disk.
    """
    # Outputs are written next to their inputs with a '.las' extension, so they
    # would be matched (and re-processed) by the glob on a second run.
    generated_suffixes = ('_classified.las', '_with_softmax.las')
    pattern = os.path.join(root_dir, "*", "*", "*", "*.las")
    las_files = sorted(
        path for path in glob.glob(pattern)
        if not path.endswith(generated_suffixes)
    )
    if not las_files:
        print(f"[Export Logits] No LAS files found matching: {pattern}")
        return

    scale_xyz = np.asarray(scale, dtype=np.float64)
    num_chunks = 2 ** pc_tiling

    for filepath in las_files:
        print(f"\n[Export Logits] Processing: {filepath}")

        # Only the header offset is needed from the raw file: read the header
        # alone instead of loading every point a second time.
        with laspy.open(filepath) as reader:
            offset_initial_las = np.array(reader.header.offsets, dtype=np.float64)

        data_las = read_gridnet_tile(
            filepath, xyz=True, intensity=True, rgb=True,
            semantic=False, instance=False, remap=True,
        )

        pos_list = []
        logits_list = []
        pos_offset_init = None

        for x in range(num_chunks):
            data = SampleRecursiveMainXYAxisTiling(x=x, steps=pc_tiling)(data_las)
            if data.pos.shape[0] == 0:
                # Nothing to predict for an empty chunk
                continue

            nag = transforms_dict['pre_transform'](data)
            # Drop the attributes that the config does not list as loaded
            point_keys = [
                k for k in nag[0].keys
                if k not in cfg.datamodule.point_load_keys
            ]
            nag = NAGRemoveKeys(level=0, keys=point_keys)(nag)
            segment_keys = [
                k for k in nag[1].keys
                if k not in cfg.datamodule.segment_load_keys
            ]
            nag = NAGRemoveKeys(level='1+', keys=segment_keys)(nag)

            nag = nag.cuda()
            nag = transforms_dict['on_device_test_transform'](nag)

            with torch.no_grad():
                output = model(nag)

            logits = output.voxel_logits_pred(super_index=nag[0].super_index)

            pos_list.append(nag[0].pos.cpu())
            logits_list.append(logits.cpu())

            # NOTE(review): only the first chunk's pos_offset is kept while
            # positions come from every chunk's nag[0]; this assumes all
            # chunks of a tile share the same offset - confirm.
            if pos_offset_init is None:
                pos_offset_init = nag[0].pos_offset.cpu()

        if not logits_list:
            print(f"[Export Logits] No points to export in: {filepath}")
            continue

        merged_pos = torch.cat(pos_list, dim=0)
        merged_logits = torch.cat(logits_list, dim=0)

        # Sum the offsets explicitly in float64 numpy instead of relying on
        # implicit Tensor + ndarray arithmetic.
        merged_pos_offset = (
            pos_offset_init.numpy().astype(np.float64).reshape(-1)
            + offset_initial_las
        )

        # Round to the nearest LAS unit: plain truncation shifts points by one
        # unit whenever float noise puts a value just below an integer.
        pos_data = np.rint(merged_pos.numpy() / scale_xyz).astype(np.int32)

        header = laspy.LasHeader(point_format=3, version="1.2")
        header.scales = scale_xyz
        header.offsets = merged_pos_offset
        las = laspy.LasData(header)
        las.X, las.Y, las.Z = pos_data[:, 0], pos_data[:, 1], pos_data[:, 2]

        # Softmax over the class axis, then map [0, 1] to uint8 in one pass
        soft_logits = F.softmax(merged_logits, dim=1).numpy()
        scaled_logits = (255 * soft_logits).clip(0, 255).astype(np.uint8)
        for i in range(scaled_logits.shape[1]):
            las.add_extra_dim(
                laspy.ExtraBytesParams(
                    name=f"sof_log{i}", type=np.uint8, description=f"Logit {i}")
            )
            setattr(las, f"sof_log{i}", np.ascontiguousarray(scaled_logits[:, i]))

        # Swap only the file extension; str.replace would also rewrite any
        # '.las' found in a directory name and leave an upper-case '.LAS'
        # path unchanged, overwriting the input.
        output_las = os.path.splitext(filepath)[0] + '_with_softmax.las'
        las.write(output_las)
        print(f"[Export Logits] Saved softmax LAS to: {output_las}")
|
|
|
|
def main():
    """Parse the command line, load the model and run the requested mode.

    ``--mode inference`` writes per-point predictions for one split through
    ``run_inference``; ``--mode export_log`` writes softmax probabilities
    through ``export_logits``. The model is instantiated from the
    ``semantic/gridnet`` experiment config, restored from ``--weights`` and
    moved to the GPU in eval mode.
    """
    parser = argparse.ArgumentParser(description="SPT Inference and Logits Export")
    parser.add_argument(
        '--mode', choices=['inference', 'export_log'], required=True,
        help="Choose between full-resolution inference or export logits")
    parser.add_argument(
        '--split', type=str, default='test',
        help="Data split to process (only used in inference mode) test or val split")
    parser.add_argument(
        '--weights', type=str, required=True,
        help="Path to model checkpoint")
    parser.add_argument(
        '--root_dir', type=str, required=True,
        help="Root directory of the dataset")
    # Parsed as an int so that argparse rejects non-numeric values up front
    parser.add_argument(
        '--pc_tiling', type=int, default=3,
        help="PC tiling for point cloud sampling")

    args = parser.parse_args()
    if args.pc_tiling < 0:
        # 2**pc_tiling chunks are produced, so the step count cannot be negative
        parser.error("--pc_tiling must be a non-negative integer")

    cfg = init_config(overrides=["experiment=semantic/gridnet"])
    transforms_dict = instantiate_datamodule_transforms(cfg.datamodule)

    model = hydra.utils.instantiate(cfg.model)
    model = model._load_from_checkpoint(args.weights)
    model = model.eval().cuda()

    # Per-axis LAS scale factors used for the written files
    scale = [0.001, 0.001, 0.001]

    if args.mode == 'inference':
        run_inference(
            model, cfg, transforms_dict, args.root_dir, args.split,
            scale, args.pc_tiling)
    elif args.mode == 'export_log':
        export_logits(
            model, cfg, transforms_dict, args.root_dir, scale, args.pc_tiling)
|
|
|
|
# Run the command-line interface only when executed as a script
if __name__ == "__main__":
    main()
|
|
|
|