Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
File size: 1,159 Bytes
751af19 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | """Model-introspection helpers shared across backends."""
from __future__ import annotations
def count_params(module) -> int:
    """
    Return the total number of parameters of a PyTorch `nn.Module`-like object.

    Works for standard Transformers models (they are `nn.Module` subclasses) and for
    SpeechBrain's `Pretrained` base class (also an `nn.Module` subclass). More
    generally, any object exposing a callable ``parameters()`` whose items have a
    callable ``numel()`` is accepted; items without a callable ``numel`` are skipped.

    Returns 0 if ``module`` is None, has no callable ``.parameters``, or counting
    fails for any reason — including an exception raised while merely looking up
    ``.parameters``. Parameter counts are treated as best-effort metadata, never a
    hard requirement, so no ``Exception`` subclass escapes this function.
    """
    if module is None:
        return 0
    try:
        # The attribute lookup lives inside the try: getattr's default only absorbs
        # AttributeError, so a property/__getattr__ raising anything else would
        # otherwise escape and break the "returns 0 on any failure" contract.
        params_fn = getattr(module, "parameters", None)
        if not callable(params_fn):
            return 0
        total = 0
        for p in params_fn():
            numel = getattr(p, "numel", None)
            if callable(numel):
                total += int(numel())
        return total
    except Exception:
        # Deliberate best-effort: a partially accumulated count is discarded.
        return 0
def attach_params(fn, module) -> None:
    """
    Convenience: attach the param count of `module` to `fn._num_params`.

    The count comes from ``count_params(module)``; if that call raises, 0 is
    attached instead. Only the counting step is guarded. If ``fn`` itself does not
    accept attribute assignment (e.g. a bound method), the resulting error
    propagates directly. It is not re-triggered from inside an ``except`` handler
    by a fallback that repeats the same assignment.
    """
    try:
        num_params = count_params(module)
    except Exception:
        # Best-effort metadata: a counting failure must not break the caller.
        num_params = 0
    fn._num_params = num_params
|