ffasr / backends /_model_utils.py
Shivam
Visualization fixes
751af19
Raw
History Blame Contribute Delete
1.16 kB
"""Model-introspection helpers shared across backends."""
from __future__ import annotations
def count_params(module) -> int:
    """Return the summed element count of every parameter exposed by `module`.

    `module` is expected to look like a PyTorch `nn.Module`: it exposes a
    callable `parameters` attribute that yields objects with a callable
    `numel`. The file's original notes say this covers standard Transformers
    models and SpeechBrain's `Pretrained` base class.

    The count is best-effort metadata, never a hard requirement, so the result
    is 0 when:

    * `module` is None,
    * `module` has no callable `parameters` attribute, or
    * any exception is raised while iterating or counting.

    Yielded items that lack a callable `numel` are skipped, not treated as
    errors.
    """
    parameters = None if module is None else getattr(module, "parameters", None)
    if not callable(parameters):
        return 0

    try:
        # Both generators are lazy, so each parameter is inspected and counted
        # before the next one is pulled from the iterator.
        size_fns = (getattr(param, "numel", None) for param in parameters())
        return sum(int(size_fn()) for size_fn in size_fns if callable(size_fn))
    except Exception:
        # Deliberately broad: a failed count degrades to 0 instead of raising.
        return 0
def attach_params(fn, module) -> None:
    """Attach the parameter count of `module` to `fn` as `fn._num_params`.

    This is a best-effort convenience that never raises for ordinary failures:

    * If counting fails, `fn._num_params` is set to 0.
    * If `fn` does not accept new attributes (for example a bound method, a
      builtin, or an object without a `__dict__`), the attribute is simply not
      set and the call returns normally.

    Args:
        fn: Object (typically a callable) to annotate with `_num_params`.
        module: Model-like object passed to `count_params`.
    """
    try:
        num_params = count_params(module)
    except Exception:
        # Parameter counts are optional metadata; degrade to 0 on any failure.
        num_params = 0

    try:
        fn._num_params = num_params
    except (AttributeError, TypeError):
        # `fn` rejects attribute assignment. Retrying the same assignment with
        # 0 would only raise again from inside the handler, so skip it.
        pass