File size: 1,003 Bytes
c20d7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""Contains utility functionality to modify torch modules.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

from typing import Any

from torch import nn

# All classes exported by ``torch.nn`` whose name contains "Norm".
# These tuples are meant to be passed to ``isinstance``; the ``isinstance(..., type)``
# guard keeps any non-class attribute that merely has a matching name out of the tuple,
# because ``isinstance(obj, some_tuple)`` raises ``TypeError`` on non-class entries.
NORM_LAYER_TYPES = tuple(
    module_type
    for name, module_type in vars(nn).items()
    if "Norm" in name and isinstance(module_type, type)
)
# Subset of the above restricted to names containing "BatchNorm".
BATCH_NORM_LAYER_TYPES = tuple(
    module_type
    for name, module_type in vars(nn).items()
    if "BatchNorm" in name and isinstance(module_type, type)
)


def freeze_norm_layer(module: nn.Module) -> nn.Module:
    """Freeze every normalization layer found in ``module``.

    Each module yielded by ``module.modules()`` that is an instance of one of
    ``NORM_LAYER_TYPES`` has ``requires_grad_(False)`` applied and gets a
    forward pre-hook that switches it to eval mode before each forward call.

    Args:
        module: The module whose normalization layers should be frozen.

    Returns:
        The same ``module`` object that was passed in.
    """

    def _force_eval(layer: nn.Module, _inputs: Any) -> None:
        # Returns None, so the forward inputs are passed through unchanged.
        layer.eval()

    norm_layers = (
        candidate
        for candidate in module.modules()
        if isinstance(candidate, NORM_LAYER_TYPES)
    )
    for layer in norm_layers:
        layer.requires_grad_(False)
        # Forcing eval mode right before every forward ensures that batch norm
        # layers are always called with the precomputed running statistics.
        layer.register_forward_pre_hook(_force_eval)

    return module