CgvKodai's picture
Upload folder using huggingface_hub
66003a2 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from wcmatch import fnmatch
from functools import wraps
from typing import List
import torch.nn as nn
# ------------------------------------------------------------
# Glob‑matching flags (behave like the Unix shell)
# ------------------------------------------------------------
# Flags handed to every ``fnmatch.fnmatch`` call in this module so that all
# name/pattern comparisons follow the same rules.
GLOB_FLAGS = (
    fnmatch.SPLIT        # one string may carry several patterns: "pat1|pat2"
    | fnmatch.EXTMATCH   # turn on extended groups such as *(foo|bar)
    | fnmatch.DOTMATCH   # wildcards such as '*' may also match a leading '.'
    | fnmatch.CASE       # compare case-sensitively
)
def freeze_modules(model: nn.Module, patterns: List[str], recursive: bool = True) -> nn.Module:
    """Freeze (stop training) parts of *model* whose *name* matches *patterns*.

    Parameters
    ----------
    model : nn.Module
        The complete model you are working with.
    patterns : list[str]
        Glob patterns compared against the names yielded by
        ``model.named_modules()``. Example: ``["encoder.*", "cls_head"]``.
        A single string is also accepted and is treated as a one-element list.
    recursive : bool, default = True
        • ``True``  → also freeze every child of a matched module.
        • ``False`` → freeze only the matched module itself.

    Returns
    -------
    nn.Module
        The same model object, now with some parts frozen.

    Raises
    ------
    ValueError
        If at least one pattern matched no module name. The check runs after
        the freezing pass, so modules matched by the remaining patterns have
        already been frozen when the error is raised.

    Example
    -------
    >>> freeze_modules(model, ["encoder.*", "decoder.layer1"], recursive=True)
    """
    if isinstance(patterns, str):
        # A bare string would otherwise be iterated character by character,
        # turning every character (wildcards included) into its own pattern.
        patterns = [patterns]
    else:
        # Snapshot the patterns: they are traversed once per module below and
        # once more by the final check, which a one-shot iterable cannot do.
        patterns = list(patterns)

    matched: set[str] = set()
    for name, mod in model.named_modules():
        # does *name* match ANY user pattern?
        if any(fnmatch.fnmatch(name, p, flags=GLOB_FLAGS) for p in patterns):
            matched.add(name)
            _freeze(mod, recursive)

    # Surface typos: a pattern that selected nothing is reported as an error.
    _check_every_pattern_used(matched, patterns)
    return model
# ------------------------------------------------------------
# helpers
# ------------------------------------------------------------
def _freeze(mod: nn.Module, recursive: bool) -> None:
"""Put *mod* in eval mode and lock its parameters."""
if recursive:
mod.eval() # affects the whole subtree
else:
mod.training = False # only this exact module
original_train = mod.train
@wraps(original_train)
def locked_train(mode: bool = True):
if recursive:
return original_train(False) # ignore user's *mode*
out = original_train(mode) # children follow user's choice
out.training = False # but this module stays frozen
return out
mod.train = locked_train # type: ignore[attr-defined]
param_iter = (
mod.parameters() # default recurse=True
if recursive
else mod.parameters(recurse=False)
)
for p in param_iter:
p.requires_grad = False
def _check_every_pattern_used(matched_names: set[str], patterns: List[str]):
    """Raise ``ValueError`` when some pattern matches none of *matched_names*.

    Parameters
    ----------
    matched_names : set[str]
        Names that were selected by at least one pattern.
    patterns : list[str]
        The glob patterns to verify.

    Raises
    ------
    ValueError
        Lists every pattern for which no name in *matched_names* matches.
    """
    unused = []
    for pattern in patterns:
        for name in matched_names:
            if fnmatch.fnmatch(name, pattern, flags=GLOB_FLAGS):
                break  # this pattern did select something
        else:
            # Inner loop finished without a hit: the pattern is unused.
            unused.append(pattern)

    if unused:
        raise ValueError(f"These patterns matched nothing: {unused}")