ml-sharp / src /sharp /utils /module_surgery.py
amael-apple's picture
Initial commit
c20d7cc
"""Contains utility functionality to modify torch modules.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
from typing import Any
from torch import nn
def _collect_nn_layer_types(keyword: str) -> tuple[type[nn.Module], ...]:
    """Return the ``nn.Module`` subclasses in ``torch.nn`` whose name contains ``keyword``.

    Only real module classes are kept so that the result is always a valid
    second argument to ``isinstance``; a non-class attribute that merely
    happens to match the name would make ``isinstance`` raise ``TypeError``.
    """
    return tuple(
        candidate
        for name, candidate in vars(nn).items()
        if keyword in name and isinstance(candidate, type) and issubclass(candidate, nn.Module)
    )


# All layer classes exposed by ``torch.nn`` with "Norm" in their name.
NORM_LAYER_TYPES = _collect_nn_layer_types("Norm")
# The subset of the above with "BatchNorm" in their name.
BATCH_NORM_LAYER_TYPES = _collect_nn_layer_types("BatchNorm")
def freeze_norm_layer(module: nn.Module) -> nn.Module:
    """Freeze every normalization layer found inside ``module``.

    Each submodule that is an instance of ``NORM_LAYER_TYPES`` has gradients
    disabled via ``requires_grad_(False)`` and receives a forward pre-hook
    that switches it to eval mode right before every forward call.

    Args:
        module: The module whose normalization layers should be frozen. It is
            modified in place.

    Returns:
        The same ``module`` object that was passed in.
    """

    def _switch_to_eval(norm_layer: nn.Module, _inputs: Any) -> None:
        # Forward pre-hook signature is (module, inputs); the inputs are not needed.
        norm_layer.eval()

    norm_layers = [
        candidate for candidate in module.modules() if isinstance(candidate, NORM_LAYER_TYPES)
    ]
    for norm_layer in norm_layers:
        norm_layer.requires_grad_(False)
        # Re-applying eval mode before each forward pass ensures that batch
        # norm layers are always called with the precomputed running
        # statistics.
        norm_layer.register_forward_pre_hook(_switch_to_eval)

    return module