Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Model-introspection helpers shared across backends.""" | |
| from __future__ import annotations | |
def count_params(module) -> int:
    """
    Count the parameters held by a PyTorch `nn.Module`-like object.

    Intended for standard Transformers models and SpeechBrain's `Pretrained`
    base class, both of which are `nn.Module` subclasses. The count is
    best-effort metadata and never a hard requirement, so this function yields
    0 instead of failing when:

    * `module` is None,
    * `module` exposes no callable `.parameters` attribute, or
    * anything raises while iterating or sizing the parameters.

    Items yielded by `.parameters()` that lack a callable `.numel` are skipped.
    """
    if module is None:
        return 0

    iter_params = getattr(module, "parameters", None)
    if not callable(iter_params):
        return 0

    try:
        # Both generators are lazy, so each parameter is looked up, checked
        # and sized in turn; any failure along the way lands in the handler.
        numel_getters = (getattr(param, "numel", None) for param in iter_params())
        return sum(
            int(get_numel()) for get_numel in numel_getters if callable(get_numel)
        )
    except Exception:
        # Deliberate catch-all: a broken model must not break the caller.
        return 0
def attach_params(fn, module) -> None:
    """
    Convenience: attach the param count of `module` to `fn._num_params`.

    The count is best-effort metadata. If counting fails, 0 is attached
    instead. If `fn` cannot carry new attributes at all (for example a bound
    method, a builtin, or an object with `__slots__`), the function returns
    without attaching anything rather than raising.

    Args:
        fn: Object (typically a callable) that should receive `_num_params`.
        module: Model-like object whose parameters are counted.
    """
    try:
        num_params = count_params(module)
    except Exception:
        # Counting is never allowed to break the caller; fall back to 0.
        num_params = 0

    try:
        fn._num_params = num_params
    except (AttributeError, TypeError):
        # `fn` rejects attribute assignment. Retrying with another value
        # (as the previous fallback did) would fail the same way, so give up.
        pass