| from detectron2.engine import AMPTrainer |
| import torch |
| import time |
| import logging |
|
|
| logger = logging.getLogger("detectron2") |
|
|
| import typing |
| from collections import defaultdict |
| import tabulate |
| from torch import nn |
|
|
|
|
def parameter_count(model: nn.Module, trainable_only: bool = False) -> typing.DefaultDict[str, int]:
    """
    Tally the number of parameter elements of a model and of each of its
    submodules.

    Args:
        model: a torch module
        trainable_only: when True, parameters whose ``requires_grad`` is
            False are left out of every total.

    Returns:
        dict (str-> int): keys are dotted names as produced by
            ``model.named_parameters()`` together with every dotted prefix
            of those names. Each value is the summed ``numel()`` of the
            counted parameters located under that key. The key ""
            holds the total over the whole model.
    """
    totals = defaultdict(int)
    for full_name, param in model.named_parameters():
        if trainable_only and not param.requires_grad:
            continue
        n_elements = param.numel()
        parts = full_name.split(".")
        # Credit the parameter to "", "a", "a.b", ... up to its own full name.
        for depth in range(len(parts) + 1):
            totals[".".join(parts[:depth])] += n_elements
    return totals
|
|
|
|
def parameter_count_table(
    model: nn.Module, max_depth: int = 3, trainable_only: bool = False
) -> str:
    """
    Format the parameter count of the model (and its submodules or parameters)
    in a nice table. It looks like this:

    ::

        | name                            | #elements or shape   |
        |:--------------------------------|:---------------------|
        | model                           | 37.9M                |
        |  backbone                       |  31.5M               |
        |   backbone.fpn_lateral3         |   0.1M               |
        |    backbone.fpn_lateral3.weight |    (256, 512, 1, 1)  |
        |    backbone.fpn_lateral3.bias   |    (256,)            |
        |   backbone.top_block            |   5.3M               |
        |    backbone.top_block.p6        |    4.7M              |
        |    backbone.top_block.p7        |    0.6M              |
        |   backbone.bottom_up            |   23.5M              |
        |    backbone.bottom_up.stem      |    9.4K              |
        |    backbone.bottom_up.res2      |    0.2M              |
        |    ......                       |    .....             |

    Args:
        model: a torch module
        max_depth (int): maximum depth to recursively print submodules or
            parameters
        trainable_only (bool): if True, only parameters with
            ``requires_grad=True`` are counted and listed.

    Returns:
        str: the table to be printed
    """
    count: typing.DefaultDict[str, int] = parameter_count(model, trainable_only)

    # Shapes are collected for every parameter; rows are driven by ``count``,
    # so parameters filtered out by ``trainable_only`` never reach the table.
    param_shape: typing.Dict[str, typing.Tuple] = {
        k: tuple(v.shape) for k, v in model.named_parameters()
    }

    table: typing.List[typing.Tuple] = []

    def format_size(x: int) -> str:
        # Each threshold sits at a tenth of its unit, so the smallest value
        # shown with a suffix renders as "0.1<unit>".
        if x > 1e8:
            return "{:.1f}G".format(x / 1e9)
        if x > 1e5:
            return "{:.1f}M".format(x / 1e6)
        if x > 1e2:
            return "{:.1f}K".format(x / 1e3)
        return str(x)

    def fill(lvl: int, prefix: str) -> None:
        # Append rows for every entry at depth ``lvl`` under ``prefix``,
        # descending into submodules until ``max_depth`` is reached.
        if lvl >= max_depth:
            return
        for name, v in count.items():
            if name.count(".") == lvl and name.startswith(prefix):
                indent = " " * (lvl + 1)
                if name in param_shape:
                    # Leaf parameter: show its shape instead of a count.
                    table.append((indent + name, indent + str(param_shape[name])))
                else:
                    table.append((indent + name, indent + format_size(v)))
                    fill(lvl + 1, name + ".")

    # The "" entry is absent when nothing was counted (a model without
    # parameters, or without trainable ones); defaultdict.pop does not fall
    # back to default_factory, so supply 0 explicitly instead of raising.
    table.append(("model", format_size(count.pop("", 0))))
    fill(0, "")

    # The leading spaces encode nesting depth, so tabulate must not strip
    # them. The flag is module-global: restore it even if tabulate raises.
    old_ws = tabulate.PRESERVE_WHITESPACE
    tabulate.PRESERVE_WHITESPACE = True
    try:
        tab = tabulate.tabulate(table, headers=["name", "#elements or shape"], tablefmt="pipe")
    finally:
        tabulate.PRESERVE_WHITESPACE = old_ws
    return tab
|
|
|
|
def cycle(iterable):
    """
    Yield the items of ``iterable`` over and over again.

    Unlike ``itertools.cycle``, nothing is cached: ``iterable`` is iterated
    afresh on every pass, so whatever its ``__iter__`` does per pass happens
    again each time.

    The generator ends as soon as a complete pass yields no item. This covers
    an empty iterable and a one-shot iterator that has been exhausted; without
    the check the outer loop would spin forever without yielding and the
    consumer's ``next()`` would never return.

    Args:
        iterable: any iterable; re-iterable objects repeat indefinitely.

    Yields:
        The items of ``iterable``, pass after pass.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            return
|
|
class MattingTrainer(AMPTrainer):
    """
    AMPTrainer variant that pulls batches from an endlessly repeating view of
    the data loader and logs parameter-count tables at construction time.
    """

    def __init__(self, model, data_loader, optimizer, grad_scaler=None):
        """
        Args:
            model: the module to train; also summarised in the log.
            data_loader: iterable of batches, wrapped with ``cycle``.
            optimizer: optimizer stepped through the gradient scaler.
            grad_scaler: forwarded unchanged to ``AMPTrainer``; what ``None``
                means is decided by the base class (not visible here).
        """
        # Forward the caller's scaler. It used to be hard-coded to None,
        # which silently discarded any scaler passed in.
        super().__init__(model, data_loader, optimizer, grad_scaler=grad_scaler)
        # NOTE(review): relies on the base class storing the loader as
        # ``self.data_loader`` - confirm against the detectron2 version in use.
        self.data_loader_iter = iter(cycle(self.data_loader))

        logger.info("All parameters: \n" + parameter_count_table(model))
        logger.info("Trainable parameters: \n" + parameter_count_table(model, trainable_only=True, max_depth=8))

    def run_step(self):
        """
        Implement the AMP training logic.

        Fetches one batch, runs the model under autocast, backpropagates the
        scaled loss, reports metrics, then steps the optimizer through the
        gradient scaler and updates the scaler.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        # Time only the wait for the next batch.
        start = time.perf_counter()
        data = next(self.data_loader_iter)
        data_time = time.perf_counter() - start

        with autocast():
            loss_dict = self.model(data)
            if isinstance(loss_dict, torch.Tensor):
                # A bare tensor is the total loss; wrap it so the metrics
                # call always receives a dict.
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        self.grad_scaler.scale(losses).backward()

        self._write_metrics(loss_dict, data_time)

        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()