"""LightningModule that wraps around the models, losses and optims."""
from __future__ import annotations
from typing import Any
import lightning.pytorch as pl
from lightning.pytorch import seed_everything
from lightning.pytorch.core.optimizer import LightningOptimizer
from ml_collections import ConfigDict
from torch import nn
from torch.optim.optimizer import Optimizer
from vis4d.common.ckpt import load_model_checkpoint
from vis4d.common.distributed import broadcast
from vis4d.common.imports import FVCORE_AVAILABLE
from vis4d.common.logging import rank_zero_info
from vis4d.common.typing import DictStrAny, GenericFunc
from vis4d.common.util import init_random_seed
from vis4d.config import instantiate_classes
from vis4d.config.typing import OptimizerConfig
from vis4d.data.typing import DictData
from vis4d.engine.connectors import DataConnector
from vis4d.engine.loss_module import LossModule
from vis4d.engine.optim import LRSchedulerWrapper, set_up_optimizers
from vis4d.model.adapter.flops import IGNORED_OPS, FlopsModelAdapter
if FVCORE_AVAILABLE:
from fvcore.nn import FlopCountAnalysis
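# A minimal usage sketch (hypothetical, not part of this module). The config
# objects and the data module below are illustrative placeholders that would
# normally come from a vis4d experiment config:
#
#   training_module = TrainingModule(
#       model_cfg=config.model,
#       optimizers_cfg=config.optimizers,
#       loss_module=instantiate_classes(config.loss),
#       train_data_connector=instantiate_classes(config.train_data_connector),
#       test_data_connector=instantiate_classes(config.test_data_connector),
#       seed=42,
#   )
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(training_module, datamodule=data_module)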
class TrainingModule(pl.LightningModule):
"""LightningModule that wraps around the vis4d implementations.
    This is a wrapper around the vis4d implementations that makes it possible
    to use pytorch-lightning for training and testing.
"""
def __init__(
self,
model_cfg: ConfigDict,
optimizers_cfg: list[OptimizerConfig],
loss_module: None | LossModule,
train_data_connector: None | DataConnector,
test_data_connector: None | DataConnector,
hyper_parameters: DictStrAny | None = None,
seed: int = -1,
ckpt_path: None | str = None,
compute_flops: bool = False,
check_unused_parameters: bool = False,
) -> None:
"""Initialize the TrainingModule.
Args:
model_cfg: The model config.
optimizers_cfg: The optimizers config.
loss_module: The loss module.
            train_data_connector: The data connector to use for training.
            test_data_connector: The data connector to use for testing.
hyper_parameters (DictStrAny | None, optional): The hyper
parameters to use. Defaults to None.
seed (int, optional): The integer value seed for global random
state. Defaults to -1. If -1, a random seed will be generated.
ckpt_path (str, optional): The path to the checkpoint to load.
Defaults to None.
            compute_flops (bool, optional): Whether to compute the FLOPs of
                the model. Defaults to False.
            check_unused_parameters (bool, optional): Whether to check for
                unused parameters. Defaults to False.
"""
super().__init__()
self.model_cfg = model_cfg
self.optimizers_cfg = optimizers_cfg
self.loss_module = loss_module
self.train_data_connector = train_data_connector
self.test_data_connector = test_data_connector
self.hyper_parameters = hyper_parameters
self.seed = seed
self.ckpt_path = ckpt_path
self.compute_flops = compute_flops
self.check_unused_parameters = check_unused_parameters
# Create model placeholder
self.model: nn.Module
def setup(self, stage: str) -> None:
"""Setup the model."""
if stage == "fit":
if self.seed == -1:
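                # A random seed is created on the main process and broadcast
                # so that every rank uses the same global seed.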
self.seed = init_random_seed()
self.seed = broadcast(self.seed)
self.trainer.seed = self.seed # type: ignore
seed_everything(self.seed, workers=True)
rank_zero_info(f"Global seed set to {self.seed}")
if self.hyper_parameters is not None:
self.hyper_parameters["seed"] = self.seed
if "checkpoint_callback" in self.hyper_parameters:
self.hyper_parameters.pop("checkpoint_callback")
self.save_hyperparameters(self.hyper_parameters)
# Instantiate the model after the seed has been set
self.model = instantiate_classes(self.model_cfg)
if self.ckpt_path is not None:
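            # rev_keys strips leading "model." / "module." prefixes from the
            # checkpoint keys (e.g. "model.backbone.weight" becomes
            # "backbone.weight") so that checkpoints saved by this wrapper or
            # by DDP-wrapped models load into the bare model.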
load_model_checkpoint(
self.model,
self.ckpt_path,
rev_keys=[(r"^model\.", ""), (r"^module\.", "")],
)
def forward( # type: ignore # pylint: disable=arguments-differ
self, data: DictData
) -> Any:
"""Forward pass through the model."""
if self.training:
assert self.train_data_connector is not None
return self.model(**self.train_data_connector(data))
assert self.test_data_connector is not None
return self.model(**self.test_data_connector(data))
def training_step( # type: ignore # pylint: disable=arguments-differ,line-too-long,unused-argument
self, batch: DictData, batch_idx: int
) -> Any:
"""Perform a single training step."""
assert self.train_data_connector is not None
out = self.model(**self.train_data_connector(batch))
assert self.loss_module is not None
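        # The loss module returns the total loss together with a dict of
        # named loss metrics for logging.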
total_loss, metrics = self.loss_module(out, batch)
return {
"loss": total_loss,
"metrics": metrics,
"predictions": out,
}
def validation_step( # pylint: disable=arguments-differ,line-too-long,unused-argument
self, batch: DictData, batch_idx: int, dataloader_idx: int = 0
) -> DictData:
"""Perform a single validation step."""
assert self.test_data_connector is not None
out = self.model(**self.test_data_connector(batch))
return out
def test_step( # pylint: disable=arguments-differ,line-too-long,unused-argument
self, batch: DictData, batch_idx: int, dataloader_idx: int = 0
) -> DictData:
"""Perform a single test step."""
assert self.test_data_connector is not None
        if self.compute_flops:
            if not FVCORE_AVAILABLE:
                raise RuntimeError(
                    "Please install fvcore to compute FLOPs of the model."
                )
            # Run the data connector once and flatten its outputs into the
            # positional inputs expected by fvcore.
            connector_inputs = self.test_data_connector(batch)
            flatten_inputs = list(connector_inputs.values())
            flops_model = FlopsModelAdapter(
                self.model, self.test_data_connector
            )
            flop_analyzer = FlopCountAnalysis(  # pylint: disable=possibly-used-before-assignment, line-too-long
                flops_model, flatten_inputs
            )
            # Ops listed in IGNORED_OPS are excluded from the count.
            flop_analyzer.set_op_handle(**{k: None for k in IGNORED_OPS})
            flops = flop_analyzer.total() / 1e9
            rank_zero_info(f"Flops: {flops:.2f} Gflops")
out = self.model(**self.test_data_connector(batch))
return out
def configure_optimizers(self) -> Any: # type: ignore
"""Return the optimizer to use."""
self.trainer.fit_loop.setup_data()
steps_per_epoch = len(self.trainer.train_dataloader) # type: ignore
return set_up_optimizers(
self.optimizers_cfg, [self.model], steps_per_epoch
)
def lr_scheduler_step( # type: ignore # pylint: disable=arguments-differ,line-too-long,unused-argument
self, scheduler: LRSchedulerWrapper, metric: Any | None = None
) -> None:
"""Perform a step on the lr scheduler."""
# TODO: Support metric if needed
scheduler.step(self.current_epoch)
def optimizer_step(
self,
epoch: int,
batch_idx: int,
optimizer: Optimizer | LightningOptimizer,
optimizer_closure: GenericFunc | None = None,
) -> None:
"""Optimizer step."""
if self.check_unused_parameters:
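            # Report parameters that did not receive a gradient in this step;
            # useful for spotting unused submodules (e.g. under DDP).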
for name, param in self.model.named_parameters():
if param.grad is None:
rank_zero_info(name)
optimizer.step(closure=optimizer_closure)