|
|
"""Contains utility functionality to modify torch modules. |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Any |
|
|
|
|
|
from torch import nn |
|
|
|
|
|
def _collect_module_types(substring: str) -> tuple[type[nn.Module], ...]:
    """Return every ``torch.nn`` module class whose exported name contains ``substring``.

    Only genuine ``nn.Module`` subclasses are kept. The result is meant to be
    passed to ``isinstance``, which raises ``TypeError`` if the tuple contains
    anything that is not a type, so a stray non-class export (function, alias)
    matching the substring must not leak in.
    """
    return tuple(
        obj
        for name, obj in vars(nn).items()
        if substring in name and isinstance(obj, type) and issubclass(obj, nn.Module)
    )


# All normalization layer classes exported by ``torch.nn`` (BatchNorm*, LayerNorm,
# GroupNorm, InstanceNorm*, ... -- whatever the installed torch version exports).
NORM_LAYER_TYPES = _collect_module_types("Norm")

# Subset of the above restricted to batch-normalization variants.
BATCH_NORM_LAYER_TYPES = _collect_module_types("BatchNorm")
|
|
|
|
|
|
|
|
def freeze_norm_layer(
    module: nn.Module,
    *,
    layer_types: tuple[type[nn.Module], ...] | None = None,
) -> nn.Module:
    """Freeze all normalization layers of ``module`` in place.

    Every submodule yielded by ``module.modules()`` (which includes ``module``
    itself) that is an instance of one of ``layer_types`` gets two changes:
    its parameters are set to ``requires_grad=False``, and it receives a
    forward pre-hook that calls ``.eval()`` on it right before each forward
    pass.

    The hook (rather than a single ``.eval()`` call) keeps the layer in eval
    mode even if a parent module is later switched back with ``.train()``.

    Args:
        module: Root module whose submodules are inspected.
        layer_types: Layer classes to freeze. Defaults to ``NORM_LAYER_TYPES``.
            Pass e.g. ``BATCH_NORM_LAYER_TYPES`` to freeze only batch-norm
            layers. An empty tuple makes this a no-op.

    Returns:
        The same ``module`` object, modified in place.

    Note:
        Hooks are not deduplicated. Calling this twice on the same module
        registers the pre-hook twice.
    """
    if layer_types is None:
        layer_types = NORM_LAYER_TYPES

    def set_module_eval_mode(layer: nn.Module, _: Any) -> None:
        # Forward pre-hook signature is (module, args). Returning None leaves
        # the forward inputs untouched; we only force eval mode.
        layer.eval()

    for submodule in module.modules():
        if isinstance(submodule, layer_types):
            submodule.requires_grad_(False)
            submodule.register_forward_pre_hook(set_module_eval_mode)

    return module
|
|
|