File size: 5,797 Bytes
26225c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import sys
import hydra
import torch
import os.path as osp
from tqdm import tqdm
from copy import deepcopy
from src.utils.hydra import init_config
# Extend `sys.path` so that the project folder and the compiled
# dependencies (grid_graph, parallel_cut_pursuit) are importable.
# `src_folder` is the parent of the directory containing this file.
# NOTE(review): these appends run after the `src.utils.hydra` import
# above, so that import relies on the project root already being on the
# path — confirm this ordering is intended
src_folder = osp.dirname(osp.dirname(osp.abspath(__file__)))
sys.path.append(src_folder)
sys.path.append(osp.join(src_folder, "dependencies/grid_graph/python/bin"))
sys.path.append(osp.join(src_folder, "dependencies/parallel_cut_pursuit/python/wrappers"))
# Public API of this module
__all__ = ['compute_semantic_metrics', 'compute_semantic_metrics_s3dis_6fold']
def compute_semantic_metrics(
        model,
        datamodule,
        stage='val',
        verbose=True):
    """Helper function to compute the semantic segmentation metrics of a
    model on a given dataset.

    Every batch of the requested stage's dataloader is moved to the GPU,
    passed through the dataset's on-device transforms, and fed to
    `model.validation_step()`. The confusion matrix accumulated in
    `model.val_cm` is then snapshotted and returned.

    :param model:
        Model exposing `validation_step(nag, batch_idx)` and a `val_cm`
        confusion matrix with a `reset()` method. Since batches are
        moved to CUDA, the model is expected to live on CUDA too. The
        model's train/eval mode is not changed here; put it in eval mode
        beforehand for inference
    :param datamodule:
        Datamodule exposing `train_dataset`/`val_dataset`/`test_dataset`
        and the matching `*_dataloader()` methods
    :param stage: str
        Which split to evaluate on: 'train', 'val', or 'test'
    :param verbose: bool
        Whether to display a progress bar and print mIoU, OA and mAcc
    :return:
        A deep copy of `model.val_cm` holding the metrics accumulated
        over the whole split
    :raises ValueError: if `stage` is not 'train', 'val' or 'test'
    """
    # Local imports to avoid import loop errors
    from src.data import NAGBatch

    # Pick among train, val, and test datasets. It is important to note
    # that the train dataset produces augmented spherical samples of
    # large scenes, while the val and test datasets produce entire
    # tiles
    if stage == 'train':
        dataset = datamodule.train_dataset
        dataloader = datamodule.train_dataloader()
    elif stage == 'val':
        dataset = datamodule.val_dataset
        dataloader = datamodule.val_dataloader()
    elif stage == 'test':
        dataset = datamodule.test_dataset
        dataloader = datamodule.test_dataloader()
    else:
        raise ValueError(f"Unknown stage : {stage}")

    # Prevent `NAGAddKeysTo` from removing attributes to allow
    # visualizing them after model inference
    dataset = _set_attribute_preserving_transforms(dataset)

    # Start from an empty confusion matrix, so the returned metrics only
    # reflect the data seen in this call and not leftovers from earlier
    # validation steps run on this model
    model.val_cm.reset()

    with torch.no_grad():
        batches = tqdm(dataloader) if verbose else dataloader
        for nag_list in batches:
            # Move each NAG of the batch to the GPU, then collate them
            # into a single NAGBatch
            nag = NAGBatch.from_nag_list([n.cuda() for n in nag_list])

            # Apply on-device transforms on the NAG object. For the
            # train dataset, this will select a spherical sample of the
            # larger tile and apply some data augmentations. For the
            # validation and test datasets, this will prepare an entire
            # tile for inference
            nag = dataset.on_device_transform(nag)

            # NB: we use the "validation_step" protocol here, regardless
            # of the stage the data comes from
            model.validation_step(nag, None)

    # Actions taken from on_validation_epoch_end(): snapshot the
    # accumulated confusion matrix, then reset the model's metric
    semantic = deepcopy(model.val_cm)
    model.val_cm.reset()

    if verbose:
        print(f"mIoU : {semantic.miou().cpu().item()}")
        print(f"OA : {semantic.oa().cpu().item()}")
        print(f"mAcc : {semantic.macc().cpu().item()}")

    return semantic
def compute_semantic_metrics_s3dis_6fold(
        fold_ckpt,
        experiment_config,
        stage='val',
        verbose=False):
    """Helper function to compute the semantic segmentation metrics of a
    model on a S3DIS 6-fold.

    For each fold, the datamodule and model are instantiated from the
    Hydra config, the checkpoint weights are loaded, and
    `compute_semantic_metrics()` is run. The per-fold confusion matrices
    are then summed into a single 6-fold confusion matrix, whose mIoU,
    OA and mAcc are printed.

    :param fold_ckpt: dict
        Dictionary with S3DIS fold numbers as keys and checkpoint paths
        as values
    :param experiment_config: str
        Experiment config to use for inference. For instance for S3DIS
        with semantic segmentation: 'semantic/s3dis'
    :param stage: str
        Split on which each fold's model is evaluated: 'train', 'val',
        or 'test'
    :param verbose: bool
        Whether to print per-fold progress and metrics. The aggregated
        6-fold metrics are always printed
    :return: tuple
        `(semantic_6fold, semantic_list)`, where `semantic_6fold` is the
        `ConfusionMatrix` summed over all folds and `semantic_list` holds
        the per-fold confusion matrices in `fold_ckpt` iteration order
    :raises ValueError: if `fold_ckpt` is empty
    """
    # Without any fold, the number of classes is unknown and the 6-fold
    # confusion matrix cannot be built
    if not fold_ckpt:
        raise ValueError(
            "`fold_ckpt` must map at least one fold to a checkpoint path")

    # Local imports to avoid import loop errors
    import warnings
    from src.metrics import ConfusionMatrix

    # Ignore lightning's warning messages about the trainer and modules
    # not being connected. The filter is scoped to this function so the
    # caller's warning configuration is restored on exit
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")

        semantic_list = []
        num_classes = None
        for fold, ckpt_path in fold_ckpt.items():

            if verbose:
                print(f"\nFold {fold}")

            # Parse the configs using hydra
            cfg = init_config(overrides=[
                f"experiment={experiment_config}",
                f"datamodule.fold={fold}",
                f"ckpt_path={ckpt_path}"])

            # Instantiate the datamodule
            datamodule = hydra.utils.instantiate(cfg.datamodule)
            datamodule.prepare_data()
            datamodule.setup()

            # Instantiate the model and load pretrained weights from a
            # checkpoint file
            model = hydra.utils.instantiate(cfg.model)
            model = model._load_from_checkpoint(cfg.ckpt_path)
            model = model.eval().cuda()

            # Compute metrics on the fold
            semantic = compute_semantic_metrics(
                model,
                datamodule,
                stage=stage,
                verbose=verbose)

            # Gather some details from the model and datamodule before
            # deleting them
            num_classes = datamodule.train_dataset.num_classes
            del model, datamodule

            # Store the metrics for each fold
            semantic_list.append(semantic)

        # Sum the per-fold semantic confusion matrices into the 6-fold
        # confusion matrix
        semantic_6fold = ConfusionMatrix(num_classes)
        for semantic in semantic_list:
            semantic_6fold.confmat += semantic.confmat.cpu()

        # Print the computed metrics
        print("\n6-fold")
        print(f"mIoU : {semantic_6fold.miou().cpu().item()}")
        print(f"OA : {semantic_6fold.oa().cpu().item()}")
        print(f"mAcc : {semantic_6fold.macc().cpu().item()}")

    return semantic_6fold, semantic_list
def _set_attribute_preserving_transforms(dataset):
    """Make every `NAGAddKeysTo` on-device transform of `dataset` keep
    its input `Data` attributes.

    By default, `NAGAddKeysTo` may delete the attributes it moves into
    `Data.x`. Setting `delete_after = False` keeps them around so they
    can be visualized after inference.

    :param dataset:
        Dataset whose `on_device_transform.transforms` is iterated
    :return:
        The same `dataset` object, whose transforms are modified in place
    """
    # Local imports to avoid import loop errors
    from src.transforms import NAGAddKeysTo

    add_keys_transforms = (
        transform
        for transform in dataset.on_device_transform.transforms
        if isinstance(transform, NAGAddKeysTo))
    for transform in add_keys_transforms:
        transform.delete_after = False

    return dataset
|