|
|
import sys |
|
|
import hydra |
|
|
import torch |
|
|
import os.path as osp |
|
|
from tqdm import tqdm |
|
|
from copy import deepcopy |
|
|
from src.utils.hydra import init_config |
|
|
|
|
|
|
|
|
# Resolve the repository root (two directory levels up from this file) so the
# project can be run without being pip-installed.
src_folder = osp.dirname(osp.dirname(osp.abspath(__file__)))

sys.path.append(src_folder)

# The grid-graph and parallel cut-pursuit dependencies ship compiled Python
# bindings outside the package tree; expose them on the import path.
sys.path.append(osp.join(src_folder, "dependencies/grid_graph/python/bin"))

sys.path.append(osp.join(src_folder, "dependencies/parallel_cut_pursuit/python/wrappers"))

# Public API of this module.
__all__ = ['compute_semantic_metrics', 'compute_semantic_metrics_s3dis_6fold']
|
|
|
|
|
|
|
|
def compute_semantic_metrics(
        model,
        datamodule,
        stage='val',
        verbose=True):
    """Helper function to compute the semantic segmentation metrics of a
    model on a given dataset.

    :param model: model exposing `validation_step` and a `val_cm`
        confusion matrix attribute
    :param datamodule: datamodule holding the train/val/test datasets
        and dataloaders
    :param stage: str
        One of 'train', 'val', 'test'
    :param verbose: bool
        If True, show a progress bar and print the computed metrics
    :return: confusion matrix accumulated over the whole dataset
    """
    from src.data import NAGBatch

    # Map each stage to the datamodule attributes holding its dataset and
    # dataloader factory. Attribute names (not objects) are stored so only
    # the requested stage's dataloader is ever built.
    stage_attrs = {
        'train': ('train_dataset', 'train_dataloader'),
        'val': ('val_dataset', 'val_dataloader'),
        'test': ('test_dataset', 'test_dataloader')}
    if stage not in stage_attrs:
        raise ValueError(f"Unknown stage : {stage}")
    dataset_attr, loader_attr = stage_attrs[stage]
    dataset = getattr(datamodule, dataset_attr)
    dataloader = getattr(datamodule, loader_attr)()

    # Keep input Data attributes around after they are moved to Data.x, so
    # they remain available downstream
    dataset = _set_attribute_preserving_transforms(dataset)

    # Run the model over the whole dataset, letting `validation_step`
    # accumulate predictions into `model.val_cm`
    with torch.no_grad():
        iterator = tqdm(dataloader) if verbose else dataloader
        for nag_list in iterator:
            batch = NAGBatch.from_nag_list([nag.cuda() for nag in nag_list])
            batch = dataset.on_device_transform(batch)
            model.validation_step(batch, None)

    # Snapshot the accumulated confusion matrix, then reset the model's so
    # subsequent evaluations start clean
    semantic = deepcopy(model.val_cm)
    model.val_cm.reset()

    if verbose:
        print(f"mIoU : {semantic.miou().cpu().item()}")
        print(f"OA : {semantic.oa().cpu().item()}")
        print(f"mAcc : {semantic.macc().cpu().item()}")

    return semantic
|
|
|
|
|
|
|
|
def compute_semantic_metrics_s3dis_6fold(
        fold_ckpt,
        experiment_config,
        stage='val',
        verbose=False):
    """Helper function to compute the semantic segmentation metrics of a
    model on a S3DIS 6-fold.

    :param fold_ckpt: dict
        Dictionary with S3DIS fold numbers as keys and checkpoint paths
        as values
    :param experiment_config: str
        Experiment config to use for inference. For instance for S3DIS
        with semantic segmentation: 'semantic/s3dis'
    :param stage: str
        One of 'train', 'val', 'test'
    :param verbose: bool
        If True, print per-fold progress and metrics
    :return: (6-fold aggregated confusion matrix, list of per-fold
        confusion matrices)
    :raises ValueError: if `fold_ckpt` is empty
    """
    from src.metrics import ConfusionMatrix

    import warnings
    warnings.filterwarnings("ignore")

    # Guard against an empty fold dict, which would otherwise fail later
    # with an obscure error when building the aggregated confusion matrix
    if not fold_ckpt:
        raise ValueError("`fold_ckpt` must map at least one fold to a checkpoint path")

    semantic_list = []
    num_classes = None

    for fold, ckpt_path in fold_ckpt.items():

        if verbose:
            print(f"\nFold {fold}")

        # Build the hydra config for this fold's experiment and checkpoint
        cfg = init_config(overrides=[
            f"experiment={experiment_config}",
            f"datamodule.fold={fold}",
            f"ckpt_path={ckpt_path}"])

        # Instantiate and prepare the fold's datamodule
        datamodule = hydra.utils.instantiate(cfg.datamodule)
        datamodule.prepare_data()
        datamodule.setup()

        # Instantiate the model and load the fold's checkpoint.
        # NOTE(review): `_load_from_checkpoint` is a private Lightning API —
        # presumably used here to bypass hparams re-instantiation; confirm it
        # still exists when upgrading Lightning
        model = hydra.utils.instantiate(cfg.model)
        model = model._load_from_checkpoint(cfg.ckpt_path)
        model = model.eval().cuda()

        # Accumulate this fold's confusion matrix
        semantic = compute_semantic_metrics(
            model,
            datamodule,
            stage=stage,
            verbose=verbose)

        # All folds share the same class set, so the last value suffices
        num_classes = datamodule.train_dataset.num_classes

        # Release model and data references before the next fold to free
        # GPU memory
        del model, datamodule

        semantic_list.append(semantic)

    # Aggregate the per-fold confusion matrices into a single 6-fold matrix
    semantic_6fold = ConfusionMatrix(num_classes)
    for semantic in semantic_list:
        semantic_6fold.confmat += semantic.confmat.cpu()

    print("\n6-fold")
    print(f"mIoU : {semantic_6fold.miou().cpu().item()}")
    print(f"OA : {semantic_6fold.oa().cpu().item()}")
    print(f"mAcc : {semantic_6fold.macc().cpu().item()}")

    return semantic_6fold, semantic_list
|
|
|
|
|
|
|
|
def _set_attribute_preserving_transforms(dataset):
    """Configure the dataset's on-device transforms so that `NAGAddKeysTo`
    keeps input `Data` attributes after copying them into `Data.x`,
    leaving them available for later inspection/visualization.

    :param dataset: dataset whose `on_device_transform` pipeline is
        modified in place
    :return: the same dataset, for call chaining
    """
    from src.transforms import NAGAddKeysTo

    for transform in dataset.on_device_transform.transforms:
        if not isinstance(transform, NAGAddKeysTo):
            continue
        transform.delete_after = False

    return dataset
|
|
|