| | |
| | import inspect |
| |
|
| | import torch |
| |
|
| |
|
def skip_init(module_cls, *args, **kwargs):
    r"""
    Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers.

    This can be useful if initialization is slow or if custom initialization will
    be performed, making the default initialization unnecessary. There are some caveats to this, due to
    the way this function is implemented:

    1. The module must accept a `device` arg in its constructor that is passed to any parameters
    or buffers created during construction.

    2. The module must not perform any computation on parameters in its constructor except
    initialization (i.e. functions from :mod:`torch.nn.init`).

    If these conditions are satisfied, the module can be instantiated with parameter / buffer values
    uninitialized, as if having been created using :func:`torch.empty`.

    The module is constructed with ``device="meta"`` and then moved to the requested device with
    ``to_empty``. The requested device is taken from the ``device`` keyword argument, or from the
    positional argument that lines up with the constructor's ``device`` parameter; if neither is
    given, ``"cpu"`` is used.

    Args:
        module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
        args: args to pass to the module's constructor
        kwargs: kwargs to pass to the module's constructor

    Returns:
        Instantiated module with uninitialized parameters / buffers

    Raises:
        RuntimeError: if ``module_cls`` is not a subclass of :class:`torch.nn.Module`, or if its
            constructor signature has no ``device`` parameter.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> import torch
        >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
        >>> m.weight
        Parameter containing:
        tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
               requires_grad=True)
        >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
        >>> m2.weight
        Parameter containing:
        tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24,
                 4.5915e-41]], requires_grad=True)

    """
    if not issubclass(module_cls, torch.nn.Module):
        raise RuntimeError(f"Expected a Module; got {module_cls}")
    params = inspect.signature(module_cls).parameters
    if "device" not in params:
        raise RuntimeError("Module must support a 'device' arg to skip initialization")

    # Find where `device` sits among the parameters that can be filled
    # positionally. Positional parameters always precede *args / keyword-only
    # ones, so the scan can stop at the first non-positional kind.
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    device_index = None
    for index, (name, param) in enumerate(params.items()):
        if param.kind not in positional_kinds:
            break
        if name == "device":
            device_index = index
            break

    if "device" not in kwargs and device_index is not None and device_index < len(args):
        # The caller supplied `device` positionally: swap it for "meta" in
        # place. Adding a `device` keyword here would make the constructor fail
        # with "got multiple values for argument 'device'".
        final_device = args[device_index]
        args = args[:device_index] + ("meta",) + args[device_index + 1:]
    else:
        final_device = kwargs.pop("device", "cpu")
        kwargs["device"] = "meta"

    # Build on the meta device (no real storage, no initialization work), then
    # allocate uninitialized storage on the device the caller asked for.
    return module_cls(*args, **kwargs).to_empty(device=final_device)
| |
|