import torch
import os
import os.path as osp
from torch.nn import ModuleList
import logging
from copy import deepcopy
from typing import Any, List, Tuple, Dict
from pytorch_lightning import LightningModule
from torchmetrics import MaxMetric, MeanMetric
from pytorch_lightning.loggers.wandb import WandbLogger

from src.metrics import ConfusionMatrix
from src.utils import loss_with_target_histogram, atomic_to_histogram, \
    init_weights, wandb_confusion_matrix, knn_2, garbage_collection_cuda, \
    SemanticSegmentationOutput
from src.nn import Classifier
from src.loss import MultiLoss
from src.optim.lr_scheduler import ON_PLATEAU_SCHEDULERS
from src.data import NAG
from src.transforms import Transform, NAGSaveNodeIndex

log = logging.getLogger(__name__)


__all__ = ['SemanticSegmentationModule']


class SemanticSegmentationModule(LightningModule):
    """A LightningModule for semantic segmentation of point clouds.

    :param net: torch.nn.Module
        Backbone model. This can typically be an `SPT` object
    :param criterion: torch.nn._Loss
        Loss criterion. A `MultiLoss` may also be passed for
        multi-stage supervision
    :param optimizer: torch.optim.Optimizer
        Optimizer
    :param scheduler: torch.optim.lr_scheduler.LRScheduler
        Learning rate scheduler
    :param num_classes: int
        Number of classes in the dataset
    :param class_names: List[str]
        Name for each class
    :param sampling_loss: bool
        If True, the target labels will be obtained from labels of
        the points sampled in the batch at hand. This affects
        training supervision where sampling augmentations may be
        used for dropping some points or superpoints. If False, the
        target labels will be based on exact superpoint-wise
        histograms of labels computed at preprocessing time,
        disregarding potential level-0 point down-sampling
    :param loss_type: str
        Type of loss applied.
        'ce': cross-entropy (if `multi_stage_loss_lambdas` is used,
        all 1+ levels will be supervised with cross-entropy).
        'kl': Kullback-Leibler divergence (if `multi_stage_loss_lambdas`
        is used, all 1+ levels will be supervised with Kullback-Leibler
        divergence).
        'ce_kl': cross-entropy on level 1 and Kullback-Leibler for
        all levels above
        'wce': not documented for now
        'wce_kl': not documented for now
    :param weighted_loss: bool
        If True, the loss will be weighted based on the class
        frequencies computed on the train dataset. See
        `BaseDataset.get_class_weight()` for more
    :param init_linear: str
        Initialization method for all linear layers. Supports
        'xavier_uniform', 'xavier_normal', 'kaiming_uniform',
        'kaiming_normal', 'trunc_normal'
    :param init_rpe: str
        Initialization method for all linear layers producing
        relative positional encodings. Supports 'xavier_uniform',
        'xavier_normal', 'kaiming_uniform', 'kaiming_normal',
        'trunc_normal'
    :param transformer_lr_scale: float
        Scaling parameter applied to the learning rate for the
        `TransformerBlock` in each `Stage` and for the pooling block
        in `DownNFuseStage` modules. Setting this to a value lower
        than 1 mitigates exploding gradients in attentive blocks
        during training
    :param multi_stage_loss_lambdas: List[float]
        List of weights for combining losses computed on the output
        of each partition level. If not specified, the loss will
        be computed on the level 1 outputs only
    :param gc_every_n_steps: int
        Explicitly call the garbage collector after a certain number
        of steps. May involve a computation overhead. Mostly here
        for debugging purposes when observing suspicious GPU memory
        increase during training
    :param track_val_every_n_epoch: int
        If specified, the output for a validation batch of interest
        specified with `track_val_idx` will be stored to disk every
        `track_val_every_n_epoch` epochs. Must be a multiple of
        `check_val_every_n_epoch`. See `track_batch()` for more
    :param track_val_idx: int
        If specified, the output for the `track_val_idx`th
        validation batch will be saved to disk periodically based on
        `track_val_every_n_epoch`. If `track_val_idx=-1`, predictions
        for the entire validation set will be saved to disk.
        Importantly, this index is expected to match the `Dataloader`'s
        index wrt the current epoch and NOT an index wrt the `Dataset`.
        In other words, if the `Dataloader` uses `shuffle=True`, the
        stored batch will not be the same at each epoch. For this
        reason, if tracking the same object across training is needed,
        the `Dataloader` and the transforms should be free from any
        stochasticity
    :param track_test_idx: int
        If specified, the output for the `track_test_idx`th
        test batch will be saved to disk. If `track_test_idx=-1`,
        predictions for the entire test set will be saved to disk
    :param kwargs: Dict
        Kwargs will be passed to `_load_from_checkpoint()`
    """

    _IGNORED_HYPERPARAMETERS = ['net', 'criterion']

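    # NOTE: hedged usage sketch, not part of the original module. The
    # `optimizer` and `scheduler` arguments are expected to be partially
    # instantiated (e.g. `functools.partial` objects or Hydra partials),
    # since `configure_optimizers()` later calls
    # `self.hparams.optimizer(params=...)`, reads
    # `self.hparams.optimizer.keywords['lr']`, and calls
    # `self.hparams.scheduler(optimizer=...)`. The backbone name and the
    # hyperparameter values below are illustrative assumptions only:
    #
    #   from functools import partial
    #   net = SPT(...)  # any backbone exposing `out_dim`
    #   module = SemanticSegmentationModule(
    #       net=net,
    #       criterion=torch.nn.CrossEntropyLoss(),
    #       optimizer=partial(torch.optim.AdamW, lr=0.01),
    #       scheduler=partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=100),
    #       num_classes=13,
    #       multi_stage_loss_lambdas=[1, 50])
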
    def __init__(
            self,
            net: torch.nn.Module,
            criterion: 'torch.nn._Loss',
            optimizer: torch.optim.Optimizer,
            scheduler: torch.optim.lr_scheduler.LRScheduler,
            num_classes: int,
            class_names: List[str] = None,
            sampling_loss: bool = False,
            loss_type: str = 'ce_kl',
            weighted_loss: bool = True,
            init_linear: str = None,
            init_rpe: str = None,
            transformer_lr_scale: float = 1,
            multi_stage_loss_lambdas: List[float] = None,
            gc_every_n_steps: int = 0,
            track_val_every_n_epoch: int = 1,
            track_val_idx: int = None,
            track_test_idx: int = None,
            **kwargs):
        super().__init__()

        # Allows accessing the init parameters with `self.hparams` and
        # ensures they are stored in checkpoints. The `net` and
        # `criterion` modules are excluded, since their parameters are
        # already saved in the checkpoint's state_dict
        self.save_hyperparameters(
            logger=False, ignore=self._IGNORED_HYPERPARAMETERS)

        # Store the number of classes and the class names
        self.num_classes = num_classes
        self.class_names = class_names if class_names is not None \
            else [f'class-{i}' for i in range(num_classes)]

        # Loss function. If `multi_stage_loss_lambdas` is provided and
        # the criterion is not already a `MultiLoss`, one criterion is
        # instantiated per supervised partition level
        if isinstance(criterion, MultiLoss):
            self.criterion = criterion
        elif multi_stage_loss_lambdas is not None:
            criteria = [
                deepcopy(criterion)
                for _ in range(len(multi_stage_loss_lambdas))]
            self.criterion = MultiLoss(criteria, multi_stage_loss_lambdas)
        else:
            self.criterion = criterion

        # Ignore the `num_classes` label, which is reserved for
        # void/unlabeled points
        if isinstance(self.criterion, MultiLoss):
            for i in range(len(self.criterion.criteria)):
                self.criterion.criteria[i].ignore_index = num_classes
        else:
            self.criterion.ignore_index = num_classes

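        # NOTE: illustrative comment added for clarity, assuming the
        # label convention used throughout this module: with
        # `num_classes=3`, valid labels are {0, 1, 2} and the extra
        # label 3 marks void/unlabeled points. This is why target
        # histograms are built with `n_bins=num_classes + 1` in
        # `get_target()` and why `ignore_index=num_classes` is set on
        # the criterion above
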
        # Backbone network. When using a multi-stage loss, the network
        # is asked to output features for all partition levels rather
        # than only level 1
        self.net = net
        if self.multi_stage_loss:
            self.net.output_stage_wise = True
            assert len(self.net.out_dim) == len(self.criterion), \
                f"The number of items in the multi-stage loss must match the " \
                f"number of stages in the net. Found " \
                f"{len(self.net.out_dim)} stages, but {len(self.criterion)} " \
                f"criteria in the loss."

        # Classification head(s), one per supervised partition level
        if self.multi_stage_loss:
            self.head = ModuleList([
                Classifier(dim, num_classes) for dim in self.net.out_dim])
        else:
            self.head = Classifier(self.net.out_dim, num_classes)

        # Custom weight initialization for the linear and relative
        # positional encoding layers
        init = lambda m: init_weights(m, linear=init_linear, rpe=init_rpe)
        self.net.apply(init)
        self.head.apply(init)

        # Metric objects for calculating scores on each split
        self.train_cm = ConfusionMatrix(num_classes)
        self.val_cm = ConfusionMatrix(num_classes)
        self.test_cm = ConfusionMatrix(num_classes)

        # For averaging the loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # For tracking the best-so-far validation metrics
        self.val_miou_best = MaxMetric()
        self.val_oa_best = MaxMetric()
        self.val_macc_best = MaxMetric()

        # Whether the test set carries labels. This is updated in
        # `test_step()` if a batch without targets is encountered
        self.test_has_target = True

        # Explicit garbage collection period, see `gc_collect()`
        self.gc_every_n_steps = int(gc_every_n_steps)

    def forward(self, nag: NAG) -> SemanticSegmentationOutput:
        """Run the backbone on the input `NAG` and classify the
        resulting features with the segmentation head(s).
        """
        x = self.net(nag)
        logits = [head(x_) for head, x_ in zip(self.head, x)] \
            if self.multi_stage_loss else self.head(x)
        return SemanticSegmentationOutput(logits)

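    # NOTE: hedged usage sketch for `forward()`, added for illustration.
    # With a single-stage criterion, `output.logits` is a
    # [N1, num_classes] tensor for the N1 level-1 superpoints of the
    # input NAG; with a `MultiLoss`, it is a list of such tensors, one
    # per partition level 1, 2, ... (same convention as in
    # `_create_empty_output()`):
    #
    #   output = module(nag)                     # same as module.forward(nag)
    #   probas = output.logits.softmax(dim=1)    # single-stage case
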
    @property
    def multi_stage_loss(self) -> bool:
        """Whether the criterion supervises multiple partition levels."""
        return isinstance(self.criterion, MultiLoss)

    def on_fit_start(self) -> None:
        # Late initialization: at this point, LightningDataModule
        # attributes that were not available at __init__ time can be
        # accessed through the Trainer

        # Get the LightningDataModule number of classes and make sure it
        # matches self.num_classes
        num_classes = self.trainer.datamodule.train_dataset.num_classes
        assert num_classes == self.num_classes, \
            f'LightningModule has {self.num_classes} classes while the ' \
            f'LightningDataModule has {num_classes} classes.'

        # Get the class names from the train dataset
        self.class_names = self.trainer.datamodule.train_dataset.class_names

        # If class-weighted loss is not requested, skip the rest
        if not self.hparams.weighted_loss:
            return

        # Some criteria (e.g. KL divergence) do not expose a 'weight'
        # attribute, in which case class weights are simply ignored
        if not hasattr(self.criterion, 'weight'):
            log.warning(
                f"{self.criterion} does not have a 'weight' attribute. "
                f"Class weights will be ignored...")
            return

        # Set the class weights from the train dataset statistics
        weight = self.trainer.datamodule.train_dataset.get_class_weight()
        self.criterion.weight = weight.to(self.device)

        # Make sure the validation-tracking period is consistent with
        # the Trainer's validation period
        if self.trainer.check_val_every_n_epoch is not None:
            assert (self.hparams.track_val_every_n_epoch
                    % self.trainer.check_val_every_n_epoch == 0), \
                (f"Expected 'track_val_every_n_epoch' to be a multiple of "
                 f"'check_val_every_n_epoch', but received "
                 f"{self.hparams.track_val_every_n_epoch} and "
                 f"{self.trainer.check_val_every_n_epoch} instead.")

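    # NOTE: illustrative example of the constraint asserted above, added
    # for clarity: with `Trainer(check_val_every_n_epoch=2)`, validation
    # outputs can be tracked with `track_val_every_n_epoch` set to 2, 4,
    # 6, ... but a value of 3 would fail the assertion, since tracking
    # can only happen on epochs where validation actually runs.
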
    def on_train_start(self) -> None:
        # By default, Lightning executes validation sanity checks before
        # training starts, so we need to make sure the validation
        # metrics do not store anything from these checks
        self.val_cm.reset()
        self.val_miou_best.reset()
        self.val_oa_best.reset()
        self.val_macc_best.reset()

    def gc_collect(self) -> None:
        """Explicitly run garbage collection every `gc_every_n_steps`
        steps, if this period is set.
        """
        num_steps = self.trainer.fit_loop.epoch_loop._batches_that_stepped + 1
        period = self.gc_every_n_steps
        if period is None or period < 1:
            return
        if num_steps % period == 0:
            garbage_collection_cuda()

    def on_train_batch_start(self, *args) -> None:
        self.gc_collect()

    def on_validation_batch_start(self, *args) -> None:
        self.gc_collect()

    def on_test_batch_start(self, *args) -> None:
        self.gc_collect()

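    # NOTE: illustrative worked example for `model_step()` below, added
    # for clarity. Targets are per-superpoint label histograms with
    # `num_classes + 1` bins (the last bin counts void points). For
    # `num_classes=3`, a superpoint whose sampled points carry labels
    # [0, 0, 2, 0, 2] has the histogram y_hist = [3, 0, 2, 0]. 'ce'
    # supervises with the dominant class y_hist.argmax() = 0, 'kl'
    # supervises with the full histogram through
    # `loss_with_target_histogram()`, and 'wce' first collapses the
    # histogram onto its dominant class (here [5, 0, 0, 0]).
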
    def model_step(
            self,
            batch: NAG
    ) -> Tuple[torch.Tensor, SemanticSegmentationOutput]:
        # Forward pass. A simple NAG input means single-run inference,
        # while a (nag, transform, num_runs) tuple triggers multi-run
        # inference (typically test-time augmentation)
        output = self.step_single_run_inference(batch) \
            if isinstance(batch, NAG) \
            else self.step_multi_run_inference(*batch)

        # If the input batch does not carry any labels, the loss is None
        if not output.has_target:
            return None, output

        # Compute the loss, either from a single-stage criterion on the
        # level-1 predictions or from a MultiLoss on all levels
        if self.multi_stage_loss:
            if self.hparams.loss_type == 'ce':
                loss = self.criterion(
                    output.logits, [y.argmax(dim=1) for y in output.y_hist])
            elif self.hparams.loss_type == 'wce':
                # Collapse each histogram onto its dominant class: row i
                # receives its total point count at column y_dominant[i]
                y_hist_dominant = []
                for y in output.y_hist:
                    y_dominant = y.argmax(dim=1)
                    y_hist_dominant_ = torch.zeros_like(y)
                    y_hist_dominant_[
                        torch.arange(y.shape[0], device=y.device),
                        y_dominant] = y.sum(dim=1)
                    y_hist_dominant.append(y_hist_dominant_)
                loss = 0
                enum = zip(
                    self.criterion.lambdas,
                    self.criterion.criteria,
                    output.logits,
                    y_hist_dominant)
                for lamb, criterion, a, b in enum:
                    loss = loss + lamb * loss_with_target_histogram(
                        criterion, a, b)
            elif self.hparams.loss_type == 'ce_kl':
                # Cross-entropy on level 1, KL divergence on the levels
                # above
                loss = 0
                enum = zip(
                    self.criterion.lambdas,
                    self.criterion.criteria,
                    output.logits,
                    output.y_hist)
                for i, (lamb, criterion, a, b) in enumerate(enum):
                    if i == 0:
                        loss = loss + criterion(a, b.argmax(dim=1))
                        continue
                    loss = loss + lamb * loss_with_target_histogram(
                        criterion, a, b)
            elif self.hparams.loss_type == 'wce_kl':
                # Weighted cross-entropy on level 1, KL divergence on
                # the levels above
                loss = 0
                enum = zip(
                    self.criterion.lambdas,
                    self.criterion.criteria,
                    output.logits,
                    output.y_hist)
                for i, (lamb, criterion, a, b) in enumerate(enum):
                    if i == 0:
                        y_dominant = b.argmax(dim=1)
                        y_hist_dominant = torch.zeros_like(b)
                        y_hist_dominant[
                            torch.arange(b.shape[0], device=b.device),
                            y_dominant] = b.sum(dim=1)
                        loss = loss + loss_with_target_histogram(
                            criterion, a, y_hist_dominant)
                        continue
                    loss = loss + lamb * loss_with_target_histogram(
                        criterion, a, b)
            elif self.hparams.loss_type == 'kl':
                loss = 0
                enum = zip(
                    self.criterion.lambdas,
                    self.criterion.criteria,
                    output.logits,
                    output.y_hist)
                for lamb, criterion, a, b in enum:
                    loss = loss + lamb * loss_with_target_histogram(
                        criterion, a, b)
            else:
                raise ValueError(
                    f"Unknown multi-stage loss '{self.hparams.loss_type}'")
        else:
            if self.hparams.loss_type == 'ce':
                loss = self.criterion(output.logits, output.y_hist.argmax(dim=1))
            elif self.hparams.loss_type == 'wce':
                # Collapse the histogram onto its dominant class before
                # computing the weighted loss
                y_dominant = output.y_hist.argmax(dim=1)
                y_hist_dominant = torch.zeros_like(output.y_hist)
                y_hist_dominant[
                    torch.arange(y_dominant.shape[0], device=y_dominant.device),
                    y_dominant] = output.y_hist.sum(dim=1)
                loss = loss_with_target_histogram(
                    self.criterion, output.logits, y_hist_dominant)
            elif self.hparams.loss_type == 'kl':
                loss = loss_with_target_histogram(
                    self.criterion, output.logits, output.y_hist)
            else:
                raise ValueError(
                    f"Unknown single-stage loss '{self.hparams.loss_type}'")

        return loss, output

    def step_single_run_inference(self, nag: NAG) -> SemanticSegmentationOutput:
        """Single-run inference.
        """
        output = self.forward(nag)
        output = self.get_target(nag, output)
        return output

    def step_multi_run_inference(
            self,
            nag: NAG,
            transform: Transform,
            num_runs: int,
            key: str = 'tta_node_id'
    ) -> SemanticSegmentationOutput:
        """Multi-run inference, typically with test-time augmentation.
        See `BaseDataModule.on_after_batch_transfer`
        """
        # Since the transform may change the sampled nodes, prepend a
        # transform saving the level-1 node indices under `key`, so that
        # predictions can be mapped back to the input NAG
        transform.transforms = [NAGSaveNodeIndex(key=key)] \
            + transform.transforms

        # Create empty logits to accumulate the predictions of each run
        output_multi = self._create_empty_output(nag)

        # Recover the target labels from the reference NAG
        output_multi = self.get_target(nag, output_multi)

        # Keep track of the level-1 nodes seen by at least one run
        seen = torch.zeros(
            nag.num_points[1], dtype=torch.bool, device=nag.device)

        for i_run in range(num_runs):

            # Apply the transform on a copy of the reference NAG
            nag_ = transform(nag.clone())

            # Forward pass on the augmented NAG
            output = self.forward(nag_)

            # Accumulate the logits into the reference-sized output
            output_multi = self._update_output_multi(
                output_multi, nag, output, nag_, key)

            # Mark the level-1 nodes predicted in this run as seen
            node_id = nag_[1][key]
            seen[node_id] = True

        # Restore the original transform, without the node-index saver
        transform.transforms = transform.transforms[1:]

        # If some nodes were never seen across the runs, heuristically
        # propagate the predictions of their nearest seen neighbor
        unseen_idx = torch.where(~seen)[0]
        batch = nag[1].batch
        if unseen_idx.shape[0] > 0:
            seen_idx = torch.where(seen)[0]
            x_search = nag[1].pos[seen_idx]
            x_query = nag[1].pos[unseen_idx]
            neighbors = knn_2(
                x_search,
                x_query,
                1,
                r_max=2,
                batch_search=batch[seen_idx] if batch is not None else None,
                batch_query=batch[unseen_idx] if batch is not None else None)[0]
            num_unseen = unseen_idx.shape[0]
            num_seen = seen_idx.shape[0]
            num_left_out = (neighbors == -1).sum().long()
            if num_left_out > 0:
                log.warning(
                    f"Could not find a neighbor for all unseen nodes: num_seen="
                    f"{num_seen}, num_unseen={num_unseen}, num_left_out="
                    f"{num_left_out}. These left-out nodes will default to "
                    f"label-0 class prediction. Consider sampling fewer nodes "
                    f"in the augmentations, or increasing the search radius")

            # Propagate the seen predictions to the unseen neighbors
            output_multi = self._propagate_output_to_unseen_neighbors(
                output_multi, nag, seen, neighbors)

        return output_multi

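    # NOTE: hedged usage sketch for multi-run inference, added for
    # illustration. The datamodule is expected to provide the
    # (nag, transform, num_runs) triplet (see
    # `BaseDataModule.on_after_batch_transfer`); logits are summed over
    # the runs, so downstream argmax/softmax act on accumulated votes.
    # `tta_transform` below is an assumed name for the augmentation:
    #
    #   output = module.step_multi_run_inference(nag, tta_transform, 3)
    #   pred = output.logits.argmax(dim=1)   # single-stage case
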
    def _create_empty_output(self, nag: NAG) -> SemanticSegmentationOutput:
        """Local helper method to initialize an empty output for
        multi-run prediction.
        """
        device = nag.device
        num_classes = self.num_classes
        if self.multi_stage_loss:
            logits = [
                torch.zeros(num_points, num_classes, device=device)
                for num_points in nag.num_points[1:]]
        else:
            logits = torch.zeros(nag.num_points[1], num_classes, device=device)
        return SemanticSegmentationOutput(logits)

    @staticmethod
    def _update_output_multi(
            output_multi: SemanticSegmentationOutput,
            nag: NAG,
            output: SemanticSegmentationOutput,
            nag_transformed: NAG,
            key: str
    ) -> SemanticSegmentationOutput:
        """Local helper method to accumulate multiple predictions on
        the same--or part of the same--point cloud.
        """
        # The node indices saved under `key` map the transformed NAG's
        # nodes back to the reference NAG, where the logits of each run
        # are summed
        if output.multi_stage:
            for i in range(len(output.logits)):
                node_id = nag_transformed[i + 1][key]
                output_multi.logits[i][node_id] += output.logits[i]
        else:
            node_id = nag_transformed[1][key]
            output_multi.logits[node_id] += output.logits
        return output_multi

    @staticmethod
    def _propagate_output_to_unseen_neighbors(
            output: SemanticSegmentationOutput,
            nag: NAG,
            seen: torch.Tensor,
            neighbors: torch.Tensor
    ) -> SemanticSegmentationOutput:
        """Local helper method to propagate predictions to unseen
        neighbors.
        """
        seen_idx = torch.where(seen)[0]
        unseen_idx = torch.where(~seen)[0]
        if output.multi_stage:
            output.logits[0][unseen_idx] = output.logits[0][seen_idx][neighbors]
        else:
            output.logits[unseen_idx] = output.logits[seen_idx][neighbors]
        return output

    def get_target(
            self,
            nag: NAG,
            output: SemanticSegmentationOutput
    ) -> SemanticSegmentationOutput:
        """Recover the target histogram of labels from the NAG object.
        The labels will be saved in `output.y_hist`.

        If `multi_stage_loss=True`, a list of label histograms
        will be recovered (one for each prediction level).

        If `sampling_loss=True`, the histogram(s) will be updated based
        on the actual level-0 point sampling. That is, superpoints will
        be supervised by the labels of the sampled points at train time,
        rather than the true full-resolution label histogram.

        If no labels are found in the NAG, `output.y_hist` will be None.
        """
        # If the required labels are absent from the NAG, return with a
        # None target
        if self.hparams.sampling_loss and nag[0].y is None:
            output.y_hist = None
            return output
        elif self.multi_stage_loss:
            for i in range(1, nag.num_levels):
                if nag[i].y is None:
                    output.y_hist = None
                    return output
        elif nag[1].y is None:
            output.y_hist = None
            return output

        # Build the histogram(s) of labels. With `sampling_loss`, the
        # histograms are recomputed from the level-0 points actually
        # present in the batch; otherwise, the precomputed superpoint
        # histograms are used as-is. Histograms have `num_classes + 1`
        # bins, the last one counting void points
        if self.hparams.sampling_loss and self.multi_stage_loss:
            y_hist = [
                atomic_to_histogram(
                    nag[0].y,
                    nag.get_super_index(i_level), n_bins=self.num_classes + 1)
                for i_level in range(1, nag.num_levels)]

        elif self.hparams.sampling_loss:
            idx = nag[0].super_index
            y = nag[0].y

            # Convert level-0 labels to superpoint-wise histograms
            y_hist = atomic_to_histogram(y, idx, n_bins=self.num_classes + 1)

        elif self.multi_stage_loss:
            y_hist = [nag[i_level].y for i_level in range(1, nag.num_levels)]

        else:
            y_hist = nag[1].y

        # Store the target in the output object
        output.y_hist = y_hist

        return output

    def training_step(
            self,
            batch: NAG,
            batch_idx: int
    ) -> torch.Tensor:
        loss, output = self.model_step(batch)

        # Update and log metrics
        self.train_step_update_metrics(loss, output)
        self.train_step_log_metrics()

        # Explicitly delete the output, for memory release
        del output

        # Return the loss for the backward pass
        return loss

    def train_step_update_metrics(
            self,
            loss: torch.Tensor,
            output: SemanticSegmentationOutput
    ) -> None:
        """Update train metrics after a single step, with the content of
        the output object.
        """
        self.train_loss(loss.detach())
        self.train_cm(output.semantic_pred().detach(), output.semantic_target.detach())

    def train_step_log_metrics(self) -> None:
        """Log train metrics after a single step with the content of the
        output object.
        """
        self.log(
            "train/loss", self.train_loss, on_step=False, on_epoch=True,
            prog_bar=True)

    def on_train_epoch_end(self) -> None:
        # Gather the confusion matrices from all devices, if need be
        if self.trainer.num_devices > 1:
            epoch_cm = torch.sum(self.all_gather(self.train_cm.confmat), dim=0)
            epoch_cm = ConfusionMatrix(self.num_classes).from_confusion_matrix(epoch_cm)
        else:
            epoch_cm = self.train_cm

        # Log the epoch-level metrics
        self.log("train/miou", epoch_cm.miou(), prog_bar=True, rank_zero_only=True)
        self.log("train/oa", epoch_cm.oa(), prog_bar=True, rank_zero_only=True)
        self.log("train/macc", epoch_cm.macc(), prog_bar=True, rank_zero_only=True)
        for iou, seen, name in zip(*epoch_cm.iou(), self.class_names):
            if seen:
                self.log(f"train/iou_{name}", iou, prog_bar=True, rank_zero_only=True)

        # Reset the confusion matrices for the next epoch
        self.train_cm.reset()
        epoch_cm.reset()

    def validation_step(
            self,
            batch: NAG,
            batch_idx: int
    ) -> None:
        loss, output = self.model_step(batch)

        # Update and log metrics
        self.validation_step_update_metrics(loss, output)
        self.validation_step_log_metrics()

        # `current_epoch` is 0-based; the +1 makes the tracking period
        # count epochs starting from 1
        epoch = self.current_epoch + 1

        # Store the batch prediction to disk if the current epoch and
        # batch index match the tracking settings. `track_val_idx=-1`
        # means all validation batches are tracked
        track_epoch = epoch % self.hparams.track_val_every_n_epoch == 0
        track_batch = batch_idx == self.hparams.track_val_idx
        track_all_batches = self.hparams.track_val_idx == -1
        if track_epoch and (track_batch or track_all_batches):
            self.track_batch(batch, batch_idx, output)

        # Explicitly delete the output, for memory release
        del output

    def validation_step_update_metrics(
            self,
            loss: torch.Tensor,
            output: SemanticSegmentationOutput
    ) -> None:
        """Update validation metrics with the content of the output
        object.
        """
        self.val_loss(loss.detach())
        self.val_cm(output.semantic_pred().detach(), output.semantic_target.detach())

    def validation_step_log_metrics(self) -> None:
        """Log validation metrics after a single step with the content
        of the output object.
        """
        self.log(
            "val/loss", self.val_loss, on_step=False, on_epoch=True,
            prog_bar=True)

    def on_validation_epoch_end(self) -> None:
        # Gather the confusion matrices from all devices, if need be
        if self.trainer.num_devices > 1:
            epoch_cm = torch.sum(self.all_gather(self.val_cm.confmat), dim=0)
            epoch_cm = ConfusionMatrix(self.num_classes).from_confusion_matrix(epoch_cm)
        else:
            epoch_cm = self.val_cm

        miou = epoch_cm.miou()
        oa = epoch_cm.oa()
        macc = epoch_cm.macc()

        # Update the best-so-far metrics
        self.val_miou_best(miou)
        self.val_oa_best(oa)
        self.val_macc_best(macc)

        # Log the epoch-level metrics
        self.log("val/miou", miou, prog_bar=True, rank_zero_only=True)
        self.log("val/oa", oa, prog_bar=True, rank_zero_only=True)
        self.log("val/macc", macc, prog_bar=True, rank_zero_only=True)
        for iou, seen, name in zip(*epoch_cm.iou(), self.class_names):
            if seen:
                self.log(f"val/iou_{name}", iou, prog_bar=True, rank_zero_only=True)

        # Log the best-so-far metrics as plain values (via `.compute()`)
        # rather than as Metric objects
        self.log("val/miou_best", self.val_miou_best.compute(), prog_bar=True, rank_zero_only=True)
        self.log("val/oa_best", self.val_oa_best.compute(), prog_bar=True, rank_zero_only=True)
        self.log("val/macc_best", self.val_macc_best.compute(), prog_bar=True, rank_zero_only=True)

        # Reset the confusion matrices for the next epoch
        self.val_cm.reset()
        epoch_cm.reset()

    def on_test_start(self) -> None:
        # Prepare the submission directory, in case the test predictions
        # are to be submitted to a benchmark, and run the same late
        # initialization as for fitting
        self.submission_dir = self.trainer.datamodule.test_dataset.submission_dir
        self.on_fit_start()

    def test_step(self, batch: NAG, batch_idx: int) -> None:
        loss, output = self.model_step(batch)

        # If the batch does not carry any labels, `model_step()` returns
        # a None loss; use 0 as a placeholder instead
        loss = 0 if loss is None else loss

        # Remember that the test set has no labels, so that metric
        # updates and epoch-end metrics can be skipped
        if not output.has_target:
            self.test_has_target = False

        # Update and log metrics
        self.test_step_update_metrics(loss, output)
        self.test_step_log_metrics()

        # Prepare submission for held-out test sets, by propagating the
        # level-1 predictions to the level-0 points
        if self.trainer.datamodule.hparams.submit:
            nag = batch if isinstance(batch, NAG) else batch[0]
            l0_pos = nag[0].pos.detach().cpu()
            l0_pred = output.semantic_pred()[nag[0].super_index].detach().cpu()
            self.trainer.datamodule.test_dataset.make_submission(
                batch_idx, l0_pred, l0_pos, submission_dir=self.submission_dir)

        # Store the batch prediction to disk if the batch index matches
        # the tracking settings. `track_test_idx=-1` means all test
        # batches are tracked
        track_batch = batch_idx == self.hparams.track_test_idx
        track_all_batches = self.hparams.track_test_idx == -1
        if track_batch or track_all_batches:
            self.track_batch(batch, batch_idx, output)

        # Explicitly delete the output, for memory release
        del output

    def test_step_update_metrics(
            self,
            loss: torch.Tensor,
            output: SemanticSegmentationOutput
    ) -> None:
        """Update test metrics with the content of the output object.
        """
        # Skip if the test set does not carry any labels
        if not self.test_has_target:
            return

        self.test_loss(loss.detach())
        self.test_cm(output.semantic_pred().detach(), output.semantic_target.detach())

    def test_step_log_metrics(self) -> None:
        """Log test metrics after a single step with the content of the
        output object.
        """
        # Skip if the test set does not carry any labels
        if not self.test_has_target:
            return

        self.log(
            "test/loss", self.test_loss, on_step=False, on_epoch=True,
            prog_bar=True)

    def on_test_epoch_end(self) -> None:
        # Finalize the submission, if any
        if self.trainer.datamodule.hparams.submit:
            self.trainer.datamodule.test_dataset.finalize_submission(
                self.submission_dir)

        # Skip the metrics if the test set does not carry any labels
        if not self.test_has_target:
            self.test_cm.reset()
            return

        # Gather the confusion matrices from all devices, if need be
        if self.trainer.num_devices > 1:
            epoch_cm = torch.sum(self.all_gather(self.test_cm.confmat), dim=0)
            epoch_cm = ConfusionMatrix(self.num_classes).from_confusion_matrix(epoch_cm)
        else:
            epoch_cm = self.test_cm

        # Log the epoch-level metrics
        self.log("test/miou", epoch_cm.miou(), prog_bar=True, rank_zero_only=True)
        self.log("test/oa", epoch_cm.oa(), prog_bar=True, rank_zero_only=True)
        self.log("test/macc", epoch_cm.macc(), prog_bar=True, rank_zero_only=True)
        for iou, seen, name in zip(*epoch_cm.iou(), self.class_names):
            if seen:
                self.log(f"test/iou_{name}", iou, prog_bar=True, rank_zero_only=True)

        # Log the full confusion matrix to wandb, if available
        if isinstance(self.logger, WandbLogger):
            self.logger.experiment.log({
                "test/cm": wandb_confusion_matrix(
                    epoch_cm.confmat, class_names=self.class_names)})

        # Reset the confusion matrices for the next epoch
        self.test_cm.reset()
        epoch_cm.reset()

    def predict_step(
            self,
            batch: NAG,
            batch_idx: int
    ) -> Tuple[NAG, SemanticSegmentationOutput]:
        _, output = self.model_step(batch)
        return batch, output

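    # NOTE: hedged inference sketch, added for illustration. With a
    # Lightning `Trainer`, `predict_step()` is what `trainer.predict()`
    # calls for each batch; per-superpoint and per-point class
    # predictions can then be read from the returned pair, e.g. in the
    # single-stage case:
    #
    #   nag, output = module.predict_step(batch, 0)
    #   sp_pred = output.semantic_pred()         # level-1 superpoints
    #   p_pred = sp_pred[nag[0].super_index]     # level-0 points
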
    def track_batch(
            self,
            batch: NAG,
            batch_idx: int,
            output: SemanticSegmentationOutput,
            folder: str = None
    ) -> None:
        """Store a batch prediction to disk. The corresponding `NAG`
        object will be populated with semantic segmentation predictions
        for:
          - levels 1+ if `multi_stage` output (i.e. loss supervision on
            levels 1 and above)
          - only level 1 otherwise

        Besides, we also pre-compute the level-0 predictions as this is
        frequently required for downstream tasks. However, we choose not
        to compute the full-resolution predictions for the sake of disk
        memory.

        If a `folder` is provided, the NAG will be saved there under:
          <folder>/predictions/<stage>/<epoch>/batch_<batch_idx>.h5
        If not, the folder will be the logger's directory, if any.
        If not, the current working directory will be used.

        :param batch: NAG
            Object that will be stored to disk. Before that, the
            model predictions will be added to the attributes of each
            level, to facilitate downstream use of the stored `NAG`
        :param batch_idx: int
            Index of the batch to be stored
        :param output: SemanticSegmentationOutput
            Output of `self.model_step()`
        :param folder: str
            Path where to save the tracked batch. If not provided, the
            logger's saving directory will be used as fallback. If no
            logger is found, the current working directory will be used
        :return:
        """
        # Only simple NAG batches are supported here
        if not isinstance(batch, NAG):
            raise NotImplementedError(
                f"Expected a NAG, but received a {type(batch)}. Are you "
                f"perhaps running multi-run inference? If so, this is not "
                f"compatible with batch saving, please deactivate one of them.")

        # Add the predictions and logits to the NAG levels, so they are
        # stored along with the rest of the attributes
        if not output.multi_stage:
            logits = output.logits
            pred = torch.argmax(logits, dim=1)

            # Level-1 predictions
            batch[1].semantic_pred = pred
            batch[1].logits = logits

            # Level-0 (i.e. point-wise) predictions
            batch[0].semantic_pred = pred[batch[0].super_index]
            batch[0].logits = logits[batch[0].super_index]

        else:
            for i, _logits in enumerate(output.logits):
                logits = _logits
                pred = torch.argmax(logits, dim=1)

                # Level-(i+1) predictions
                batch[i + 1].semantic_pred = pred
                batch[i + 1].logits = logits

                # Level-0 predictions, computed from the level-1 output
                # only
                if i > 0:
                    continue
                batch[0].semantic_pred = pred[batch[0].super_index]
                batch[0].logits = logits[batch[0].super_index]

        # Move the batch to CPU before saving
        batch = batch.detach().cpu()

        # Recover the current stage, to structure the output directory
        if self.trainer is None:
            stage = 'unknown_stage'
        elif self.trainer.training:
            stage = 'train'
        elif self.trainer.validating:
            stage = 'val'
        elif self.trainer.testing:
            stage = 'test'
        elif self.trainer.predicting:
            stage = 'predict'
        else:
            stage = 'unknown_stage'
        if folder is None:
            if self.logger and self.logger.save_dir:
                folder = self.logger.save_dir
            else:
                folder = ''
        folder = osp.join(folder, 'predictions', stage, str(self.current_epoch))
        if not osp.isdir(folder):
            os.makedirs(folder, exist_ok=True)

        # Save the NAG to disk
        path = osp.join(folder, f"batch_{batch_idx}.h5")
        batch.save(path)
        log.info(f'Stored predictions at: "{path}"')

        # Placeholder for logging the predictions to wandb
        if isinstance(self.logger, WandbLogger):
            pass

    def configure_optimizers(self) -> Dict:
        """Choose what optimizers and learning-rate schedulers to use in
        your optimization. Normally you'd need one. But in the case of
        GANs or similar you might have multiple.

        Examples:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        # Differential learning rate for transformer blocks: parameters
        # whose name contains one of `t_names` are trained with
        # `lr * transformer_lr_scale`, the rest with `lr`
        t_names = ['transformer_blocks', 'down_pool_block']
        lr = self.hparams.optimizer.keywords['lr']
        t_lr = lr * self.hparams.transformer_lr_scale
        param_dicts = [
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if all([t not in n for t in t_names]) and p.requires_grad]},
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if any([t in n for t in t_names]) and p.requires_grad],
                "lr": t_lr}]
        optimizer = self.hparams.optimizer(params=param_dicts)

        # Return the optimizer alone if no scheduler is used
        if self.hparams.scheduler is None:
            return {"optimizer": optimizer}

        # Build the scheduler from the partially-instantiated
        # `self.hparams.scheduler` and indicate whether it is a
        # reduce-on-plateau scheduler
        scheduler = self.hparams.scheduler(optimizer=optimizer)
        reduce_on_plateau = isinstance(scheduler, ON_PLATEAU_SCHEDULERS)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
                "interval": "epoch",
                "frequency": 1,
                "reduce_on_plateau": reduce_on_plateau}}

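    # NOTE: illustrative example for `configure_optimizers()` above,
    # added for clarity: with `lr=0.01` and `transformer_lr_scale=0.1`,
    # parameters whose names contain 'transformer_blocks' or
    # 'down_pool_block' are optimized with a learning rate of 0.001,
    # while all other parameters use 0.01.
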
    def load_state_dict(
            self,
            state_dict: Dict,
            strict: bool = True
    ) -> None:
        """Basic `load_state_dict` from `torch.nn.Module` with a bit of
        acrobatics due to `criterion.weight`.

        This attribute, when present in the `state_dict`, causes
        `load_state_dict` to crash. More precisely, `criterion.weight`
        is holding the per-class weights for classification losses.
        """
        # Temporarily remove the criterion's class weights
        class_weight_bckp = self.criterion.weight
        self.criterion.weight = None

        # Recover the class weights from any `criterion.*.weight` key in
        # the state_dict and remove those keys before loading
        keys = []
        for key in state_dict.keys():
            if key.startswith('criterion.') and key.endswith('.weight'):
                keys.append(key)
        class_weight = state_dict[keys[0]] if len(keys) > 0 else None
        for key in keys:
            state_dict.pop(key)

        # Load the state_dict as usual
        super().load_state_dict(state_dict, strict=strict)

        # Restore the class weights, from the checkpoint if it carried
        # them, from the backup otherwise
        self.criterion.weight = class_weight if class_weight is not None \
            else class_weight_bckp

    def _load_from_checkpoint(
            self,
            checkpoint_path: str,
            **kwargs
    ) -> 'SemanticSegmentationModule':
        """Simpler version of `LightningModule.load_from_checkpoint()`
        for easier use: no need to explicitly pass `model.net`,
        `model.criterion`, etc.
        """
        return self.__class__.load_from_checkpoint(
            checkpoint_path, net=self.net, criterion=self.criterion, **kwargs)

    @staticmethod
    def sanitize_step_output(out_dict: Dict) -> Dict:
        """Helper to be used for cleaning up the `_step` functions.
        Lightning expects those to return the loss (on GPU, with the
        computation graph intact) for the backward step. Any other
        element passed in this dict will be detached and moved to CPU
        here. This avoids memory leaks.
        """
        return {
            k: v if ((k == "loss") or (not isinstance(v, torch.Tensor)))
            else v.detach().cpu()
            for k, v in out_dict.items()}


if __name__ == "__main__":
    import hydra
    import omegaconf
    import pyrootutils

    root = str(pyrootutils.setup_root(__file__, pythonpath=True))
    cfg = omegaconf.OmegaConf.load(root + "/configs/model/semantic/spt-2.yaml")
    _ = hydra.utils.instantiate(cfg)
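
    # NOTE: hedged extension of the demo above, added for illustration.
    # Assuming `hydra.utils.instantiate(cfg)` returns the module built
    # from the YAML config, the sanity check can be pushed one step
    # further by inspecting attributes defined in `__init__()`:
    #
    #   model = hydra.utils.instantiate(cfg)
    #   print(model.num_classes, model.class_names)
    #   print(type(model.criterion), model.multi_stage_loss)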