| | import sys |
| | import hydra |
| | import torch |
| | import os.path as osp |
| | from tqdm import tqdm |
| | from copy import deepcopy |
| | from src.utils.hydra import init_config |
| |
|
| |
|
# Directory two levels above this file; it is added to the import search
# path together with two of its `dependencies/` sub-directories.
src_folder = osp.dirname(osp.dirname(osp.abspath(__file__)))
sys.path.extend([
    src_folder,
    osp.join(src_folder, "dependencies/grid_graph/python/bin"),
    osp.join(src_folder, "dependencies/parallel_cut_pursuit/python/wrappers"),
])
| |
|
| |
|
# Names exported by `from <this module> import *`
__all__ = [
    'compute_semantic_metrics',
    'compute_semantic_metrics_s3dis_6fold',
]
| |
|
| |
|
def compute_semantic_metrics(
        model,
        datamodule,
        stage='val',
        verbose=True):
    """Run a model over one split of a datamodule and return the
    resulting semantic segmentation confusion matrix.

    Every batch produced by the split's dataloader is moved to the GPU,
    passed through the dataset's on-device transforms and fed to
    `model.validation_step`. Once the dataloader is exhausted, a deep
    copy of `model.val_cm` is taken and `model.val_cm` is reset.

    :param model:
        Model exposing `validation_step` and a `val_cm` attribute
    :param datamodule:
        Datamodule exposing `{train,val,test}_dataset` attributes and
        `{train,val,test}_dataloader()` methods
    :param stage: str
        One of 'train', 'val' or 'test'
    :param verbose: bool
        If True, show a progress bar and print mIoU, OA and mAcc
    :return:
        Deep copy of `model.val_cm` taken after the last batch
    :raises ValueError: if `stage` is not one of the supported values
    """
    # Imported inside the function, as in the rest of this module
    # (presumably to avoid circular imports -- TODO confirm)
    from src.data import NAGBatch

    # Resolve the dataset and the dataloader matching the requested
    # stage
    for split in ('train', 'val', 'test'):
        if stage == split:
            dataset = getattr(datamodule, f'{split}_dataset')
            dataloader = getattr(datamodule, f'{split}_dataloader')()
            break
    else:
        raise ValueError(f"Unknown stage : {stage}")

    # Keep the input attributes alive through `NAGAddKeysTo`. Note that
    # this modifies the dataset's on-device transforms in place
    dataset = _set_attribute_preserving_transforms(dataset)

    # Only wrap the dataloader in a progress bar when verbose
    if verbose:
        batches = tqdm(dataloader)
    else:
        batches = dataloader

    with torch.no_grad():
        for nag_list in batches:
            # Each dataloader item is a list of NAGs, which are moved
            # to the GPU before being collated into a single batch
            cuda_nags = [item.cuda() for item in nag_list]
            batch = NAGBatch.from_nag_list(cuda_nags)

            # On-device transforms are applied here, after the batch
            # has been moved to the GPU
            batch = dataset.on_device_transform(batch)

            # The step's return value is not used: results are read
            # from `model.val_cm` below (presumably updated by the
            # step -- TODO confirm)
            model.validation_step(batch, None)

    # Snapshot the model's confusion matrix, then reset the model's own
    # copy
    semantic = deepcopy(model.val_cm)
    model.val_cm.reset()

    if verbose:
        print(f"mIoU : {semantic.miou().cpu().item()}")
        print(f"OA : {semantic.oa().cpu().item()}")
        print(f"mAcc : {semantic.macc().cpu().item()}")

    return semantic
| |
|
| |
|
def compute_semantic_metrics_s3dis_6fold(
        fold_ckpt,
        experiment_config,
        stage='val',
        verbose=False):
    """Helper function to compute the semantic segmentation metrics of a
    model on a S3DIS 6-fold.

    For each fold, the datamodule and the model are instantiated from
    the hydra config, the fold checkpoint is loaded, and
    `compute_semantic_metrics` is run on the requested stage. The
    per-fold confusion matrices are then summed into a single confusion
    matrix whose mIoU, OA and mAcc are printed.

    Python warnings are silenced for the duration of the call only; the
    caller's warning filters are restored on exit.

    :param fold_ckpt: dict
        Dictionary with S3DIS fold numbers as keys and checkpoint paths
        as values
    :param experiment_config: str
        Experiment config to use for inference. For instance for S3DIS
        with semantic segmentation: 'semantic/s3dis'
    :param stage: str
        Dataset split forwarded to `compute_semantic_metrics`
    :param verbose: bool
        If True, print the fold number and the per-fold metrics, and
        show progress bars. The aggregated metrics are always printed
    :return: tuple
        `(semantic_6fold, semantic_list)`: the confusion matrix summed
        over all folds, and the list of per-fold confusion matrices in
        the iteration order of `fold_ckpt`
    :raises ValueError: if `fold_ckpt` is empty
    """
    # Fail early with an explicit message: without any fold, the number
    # of classes needed to build the aggregated confusion matrix is
    # never set
    if not fold_ckpt:
        raise ValueError(
            "fold_ckpt must contain at least one fold-to-checkpoint entry")

    import warnings

    # Imported inside the function, as in the rest of this module
    # (presumably to avoid circular imports -- TODO confirm)
    from src.metrics import ConfusionMatrix

    semantic_list = []
    num_classes = None

    # Silence warnings while evaluating, without leaking the "ignore"
    # filter into the rest of the process
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")

        for fold, ckpt_path in fold_ckpt.items():

            if verbose:
                print(f"\nFold {fold}")

            # Build the config for this fold and its checkpoint
            cfg = init_config(overrides=[
                f"experiment={experiment_config}",
                f"datamodule.fold={fold}",
                f"ckpt_path={ckpt_path}"])

            # Instantiate and set up the datamodule
            datamodule = hydra.utils.instantiate(cfg.datamodule)
            datamodule.prepare_data()
            datamodule.setup()

            # Instantiate the model
            model = hydra.utils.instantiate(cfg.model)

            # Load the fold checkpoint, then switch to evaluation mode
            # on the GPU
            model = model._load_from_checkpoint(cfg.ckpt_path)
            model = model.eval().cuda()

            # Confusion matrix of this fold on the requested stage
            semantic = compute_semantic_metrics(
                model,
                datamodule,
                stage=stage,
                verbose=verbose)

            # Number of classes used to size the aggregated confusion
            # matrix; the value from the last fold is the one kept
            num_classes = datamodule.train_dataset.num_classes

            # Drop the references to the model and datamodule before
            # moving on to the next fold
            del model, datamodule

            semantic_list.append(semantic)

        # Aggregate the per-fold confusion matrices on CPU
        semantic_6fold = ConfusionMatrix(num_classes)
        for fold_semantic in semantic_list:
            semantic_6fold.confmat += fold_semantic.confmat.cpu()

        # The aggregated metrics are printed regardless of `verbose`
        print("\n6-fold")
        print(f"mIoU : {semantic_6fold.miou().cpu().item()}")
        print(f"OA : {semantic_6fold.oa().cpu().item()}")
        print(f"mAcc : {semantic_6fold.macc().cpu().item()}")

    return semantic_6fold, semantic_list
| |
|
| |
|
def _set_attribute_preserving_transforms(dataset):
    """Make the `NAGAddKeysTo` on-device transforms of a dataset keep
    the input attributes.

    For the sake of visualization, `NAGAddKeysTo` must not remove the
    input `Data` attributes after moving them to `Data.x`. To that end,
    every `NAGAddKeysTo` found in
    `dataset.on_device_transform.transforms` has its `delete_after`
    attribute set to False. The transforms are modified in place.

    :param dataset:
        Dataset whose `on_device_transform` exposes a `transforms`
        iterable
    :return:
        The same `dataset` object that was passed in
    """
    # Imported inside the function, as in the rest of this module
    # (presumably to avoid circular imports -- TODO confirm)
    from src.transforms import NAGAddKeysTo

    for transform in dataset.on_device_transform.transforms:
        # Leave every other kind of transform untouched
        if not isinstance(transform, NAGAddKeysTo):
            continue
        transform.delete_after = False

    return dataset
| |
|