File size: 1,323 Bytes
59f1501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import contextlib
import threading
from collections.abc import Generator
from typing import Any

import torch


# Thread-local storage for the per-thread "freezing is active" flag; each
# thread sees its own independent value (default: unset, treated as False).
_TLS = threading.local()


def _freezing_active() -> bool:
    """Return True if freezing is active on the calling thread."""
    # Threads that never entered enter_freezing() have no attribute set;
    # treat that as "not active".
    active = getattr(_TLS, "freezing_active", False)
    return active


@contextlib.contextmanager
def enter_freezing() -> Generator[Any, None, None]:
    """Context manager to designate when freezing is active.

    Sets the thread-local flag on entry and restores the previous value on
    exit, so nested uses unwind correctly.
    """
    saved = _freezing_active()
    _TLS.freezing_active = True
    try:
        yield
    finally:
        # Restore (not clear) the flag so a nested context does not turn
        # freezing off for an enclosing one.
        _TLS.freezing_active = saved


def record_has_frozen_params(gm: torch.fx.GraphModule) -> None:
    """Mark *gm* as having frozen params.

    Later passes can query this tag via has_frozen_params().
    """
    # setattr avoids the need for a type-ignore on the dynamic attribute.
    setattr(gm, "_has_frozen_params", True)


def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
    """Return True if *gm* has been tagged as having frozen parameters."""
    # The tag is only present if record_has_frozen_params() ran on this gm.
    if hasattr(gm, "_has_frozen_params"):
        return gm._has_frozen_params
    return False


def maybe_set_is_frozen_param(t: torch.Tensor) -> None:
    """Mark *t* as a frozen param, but only while freezing is active.

    No-op when called outside an enter_freezing() context.
    """
    if not _freezing_active():
        return
    t._is_frozen_param = True  # type: ignore[attr-defined]


def is_frozen_param(t: torch.Tensor) -> bool:
    """Return True if *t* carries the frozen-param tag."""
    # Tensors never tagged by maybe_set_is_frozen_param() lack the attribute.
    if hasattr(t, "_is_frozen_param"):
        return t._is_frozen_param
    return False