import torch
import logging
from pytorch_lightning import LightningDataModule
from typing import Any, Dict, List, Tuple, Union
from src.transforms import *
from src.loader import DataLoader
from src.data import NAGBatch
log = logging.getLogger(__name__)
# List of transforms not allowed for test-time augmentation
_TTA_CONFLICTS = []
# List of transforms not allowed for test prediction submission
_SUBMISSION_CONFLICTS = [
CenterPosition,
RandomTiltAndRotate,
RandomAnisotropicScale,
RandomAxisFlip,
Inliers,
Outliers,
Shuffle,
GridSampling3D,
SampleXYTiling,
SampleRecursiveMainXYAxisTiling,
SampleSubNodes,
SampleSegments,
SampleKHopSubgraphs,
    SampleRadiusSubgraphs]
class BaseDataModule(LightningDataModule):
"""Base LightningDataModule class.
Child classes should overwrite:
```
    class MyDataModule(BaseDataModule):
_DATASET_CLASS = ...
_MINIDATASET_CLASS = ...
```
    A DataModule implements 6 key methods:
def prepare_data(self):
# things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
# download data, pre-process, split, save to disk, etc...
def setup(self, stage):
# things to do on every process in DDP
# load data, set variables, etc...
def train_dataloader(self):
# return train dataloader
def val_dataloader(self):
# return validation dataloader
def test_dataloader(self):
# return test dataloader
def teardown(self):
# called on every process in DDP
# clean up after fit or test
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
Read the docs:
https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
"""
_DATASET_CLASS = None
_MINIDATASET_CLASS = None
def __init__(
self,
data_dir: str = '',
pre_transform: Transform = None,
train_transform: Transform = None,
val_transform: Transform = None,
test_transform: Transform = None,
on_device_train_transform: Transform = None,
on_device_val_transform: Transform = None,
on_device_test_transform: Transform = None,
dataloader: DataLoader = None,
mini: bool = False,
trainval: bool = False,
val_on_test: bool = False,
tta_runs: int = None,
tta_val: bool = False,
submit: bool = False,
**kwargs):
super().__init__()
        # This line allows accessing init params with the 'self.hparams'
        # attribute; it also ensures init params will be stored in the
        # checkpoint
self.save_hyperparameters(logger=False)
self.kwargs = kwargs
# Make sure `_DATASET_CLASS` and `_MINIDATASET_CLASS` have been
# specified
if self.dataset_class is None:
raise NotImplementedError
self.train_dataset = None
self.val_dataset = None
self.test_dataset = None
# Do not set the transforms directly, use self.set_transforms()
# instead to parse the input configs
self.pre_transform = None
self.train_transform = None
self.val_transform = None
self.test_transform = None
self.on_device_train_transform = None
self.on_device_val_transform = None
self.on_device_test_transform = None
# Instantiate the transforms
self.set_transforms()
# Check TTA and transforms conflicts
self.check_tta_conflicts()
# Check test submission and transforms conflicts
self.check_submission_conflicts()
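    # A hedged usage sketch, not taken from this repository's configs: a child
    # datamodule is typically built from a config and then handed to the
    # Lightning Trainer. `MyDataModule` and the config objects below are
    # illustrative placeholders only:
    #
    #     datamodule = MyDataModule(
    #         data_dir='/path/to/data',
    #         train_transform=train_transform_cfg,            # CPU-side transforms
    #         on_device_train_transform=on_device_train_cfg,  # GPU-side transforms
    #         dataloader=dataloader_cfg,  # exposes batch_size, num_workers, ...
    #         tta_runs=None,
    #         submit=False)
    #     trainer.fit(model, datamodule=datamodule)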
@property
def dataset_class(self) -> type:
"""Return the LightningDataModule's Dataset class.
"""
if self.hparams.mini:
return self._MINIDATASET_CLASS
return self._DATASET_CLASS
@property
def train_stage(self) -> str:
"""Return either 'train' or 'trainval' depending on how
`self.hparams.trainval` is configured.
"""
return 'trainval' if self.hparams.trainval else 'train'
@property
def val_stage(self) -> str:
"""Return either 'val' or 'test' depending on how
`self.hparams.val_on_test` is configured.
"""
return 'test' if self.hparams.val_on_test else 'val'
def prepare_data(self) -> None:
"""Download and heavy preprocessing of data should be triggered
here.
However, do not use it to assign state (e.g. self.x = y) because
it will not be preserved outside this scope.
"""
self.dataset_class(
self.hparams.data_dir, stage=self.train_stage,
transform=self.train_transform, pre_transform=self.pre_transform,
on_device_transform=self.on_device_train_transform, **self.kwargs)
self.dataset_class(
self.hparams.data_dir, stage=self.val_stage,
transform=self.val_transform, pre_transform=self.pre_transform,
on_device_transform=self.on_device_val_transform, **self.kwargs)
self.dataset_class(
self.hparams.data_dir, stage='test',
transform=self.test_transform, pre_transform=self.pre_transform,
on_device_transform=self.on_device_test_transform, **self.kwargs)
def setup(self, stage=None) -> None:
"""Load data. Set variables: `self.train_dataset`,
`self.val_dataset`, `self.test_dataset`.
        This method is called by Lightning with both `trainer.fit()`
and `trainer.test()`, so be careful not to execute things like
random split twice!
"""
self.train_dataset = self.dataset_class(
self.hparams.data_dir, stage=self.train_stage,
transform=self.train_transform, pre_transform=self.pre_transform,
on_device_transform=self.on_device_train_transform, **self.kwargs)
self.val_dataset = self.dataset_class(
self.hparams.data_dir, stage=self.val_stage,
transform=self.val_transform, pre_transform=self.pre_transform,
on_device_transform=self.on_device_val_transform, **self.kwargs)
self.test_dataset = self.dataset_class(
self.hparams.data_dir, stage='test',
transform=self.test_transform, pre_transform=self.pre_transform,
on_device_transform=self.on_device_test_transform, **self.kwargs)
def set_transforms(self) -> None:
"""Parse in self.hparams in search for '*transform*' keys and
instantiate the corresponding transforms.
"""
t_dict = instantiate_datamodule_transforms(self.hparams, log=log)
for key, transform in t_dict.items():
setattr(self, key, transform)
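    # A hedged illustration of the convention the method above relies on:
    # every '*transform*' entry in the datamodule hyperparameters is
    # instantiated and stored under the attribute of the same name (the
    # attributes below are those declared in __init__; the example values are
    # assumptions, depending on the config):
    #
    #     self.pre_transform              # e.g. Compose([...]) or None
    #     self.train_transform            # e.g. Compose([...]) or None
    #     self.on_device_test_transform   # e.g. Compose([...]) or None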
def check_tta_conflicts(self) -> None:
"""Make sure the transforms are Test-Time Augmentation-friendly
"""
# Skip if not TTA
if self.hparams.tta_runs is None or self.hparams.tta_runs == 1:
return
# Make sure all transforms are test-time augmentation friendly
transforms = getattr(self.test_transform, 'transforms', [])
transforms += getattr(self.on_device_test_transform, 'transforms', [])
if self.hparams.tta_val:
transforms += getattr(self.val_transform, 'transforms', [])
transforms += getattr(self.on_device_val_transform, 'transforms', [])
for t in transforms:
            if isinstance(t, tuple(_TTA_CONFLICTS)):
raise NotImplementedError(
f"Cannot use {t} with test-time augmentation. The "
f"following transforms are not supported: {_TTA_CONFLICTS}")
def check_submission_conflicts(self) -> None:
"""Make sure the transforms and other parameters do not prevent
test prediction submission.
"""
# Skip if submission not needed
if not self.hparams.submit:
return
# TODO
# # Make sure the test dataset does not have any tiling
# if self.test_dataset.xy_tiling is not None \
# or self.test_dataset.pc_tiling is not None:
# raise NotImplementedError(
# f"Cannot run test prediction submission for test datasets "
# f"with tiling")
# Make sure the dataloader only produces predictions for 1 cloud
# at a time
if self.hparams.dataloader.batch_size > 1:
raise NotImplementedError(
f"Cannot run test prediction submission for dataloaders "
f"with batch size > 1")
# Make sure all transforms are test submission friendly
transforms = getattr(self.test_transform, 'transforms', [])
transforms += getattr(self.on_device_test_transform, 'transforms', [])
for t in transforms:
            if isinstance(t, tuple(_SUBMISSION_CONFLICTS)):
raise NotImplementedError(
f"Cannot use {t} with test prediction submission. The "
f"following transforms are not supported: "
f"{_SUBMISSION_CONFLICTS}")
def train_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.train_dataset,
batch_size=self.hparams.dataloader.batch_size,
num_workers=self.hparams.dataloader.num_workers,
pin_memory=self.hparams.dataloader.pin_memory,
persistent_workers=self.hparams.dataloader.persistent_workers,
shuffle=True)
def val_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.val_dataset,
batch_size=self.hparams.dataloader.batch_size,
num_workers=self.hparams.dataloader.num_workers,
pin_memory=self.hparams.dataloader.pin_memory,
persistent_workers=self.hparams.dataloader.persistent_workers,
shuffle=False)
def test_dataloader(self) -> DataLoader:
return DataLoader(
dataset=self.test_dataset,
batch_size=self.hparams.dataloader.batch_size,
num_workers=self.hparams.dataloader.num_workers,
pin_memory=self.hparams.dataloader.pin_memory,
persistent_workers=self.hparams.dataloader.persistent_workers,
shuffle=False)
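    # Hedged note on the `dataloader` hyperparameter used by the three loaders
    # above: the code only reads `batch_size`, `num_workers`, `pin_memory` and
    # `persistent_workers` from it, so any attribute-style config works, e.g.
    # (illustrative only, not this repository's actual config class):
    #
    #     from types import SimpleNamespace
    #     dataloader_cfg = SimpleNamespace(
    #         batch_size=4, num_workers=4, pin_memory=True,
    #         persistent_workers=True)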
def predict_dataloader(self) -> DataLoader:
"""By default, each DataModule uses its test dataset for predict
behavior.
"""
return self.test_dataloader()
def teardown(self, stage: str = None) -> None:
"""Clean up after fit or test."""
pass
def state_dict(self) -> Dict:
"""Extra things to save to checkpoint."""
return {}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Things to do when loading checkpoint."""
pass
@torch.no_grad()
def on_after_batch_transfer(
self,
nag_list: List['NAG'],
dataloader_idx: int,
) -> Union['NAG', Tuple['NAG', Transform, int]]:
"""Intended to call on-device operations. Typically,
NAGBatch.from_nag_list and some Transforms like SampleSubNodes
and SampleSegments are faster on GPU, and we may prefer
executing those on GPU rather than in CPU-based DataLoader.
Use self.on_device_<stage>_transform, to benefit from this hook.
"""
        # Since NAGBatch.from_nag_list takes a bit of time, we asked
        # src.loader.DataLoader to simply pass a list of NAG objects,
        # waiting to be batched on device
nag = NAGBatch.from_nag_list(nag_list)
del nag_list
# Here we run on_device_transform, which contains NAG transforms
# that we could not / did not want to run using CPU-based
# DataLoaders
if self.trainer.training:
on_device_transform = self.on_device_train_transform
elif self.trainer.validating:
on_device_transform = self.on_device_val_transform
elif self.trainer.testing:
on_device_transform = self.on_device_test_transform
elif self.trainer.predicting:
on_device_transform = self.on_device_test_transform
elif self.trainer.evaluating:
on_device_transform = self.on_device_test_transform
elif self.trainer.sanity_checking:
on_device_transform = self.on_device_train_transform
else:
log.warning(
'Unsure which stage we are in, defaulting to '
'self.on_device_train_transform')
on_device_transform = self.on_device_train_transform
# Skip on_device_transform if None
if on_device_transform is None:
return nag
        # Apply the on_device_transform only once when in training mode, or
        # when no test-time augmentation is required
        if self.trainer.training \
                or self.hparams.tta_runs is None \
                or self.hparams.tta_runs == 1 \
                or (self.trainer.validating and not self.hparams.tta_val):
return on_device_transform(nag)
# We return the input NAG as well as the augmentation transform
# and the number of runs. Those will be used by
# `LightningModule.step` to accumulate multiple augmented runs
return nag, on_device_transform, self.hparams.tta_runs
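    # Hedged sketch of how downstream code might consume the TTA output above.
    # Per the comment above, the accumulation happens in `LightningModule.step`;
    # the method body below is illustrative only and does not reproduce this
    # repository's actual step logic:
    #
    #     def step(self, batch):
    #         if isinstance(batch, tuple):
    #             nag, transform, tta_runs = batch
    #             # Average predictions over several augmented forward passes
    #             logits = sum(
    #                 self(transform(nag)) for _ in range(tta_runs)) / tta_runs
    #         else:
    #             logits = self(batch)
    #         ...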
def __repr__(self):
return f'{self.__class__.__name__}'