File size: 1,159 Bytes
751af19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""Model-introspection helpers shared across backends."""

from __future__ import annotations


def count_params(module) -> int:
    """
    Return the total number of parameters of a PyTorch `nn.Module`-like object.

    Works for standard Transformers models (they are `nn.Module` subclasses) and for
    SpeechBrain's `Pretrained` base class (also an `nn.Module` subclass). Returns 0 if
    the object has no `.parameters()` method or counting fails for any reason — we
    treat parameter counts as best-effort metadata, never a hard requirement.

    Args:
        module: Any object. Only its ``parameters()`` method is used; each item it
            yields contributes ``int(item.numel())`` when the item has a callable
            ``numel``. Items without one are skipped.

    Returns:
        The summed element count. Returns 0 when ``module`` is None, when it has no
        callable ``parameters``, or when the attribute lookup, the iteration, or any
        ``numel()``/``int()`` call raises.
    """
    if module is None:
        return 0
    try:
        # The lookup itself is guarded: ``getattr`` with a default only swallows
        # AttributeError, so a property or ``__getattr__`` raising anything else
        # would otherwise escape despite the "fails for any reason" contract.
        params_fn = getattr(module, "parameters", None)
        if not callable(params_fn):
            return 0
        total = 0
        for p in params_fn():
            numel = getattr(p, "numel", None)
            if callable(numel):
                total += int(numel())
        return total
    except Exception:
        # Deliberate best-effort: a partial or failed count is reported as 0.
        return 0


def attach_params(fn, module) -> None:
    """
    Convenience: attach the param count of `module` to `fn._num_params`.

    The count comes from `count_params`; if computing it raises, 0 is attached
    instead. The attribute is assigned exactly once, outside the ``try``. A ``fn``
    that rejects attribute assignment (e.g. a bound method) therefore raises one
    clear AttributeError, instead of failing again inside the error handler.

    Args:
        fn: Object that receives the ``_num_params`` attribute (typically a
            function); it must allow attribute assignment.
        module: Object whose parameters are counted.
    """
    try:
        num_params = count_params(module)
    except Exception:
        num_params = 0
    fn._num_params = num_params