| | |
| | from __future__ import annotations |
| |
|
| | import operator |
| | from collections import abc as container_abcs, OrderedDict |
| | from itertools import chain, islice |
| | from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar, Union |
| | from typing_extensions import deprecated, Self |
| |
|
| | import torch |
| | from torch._jit_internal import _copy_to_script_wrapper |
| | from torch.nn.parameter import Parameter |
| |
|
| | from .module import Module |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from collections.abc import Iterable, Iterator, Mapping |
| |
|
| |
|
# Public API of this module: the names exported by ``import *``.
__all__ = [
    "Container", "Sequential",
    "ModuleList", "ModuleDict",
    "ParameterList", "ParameterDict",
]

# Unconstrained element type used by ``Sequential._get_item_by_idx``.
# Bound to ``Module`` so annotations can say "same (sub)class as self".
T = TypeVar("T", bound=Module)
_V = TypeVar("_V")
| |
|
| |
|
| | |
def _addindent(s_, numSpaces):
    """Indent every line of ``s_`` except the first by ``numSpaces`` spaces.

    Strings without a newline are returned unchanged. Used by
    ``ModuleList.__repr__`` to nest multi-line child reprs.
    """
    head, newline, tail = s_.partition("\n")
    if not newline:
        return s_
    pad = " " * numSpaces
    return head + "\n" + "\n".join(pad + line for line in tail.split("\n"))
| |
|
| |
|
@deprecated(
    "`nn.Container` is deprecated. "
    "All of it's functionality is now implemented in `nn.Module`. Subclass that instead.",
    category=FutureWarning,
)
class Container(Module):
    """Deprecated container that registers each keyword argument as a submodule.

    Retained only for backward compatibility; subclass :class:`~torch.nn.Module`
    directly instead.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        # Each keyword becomes the submodule's registered name, in keyword order.
        for name, submodule in kwargs.items():
            self.add_module(name, submodule)
| |
|
| |
|
class Sequential(Module):
    r"""A sequential container.

    Modules will be added to it in the order they are passed in the
    constructor. Alternatively, an ``OrderedDict`` of modules can be
    passed in. The ``forward()`` method of ``Sequential`` accepts any
    input and forwards it to the first module it contains. It then
    "chains" outputs to inputs sequentially for each subsequent module,
    finally returning the output of the last module.

    The value a ``Sequential`` provides over manually calling a sequence
    of modules is that it allows treating the whole container as a
    single module, such that performing a transformation on the
    ``Sequential`` applies to each of the modules it stores (which are
    each a registered submodule of the ``Sequential``).

    What's the difference between a ``Sequential`` and a
    :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
    sounds like--a list for storing ``Module`` s! On the other hand,
    the layers in a ``Sequential`` are connected in a cascading way.

    Example::

        # Using Sequential to create a small model. When `model` is run,
        # input will first be passed to `Conv2d(1,20,5)`. The output of
        # `Conv2d(1,20,5)` will be used as the input to the first
        # `ReLU`; the output of the first `ReLU` will become the input
        # for `Conv2d(20,64,5)`. Finally, the output of
        # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
        model = nn.Sequential(
            nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()
        )

        # Using Sequential with OrderedDict. This is functionally the
        # same as the above code
        model = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(1, 20, 5)),
                    ("relu1", nn.ReLU()),
                    ("conv2", nn.Conv2d(20, 64, 5)),
                    ("relu2", nn.ReLU()),
                ]
            )
        )
    """

    _modules: dict[str, Module]

    @overload
    def __init__(self, *args: Module) -> None: ...

    @overload
    def __init__(self, arg: OrderedDict[str, Module]) -> None: ...

    def __init__(self, *args):
        super().__init__()
        # A single OrderedDict argument supplies explicit submodule names;
        # otherwise modules are registered under "0", "1", ... in call order.
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator: Iterable[_V], idx: int) -> _V:
        """Get the idx-th item of the iterator.

        ``idx`` may be negative (counted from the end of the container).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} is out of range")
        idx %= size
        return next(islice(iterator, idx, None))

    @_copy_to_script_wrapper
    def __getitem__(self, idx: Union[slice, int]) -> Union[Sequential, Module]:
        """Return the module at ``idx``, or a new container of the same class for a slice.

        A slice result keeps the original submodule names.
        """
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: Module) -> None:
        """Replace the module at position ``idx``, keeping its registered name."""
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        """Delete the module(s) at ``idx`` and renumber the rest as "0", "1", ..."""
        if isinstance(idx, slice):
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)
        # Re-key the survivors so positional names stay contiguous.
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    def __add__(self, other) -> Sequential:
        """Return a new ``Sequential`` holding the modules of ``self`` followed by ``other``.

        Raises:
            ValueError: if ``other`` is not a ``Sequential``.
        """
        if isinstance(other, Sequential):
            ret = Sequential()
            for layer in self:
                ret.append(layer)
            for layer in other:
                ret.append(layer)
            return ret
        else:
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )

    def pop(self, key: Union[int, slice]) -> Module:
        """Remove and return the module at ``key``.

        A slice ``key`` returns a ``Sequential`` of the removed modules. The
        remaining modules are renumbered.
        """
        v = self[key]
        del self[key]
        return v

    def __iadd__(self, other) -> Self:
        """Append the modules of another ``Sequential`` in place.

        Raises:
            ValueError: if ``other`` is not a ``Sequential``.
        """
        if isinstance(other, Sequential):
            offset = len(self)
            # Snapshot first: with ``seq += seq`` we would otherwise register
            # new entries in ``self._modules`` while iterating over it, which
            # raises "dictionary changed size during iteration".
            for i, module in enumerate(list(other)):
                self.add_module(str(i + offset), module)
            return self
        else:
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )

    def __mul__(self, other: int) -> Sequential:
        """Return a new ``Sequential`` repeating this one's modules ``other`` times.

        The same module objects are reused (not copied).

        Raises:
            TypeError: if ``other`` is not an ``int``.
            ValueError: if ``other`` is not positive.
        """
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        elif other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        else:
            combined = Sequential()
            offset = 0
            for _ in range(other):
                for module in self:
                    combined.add_module(str(offset), module)
                    offset += 1
            return combined

    def __rmul__(self, other: int) -> Sequential:
        return self.__mul__(other)

    def __imul__(self, other: int) -> Self:
        """Repeat this container's modules ``other`` times in place.

        Raises:
            TypeError: if ``other`` is not an ``int``.
            ValueError: if ``other`` is not positive.
        """
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        elif other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        else:
            len_original = len(self)
            offset = len(self)
            # Index-based copy over the first ``len_original`` entries, so the
            # growing ``_modules`` dict is never iterated while it is mutated.
            for _ in range(other - 1):
                for i in range(len_original):
                    self.add_module(str(i + offset), self._modules[str(i)])
                offset += len_original
            return self

    @_copy_to_script_wrapper
    def __dir__(self) -> list[str]:
        # Hide the positional submodule names ("0", "1", ...) from dir().
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    # ``input`` is left unannotated on purpose: each submodule's output,
    # whatever its type, becomes the next submodule's input.
    def forward(self, input):
        """
        Runs the forward pass.
        """
        for module in self:
            input = module(input)
        return input

    def append(self, module: Module) -> Self:
        r"""Append a given module to the end.

        Args:
            module (nn.Module): module to append

        Example::

            >>> import torch.nn as nn
            >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
            >>> n.append(nn.Linear(3, 4))
            Sequential(
              (0): Linear(in_features=1, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=3, bias=True)
              (2): Linear(in_features=3, out_features=4, bias=True)
            )

        """
        self.add_module(str(len(self)), module)
        return self

    def insert(self, index: int, module: Module) -> Self:
        """
        Inserts a module into the Sequential container at the specified index.

        Args:
            index (int): The index to insert the module. Negative values count
                from the end; valid range is ``[-len(self), len(self)]``.
            module (Module): The module to be inserted.

        Raises:
            AssertionError: if ``module`` is not a ``Module``.
            IndexError: if ``index`` is out of range.

        Example::

            >>> import torch.nn as nn
            >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
            >>> n.insert(0, nn.Linear(3, 4))
            Sequential(
              (0): Linear(in_features=3, out_features=4, bias=True)
              (1): Linear(in_features=1, out_features=2, bias=True)
              (2): Linear(in_features=2, out_features=3, bias=True)
            )

        """
        if not isinstance(module, Module):
            raise AssertionError(f"module should be of type: {Module}")
        n = len(self._modules)
        if not (-n <= index <= n):
            raise IndexError(f"Index out of range: {index}")
        if index < 0:
            index += n
        # Shift entries [index, n) up by one slot, working from the end.
        for i in range(n, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module
        return self

    def extend(self, sequential: Iterable[Module]) -> Self:
        """
        Extends the current Sequential container with layers from another Sequential container.

        Args:
            sequential (Sequential): A Sequential container (or any iterable of
                modules) whose layers will be added to the current container.

        Example::

            >>> import torch.nn as nn
            >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
            >>> other = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5))
            >>> n.extend(other) # or `n + other`
            Sequential(
              (0): Linear(in_features=1, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=3, bias=True)
              (2): Linear(in_features=3, out_features=4, bias=True)
              (3): Linear(in_features=4, out_features=5, bias=True)
            )

        """
        # Materialize first so ``seq.extend(seq)`` (or any view over ``self``)
        # is not iterated while ``append`` mutates ``self._modules``.
        for layer in list(sequential):
            self.append(layer)
        return self
| |
|
| |
|
class ModuleList(Module):
    r"""Holds submodules in a list.

    :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
    modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    Args:
        modules (iterable, optional): an iterable of modules to add

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

            def forward(self, x):
                # ModuleList can act as an iterable, or be indexed using ints
                for i, l in enumerate(self.linears):
                    x = self.linears[i // 2](x) + l(x)
                return x
    """

    _modules: dict[str, Module]

    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules.

        Negative ``idx`` counts from the end. Returns the index as a string,
        i.e. the submodule's registered name.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: slice) -> ModuleList: ...

    @overload
    def __getitem__(self, idx: int) -> Module: ...

    @_copy_to_script_wrapper
    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, ModuleList]:
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: Module) -> None:
        idx = self._get_abs_string_index(idx)
        return setattr(self, str(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        """Delete the module(s) at ``idx`` and renumber the rest as "0", "1", ..."""
        if isinstance(idx, slice):
            for k in range(len(self._modules))[idx]:
                delattr(self, str(k))
        else:
            delattr(self, self._get_abs_string_index(idx))
        # Re-key the survivors so positional names stay contiguous.
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    def __iadd__(self, modules: Iterable[Module]) -> Self:
        return self.extend(modules)

    def __add__(self, other: Iterable[Module]) -> ModuleList:
        """Return a new ``ModuleList`` with the modules of ``self`` followed by ``other``."""
        combined = ModuleList()
        for i, module in enumerate(chain(self, other)):
            combined.add_module(str(i), module)
        return combined

    def __repr__(self) -> str:
        """Return a custom repr for ModuleList that compresses repeated module representations.

        Consecutive entries with identical reprs are collapsed into a single
        ``(start-end): N x <repr>`` line.
        """
        list_of_reprs = [repr(item) for item in self]
        if len(list_of_reprs) == 0:
            return self._get_name() + "()"

        # Group runs of identical consecutive reprs: each run is recorded as
        # [start, end] (inclusive) plus the shared repr string.
        start_end_indices = [[0, 0]]
        repeated_blocks = [list_of_reprs[0]]
        for i, r in enumerate(list_of_reprs[1:], 1):
            if r == repeated_blocks[-1]:
                start_end_indices[-1][1] += 1
                continue

            start_end_indices.append([i, i])
            repeated_blocks.append(r)

        lines = []
        main_str = self._get_name() + "("
        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
            local_repr = f"({start_id}): {b}"

            if start_id != end_id:
                n = end_id - start_id + 1
                local_repr = f"({start_id}-{end_id}): {n} x {b}"

            local_repr = _addindent(local_repr, 2)
            lines.append(local_repr)

        # Indent each entry by the same 2 spaces that _addindent applied to the
        # entry's continuation lines, so nested reprs stay aligned.
        main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str

    @_copy_to_script_wrapper
    def __dir__(self) -> list[str]:
        # Hide the positional submodule names ("0", "1", ...) from dir().
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def insert(self, index: int, module: Module) -> None:
        r"""Insert a given module before a given index in the list.

        Follows :meth:`list.insert` semantics: a negative ``index`` counts from
        the end, and an index beyond either end is clamped, so the module is
        prepended or appended.

        Args:
            index (int): index to insert.
            module (nn.Module): module to insert
        """
        n = len(self._modules)
        index = operator.index(index)
        # Normalize to [0, n]. A raw negative or too-large index would otherwise
        # reference a nonexistent key ("-1") or leave a gap in the numbering.
        if index < 0:
            index = max(index + n, 0)
        else:
            index = min(index, n)
        # Shift entries [index, n) up by one slot, working from the end.
        for i in range(n, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def append(self, module: Module) -> Self:
        r"""Append a given module to the end of the list.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def pop(self, key: Union[int, slice]) -> Module:
        r"""Remove and return the module at ``key``.

        A slice ``key`` returns a ``ModuleList`` of the removed modules. The
        remaining modules are renumbered.

        Args:
            key (int or slice): position(s) to remove
        """
        v = self[key]
        del self[key]
        return v

    def extend(self, modules: Iterable[Module]) -> Self:
        r"""Append modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): iterable of modules to append

        Raises:
            TypeError: if ``modules`` is not iterable.
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError(
                "ModuleList.extend should be called with an "
                "iterable, but got " + type(modules).__name__
            )
        offset = len(self)
        # Snapshot first: iterating ``self`` (e.g. ``ml += ml``) while
        # registering new entries would mutate ``self._modules`` mid-iteration.
        for i, module in enumerate(list(modules)):
            self.add_module(str(offset + i), module)
        return self
| |
|
| | |
| |
|
| |
|
class ModuleDict(Module):
    r"""Holds submodules in a dictionary.

    :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
    but modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
      ``OrderedDict``, ``dict`` (started from Python 3.6) or another
      :class:`~torch.nn.ModuleDict` (the argument to
      :meth:`~torch.nn.ModuleDict.update`).

    Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict`` before Python version 3.6) does not
    preserve the order of the merged mapping.

    Args:
        modules (iterable, optional): a mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module)

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.choices = nn.ModuleDict(
                    {"conv": nn.Conv2d(10, 10, 3), "pool": nn.MaxPool2d(3)}
                )
                self.activations = nn.ModuleDict(
                    [["lrelu", nn.LeakyReLU()], ["prelu", nn.PReLU()]]
                )

            def forward(self, x, choice, act):
                x = self.choices[choice](x)
                x = self.activations[act](x)
                return x
    """

    _modules: dict[str, Module]

    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self.update(modules)

    @_copy_to_script_wrapper
    def __getitem__(self, key: str) -> Module:
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        # Route through add_module so the name and module type are validated.
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    @_copy_to_script_wrapper
    def __contains__(self, key: str) -> bool:
        return key in self._modules

    def clear(self) -> None:
        """Remove all items from the ModuleDict."""
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (str): key to pop from the ModuleDict

        Raises:
            KeyError: if ``key`` is not present.
        """
        v = self[key]
        del self[key]
        return v

    @_copy_to_script_wrapper
    def keys(self) -> container_abcs.KeysView[str]:
        r"""Return an iterable of the ModuleDict keys."""
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> container_abcs.ItemsView[str, Module]:
        r"""Return an iterable of the ModuleDict key/value pairs."""
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> container_abcs.ValuesView[Module]:
        r"""Return an iterable of the ModuleDict values."""
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)

        Raises:
            TypeError: if ``modules`` (or one of its elements) is not iterable.
            ValueError: if an element of a pair sequence does not have length 2.
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError(
                "ModuleDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(modules).__name__
            )

        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
        else:
            # A sequence of (name, module) pairs, e.g. [("a", m1), ["b", m2]].
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    raise TypeError(
                        "ModuleDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(m).__name__
                    )
                if not len(m) == 2:
                    raise ValueError(
                        "ModuleDict update sequence element "
                        "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
                    )
                self[m[0]] = m[1]
| |
|
| | |
| |
|
| |
|
class ParameterList(Module):
    r"""Holds parameters in a list.

    :class:`~torch.nn.ParameterList` can be used like a regular Python
    list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
    and will be visible by all :class:`~torch.nn.Module` methods.

    Note that the constructor, assigning an element of the list, the
    :meth:`~torch.nn.ParameterList.append` method and the :meth:`~torch.nn.ParameterList.extend`
    method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.

    Args:
        parameters (iterable, optional): an iterable of elements to add to the list.

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.params = nn.ParameterList(
                    [nn.Parameter(torch.randn(10, 10)) for i in range(10)]
                )

            def forward(self, x):
                # ParameterList can act as an iterable, or be indexed using ints
                for i, p in enumerate(self.params):
                    x = self.params[i // 2].mm(x) + p.mm(x)
                return x
    """

    def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
        super().__init__()
        # Number of list entries; entries are stored as attributes "0".."_size-1".
        self._size = 0
        if values is not None:
            self += values

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules.

        Negative ``idx`` counts from the end. Returns the index as a string,
        i.e. the attribute name of the entry.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: int) -> Any: ...

    @overload
    def __getitem__(self: T, idx: slice) -> T: ...

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            out = self.__class__()
            for i in range(start, stop, step):
                out.append(self[i])
            return out
        else:
            idx = self._get_abs_string_index(idx)
            return getattr(self, str(idx))

    def __setitem__(self, idx: int, param: Any) -> None:
        # Every path that adds a list entry (constructor, +=, extend, append)
        # ends up here, so this is the single place where plain tensors are
        # wrapped into Parameter. Attributes set directly via setattr() are not
        # list entries and do not pass through this method.
        idx = self._get_abs_string_index(idx)
        if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
            param = Parameter(param)
        return setattr(self, str(idx), param)

    def __len__(self) -> int:
        return self._size

    def __iter__(self) -> Iterator[Any]:
        return iter(self[i] for i in range(len(self)))

    def __iadd__(self, parameters: Iterable[Any]) -> Self:
        return self.extend(parameters)

    def __dir__(self) -> list[str]:
        # Hide the positional entry names ("0", "1", ...) from dir().
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def append(self, value: Any) -> Self:
        """Append a given value at the end of the list.

        Tensors that are not already :class:`~torch.nn.Parameter` are wrapped
        into one. If storing the value fails (for example, a Parameter cannot be
        built from the tensor), the list is left unchanged and the error propagates.

        Args:
            value (Any): value to append
        """
        new_idx = len(self)
        # Grow first so the bounds check in __setitem__ accepts new_idx...
        self._size += 1
        try:
            self[new_idx] = value
        except BaseException:
            # ...and roll back so len() never counts an entry that was not stored.
            self._size -= 1
            raise
        return self

    def extend(self, values: Iterable[Any]) -> Self:
        """Append values from a Python iterable to the end of the list.

        Args:
            values (iterable): iterable of values to append

        Raises:
            TypeError: if ``values`` is not iterable or is a single Tensor.
        """
        # A Tensor is iterable, but it must never be unpacked element-wise here.
        if not isinstance(values, container_abcs.Iterable) or isinstance(
            values, torch.Tensor
        ):
            raise TypeError(
                "ParameterList.extend should be called with an "
                "iterable, but got " + type(values).__name__
            )
        for value in values:
            self.append(value)
        return self

    def extra_repr(self) -> str:
        """
        Return the extra representation of the module.
        """
        child_lines = []
        for k, p in enumerate(self):
            if isinstance(p, torch.Tensor):
                size_str = "x".join(str(size) for size in p.size())
                if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
                    device_str = f" ({p.device})"
                else:
                    device_str = ""
                parastr = "{} containing: [{} of size {}{}]".format(
                    "Parameter" if isinstance(p, Parameter) else "Tensor",
                    p.dtype,
                    size_str,
                    device_str,
                )
                child_lines.append(" (" + str(k) + "): " + parastr)
            else:
                child_lines.append(
                    " (" + str(k) + "): Object of type: " + type(p).__name__
                )

        tmpstr = "\n".join(child_lines)
        return tmpstr

    def __call__(self, *args, **kwargs):
        raise RuntimeError("ParameterList should not be called.")
| |
|
| |
|
class ParameterDict(Module):
    r"""Holds parameters in a dictionary.

    ParameterDict can be indexed like a regular Python dictionary, but Parameters it
    contains are properly registered, and will be visible by all Module methods.
    Other objects are treated as would be done by a regular Python dictionary

    :class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
    :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict``) does not preserve the order of the
    merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
    will preserve their ordering.

    Note that the constructor, assigning an element of the dictionary and the
    :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
    :class:`~torch.nn.Parameter`.

    Args:
        values (iterable, optional): a mapping (dictionary) of
            (string : Any) or an iterable of key-value pairs
            of type (string, Any)

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.params = nn.ParameterDict(
                    {
                        "left": nn.Parameter(torch.randn(5, 10)),
                        "right": nn.Parameter(torch.randn(5, 10)),
                    }
                )

            def forward(self, x, choice):
                x = self.params[choice].mm(x)
                return x
    """

    def __init__(self, parameters: Any = None) -> None:
        super().__init__()
        # Insertion-ordered set of dictionary keys; values live as attributes.
        self._keys: dict[str, None] = {}
        if parameters is not None:
            self.update(parameters)

    def _key_to_attr(self, key: str) -> str:
        """Map a dictionary key to the attribute name that stores its value.

        Raises:
            TypeError: if ``key`` is not a ``str``.
        """
        if not isinstance(key, str):
            raise TypeError(
                "Index given to ParameterDict cannot be used as a key as it is "
                f"not a string (type is '{type(key).__name__}'). Open an issue on "
                "github if you need non-string keys."
            )
        else:
            # The key is used verbatim as the attribute name.
            return key

    def __getitem__(self, key: str) -> Any:
        attr = self._key_to_attr(key)
        return getattr(self, attr)

    def __setitem__(self, key: str, value: Any) -> None:
        # Every path that adds a dictionary entry (constructor, update,
        # setdefault, ...) ends up here, so this is the single place where plain
        # tensors are wrapped into Parameter. Attributes set directly via
        # setattr() are not dictionary entries and do not pass through here.
        #
        # Validate the key, convert the value and store it BEFORE recording the
        # key. Otherwise a non-str key or a failed Parameter() conversion would
        # leave a phantom entry in ``_keys`` with no backing attribute.
        attr = self._key_to_attr(key)
        if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
            value = Parameter(value)
        setattr(self, attr, value)
        self._keys[key] = None

    def __delitem__(self, key: str) -> None:
        # Remove from ``_keys`` first: this raises KeyError for non-entries and
        # so never deletes an unrelated attribute of the module.
        del self._keys[key]
        attr = self._key_to_attr(key)
        delattr(self, attr)

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> Iterator[str]:
        return iter(self._keys)

    def __reversed__(self) -> Iterator[str]:
        return reversed(self._keys)

    def copy(self) -> ParameterDict:
        """Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
        # Pass an OrderedDict: update() keeps OrderedDict order but sorts plain
        # mappings, so a plain dict would not preserve this instance's order.
        return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))

    def __contains__(self, key: str) -> bool:
        return key in self._keys

    def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
        """Set the default for a key in the ParameterDict.

        If key is in the ParameterDict, return its value.
        If not, insert `key` with a parameter `default` and return `default`
        (a Tensor default is returned wrapped as a Parameter).
        `default` defaults to `None`.

        Args:
            key (str): key to set default for
            default (Any): the parameter set to the key
        """
        if key not in self:
            self[key] = default
        return self[key]

    def clear(self) -> None:
        """Remove all items from the ParameterDict."""
        # Iterate a copy because __delitem__ mutates ``_keys``.
        for k in self._keys.copy():
            del self[k]

    def pop(self, key: str) -> Any:
        r"""Remove key from the ParameterDict and return its parameter.

        Args:
            key (str): key to pop from the ParameterDict
        """
        v = self[key]
        del self[key]
        return v

    def popitem(self) -> tuple[str, Any]:
        """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
        k, _ = self._keys.popitem()
        # Put the key back so that ``del self[k]`` below, which removes it from
        # ``_keys``, also deletes the backing attribute.
        self._keys[k] = None
        val = self[k]
        del self[k]
        return k, val

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.

        Args:
            key (str): key to get from the ParameterDict
            default (Parameter, optional): value to return if key not present
        """
        return self[key] if key in self else default

    def fromkeys(
        self, keys: Iterable[str], default: Optional[Any] = None
    ) -> ParameterDict:
        r"""Return a new ParameterDict with the keys provided.

        Args:
            keys (iterable, string): keys to make the new ParameterDict from
            default (Parameter, optional): value to set for all keys
        """
        return ParameterDict((k, default) for k in keys)

    def keys(self) -> container_abcs.KeysView[str]:
        r"""Return an iterable of the ParameterDict keys."""
        return self._keys.keys()

    def items(self) -> Iterable[tuple[str, Any]]:
        r"""Return an iterable of the ParameterDict key/value pairs."""
        return ((k, self[k]) for k in self._keys)

    def values(self) -> Iterable[Any]:
        r"""Return an iterable of the ParameterDict values."""
        return (self[k] for k in self._keys)

    def update(self, parameters: Union[Mapping[str, Any], ParameterDict]) -> None:
        r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.

        .. note::
            If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.
            Entries of any other mapping (e.g. a plain ``dict``) are inserted in
            sorted key order.

        Args:
            parameters (iterable): a mapping (dictionary) from string to
                :class:`~torch.nn.Parameter`, or an iterable of
                key-value pairs of type (string, :class:`~torch.nn.Parameter`)

        Raises:
            TypeError: if ``parameters`` (or one of its elements) is not iterable.
            ValueError: if an element of a pair sequence does not have length 2.
        """
        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError(
                "ParameterDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(parameters).__name__
            )

        if isinstance(parameters, (OrderedDict, ParameterDict)):
            for key, parameter in parameters.items():
                self[key] = parameter
        elif isinstance(parameters, container_abcs.Mapping):
            for key, parameter in sorted(parameters.items()):
                self[key] = parameter
        else:
            for j, p in enumerate(parameters):
                if not isinstance(p, container_abcs.Iterable):
                    raise TypeError(
                        "ParameterDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(p).__name__
                    )
                if not len(p) == 2:
                    raise ValueError(
                        "ParameterDict update sequence element "
                        "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
                    )
                self[p[0]] = p[1]

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in self.items():
            if isinstance(p, torch.Tensor):
                size_str = "x".join(str(size) for size in p.size())
                if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
                    device_str = f" ({p.device})"
                else:
                    device_str = ""
                parastr = "{} containing: [{} of size {}{}]".format(
                    "Parameter" if isinstance(p, Parameter) else "Tensor",
                    torch.typename(p),
                    size_str,
                    device_str,
                )
                child_lines.append(" (" + str(k) + "): " + parastr)
            else:
                child_lines.append(
                    " (" + str(k) + "): Object of type: " + type(p).__name__
                )
        tmpstr = "\n".join(child_lines)
        return tmpstr

    def __call__(self, input):
        raise RuntimeError("ParameterDict should not be called.")

    def __or__(self, other: ParameterDict) -> ParameterDict:
        copy = self.copy()
        copy.update(other)
        return copy

    def __ror__(self, other: ParameterDict) -> ParameterDict:
        copy = other.copy()
        copy.update(self)
        return copy

    def __ior__(self, other: ParameterDict) -> Self:
        self.update(other)
        return self
| |
|