"""Model-introspection helpers shared across backends."""

from __future__ import annotations


def count_params(module: object) -> int:
    """
    Return the total number of parameters of a PyTorch `nn.Module`-like object.

    Works for standard Transformers models (they are `nn.Module` subclasses)
    and for SpeechBrain's `Pretrained` base class (also an `nn.Module`
    subclass). More generally, any object exposing a callable `.parameters()`
    whose items have a callable `.numel()` is accepted; items lacking a
    callable `.numel()` are skipped.

    Returns 0 if `module` is None, has no callable `.parameters()`, or
    counting fails for any reason (including an exception raised while
    looking up `.parameters` itself). Parameter counts are treated as
    best-effort metadata, never a hard requirement.
    """
    if module is None:
        return 0
    try:
        # The lookup is inside the try on purpose: getattr's default only
        # covers AttributeError, and a `parameters` property raising anything
        # else must not escape this best-effort helper.
        params_fn = getattr(module, "parameters", None)
        if not callable(params_fn):
            return 0
        total = 0
        for p in params_fn():
            numel = getattr(p, "numel", None)
            if callable(numel):
                total += int(numel())
        return total
    except Exception:
        # A partial count would be misleading; report "unknown" as 0.
        return 0


def attach_params(fn: object, module: object) -> None:
    """
    Convenience: attach the param count of `module` to `fn._num_params`.

    The count is computed with `count_params`, which returns 0 instead of
    raising on failure. If `fn` does not accept attribute assignment (for
    example a bound method, a builtin, or a `__slots__` object), the attribute
    is left unset rather than raising, keeping this helper best-effort.
    """
    num_params = count_params(module)
    try:
        fn._num_params = num_params
    except (AttributeError, TypeError):
        # Metadata is optional; an object that rejects new attributes should
        # not make wrapping fail.
        pass