| | from detectron2.engine import AMPTrainer |
| | import torch |
| | import time |
| | import logging |
| |
|
| | logger = logging.getLogger("detectron2") |
| |
|
| | import typing |
| | from collections import defaultdict |
| | import tabulate |
| | from torch import nn |
| |
|
| |
|
def parameter_count(model: nn.Module, trainable_only: bool = False) -> typing.DefaultDict[str, int]:
    """
    Count parameters of a model and its submodules.

    Args:
        model: a torch module
        trainable_only: if True, only count parameters with
            ``requires_grad=True`` (i.e. those the optimizer will update).

    Returns:
        dict (str-> int): the key is either a parameter name or a module name.
            The value is the number of elements in the parameter, or in all
            parameters of the module. The key "" corresponds to the total
            number of parameters of the model.
    """
    r: typing.DefaultDict[str, int] = defaultdict(int)
    for name, prm in model.named_parameters():
        # Skip frozen parameters when only trainable ones were requested.
        if trainable_only and not prm.requires_grad:
            continue
        size = prm.numel()
        # Credit this parameter's size to every dotted prefix of its name:
        # "" (whole model), each enclosing module, and the parameter itself.
        parts = name.split(".")
        for k in range(len(parts) + 1):
            prefix = ".".join(parts[:k])
            r[prefix] += size
    return r
| |
|
| |
|
def parameter_count_table(
    model: nn.Module, max_depth: int = 3, trainable_only: bool = False
) -> str:
    """
    Format the parameter count of the model (and its submodules or parameters)
    in a nice table. It looks like this:

    ::

        | name                            | #elements or shape   |
        |:--------------------------------|:---------------------|
        | model                           | 37.9M                |
        |  backbone                       |  31.5M               |
        |   backbone.fpn_lateral3         |   0.1M               |
        |    backbone.fpn_lateral3.weight |    (256, 512, 1, 1)  |
        |    backbone.fpn_lateral3.bias   |    (256,)            |
        |   backbone.bottom_up            |   23.5M              |
        |    ......                       |    .....             |

    Args:
        model: a torch module
        max_depth (int): maximum depth to recursively print submodules or
            parameters
        trainable_only (bool): if True, only count parameters with
            ``requires_grad=True``

    Returns:
        str: the table to be printed
    """
    count: typing.DefaultDict[str, int] = parameter_count(model, trainable_only)
    # Leaf entries show the parameter's shape instead of an element count.
    param_shape: typing.Dict[str, typing.Tuple] = {
        k: tuple(v.shape) for k, v in model.named_parameters()
    }

    table: typing.List[typing.Tuple] = []

    def format_size(x: int) -> str:
        # Human-readable element count (G/M/K suffixes).
        if x > 1e8:
            return "{:.1f}G".format(x / 1e9)
        if x > 1e5:
            return "{:.1f}M".format(x / 1e6)
        if x > 1e2:
            return "{:.1f}K".format(x / 1e3)
        return str(x)

    def fill(lvl: int, prefix: str) -> None:
        # Depth-first walk over `count`, emitting rows indented by depth.
        if lvl >= max_depth:
            return
        for name, v in count.items():
            if name.count(".") == lvl and name.startswith(prefix):
                indent = " " * (lvl + 1)
                if name in param_shape:
                    table.append((indent + name, indent + str(param_shape[name])))
                else:
                    table.append((indent + name, indent + format_size(v)))
                fill(lvl + 1, name + ".")

    # "" holds the grand total; pop it so fill() doesn't re-emit it.
    table.append(("model", format_size(count.pop(""))))
    fill(0, "")

    # PRESERVE_WHITESPACE is a module-level global; restore it even if
    # tabulate() raises, so we don't leak the changed setting to callers.
    old_ws = tabulate.PRESERVE_WHITESPACE
    tabulate.PRESERVE_WHITESPACE = True
    try:
        tab = tabulate.tabulate(
            table, headers=["name", "#elements or shape"], tablefmt="pipe"
        )
    finally:
        tabulate.PRESERVE_WHITESPACE = old_ws
    return tab
| |
|
| |
|
def cycle(iterable):
    """
    Yield items from *iterable* forever, restarting from the beginning each
    time it is exhausted.

    Deliberately NOT ``itertools.cycle``: that caches every element and never
    re-invokes ``iter(iterable)``, whereas this version obtains a fresh
    iterator on every pass — so e.g. a DataLoader reshuffles each epoch and
    memory use stays constant.
    """
    while True:
        yield from iterable
| |
|
class MattingTrainer(AMPTrainer):
    """
    An :class:`AMPTrainer` whose data loader never runs out: the loader is
    wrapped in :func:`cycle`, so training can run for an arbitrary number of
    iterations regardless of dataset length.
    """

    def __init__(self, model, data_loader, optimizer, grad_scaler=None):
        """
        Args:
            model: the model to train (must return a loss dict or a loss
                tensor from its forward pass).
            data_loader: an iterable of training data.
            optimizer: a torch optimizer.
            grad_scaler: optional ``torch.cuda.amp.GradScaler``; if None the
                base class creates a default one.
        """
        # BUGFIX: forward the caller-provided grad_scaler instead of
        # hard-coding None, which silently discarded any custom scaler.
        super().__init__(model, data_loader, optimizer, grad_scaler=grad_scaler)
        # Infinite iterator over the loader; restarts it at each epoch end.
        self.data_loader_iter = iter(cycle(self.data_loader))

        logger.info("All parameters: \n" + parameter_count_table(model))
        logger.info(
            "Trainable parameters: \n"
            + parameter_count_table(model, trainable_only=True, max_depth=8)
        )

    def run_step(self):
        """
        Implement the AMP training logic.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        # Time spent waiting on data is reported separately from compute.
        start = time.perf_counter()
        data = next(self.data_loader_iter)
        data_time = time.perf_counter() - start

        with autocast():
            loss_dict = self.model(data)
            # The model may return either a single loss tensor or a dict of
            # named losses; normalize to (scalar loss, dict for logging).
            if isinstance(loss_dict, torch.Tensor):
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        # Scale the loss before backward to avoid fp16 gradient underflow.
        self.grad_scaler.scale(losses).backward()

        self._write_metrics(loss_dict, data_time)

        # step() unscales gradients and skips the update on inf/nan.
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()