Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__pycache__/fft.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/graph_matcher.py +460 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/graph_passes.py +950 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_safeguard.py +42 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_tree_utils.py +64 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/common_types.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/cpp.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/init.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/parameter.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__pycache__/thnn.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/thnn.py +4 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/__init__.py +35 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/modules/fused.py +30 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/__init__.py +13 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +5 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py +12 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__init__.py +68 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/_functions.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/container.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/conv.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/flatten.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/fold.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/lazy.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/activation.py +1624 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/batchnorm.py +849 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/channelshuffle.py +57 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py +911 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/dropout.py +294 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/flatten.py +144 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/normalization.py +297 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/padding.py +801 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/pixelshuffle.py +113 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/pooling.py +1306 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/transformer.py +975 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__pycache__/fft.cpython-311.pyc
ADDED
|
Binary file (29.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-311.pyc
ADDED
|
Binary file (33 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-311.pyc
ADDED
|
Binary file (1.41 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-311.pyc
ADDED
|
Binary file (8.03 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-311.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (23.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/graph_matcher.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import enum
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
toq = torch.ops.quantized
|
| 6 |
+
|
| 7 |
+
from torch.fx import GraphModule
|
| 8 |
+
from torch.fx.graph import Graph, Node
|
| 9 |
+
|
| 10 |
+
from torch.ao.quantization.utils import getattr_from_fqn
|
| 11 |
+
from .ns_types import NSSubgraph, NSNodeTargetType
|
| 12 |
+
from .mappings import (
|
| 13 |
+
get_base_name_to_sets_of_related_ops,
|
| 14 |
+
get_unmatchable_types_map,
|
| 15 |
+
)
|
| 16 |
+
from .pattern_utils import (
|
| 17 |
+
get_type_a_related_to_b,
|
| 18 |
+
get_reversed_fusions,
|
| 19 |
+
end_node_matches_reversed_fusion,
|
| 20 |
+
)
|
| 21 |
+
from torch.ao.quantization import (
|
| 22 |
+
ObserverBase,
|
| 23 |
+
FakeQuantizeBase,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from typing import Dict, Tuple, List, Optional, Set, Any
|
| 27 |
+
|
| 28 |
+
def _get_output_nodes(g: Graph) -> List[Node]:
|
| 29 |
+
return [n for n in g.nodes if n.op == 'output']
|
| 30 |
+
|
| 31 |
+
class _NSGraphMatchableSubgraphsIterator:
|
| 32 |
+
"""
|
| 33 |
+
Iterates through the graph of gm, starting with the output nodes
|
| 34 |
+
and continuing backwards.
|
| 35 |
+
1. Returns matchable subgraphs, in order. A subgraph is defined by
|
| 36 |
+
(start_node, end_node).
|
| 37 |
+
2. Skips over non-matchable subgraphs
|
| 38 |
+
"""
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
gm: GraphModule,
|
| 42 |
+
non_matchable_functions: Set[NSNodeTargetType],
|
| 43 |
+
non_matchable_modules: Set[NSNodeTargetType],
|
| 44 |
+
non_matchable_methods: Set[NSNodeTargetType],
|
| 45 |
+
):
|
| 46 |
+
self.gm: GraphModule = gm
|
| 47 |
+
self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
|
| 48 |
+
self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
|
| 49 |
+
self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
|
| 50 |
+
self.seen_nodes: Set[Node] = set()
|
| 51 |
+
self.stack: List[Node] = []
|
| 52 |
+
for start_node in _get_output_nodes(self.gm.graph):
|
| 53 |
+
self.stack.append(start_node)
|
| 54 |
+
|
| 55 |
+
def __iter__(self):
|
| 56 |
+
return self
|
| 57 |
+
|
| 58 |
+
def __next__(self) -> NSSubgraph:
|
| 59 |
+
"""
|
| 60 |
+
Returns the next matchable subgraph.
|
| 61 |
+
"""
|
| 62 |
+
while len(self.stack) > 0:
|
| 63 |
+
cur_end_node = self.stack.pop()
|
| 64 |
+
if cur_end_node in self.seen_nodes:
|
| 65 |
+
continue
|
| 66 |
+
|
| 67 |
+
# for subgraphs which are single nodes, start_node == end_node
|
| 68 |
+
# for subgraphs with more than one node, start node != end_node
|
| 69 |
+
cur_start_node = cur_end_node
|
| 70 |
+
# Subgraphs like linear-relu have the base node as the start node.
|
| 71 |
+
# Subgraphs like dequantize-linear-relu-to(torch.float16) have the
|
| 72 |
+
# base node as the second node.
|
| 73 |
+
# The cur_base_op_node var will move to the actual node during
|
| 74 |
+
# the fusion matching later in this code block.
|
| 75 |
+
cur_base_op_node = cur_end_node
|
| 76 |
+
|
| 77 |
+
# Check for potential fusions. For now, we are greedy
|
| 78 |
+
# and always skip all non-base nodes of a fusion. For example,
|
| 79 |
+
# if we match linear-relu backwards, we will always skip the
|
| 80 |
+
# relu node and attempt to match the linear node. This can
|
| 81 |
+
# be made configurable later if needed.
|
| 82 |
+
for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
|
| 83 |
+
is_match = end_node_matches_reversed_fusion(
|
| 84 |
+
cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes)
|
| 85 |
+
if is_match:
|
| 86 |
+
# navigate to the base node
|
| 87 |
+
for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
|
| 88 |
+
self.seen_nodes.add(cur_start_node)
|
| 89 |
+
# for now, assume that there are no other nodes
|
| 90 |
+
# which need to be added to the stack
|
| 91 |
+
cur_start_node = cur_start_node.args[0] # type: ignore[assignment]
|
| 92 |
+
# if the base op index matches the current node, set it
|
| 93 |
+
rev_base_op_idx = \
|
| 94 |
+
len(_reverse_fusion_ops) - 2 - base_op_idx
|
| 95 |
+
if rev_fusion_idx == rev_base_op_idx:
|
| 96 |
+
cur_base_op_node = cur_start_node
|
| 97 |
+
break
|
| 98 |
+
|
| 99 |
+
self.seen_nodes.add(cur_start_node)
|
| 100 |
+
# add args of previous nodes to stack
|
| 101 |
+
for arg in cur_start_node.all_input_nodes:
|
| 102 |
+
self._recursively_add_node_arg_to_stack(arg)
|
| 103 |
+
|
| 104 |
+
# skip unmatchable nodes
|
| 105 |
+
# note: this check is done on the start_node, i.e.
|
| 106 |
+
# if we are matching linear-relu in reverse, this would do the matchable
|
| 107 |
+
# check on the linear
|
| 108 |
+
if not self._is_matchable(cur_base_op_node):
|
| 109 |
+
continue
|
| 110 |
+
|
| 111 |
+
# If an observer or a fake_quant was not matched as a part of
|
| 112 |
+
# a pattern of multiple nodes, ignore it. One case where this is
|
| 113 |
+
# relevant is an observer on a graph input, which was added because
|
| 114 |
+
# it is necessary for the next node.
|
| 115 |
+
if cur_end_node.op == 'call_module' and cur_start_node is cur_end_node:
|
| 116 |
+
maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target) # type: ignore[arg-type]
|
| 117 |
+
if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
|
| 118 |
+
continue
|
| 119 |
+
|
| 120 |
+
return NSSubgraph(
|
| 121 |
+
start_node=cur_start_node, end_node=cur_end_node,
|
| 122 |
+
base_op_node=cur_base_op_node)
|
| 123 |
+
|
| 124 |
+
raise StopIteration
|
| 125 |
+
|
| 126 |
+
def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
|
| 127 |
+
"""
|
| 128 |
+
Adds all of the nodes in this arg to the stack, properly navigating
|
| 129 |
+
through list, dicts and tuples.
|
| 130 |
+
"""
|
| 131 |
+
if isinstance(arg, Node):
|
| 132 |
+
self.stack.append(arg)
|
| 133 |
+
elif isinstance(arg, torch.fx.immutable_collections.immutable_list) or type(arg) is tuple:
|
| 134 |
+
for inner_arg in arg:
|
| 135 |
+
self._recursively_add_node_arg_to_stack(inner_arg)
|
| 136 |
+
elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
|
| 137 |
+
for value in arg.values():
|
| 138 |
+
self._recursively_add_node_arg_to_stack(value)
|
| 139 |
+
|
| 140 |
+
def _is_matchable(self, node: Node) -> bool:
|
| 141 |
+
if node.op == 'call_function':
|
| 142 |
+
return node.target not in self.non_matchable_functions
|
| 143 |
+
elif node.op == 'call_module':
|
| 144 |
+
assert isinstance(node.target, str)
|
| 145 |
+
target_mod = getattr_from_fqn(self.gm, node.target)
|
| 146 |
+
return not \
|
| 147 |
+
any(isinstance(target_mod, t) # type: ignore[arg-type]
|
| 148 |
+
for t in self.non_matchable_modules)
|
| 149 |
+
elif node.op == 'call_method':
|
| 150 |
+
return node.target not in self.non_matchable_methods
|
| 151 |
+
else:
|
| 152 |
+
return False
|
| 153 |
+
|
| 154 |
+
class GraphMatchingException(Exception):
|
| 155 |
+
"""
|
| 156 |
+
Exception raised when two graphs cannot be matched.
|
| 157 |
+
"""
|
| 158 |
+
pass
|
| 159 |
+
|
| 160 |
+
class SubgraphTypeRelationship(enum.Enum):
|
| 161 |
+
# same type, known
|
| 162 |
+
# example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
|
| 163 |
+
EQUAL = enum.auto()
|
| 164 |
+
# same type, but the type is not known to Numerical Suite
|
| 165 |
+
# (user defined type, etc).
|
| 166 |
+
EQUAL_BUT_UKNOWN = enum.auto()
|
| 167 |
+
# known, same subgraph_relationship set, but not the same type
|
| 168 |
+
# example: F.linear and toq.linear
|
| 169 |
+
RELATED_BUT_NOT_EQUAL = enum.auto()
|
| 170 |
+
# not related
|
| 171 |
+
NOT_RELATED = enum.auto()
|
| 172 |
+
|
| 173 |
+
def _get_subgraph_relationship_type(
|
| 174 |
+
subgraph_a: NSSubgraph,
|
| 175 |
+
subgraph_b: NSSubgraph,
|
| 176 |
+
gm_a: GraphModule,
|
| 177 |
+
gm_b: GraphModule,
|
| 178 |
+
type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
|
| 179 |
+
) -> SubgraphTypeRelationship:
|
| 180 |
+
node_a = subgraph_a.base_op_node
|
| 181 |
+
node_b = subgraph_b.base_op_node
|
| 182 |
+
|
| 183 |
+
# TODO(next): make this code handle matching by what is before the base op
|
| 184 |
+
if node_a.op != node_b.op:
|
| 185 |
+
if not (
|
| 186 |
+
node_a.op in ('call_function', 'call_method') and
|
| 187 |
+
node_b.op in ('call_function', 'call_method')
|
| 188 |
+
):
|
| 189 |
+
return SubgraphTypeRelationship.NOT_RELATED
|
| 190 |
+
|
| 191 |
+
if node_a.op in ('call_function', 'call_method'):
|
| 192 |
+
key = (node_a.target, node_b.target)
|
| 193 |
+
|
| 194 |
+
if key not in type_a_related_to_b:
|
| 195 |
+
if node_a.target == node_b.target:
|
| 196 |
+
return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
|
| 197 |
+
else:
|
| 198 |
+
return SubgraphTypeRelationship.NOT_RELATED
|
| 199 |
+
# after this point, we are dealing with known types
|
| 200 |
+
|
| 201 |
+
if node_a.target == node_b.target:
|
| 202 |
+
node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
|
| 203 |
+
node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
|
| 204 |
+
if node_a_has_prev and (not node_b_has_prev):
|
| 205 |
+
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
| 206 |
+
elif (not node_a_has_prev) and node_b_has_prev:
|
| 207 |
+
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
| 208 |
+
elif (not node_a_has_prev) and (not node_b_has_prev):
|
| 209 |
+
return SubgraphTypeRelationship.EQUAL
|
| 210 |
+
else:
|
| 211 |
+
# TODO(future PR): check for matches start_op_node and base_op_node
|
| 212 |
+
return SubgraphTypeRelationship.EQUAL
|
| 213 |
+
|
| 214 |
+
if key in type_a_related_to_b:
|
| 215 |
+
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
| 216 |
+
else:
|
| 217 |
+
return SubgraphTypeRelationship.NOT_RELATED
|
| 218 |
+
elif node_a.op == 'call_module':
|
| 219 |
+
assert (subgraph_a.base_op_node == subgraph_a.start_node and
|
| 220 |
+
subgraph_b.base_op_node == subgraph_b.start_node), \
|
| 221 |
+
"Matching call_module patterns where base_op_node != start_node is not supported yet"
|
| 222 |
+
# for call_module, we need to look up the modules to do the type check
|
| 223 |
+
assert isinstance(node_a.target, str)
|
| 224 |
+
mod_a = getattr_from_fqn(gm_a, node_a.target)
|
| 225 |
+
assert isinstance(node_b.target, str)
|
| 226 |
+
mod_b = getattr_from_fqn(gm_b, node_b.target)
|
| 227 |
+
|
| 228 |
+
key = (type(mod_a), type(mod_b))
|
| 229 |
+
|
| 230 |
+
if key not in type_a_related_to_b:
|
| 231 |
+
if type(mod_a) == type(mod_b):
|
| 232 |
+
return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
|
| 233 |
+
else:
|
| 234 |
+
return SubgraphTypeRelationship.NOT_RELATED
|
| 235 |
+
elif type(mod_a) == type(mod_b):
|
| 236 |
+
return SubgraphTypeRelationship.EQUAL
|
| 237 |
+
else:
|
| 238 |
+
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
| 239 |
+
|
| 240 |
+
return SubgraphTypeRelationship.NOT_RELATED
|
| 241 |
+
|
| 242 |
+
def _get_name_for_subgraph(
|
| 243 |
+
subgraph_a: NSSubgraph,
|
| 244 |
+
gm_a: GraphModule,
|
| 245 |
+
base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
|
| 246 |
+
existing_names: Set[str],
|
| 247 |
+
) -> str:
|
| 248 |
+
"""
|
| 249 |
+
Returns a unique name for a subgraph. This name is based on two things:
|
| 250 |
+
1. the name of the set containing the underlying type of the base op in the
|
| 251 |
+
subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op)
|
| 252 |
+
2. the number of previous subgraphs with related underlying type of the base op
|
| 253 |
+
|
| 254 |
+
For example, in the graph
|
| 255 |
+
|
| 256 |
+
linear0 -> relu0 -> linear1 -> relu1
|
| 257 |
+
|
| 258 |
+
The subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate
|
| 259 |
+
from the output node backwards, the name given to (linear1, relu1) will be
|
| 260 |
+
`base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0)
|
| 261 |
+
will be `base_op_torch.nn.functional.linear_1`.
|
| 262 |
+
|
| 263 |
+
Why are we not just using the node name? Answer: because of two requirements:
|
| 264 |
+
A. fusions must be supported
|
| 265 |
+
B. some Numeric Suite APIs can be called without having all of the models in memory
|
| 266 |
+
|
| 267 |
+
For example, let's say we need to match nodes of
|
| 268 |
+
|
| 269 |
+
(1) ... -> linear0 -> relu0 -> ...
|
| 270 |
+
|
| 271 |
+
And
|
| 272 |
+
|
| 273 |
+
(2) ... -> linear_relu0 -> ...
|
| 274 |
+
|
| 275 |
+
Without being able to inspect them together. With the current naming scheme, if
|
| 276 |
+
we iterate through both of these graphs in the same order, and assuming the rest
|
| 277 |
+
of the graphs match, both of these subgraphs will get the same name without
|
| 278 |
+
(1) and (2) knowing anything about each other.
|
| 279 |
+
"""
|
| 280 |
+
target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
|
| 281 |
+
target_base_type = None
|
| 282 |
+
for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
|
| 283 |
+
if target_type in sets_of_related_ops:
|
| 284 |
+
target_base_type = base_name
|
| 285 |
+
target_base_name = 'base_op_' + str(target_base_type)
|
| 286 |
+
counter = 0
|
| 287 |
+
proposed_name = target_base_name + '_' + str(counter)
|
| 288 |
+
while proposed_name in existing_names:
|
| 289 |
+
counter += 1
|
| 290 |
+
proposed_name = target_base_name + '_' + str(counter)
|
| 291 |
+
existing_names.add(proposed_name)
|
| 292 |
+
return proposed_name
|
| 293 |
+
|
| 294 |
+
def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
|
| 295 |
+
if node.op in ('call_function', 'call_method'):
|
| 296 |
+
return node.target
|
| 297 |
+
elif node.op == 'call_module':
|
| 298 |
+
assert isinstance(node.target, str)
|
| 299 |
+
mod = getattr_from_fqn(gm, node.target)
|
| 300 |
+
return type(mod)
|
| 301 |
+
return None
|
| 302 |
+
|
| 303 |
+
def get_matching_subgraph_pairs(
|
| 304 |
+
gm_a: GraphModule,
|
| 305 |
+
gm_b: GraphModule,
|
| 306 |
+
base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
|
| 307 |
+
unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
|
| 308 |
+
) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
|
| 309 |
+
"""
|
| 310 |
+
Matches matchable subgraphs of graph_a to graph_b.
|
| 311 |
+
|
| 312 |
+
For a node, "matchable" is defined as a node which is not an observer,
|
| 313 |
+
fake_quants, quant or dequant.
|
| 314 |
+
|
| 315 |
+
A subgraph can contain one or more nodes. A subgraph is matchable if
|
| 316 |
+
at least one node inside of it is matchable. Currently, all nodes in
|
| 317 |
+
a subgraph must be matchable (because we assume no observers will be
|
| 318 |
+
inserted in the middle of a fusion).
|
| 319 |
+
|
| 320 |
+
A subgraph is defined by (start_node, end_node). We assume that only
|
| 321 |
+
start_node and end_node are linked with the surrounding graph, all other
|
| 322 |
+
nodes in a subgraph are self-contained.
|
| 323 |
+
|
| 324 |
+
A pair of nodes is "related" if both nodes represent the same mathematical
|
| 325 |
+
operation across different quantization flavors. For example,
|
| 326 |
+
`F.linear` and `torch.ops.quantized.linear` are related, and
|
| 327 |
+
`F.linear` and `torch.nn.Conv` are not related.
|
| 328 |
+
|
| 329 |
+
For each matchable pair of nodes node_a and node_b, they will match
|
| 330 |
+
if node_a and node_b are related.
|
| 331 |
+
|
| 332 |
+
For graphs A and B, they will match iff:
|
| 333 |
+
1. the number of matchable subgraphs in A and B is equivalent
|
| 334 |
+
2. when iterating through the matchable subgraphs of A and B in the same order, each
|
| 335 |
+
corresponding pair of base nodes is related.
|
| 336 |
+
|
| 337 |
+
This enables us to find the corresponding subgraphs between
|
| 338 |
+
graphs of related models. For example, if we had two graphs such as:
|
| 339 |
+
|
| 340 |
+
graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
|
| 341 |
+
w -/
|
| 342 |
+
b -/
|
| 343 |
+
|
| 344 |
+
graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
|
| 345 |
+
packed_params_0 -/
|
| 346 |
+
|
| 347 |
+
This function will return the following result:
|
| 348 |
+
{
|
| 349 |
+
'conv_0': ( # the name of the node in graph_b
|
| 350 |
+
(conv_0, conv_0), # (start_node_a, end_node_a)
|
| 351 |
+
(qconv_0, qconv_0), # (start_node_b, end_node_b)
|
| 352 |
+
),
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
Or, if we have a fusion pattern,
|
| 356 |
+
|
| 357 |
+
graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
|
| 358 |
+
w -/
|
| 359 |
+
b -/
|
| 360 |
+
|
| 361 |
+
graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
|
| 362 |
+
packed_params_0 -/
|
| 363 |
+
|
| 364 |
+
This function will return the following result:
|
| 365 |
+
{
|
| 366 |
+
'linear_relu_0': ( # the name of the node in graph_b
|
| 367 |
+
(linear_0, relu_0), # (start_node_a, end_node_a)
|
| 368 |
+
(linear_relu_0, linear_relu_0), # (start_node_b, end_node_b)
|
| 369 |
+
),
|
| 370 |
+
}
|
| 371 |
+
"""
|
| 372 |
+
if unmatchable_types_map is None:
|
| 373 |
+
unmatchable_types_map = get_unmatchable_types_map()
|
| 374 |
+
non_matchable_functions = unmatchable_types_map['funs_unmatchable']
|
| 375 |
+
non_matchable_modules = unmatchable_types_map['mods_unmatchable']
|
| 376 |
+
non_matchable_methods = unmatchable_types_map['meths_unmatchable']
|
| 377 |
+
|
| 378 |
+
graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
|
| 379 |
+
gm_a, non_matchable_functions, non_matchable_modules,
|
| 380 |
+
non_matchable_methods)
|
| 381 |
+
graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
|
| 382 |
+
gm_b, non_matchable_functions, non_matchable_modules,
|
| 383 |
+
non_matchable_methods)
|
| 384 |
+
results = collections.OrderedDict()
|
| 385 |
+
if base_name_to_sets_of_related_ops is None:
|
| 386 |
+
base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
|
| 387 |
+
type_a_related_to_b = \
|
| 388 |
+
get_type_a_related_to_b(base_name_to_sets_of_related_ops)
|
| 389 |
+
|
| 390 |
+
existing_names_a: Set[str] = set()
|
| 391 |
+
existing_names_b: Set[str] = set()
|
| 392 |
+
|
| 393 |
+
while True:
|
| 394 |
+
# fetch the next subgraphs from a and b
|
| 395 |
+
cur_subgraph_a, cur_subgraph_b = None, None
|
| 396 |
+
try:
|
| 397 |
+
cur_subgraph_a = next(graph_a_iterator)
|
| 398 |
+
except StopIteration:
|
| 399 |
+
pass
|
| 400 |
+
try:
|
| 401 |
+
cur_subgraph_b = next(graph_b_iterator)
|
| 402 |
+
except StopIteration:
|
| 403 |
+
pass
|
| 404 |
+
|
| 405 |
+
# look up types of a and b for useful error messages
|
| 406 |
+
type_start_a, type_start_b = None, None
|
| 407 |
+
if cur_subgraph_a is not None:
|
| 408 |
+
type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a)
|
| 409 |
+
if cur_subgraph_b is not None:
|
| 410 |
+
type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b)
|
| 411 |
+
|
| 412 |
+
# check for results and determine what to do next
|
| 413 |
+
if cur_subgraph_a is not None and cur_subgraph_b is not None:
|
| 414 |
+
# both nodes were fetched, check for subgraph_relationship
|
| 415 |
+
# note: subgraph_relationship is checked on the start node, i.e.
|
| 416 |
+
# if a linear-relu pattern is checked, we would check for subgraph_relationship
|
| 417 |
+
# of the linear
|
| 418 |
+
subgraph_relationship = _get_subgraph_relationship_type(
|
| 419 |
+
cur_subgraph_a, cur_subgraph_b,
|
| 420 |
+
gm_a, gm_b, type_a_related_to_b)
|
| 421 |
+
if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
|
| 422 |
+
msg = f"""
|
| 423 |
+
The subgraphs
|
| 424 |
+
({cur_subgraph_a}, {type_start_a}) and
|
| 425 |
+
({cur_subgraph_b}, {type_start_b})
|
| 426 |
+
are not related. Please ensure that the two models you pass in have the same number
|
| 427 |
+
of subgraphs, and each pair of subgraphs is related to each other."""
|
| 428 |
+
raise GraphMatchingException(msg)
|
| 429 |
+
elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
|
| 430 |
+
# skip matching but unknown types
|
| 431 |
+
continue
|
| 432 |
+
key_name_a = _get_name_for_subgraph(
|
| 433 |
+
cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops,
|
| 434 |
+
existing_names_a)
|
| 435 |
+
key_name_b = _get_name_for_subgraph(
|
| 436 |
+
cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops,
|
| 437 |
+
existing_names_b)
|
| 438 |
+
assert key_name_a == key_name_b, \
|
| 439 |
+
f"Subgraph names {key_name_a} and {key_name_b} do not match"
|
| 440 |
+
results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
|
| 441 |
+
continue
|
| 442 |
+
elif cur_subgraph_a is None and cur_subgraph_b is None:
|
| 443 |
+
# we reached the end of both graphs
|
| 444 |
+
break
|
| 445 |
+
else:
|
| 446 |
+
# only one node was fetched, no match possible, throw error
|
| 447 |
+
msg = f"""
|
| 448 |
+
Attempting to match
|
| 449 |
+
({cur_subgraph_a}, {type_start_a}) and
|
| 450 |
+
({cur_subgraph_b}, {type_start_b}),
|
| 451 |
+
one of which is empty. Please ensure that the two models you pass in have the same number
|
| 452 |
+
of subgraphs."""
|
| 453 |
+
raise GraphMatchingException(msg)
|
| 454 |
+
|
| 455 |
+
# The subgraph pairs are originally created by traversing the two graphs
|
| 456 |
+
# from the outputs to the inputs. Reverse the results to return the
|
| 457 |
+
# subgraphs in their order of execution.
|
| 458 |
+
results = collections.OrderedDict(reversed(list(results.items())))
|
| 459 |
+
|
| 460 |
+
return results
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/graph_passes.py
ADDED
|
@@ -0,0 +1,950 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.fx import GraphModule, map_arg
|
| 3 |
+
from torch.fx.graph import Graph, Node
|
| 4 |
+
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
|
| 5 |
+
|
| 6 |
+
from .utils import (
|
| 7 |
+
get_node_first_input_and_output_type,
|
| 8 |
+
getattr_from_fqn,
|
| 9 |
+
NodeInputOrOutputType,
|
| 10 |
+
return_first_non_observer_node,
|
| 11 |
+
get_number_of_non_param_args,
|
| 12 |
+
get_target_type_str,
|
| 13 |
+
get_arg_indices_of_inputs_to_log,
|
| 14 |
+
get_node_input_qparams,
|
| 15 |
+
op_type_supports_shadowing,
|
| 16 |
+
get_normalized_nth_input,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
from .ns_types import (
|
| 20 |
+
NSSingleResultValuesType,
|
| 21 |
+
NSSubgraph,
|
| 22 |
+
NSNodeTargetType,
|
| 23 |
+
)
|
| 24 |
+
from torch.ao.ns.fx.mappings import (
|
| 25 |
+
get_node_type_to_io_type_map,
|
| 26 |
+
)
|
| 27 |
+
from torch.ao.quantization.observer import _is_activation_post_process
|
| 28 |
+
|
| 29 |
+
from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set
|
| 30 |
+
|
| 31 |
+
def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
    """Look up the fully qualified name recorded for `node` during tracing.

    Returns None when `gm` carries no `_node_name_to_scope` mapping.
    Observers have no FQN of their own (they are created after tracing),
    so for an activation-post-process module the FQN of the node being
    observed is reported instead.
    """
    if not hasattr(gm, '_node_name_to_scope'):
        return None
    target_node = node
    if node.op == 'call_module':
        assert isinstance(node.target, str)
        if _is_activation_post_process(getattr_from_fqn(gm, node.target)):
            # observers do not exist when FQNs are created during tracing;
            # fall back to the node this observer watches
            target_node = get_normalized_nth_input(node, gm, 0)
    return gm._node_name_to_scope[target_node.name][0]  # type: ignore[index]
|
| 45 |
+
|
| 46 |
+
def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    ref_node_target_type: str,
    results_type: str,
    index_within_arg: int,
    index_of_arg: int,
    fqn: Optional[str],
) -> Node:
    """
    Given a starting graph of

    prev_node -> node -> next_node

    instantiate `logger_cls`, attach the instance to `gm` under a fresh
    attribute name, and splice a call_module node for it right after
    `node`, resulting in

    prev_node -> node -> logger_obj -> next_node

    Returns the newly created logger node.
    """
    # pick a unique attribute name derived from the observed node's name
    attr_name = get_new_attr_name_with_prefix(
        node.name + logger_node_name_suffix)(gm)
    # instantiate the logger with full bookkeeping metadata and register
    # it as a submodule of gm so the call_module node below can find it
    logger_instance = logger_cls(
        ref_node_name, node.name, model_name, ref_name,
        get_target_type_str(node, gm), ref_node_target_type,
        results_type, index_within_arg, index_of_arg, fqn)
    setattr(gm, attr_name, logger_instance)
    # wire the logger into the graph, consuming `node` as its only input
    return node.graph.create_node('call_module', attr_name, (node,), {})
|
| 84 |
+
|
| 85 |
+
def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new
    graph.

    Args:
        gm: model to instrument
        node_to_instrument_inputs_to_ref_node_name: map from nodes whose
            *inputs* should be logged to (ref_name, ref_node_type)
        node_to_instrument_outputs_to_ref_node_name: map from nodes whose
            *outputs* should be logged to (ref_name, ref_node_type)
        logger_cls: logger class; instances are attached to gm as
            submodules and invoked through call_module nodes
        model_name: model name recorded inside each logger instance
    """

    new_graph = Graph()
    # env maps an original node's name to the node in new_graph whose value
    # downstream consumers should use: either the copied node itself, or a
    # logger inserted after it (loggers pass their input through)
    env: Dict[str, Any] = {}
    # NOTE(review): `modules` appears to be unused in this function
    modules = dict(gm.named_modules())

    def load_arg(a):
        # remap an arg (or nested structure of args) into new_graph via env
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
            continue

        if (
            (node in node_to_instrument_inputs_to_ref_node_name) or
            (node in node_to_instrument_outputs_to_ref_node_name)
        ):
            fqn = _maybe_get_fqn(node, gm)

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[node]
                # Ops such add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
                            model_name, ref_name, ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=node_arg_idx,
                            fqn=fqn)
                    elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):  # type: ignore[var-annotated, arg-type]
                            prev_node = env[arg.name]
                            # NOTE(review): env is keyed here by the copied
                            # node's name (`prev_node.name`), which equals
                            # `arg.name` only if node_copy preserved the
                            # name — confirm this invariant holds
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node, gm, logger_cls, '_ns_logger_', node.name,
                                model_name, ref_name, ref_node_type,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=node_arg_idx,
                                fqn=fqn)
                    else:
                        # non-node args (scalars, etc.) are not logged
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                    model_name, ref_name, ref_node_type,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0, fqn=fqn)

        else:
            # node is not instrumented: plain copy into the new graph
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
|
| 165 |
+
|
| 166 |
+
def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    """
    Attach `scale` and `zero_point` to gm_b as attributes, mirror them as
    `get_attr` nodes in graph_c, and return a new
    `torch.quantize_per_tensor(prev_node_c, scale, zero_point, torch.quint8)`
    call_function node named `dtype_cast_name`.
    """
    def _attach_qparam(suffix: str, value) -> Node:
        # register the value on gm_b under a fresh unique name, then
        # create the matching get_attr node in graph_c
        attr_name = get_new_attr_name_with_prefix(node_a.name + suffix)(gm_b)
        setattr(gm_b, attr_name, value)
        return graph_c.create_node('get_attr', attr_name, (), {}, attr_name)

    scale_node = _attach_qparam('_input_scale_', scale)
    zero_point_node = _attach_qparam('_input_zero_point_', zero_point)
    # create the quantize_per_tensor call
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, torch.quint8), {},
        dtype_cast_name)
|
| 194 |
+
|
| 195 |
+
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    Returns the inserted cast node, or a list of cast nodes when
    prev_node_c is a list of nodes.
    """
    # exactly one of these "cast strategies" is selected by the dtype
    # matrix below: a call_function op, a module class, or a method call
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_method = None
    dtype_cast_method_dtype = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(
            node_a, gm_a, logger_cls, node_type_to_io_type_map)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(
            node_c, gm_b, logger_cls, node_type_to_io_type_map)

    if (
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.INT8) or
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP16) or
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)
    ):
        # fp32 op shadows a quantized op: dequantize before feeding node_a
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c and
        node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        # same known dtype on both sides: a pass-through Identity suffices
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map)
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    elif (
        node_input_type_a == NodeInputOrOutputType.FP16 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # fp16 shadows fp32: cast down via Tensor.to(torch.float16)
        dtype_cast_method = 'to'
        dtype_cast_method_dtype = torch.float16
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented")

    if isinstance(prev_node_c, Node):
        # single-input case: create one cast node feeding off prev_node_c
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                # quantize_per_tensor needs its qparams materialized as
                # get_attr nodes; delegate to the helper
                return _insert_quantize_per_tensor_node(
                    prev_node_c, node_a, gm_b, graph_c, dtype_cast_scale,
                    dtype_cast_zero_point, new_dtype_cast_name)
            else:
                return graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c,), {},
                    new_dtype_cast_name)
        elif dtype_cast_method:
            return graph_c.create_node(
                'call_method', dtype_cast_method,
                (prev_node_c, dtype_cast_method_dtype), {}, new_dtype_cast_name)
        else:
            assert dtype_cast_mod_cls
            dtype_cast_mod = dtype_cast_mod_cls()
            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
            return graph_c.create_node(
                'call_module', new_dtype_cast_name, (prev_node_c,), {},
                new_dtype_cast_name)
    elif isinstance(prev_node_c, list):
        # multi-input case: create one cast node per input node
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = \
                get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
            if dtype_cast_op:
                # TODO(future PR): add handling for quantize_per_tensor
                new_dtype_cast_node = graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
            else:
                assert dtype_cast_mod_cls
                dtype_cast_mod = dtype_cast_mod_cls()
                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
                new_dtype_cast_node = graph_c.create_node(
                    'call_module', new_dtype_cast_name, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
        return results
    else:
        # NOTE(review): the message below contains a stray 'f' ("type f<...>");
        # left as-is because it is a runtime string
        raise AssertionError(f"type f{type(prev_node_c)} is not handled")
|
| 319 |
+
|
| 320 |
+
# TODO(future PR): look into using copy_node API instead
|
| 321 |
+
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Simple copy of node_a to graph_c.

    Supported ops:
      * get_attr — the attribute value is copied onto gm_b (tensors are
        detached first) and a matching get_attr node is created in graph_c
      * call_method 'dequantize' / 'to' — the node providing the first
        input is copied recursively, then the method call is recreated
        in graph_c

    Anything else raises AssertionError.
    """
    if node_a.op == 'get_attr':
        copy_name = \
            get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
        attr_obj = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(attr_obj):
            # detach so the shadow copy does not participate in autograd
            attr_obj = attr_obj.detach()
        setattr(gm_b, copy_name, attr_obj)
        return graph_c.create_node(node_a.op, copy_name, (), {}, copy_name)

    if node_a.op == 'call_method':
        assert node_a.target in ('dequantize', 'to'), \
            f"target {node_a.target} is not implemented"
        # recursively copy the node that produces the first input
        arg_copy = _copy_node_from_a_to_c(
            get_normalized_nth_input(node_a, gm_a, 0),
            gm_a, gm_b, graph_c)  # type: ignore[arg-type]
        if node_a.target == 'dequantize':
            new_args = (arg_copy,)
        else:  # 'to' additionally carries the target dtype as a second arg
            new_args = (arg_copy, get_normalized_nth_input(node_a, gm_a, 1))
        copy_name = \
            get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
        return graph_c.create_node(
            node_a.op, node_a.target, new_args, {}, copy_name)

    raise AssertionError(
        f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented")
|
| 366 |
+
|
| 367 |
+
def _can_insert_copy_of_subgraph_a(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    num_non_param_args_node_a: int,
) -> bool:
    """
    This function returns `False` if the input subgraph cannot be copied by
    `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
    that there is a corner case logic for which copy is not yet implemented.

    Args:
        subgraph_a: (start_node, end_node, base_op_node) subgraph to check
        gm_a: the GraphModule containing subgraph_a
        num_non_param_args_node_a: how many leading args of the subgraph's
            first node are data inputs (1 or 2); those args are stitched
            from the base graph and therefore never need to be copied
    """
    # populate the list of nodes we need to check, walking from the end of
    # the subgraph to its start through each node's 0th input, then
    # reversing into execution order
    nodes = []
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        nodes.append(cur_node)
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
    nodes.append(cur_node)
    nodes.reverse()

    def _can_insert(node_a_arg, gm_a):
        # mirrors the arg handling of `_copy_arg` in
        # `_insert_copy_of_node_a_after_input_node_c`
        if isinstance(node_a_arg, Node):
            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
            if arg_a.op == 'call_method':
                return arg_a.target in ('dequantize', 'to')
            elif arg_a.op == 'get_attr':
                return True
            else:
                return False
        elif isinstance(node_a_arg, (list, tuple)):
            # NOTE(review): this approves lists whose elements are all
            # Nodes, while `_copy_arg` asserts list elements are *not*
            # Nodes — confirm which direction is intended
            for el in node_a_arg:
                if not isinstance(el, Node):
                    return False
            return True
        # previously this path fell off the end and returned None
        # implicitly; make the (equally falsy) result explicit
        return False

    # For each node, check if we handle the copy behavior. This follows the
    # logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
    for node_a in nodes:

        # only the first node of the subgraph receives stitched data inputs
        local_num_non_param_args_node_a = num_non_param_args_node_a \
            if node_a is nodes[0] else 1

        # prefer normalized (kwargs-only) arguments when the target's
        # schema is known; otherwise fall back to the raw args/kwargs
        norm_args_kwargs = node_a.normalized_arguments(
            gm_a, normalize_to_only_use_kwargs=True)
        if norm_args_kwargs is not None:
            norm_args, norm_kwargs = norm_args_kwargs
        else:
            norm_args, norm_kwargs = node_a.args, node_a.kwargs

        cur_idx = 0

        while cur_idx < len(norm_args):
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(norm_args[cur_idx], gm_a):
                    return False
            cur_idx += 1

        for kwarg_val in norm_kwargs.values():
            # stitch the inputs from base graph
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(kwarg_val, gm_a):
                    return False
            cur_idx += 1

    return True
|
| 439 |
+
|
| 440 |
+
def _insert_copy_of_subgraph_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Copy the linear chain of subgraph_a (start_node .. end_node) into the
    graph containing input_node_c, node by node.

    input_node_c (and optionally input_node_c_2) feed the copy of the
    first node; each later copy consumes the previous copy as its only
    input. Returns the copy of subgraph_a.end_node.
    """
    # validate the input and resolve the destination graph
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    # walk end -> start through each node's 0th input, building the chain
    # in execution order, because copies must be created start -> end
    nodes_of_a = [subgraph_a.end_node]
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
        nodes_of_a.insert(0, cur_node)

    # the first node of the chain receives the external inputs ...
    cur_node_c = _insert_copy_of_node_a_after_input_node_c(
        input_node_c,
        input_node_c_2,
        nodes_of_a[0],
        gm_a,
        gm_b,
        node_name_prefix)
    # ... every following node is fed by the copy created just before it
    for node_a in nodes_of_a[1:]:
        # TODO(future PR): enable multiple inputs for nodes which are not
        # at the start of the subgraph
        cur_node_c = _insert_copy_of_node_a_after_input_node_c(
            cur_node_c,
            None,
            node_a,
            gm_a,
            gm_b,
            node_name_prefix)
    # return the last inserted node
    return cur_node_c
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
      args (input, (input2)?, arg1, ...), and
      kwargs {kw0: kwarg0, ...}

    Note: input2 is optional. If it equals to None, we assume that the op
    has a single non-param input. If it is specified, we assume that the op
    has two non-param inputs.

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg is an observer we navigate up until we find a non-observer parent.

    If node_a is a call_module, points the module pointed to by node_a to gm_b.

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input -------------> node_a
                         / / /
    (input_2)?----------/ / /
                         / /
    weight -> weight_obs  /
                         /
    bias ----------------

    graph C (derived from B):
    =========================

    input_node_c --> node_a_copy
                     / / /
    (input_node_c_2)? / /
                     / /
    weight_copy ----/ /
                     /
    bias_copy ------/

    Returns the newly created copy of node_a.
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    # prefer normalized (kwargs-only) arguments when the target's schema
    # is known; otherwise fall back to the raw args/kwargs
    norm_args_kwargs = node_a.normalized_arguments(
        gm_a, normalize_to_only_use_kwargs=True)
    if norm_args_kwargs is not None:
        norm_args, norm_kwargs = norm_args_kwargs
    else:
        norm_args, norm_kwargs = node_a.args, node_a.kwargs

    new_args = []
    new_kwargs = {}

    def _copy_arg(arg):
        # copy the other inputs from the other graph
        if isinstance(arg, Node):
            arg = return_first_non_observer_node(arg, gm_a)
            arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
            return arg
        elif isinstance(arg, (int, float, torch.dtype)):
            return arg
        elif isinstance(arg, (list, tuple)):
            # bugfix: these two branches previously tested `kwarg_val`, a
            # name from the enclosing kwargs loop, instead of `arg` — that
            # raised NameError for positional list/tuple args and checked
            # the wrong value for kwargs
            for el in arg:
                assert not isinstance(el, Node), \
                    "handling of Node inside list is not implemented"
            return arg
        else:
            raise AssertionError(
                f"handling for arg of type {type(arg)} is not implemented")

    cur_idx = 0

    # stitch the data inputs (arg 0, and optionally arg 1) from the base
    # graph; copy everything else over from graph A
    while cur_idx < len(norm_args):
        if cur_idx == 0:
            new_arg = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_arg = input_node_c_2
        else:
            new_arg = _copy_arg(norm_args[cur_idx])
        new_args.append(new_arg)
        cur_idx += 1

    for kwarg_name, kwarg_val in norm_kwargs.items():
        # stitch the inputs from base graph
        if cur_idx == 0:
            new_kwargs[kwarg_name] = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_kwargs[kwarg_name] = input_node_c_2
        else:
            new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
        cur_idx += 1

    new_args = tuple(new_args)  # type: ignore[assignment]

    node_a_shadows_c_name = \
        get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if node_a.op == 'call_module':
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        node_a_shadows_c = graph_c.create_node(
            node_a.op, new_mod_copy_name, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c
    else:
        assert node_a.op in ('call_function', 'call_method')
        node_a_shadows_c = graph_c.create_node(
            node_a.op, node_a.target, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c
|
| 619 |
+
|
| 620 |
+
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B. For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

       / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
      /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b

    Args:
        name_a: name of model A, recorded into the A-side loggers
        gm_a: GraphModule of model A (the shadowing model)
        name_b: name of model B, recorded into the B-side loggers
        gm_b: GraphModule of model B; the result graph is built from a copy
            of this model's graph
        matched_subgraph_pairs: match name -> (subgraph in A, subgraph in B)
        logger_cls: logger module class instantiated at each logged point
        should_log_inputs: if True, also attach loggers to subgraph inputs
        node_type_to_io_type_map: optional override for the node type ->
            IO dtype sets mapping; defaults to get_node_type_to_io_type_map()

    Returns:
        A new GraphModule built on gm_b whose graph contains B's nodes with
        A's matched subgraphs spliced in as shadows, plus loggers.
    """

    if node_type_to_io_type_map is None:
        node_type_to_io_type_map = get_node_type_to_io_type_map()

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    # env_c maps original node name -> corresponding node in graph_c
    env_c: Dict[str, Any] = {}
    # NOTE(review): `modules` appears unused in this function body — verify
    # before removing.
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        # Remap any arg structure from graph_b nodes to their graph_c copies.
        return map_arg(a, lambda node: env_c[node.name])

    # Index the matches by B's subgraph boundary nodes so that the single
    # pass over gm_b's nodes below can detect subgraph starts and ends.
    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            # Several early-exit checks follow; on each failure the node is
            # copied to graph_c unshadowed and we move on.
            all_op_types_support_shadowing = (
                op_type_supports_shadowing(subgraph_a.start_node) and
                op_type_supports_shadowing(node_b)
            )
            if not all_op_types_support_shadowing:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unsupported')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(
                    subgraph_a.start_node, gm_a, logger_cls,
                    node_type_to_io_type_map)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(
                    node_b, gm_b, logger_cls,
                    node_type_to_io_type_map)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN and
                node_output_type_a != NodeInputOrOutputType.UNKNOWN and
                node_input_type_b != NodeInputOrOutputType.UNKNOWN and
                node_output_type_b != NodeInputOrOutputType.UNKNOWN
            )
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unknown dtype cast')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # If we are shadowing from fp32 to int8, we need to insert
            # quantize_per_tensor call with qparams from the previous node.
            # Only do this if we are able to infer these qparams from the graph.
            if (
                node_input_type_a == NodeInputOrOutputType.INT8 and
                node_input_type_b == NodeInputOrOutputType.FP32
            ):
                node_a_input_qparams = get_node_input_qparams(
                    subgraph_a.start_node, gm_a, node_type_to_io_type_map)
                if not node_a_input_qparams:
                    print(
                        f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                        f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                        ', unknown input qparams')
                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                    continue

            num_non_param_args_node_a = \
                get_number_of_non_param_args(subgraph_a.start_node, gm_a)
            if not _can_insert_copy_of_subgraph_a(subgraph_a, gm_a, num_non_param_args_node_a):
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unhandled logic in subgraph copy')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)  # type: ignore[possibly-undefined]

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
                    if isinstance(prev_node_b, Node):
                        prev_node_c = env_c[prev_node_b.name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                            node_b.name, name_b, ref_name, ref_node_type_b,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=0,
                            fqn=fqn_base_b)
                    elif isinstance(prev_node_b, list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]

                        for arg_idx, arg in enumerate(prev_node_b):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[prev_node_c.name] = _insert_logger_after_node(
                                prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                                node_b.name, name_b, ref_name, ref_node_type_b,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=0,
                                fqn=fqn_base_b)
                    else:
                        # logging of inputs which are not lists is not supported yet
                        raise AssertionError(f"type {type(prev_node_b)} is not handled yet")
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_start_node:

                    # cast dtype from the dtype of node_c's input to the dtype of
                    # node_a's input (dequant, etc)
                    # prev_node_c = node_c.args[0]
                    prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)  # type: ignore[possibly-undefined]
                    if should_log_inputs:
                        # skip the input logger when inserting a dtype cast
                        if isinstance(prev_node_c, Node):
                            prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
                        elif isinstance(prev_node_c, list):
                            prev_node_c = [get_normalized_nth_input(arg, gm_b, 0) for arg in prev_node_c]
                    dtype_cast_node = _insert_dtype_cast_after_node(
                        subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b, graph_c,
                        node_b.name + '_dtype_cast_', logger_cls,
                        node_type_to_io_type_map)
                    # note: not inserting to env_c because all nodes which use the dtype
                    # casts are copied from graph_a
                    #
                    # subgraph so far:
                    #
                    #           (dtype_cast_node)+
                    #          /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    # if input logging is enabled, log the input to the subgraph
                    if should_log_inputs:
                        # TODO: explain this
                        ref_node_name = ''
                        if isinstance(dtype_cast_node, Node):
                            dtype_cast_node = _insert_logger_after_node(
                                dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
                                ref_node_name, name_a, ref_name, ref_node_type_a,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=0, index_of_arg=0,
                                fqn=fqn_base_a)
                            input_logger: Union[Node, List[Node]] = dtype_cast_node
                        else:
                            assert isinstance(dtype_cast_node, list)
                            new_loggers = []
                            for dtype_cast_idx, dtype_cast_node_inner in enumerate(dtype_cast_node):
                                dtype_cast_logger = _insert_logger_after_node(
                                    dtype_cast_node_inner, gm_b, logger_cls, '_ns_logger_a_inp_',
                                    ref_node_name, name_a, ref_name, ref_node_type_a,
                                    NSSingleResultValuesType.NODE_INPUT.value,
                                    index_within_arg=dtype_cast_idx,
                                    index_of_arg=0,
                                    fqn=fqn_base_a)
                                new_loggers.append(dtype_cast_logger)
                            dtype_cast_node = new_loggers
                            input_logger = dtype_cast_node
                        # subgraph so far:
                        #
                        #       (dtype_cast_node)+ -> (logger_a_input)?
                        #      /
                        # prev_node_c -> (logger_c_input)? -> node_start_c

                    # hook up the new mod_a copy to be in the graph, receiving the
                    # same inputs as mod_b does, with dtype cast to match a
                    # Some ops, such as LSTMs, have two non-param inputs. If we have
                    # such an op, pass the second param as well. Note: dtype casting
                    # for the second param is not implemented yet, it can be added
                    # later if there is a use case.
                    node_c_second_non_param_arg = None
                    num_non_param_args_node_a = get_number_of_non_param_args(subgraph_a.start_node, gm_a)
                    if num_non_param_args_node_a == 2:
                        # node_c_second_non_param_arg = node_c.args[1]
                        node_c_second_non_param_arg = get_normalized_nth_input(node_c, gm_b, 1)
                    node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                        dtype_cast_node, node_c_second_non_param_arg,
                        subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
                    env_c[node_a_shadows_c.name] = node_a_shadows_c
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    if should_log_inputs:
                        # When we created the input logger, we left the ref_node_name
                        # as an empty string, because the subgraph copy did not exist
                        # yet. Now that the subgraph copy exists, we modify this name
                        # to its true value.
                        # Note: the alternative to this is to create the input logger
                        # after creating the subgraph, which is slightly more
                        # complicated. This is the lesser of two evils.
                        # input_logger = env_c[dtype_cast_node.name]
                        # Find the first node in the subgraph
                        cur_node = node_a_shadows_c
                        while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:  # type: ignore[possibly-undefined]
                            cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
                        if isinstance(input_logger, Node):
                            input_logger_mod = getattr(gm_b, input_logger.name)
                            input_logger_mod.ref_node_name = cur_node.name
                        else:
                            assert isinstance(input_logger, list)
                            for input_logger_inner in input_logger:
                                input_logger_mod = getattr(gm_b, input_logger_inner.name)
                                input_logger_mod.ref_node_name = cur_node.name

                    # hook up a logger to the mod_a copy
                    env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                        env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                        node_a_shadows_c.name, name_a, ref_name, ref_node_type_a,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0, index_of_arg=0,
                        fqn=fqn_base_a)
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_end_node:

                    # hook up a logger to the mod_b copy
                    env_c[node_b.name] = _insert_logger_after_node(
                        env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
                        node_b.name, name_b, ref_name, ref_node_type_b,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0, index_of_arg=0,
                        fqn=fqn_base_b)
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #      /
                    # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                    #
                    # Note: node_start_c may be the same node as node_end_c, or they
                    # may have nodes inbetween.

        else:
            # node_b is not part of any matched subgraph boundary; copy it over
            # unchanged.
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_safeguard.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
|
| 3 |
+
from torch.overrides import TorchFunctionMode
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class AutogradStateOpsFailSafeguard(TorchFunctionMode):
|
| 7 |
+
"""
|
| 8 |
+
Detect grad state ops during exporting the graph and fail the process by
|
| 9 |
+
raising an error, to avoid unexpected behavior. Those grad mode ops could be:
|
| 10 |
+
`torch.no_grad`
|
| 11 |
+
`torch.enable_grad`
|
| 12 |
+
`torch.set_grad_enabled`
|
| 13 |
+
|
| 14 |
+
Export with predispatch mode is exempted.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __torch_function__(self, func, types, args=(), kwargs=None):
|
| 18 |
+
kwargs = kwargs or {}
|
| 19 |
+
unsupported_grad_mode_ops = [
|
| 20 |
+
torch._C._set_grad_enabled,
|
| 21 |
+
]
|
| 22 |
+
# It's only enabled while tracing, by confirming the torch dispatch mode is
|
| 23 |
+
# any active PROXY. This is to allow the autograd ops out of tracing.
|
| 24 |
+
current_state = torch._C.is_grad_enabled()
|
| 25 |
+
if func in unsupported_grad_mode_ops:
|
| 26 |
+
assert len(args) == 1
|
| 27 |
+
changed_state = args[0]
|
| 28 |
+
mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
|
| 29 |
+
# Intend to check if it's not the pre_dispatch mode. It's allowed to use
|
| 30 |
+
# autograd ops in pre_dispatch mode, e.g. `torch.no_grad`
|
| 31 |
+
if (
|
| 32 |
+
mode
|
| 33 |
+
and isinstance(mode, ProxyTorchDispatchMode)
|
| 34 |
+
and not mode.pre_dispatch
|
| 35 |
+
and changed_state != current_state
|
| 36 |
+
):
|
| 37 |
+
raise RuntimeError(
|
| 38 |
+
f"Encountered autograd state manager op {func} trying to change global autograd state "
|
| 39 |
+
"while exporting. This is unsafe because we don't capture this op in torch.export "
|
| 40 |
+
"today, hence we can't reflect the user intention soundly."
|
| 41 |
+
)
|
| 42 |
+
return func(*args, **kwargs)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_tree_utils.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, Dict, Optional
|
| 2 |
+
|
| 3 |
+
from torch.utils._pytree import Context, TreeSpec
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def reorder_kwargs(user_kwargs: Dict[str, Any], spec: TreeSpec) -> Dict[str, Any]:
|
| 7 |
+
"""Reorder user-provided kwargs to match the order in `spec`. `spec` is
|
| 8 |
+
expected to be the in_spec of an exported program, i.e. the spec that
|
| 9 |
+
results from flattening `(args, kwargs)`.
|
| 10 |
+
|
| 11 |
+
We need this to provide consistent input ordering, such so that users can
|
| 12 |
+
pass in foo(a=a, b=b) OR foo(b=b, a=a) and receive the same result.
|
| 13 |
+
"""
|
| 14 |
+
# Make sure that the spec is actually shaped like (args, kwargs)
|
| 15 |
+
assert spec.type is tuple
|
| 16 |
+
assert spec.num_children == 2
|
| 17 |
+
kwargs_spec = spec.children_specs[1]
|
| 18 |
+
assert kwargs_spec.type is dict
|
| 19 |
+
|
| 20 |
+
if set(user_kwargs) != set(kwargs_spec.context):
|
| 21 |
+
raise ValueError(
|
| 22 |
+
f"kwarg key mismatch: "
|
| 23 |
+
f"Got {list(user_kwargs)} but expected {kwargs_spec.context}"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
reordered_kwargs = {}
|
| 27 |
+
for kw in kwargs_spec.context:
|
| 28 |
+
reordered_kwargs[kw] = user_kwargs[kw]
|
| 29 |
+
|
| 30 |
+
return reordered_kwargs
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def is_equivalent(
    spec1: TreeSpec,
    spec2: TreeSpec,
    equivalence_fn: Callable[[Optional[type], Context, Optional[type], Context], bool],
) -> bool:
    """Customizable equivalence check for two TreeSpecs.

    Arguments:
        spec1: The first TreeSpec to compare
        spec2: The second TreeSpec to compare
        equivalence_fn: A function to determine the equivalence of two
            TreeSpecs by examining their types and contexts. It will be called like:

                equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context)

            This function will be applied recursively to all children.

    Returns:
        True if the two TreeSpecs are equivalent, False otherwise.
    """
    # Compare the roots first; a mismatch short-circuits the recursion.
    if not equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context):
        return False

    children1 = spec1.children_specs
    children2 = spec2.children_specs
    if len(children1) != len(children2):
        return False

    # All corresponding children must also be equivalent.
    return all(
        is_equivalent(child1, child2, equivalence_fn)
        for child1, child2 in zip(children1, children2)
    )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.69 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/common_types.cpython-311.pyc
ADDED
|
Binary file (1.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/cpp.cpython-311.pyc
ADDED
|
Binary file (5.42 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/init.cpython-311.pyc
ADDED
|
Binary file (28.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/parameter.cpython-311.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (218 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__pycache__/thnn.cpython-311.pyc
ADDED
|
Binary file (346 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/thnn.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# this is for historical pickle deserialization, it is not used otherwise
|
| 2 |
+
|
| 3 |
+
def _get_thnn_function_backend():
|
| 4 |
+
pass
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Compatibility shim: every public name here is re-exported from
# `torch.ao.nn.intrinsic` so that legacy `torch.nn.intrinsic` imports keep
# working.
from torch.ao.nn.intrinsic import ConvBn1d
from torch.ao.nn.intrinsic import ConvBn2d
from torch.ao.nn.intrinsic import ConvBn3d
from torch.ao.nn.intrinsic import ConvBnReLU1d
from torch.ao.nn.intrinsic import ConvBnReLU2d
from torch.ao.nn.intrinsic import ConvBnReLU3d
from torch.ao.nn.intrinsic import ConvReLU1d
from torch.ao.nn.intrinsic import ConvReLU2d
from torch.ao.nn.intrinsic import ConvReLU3d
from torch.ao.nn.intrinsic import LinearReLU
from torch.ao.nn.intrinsic import BNReLU2d
from torch.ao.nn.intrinsic import BNReLU3d
from torch.ao.nn.intrinsic import LinearBn1d
from torch.ao.nn.intrinsic.modules.fused import _FusedModule  # noqa: F401

# Include the subpackages in case user imports from it directly
from . import modules  # noqa: F401
from . import qat  # noqa: F401
from . import quantized  # noqa: F401

__all__ = [
    'ConvBn1d',
    'ConvBn2d',
    'ConvBn3d',
    'ConvBnReLU1d',
    'ConvBnReLU2d',
    'ConvBnReLU3d',
    'ConvReLU1d',
    'ConvReLU2d',
    'ConvReLU3d',
    'LinearReLU',
    'BNReLU2d',
    'BNReLU3d',
    'LinearBn1d',
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.21 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/modules/fused.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Compatibility shim: re-exports the fused-module classes from
# `torch.ao.nn.intrinsic` under their historical
# `torch.nn.intrinsic.modules.fused` location.
from torch.ao.nn.intrinsic import BNReLU2d
from torch.ao.nn.intrinsic import BNReLU3d
from torch.ao.nn.intrinsic import ConvBn1d
from torch.ao.nn.intrinsic import ConvBn2d
from torch.ao.nn.intrinsic import ConvBn3d
from torch.ao.nn.intrinsic import ConvBnReLU1d
from torch.ao.nn.intrinsic import ConvBnReLU2d
from torch.ao.nn.intrinsic import ConvBnReLU3d
from torch.ao.nn.intrinsic import ConvReLU1d
from torch.ao.nn.intrinsic import ConvReLU2d
from torch.ao.nn.intrinsic import ConvReLU3d
from torch.ao.nn.intrinsic import LinearBn1d
from torch.ao.nn.intrinsic import LinearReLU
from torch.ao.nn.intrinsic.modules.fused import _FusedModule  # noqa: F401

__all__ = [
    'BNReLU2d',
    'BNReLU3d',
    'ConvBn1d',
    'ConvBn2d',
    'ConvBn3d',
    'ConvBnReLU1d',
    'ConvBnReLU2d',
    'ConvBnReLU3d',
    'ConvReLU1d',
    'ConvReLU2d',
    'ConvReLU3d',
    'LinearBn1d',
    'LinearReLU',
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-311.pyc
ADDED
|
Binary file (1.27 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-311.pyc
ADDED
|
Binary file (712 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Compatibility shim: pulls the quantized fused modules in from `.modules`
# and eagerly imports the `dynamic` subpackage.
from .modules import *  # noqa: F403
# to ensure customers can use the module below
# without importing it directly
import torch.nn.intrinsic.quantized.dynamic

__all__ = [
    'BNReLU2d',
    'BNReLU3d',
    'ConvReLU1d',
    'ConvReLU2d',
    'ConvReLU3d',
    'LinearReLU',
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (335 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Compatibility shim: re-exports `LinearReLU` from its new home in
# `torch.ao.nn.intrinsic.quantized.dynamic`.
from torch.ao.nn.intrinsic.quantized.dynamic import LinearReLU

__all__ = [
    'LinearReLU',
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Aggregates the quantized fused modules from the sibling shim modules into
# one importable namespace.
from .linear_relu import LinearReLU
from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
from .bn_relu import BNReLU2d, BNReLU3d

__all__ = [
    'LinearReLU',
    'ConvReLU1d',
    'ConvReLU2d',
    'ConvReLU3d',
    'BNReLU2d',
    'BNReLU3d',
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (556 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-311.pyc
ADDED
|
Binary file (350 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__init__.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .module import Module
|
| 2 |
+
from .linear import Identity, Linear, Bilinear, LazyLinear
|
| 3 |
+
from .conv import Conv1d, Conv2d, Conv3d, \
|
| 4 |
+
ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, \
|
| 5 |
+
LazyConv1d, LazyConv2d, LazyConv3d, LazyConvTranspose1d, LazyConvTranspose2d, LazyConvTranspose3d
|
| 6 |
+
from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
|
| 7 |
+
Softmax, Softmax2d, LogSoftmax, ELU, SELU, CELU, GELU, Hardshrink, LeakyReLU, LogSigmoid, \
|
| 8 |
+
Softplus, Softshrink, MultiheadAttention, PReLU, Softsign, Softmin, Tanhshrink, RReLU, GLU, \
|
| 9 |
+
Hardsigmoid, Hardswish, SiLU, Mish
|
| 10 |
+
from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
|
| 11 |
+
CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
|
| 12 |
+
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, HuberLoss, \
|
| 13 |
+
SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss, GaussianNLLLoss
|
| 14 |
+
from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
|
| 15 |
+
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
|
| 16 |
+
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, LPPool3d, \
|
| 17 |
+
AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
|
| 18 |
+
from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
|
| 19 |
+
LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d
|
| 20 |
+
from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, \
|
| 21 |
+
LazyInstanceNorm1d, LazyInstanceNorm2d, LazyInstanceNorm3d
|
| 22 |
+
from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
|
| 23 |
+
from .dropout import Dropout, Dropout1d, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
|
| 24 |
+
from .padding import ReflectionPad1d, ReflectionPad2d, ReflectionPad3d, ReplicationPad1d, ReplicationPad2d, \
|
| 25 |
+
ReplicationPad3d, ZeroPad1d, ZeroPad2d, ZeroPad3d, ConstantPad1d, ConstantPad2d, ConstantPad3d, \
|
| 26 |
+
CircularPad1d, CircularPad2d, CircularPad3d
|
| 27 |
+
from .sparse import Embedding, EmbeddingBag
|
| 28 |
+
from .rnn import RNNBase, RNN, LSTM, GRU, \
|
| 29 |
+
RNNCellBase, RNNCell, LSTMCell, GRUCell
|
| 30 |
+
from .pixelshuffle import PixelShuffle, PixelUnshuffle
|
| 31 |
+
from .upsampling import UpsamplingNearest2d, UpsamplingBilinear2d, Upsample
|
| 32 |
+
from .distance import PairwiseDistance, CosineSimilarity
|
| 33 |
+
from .fold import Fold, Unfold
|
| 34 |
+
from .adaptive import AdaptiveLogSoftmaxWithLoss
|
| 35 |
+
from .transformer import TransformerEncoder, TransformerDecoder, \
|
| 36 |
+
TransformerEncoderLayer, TransformerDecoderLayer, Transformer
|
| 37 |
+
from .flatten import Flatten, Unflatten
|
| 38 |
+
from .channelshuffle import ChannelShuffle
|
| 39 |
+
|
| 40 |
+
__all__ = [
|
| 41 |
+
'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
|
| 42 |
+
'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
|
| 43 |
+
'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
|
| 44 |
+
'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
|
| 45 |
+
'Tanhshrink', 'RReLU', 'L1Loss', 'NLLLoss', 'KLDivLoss', 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss',
|
| 46 |
+
'NLLLoss2d', 'PoissonNLLLoss', 'CosineEmbeddingLoss', 'CTCLoss', 'HingeEmbeddingLoss', 'MarginRankingLoss',
|
| 47 |
+
'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss', 'GaussianNLLLoss',
|
| 48 |
+
'HuberLoss', 'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict',
|
| 49 |
+
'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
|
| 50 |
+
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
|
| 51 |
+
'LPPool1d', 'LPPool2d', 'LPPool3d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
|
| 52 |
+
'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
|
| 53 |
+
'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
|
| 54 |
+
'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
|
| 55 |
+
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
|
| 56 |
+
'LSTMCell', 'GRUCell', 'PixelShuffle', 'PixelUnshuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
|
| 57 |
+
'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
|
| 58 |
+
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d',
|
| 59 |
+
'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
|
| 60 |
+
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
|
| 61 |
+
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
|
| 62 |
+
'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
|
| 63 |
+
'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
|
| 64 |
+
'LazyBatchNorm1d', 'LazyBatchNorm2d', 'LazyBatchNorm3d',
|
| 65 |
+
'LazyInstanceNorm1d', 'LazyInstanceNorm2d', 'LazyInstanceNorm3d',
|
| 66 |
+
'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'Mish', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
|
| 67 |
+
'CircularPad1d', 'CircularPad2d', 'CircularPad3d'
|
| 68 |
+
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (7.04 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/_functions.cpython-311.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/container.cpython-311.pyc
ADDED
|
Binary file (55.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/conv.cpython-311.pyc
ADDED
|
Binary file (76.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/flatten.cpython-311.pyc
ADDED
|
Binary file (8.13 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/fold.cpython-311.pyc
ADDED
|
Binary file (14.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/lazy.cpython-311.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/activation.py
ADDED
|
@@ -0,0 +1,1624 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from .linear import NonDynamicallyQuantizableLinear
|
| 7 |
+
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
| 8 |
+
from torch.nn.parameter import Parameter
|
| 9 |
+
from .module import Module
|
| 10 |
+
from .. import functional as F
|
| 11 |
+
|
| 12 |
+
__all__ = ['Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh',
|
| 13 |
+
'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU',
|
| 14 |
+
'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Tanhshrink',
|
| 15 |
+
'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax']
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Threshold(Module):
|
| 19 |
+
r"""Thresholds each element of the input Tensor.
|
| 20 |
+
|
| 21 |
+
Threshold is defined as:
|
| 22 |
+
|
| 23 |
+
.. math::
|
| 24 |
+
y =
|
| 25 |
+
\begin{cases}
|
| 26 |
+
x, &\text{ if } x > \text{threshold} \\
|
| 27 |
+
\text{value}, &\text{ otherwise }
|
| 28 |
+
\end{cases}
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
threshold: The value to threshold at
|
| 32 |
+
value: The value to replace with
|
| 33 |
+
inplace: can optionally do the operation in-place. Default: ``False``
|
| 34 |
+
|
| 35 |
+
Shape:
|
| 36 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 37 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 38 |
+
|
| 39 |
+
Examples::
|
| 40 |
+
|
| 41 |
+
>>> m = nn.Threshold(0.1, 20)
|
| 42 |
+
>>> input = torch.randn(2)
|
| 43 |
+
>>> output = m(input)
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
__constants__ = ['threshold', 'value', 'inplace']
|
| 47 |
+
|
| 48 |
+
threshold: float
|
| 49 |
+
value: float
|
| 50 |
+
inplace: bool
|
| 51 |
+
|
| 52 |
+
def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.threshold = threshold
|
| 55 |
+
self.value = value
|
| 56 |
+
self.inplace = inplace
|
| 57 |
+
# TODO: check in THNN (if inplace == True, then assert value <= threshold)
|
| 58 |
+
|
| 59 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 60 |
+
return F.threshold(input, self.threshold, self.value, self.inplace)
|
| 61 |
+
|
| 62 |
+
def extra_repr(self):
|
| 63 |
+
inplace_str = ', inplace=True' if self.inplace else ''
|
| 64 |
+
return f'threshold={self.threshold}, value={self.value}{inplace_str}'
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ReLU(Module):
|
| 68 |
+
r"""Applies the rectified linear unit function element-wise.
|
| 69 |
+
|
| 70 |
+
:math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
inplace: can optionally do the operation in-place. Default: ``False``
|
| 74 |
+
|
| 75 |
+
Shape:
|
| 76 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 77 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 78 |
+
|
| 79 |
+
.. image:: ../scripts/activation_images/ReLU.png
|
| 80 |
+
|
| 81 |
+
Examples::
|
| 82 |
+
|
| 83 |
+
>>> m = nn.ReLU()
|
| 84 |
+
>>> input = torch.randn(2)
|
| 85 |
+
>>> output = m(input)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
An implementation of CReLU - https://arxiv.org/abs/1603.05201
|
| 89 |
+
|
| 90 |
+
>>> m = nn.ReLU()
|
| 91 |
+
>>> input = torch.randn(2).unsqueeze(0)
|
| 92 |
+
>>> output = torch.cat((m(input), m(-input)))
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
__constants__ = ['inplace']
|
| 96 |
+
inplace: bool
|
| 97 |
+
|
| 98 |
+
def __init__(self, inplace: bool = False):
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.inplace = inplace
|
| 101 |
+
|
| 102 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 103 |
+
return F.relu(input, inplace=self.inplace)
|
| 104 |
+
|
| 105 |
+
def extra_repr(self) -> str:
|
| 106 |
+
inplace_str = 'inplace=True' if self.inplace else ''
|
| 107 |
+
return inplace_str
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class RReLU(Module):
|
| 111 |
+
r"""Applies the randomized leaky rectified linear unit function, element-wise.
|
| 112 |
+
|
| 113 |
+
Method described in the paper:
|
| 114 |
+
`Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_.
|
| 115 |
+
|
| 116 |
+
The function is defined as:
|
| 117 |
+
|
| 118 |
+
.. math::
|
| 119 |
+
\text{RReLU}(x) =
|
| 120 |
+
\begin{cases}
|
| 121 |
+
x & \text{if } x \geq 0 \\
|
| 122 |
+
ax & \text{ otherwise }
|
| 123 |
+
\end{cases}
|
| 124 |
+
|
| 125 |
+
where :math:`a` is randomly sampled from uniform distribution
|
| 126 |
+
:math:`\mathcal{U}(\text{lower}, \text{upper})` during training while during
|
| 127 |
+
evaluation :math:`a` is fixed with :math:`a = \frac{\text{lower} + \text{upper}}{2}`.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
|
| 131 |
+
upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
|
| 132 |
+
inplace: can optionally do the operation in-place. Default: ``False``
|
| 133 |
+
|
| 134 |
+
Shape:
|
| 135 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 136 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 137 |
+
|
| 138 |
+
.. image:: ../scripts/activation_images/RReLU.png
|
| 139 |
+
|
| 140 |
+
Examples::
|
| 141 |
+
|
| 142 |
+
>>> m = nn.RReLU(0.1, 0.3)
|
| 143 |
+
>>> input = torch.randn(2)
|
| 144 |
+
>>> output = m(input)
|
| 145 |
+
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
__constants__ = ['lower', 'upper', 'inplace']
|
| 149 |
+
|
| 150 |
+
lower: float
|
| 151 |
+
upper: float
|
| 152 |
+
inplace: bool
|
| 153 |
+
|
| 154 |
+
def __init__(
|
| 155 |
+
self,
|
| 156 |
+
lower: float = 1. / 8,
|
| 157 |
+
upper: float = 1. / 3,
|
| 158 |
+
inplace: bool = False
|
| 159 |
+
):
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.lower = lower
|
| 162 |
+
self.upper = upper
|
| 163 |
+
self.inplace = inplace
|
| 164 |
+
|
| 165 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 166 |
+
return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
|
| 167 |
+
|
| 168 |
+
def extra_repr(self):
|
| 169 |
+
inplace_str = ', inplace=True' if self.inplace else ''
|
| 170 |
+
return f'lower={self.lower}, upper={self.upper}{inplace_str}'
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class Hardtanh(Module):
|
| 174 |
+
r"""Applies the HardTanh function element-wise.
|
| 175 |
+
|
| 176 |
+
HardTanh is defined as:
|
| 177 |
+
|
| 178 |
+
.. math::
|
| 179 |
+
\text{HardTanh}(x) = \begin{cases}
|
| 180 |
+
\text{max\_val} & \text{ if } x > \text{ max\_val } \\
|
| 181 |
+
\text{min\_val} & \text{ if } x < \text{ min\_val } \\
|
| 182 |
+
x & \text{ otherwise } \\
|
| 183 |
+
\end{cases}
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
min_val: minimum value of the linear region range. Default: -1
|
| 187 |
+
max_val: maximum value of the linear region range. Default: 1
|
| 188 |
+
inplace: can optionally do the operation in-place. Default: ``False``
|
| 189 |
+
|
| 190 |
+
Keyword arguments :attr:`min_value` and :attr:`max_value`
|
| 191 |
+
have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
|
| 192 |
+
|
| 193 |
+
Shape:
|
| 194 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 195 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 196 |
+
|
| 197 |
+
.. image:: ../scripts/activation_images/Hardtanh.png
|
| 198 |
+
|
| 199 |
+
Examples::
|
| 200 |
+
|
| 201 |
+
>>> m = nn.Hardtanh(-2, 2)
|
| 202 |
+
>>> input = torch.randn(2)
|
| 203 |
+
>>> output = m(input)
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
__constants__ = ['min_val', 'max_val', 'inplace']
|
| 207 |
+
|
| 208 |
+
min_val: float
|
| 209 |
+
max_val: float
|
| 210 |
+
inplace: bool
|
| 211 |
+
|
| 212 |
+
def __init__(
|
| 213 |
+
self,
|
| 214 |
+
min_val: float = -1.,
|
| 215 |
+
max_val: float = 1.,
|
| 216 |
+
inplace: bool = False,
|
| 217 |
+
min_value: Optional[float] = None,
|
| 218 |
+
max_value: Optional[float] = None
|
| 219 |
+
) -> None:
|
| 220 |
+
super().__init__()
|
| 221 |
+
if min_value is not None:
|
| 222 |
+
warnings.warn("keyword argument min_value is deprecated and rename to min_val")
|
| 223 |
+
min_val = min_value
|
| 224 |
+
if max_value is not None:
|
| 225 |
+
warnings.warn("keyword argument max_value is deprecated and rename to max_val")
|
| 226 |
+
max_val = max_value
|
| 227 |
+
|
| 228 |
+
self.min_val = min_val
|
| 229 |
+
self.max_val = max_val
|
| 230 |
+
self.inplace = inplace
|
| 231 |
+
assert self.max_val > self.min_val
|
| 232 |
+
|
| 233 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 234 |
+
return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
|
| 235 |
+
|
| 236 |
+
def extra_repr(self) -> str:
|
| 237 |
+
inplace_str = ', inplace=True' if self.inplace else ''
|
| 238 |
+
return f'min_val={self.min_val}, max_val={self.max_val}{inplace_str}'
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class ReLU6(Hardtanh):
|
| 242 |
+
r"""Applies the ReLU6 function element-wise.
|
| 243 |
+
|
| 244 |
+
.. math::
|
| 245 |
+
\text{ReLU6}(x) = \min(\max(0,x), 6)
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
inplace: can optionally do the operation in-place. Default: ``False``
|
| 249 |
+
|
| 250 |
+
Shape:
|
| 251 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 252 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 253 |
+
|
| 254 |
+
.. image:: ../scripts/activation_images/ReLU6.png
|
| 255 |
+
|
| 256 |
+
Examples::
|
| 257 |
+
|
| 258 |
+
>>> m = nn.ReLU6()
|
| 259 |
+
>>> input = torch.randn(2)
|
| 260 |
+
>>> output = m(input)
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
def __init__(self, inplace: bool = False):
|
| 264 |
+
super().__init__(0., 6., inplace)
|
| 265 |
+
|
| 266 |
+
def extra_repr(self) -> str:
|
| 267 |
+
inplace_str = 'inplace=True' if self.inplace else ''
|
| 268 |
+
return inplace_str
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class Sigmoid(Module):
|
| 272 |
+
r"""Applies the Sigmoid function element-wise.
|
| 273 |
+
|
| 274 |
+
.. math::
|
| 275 |
+
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
Shape:
|
| 279 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 280 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 281 |
+
|
| 282 |
+
.. image:: ../scripts/activation_images/Sigmoid.png
|
| 283 |
+
|
| 284 |
+
Examples::
|
| 285 |
+
|
| 286 |
+
>>> m = nn.Sigmoid()
|
| 287 |
+
>>> input = torch.randn(2)
|
| 288 |
+
>>> output = m(input)
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 292 |
+
return torch.sigmoid(input)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class Hardsigmoid(Module):
|
| 296 |
+
r"""Applies the Hardsigmoid function element-wise.
|
| 297 |
+
|
| 298 |
+
Hardsigmoid is defined as:
|
| 299 |
+
|
| 300 |
+
.. math::
|
| 301 |
+
\text{Hardsigmoid}(x) = \begin{cases}
|
| 302 |
+
0 & \text{if~} x \le -3, \\
|
| 303 |
+
1 & \text{if~} x \ge +3, \\
|
| 304 |
+
x / 6 + 1 / 2 & \text{otherwise}
|
| 305 |
+
\end{cases}
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
inplace: can optionally do the operation in-place. Default: ``False``
|
| 309 |
+
|
| 310 |
+
Shape:
|
| 311 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 312 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 313 |
+
|
| 314 |
+
.. image:: ../scripts/activation_images/Hardsigmoid.png
|
| 315 |
+
|
| 316 |
+
Examples::
|
| 317 |
+
|
| 318 |
+
>>> m = nn.Hardsigmoid()
|
| 319 |
+
>>> input = torch.randn(2)
|
| 320 |
+
>>> output = m(input)
|
| 321 |
+
"""
|
| 322 |
+
|
| 323 |
+
__constants__ = ['inplace']
|
| 324 |
+
|
| 325 |
+
inplace: bool
|
| 326 |
+
|
| 327 |
+
def __init__(self, inplace : bool = False) -> None:
|
| 328 |
+
super().__init__()
|
| 329 |
+
self.inplace = inplace
|
| 330 |
+
|
| 331 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 332 |
+
return F.hardsigmoid(input, self.inplace)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class Tanh(Module):
|
| 336 |
+
r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
|
| 337 |
+
|
| 338 |
+
Tanh is defined as:
|
| 339 |
+
|
| 340 |
+
.. math::
|
| 341 |
+
\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
|
| 342 |
+
|
| 343 |
+
Shape:
|
| 344 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 345 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 346 |
+
|
| 347 |
+
.. image:: ../scripts/activation_images/Tanh.png
|
| 348 |
+
|
| 349 |
+
Examples::
|
| 350 |
+
|
| 351 |
+
>>> m = nn.Tanh()
|
| 352 |
+
>>> input = torch.randn(2)
|
| 353 |
+
>>> output = m(input)
|
| 354 |
+
"""
|
| 355 |
+
|
| 356 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 357 |
+
return torch.tanh(input)
|
| 358 |
+
|
| 359 |
+
class SiLU(Module):
|
| 360 |
+
r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
|
| 361 |
+
|
| 362 |
+
The SiLU function is also known as the swish function.
|
| 363 |
+
|
| 364 |
+
.. math::
|
| 365 |
+
\text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
|
| 366 |
+
|
| 367 |
+
.. note::
|
| 368 |
+
See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
|
| 369 |
+
where the SiLU (Sigmoid Linear Unit) was originally coined, and see
|
| 370 |
+
`Sigmoid-Weighted Linear Units for Neural Network Function Approximation
|
| 371 |
+
in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
|
| 372 |
+
a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
|
| 373 |
+
where the SiLU was experimented with later.
|
| 374 |
+
|
| 375 |
+
Shape:
|
| 376 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 377 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 378 |
+
|
| 379 |
+
.. image:: ../scripts/activation_images/SiLU.png
|
| 380 |
+
|
| 381 |
+
Examples::
|
| 382 |
+
|
| 383 |
+
>>> m = nn.SiLU()
|
| 384 |
+
>>> input = torch.randn(2)
|
| 385 |
+
>>> output = m(input)
|
| 386 |
+
"""
|
| 387 |
+
|
| 388 |
+
__constants__ = ['inplace']
|
| 389 |
+
inplace: bool
|
| 390 |
+
|
| 391 |
+
def __init__(self, inplace: bool = False):
|
| 392 |
+
super().__init__()
|
| 393 |
+
self.inplace = inplace
|
| 394 |
+
|
| 395 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 396 |
+
return F.silu(input, inplace=self.inplace)
|
| 397 |
+
|
| 398 |
+
def extra_repr(self) -> str:
|
| 399 |
+
inplace_str = 'inplace=True' if self.inplace else ''
|
| 400 |
+
return inplace_str
|
| 401 |
+
|
| 402 |
+
class Mish(Module):
|
| 403 |
+
r"""Applies the Mish function, element-wise.
|
| 404 |
+
|
| 405 |
+
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
|
| 406 |
+
|
| 407 |
+
.. math::
|
| 408 |
+
\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
|
| 409 |
+
|
| 410 |
+
.. note::
|
| 411 |
+
See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
|
| 412 |
+
|
| 413 |
+
Shape:
|
| 414 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 415 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 416 |
+
|
| 417 |
+
.. image:: ../scripts/activation_images/Mish.png
|
| 418 |
+
|
| 419 |
+
Examples::
|
| 420 |
+
|
| 421 |
+
>>> m = nn.Mish()
|
| 422 |
+
>>> input = torch.randn(2)
|
| 423 |
+
>>> output = m(input)
|
| 424 |
+
"""
|
| 425 |
+
|
| 426 |
+
__constants__ = ['inplace']
|
| 427 |
+
inplace: bool
|
| 428 |
+
|
| 429 |
+
def __init__(self, inplace: bool = False):
|
| 430 |
+
super().__init__()
|
| 431 |
+
self.inplace = inplace
|
| 432 |
+
|
| 433 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 434 |
+
return F.mish(input, inplace=self.inplace)
|
| 435 |
+
|
| 436 |
+
def extra_repr(self) -> str:
|
| 437 |
+
inplace_str = 'inplace=True' if self.inplace else ''
|
| 438 |
+
return inplace_str
|
| 439 |
+
|
| 440 |
+
class Hardswish(Module):
|
| 441 |
+
r"""Applies the Hardswish function, element-wise.
|
| 442 |
+
|
| 443 |
+
Method described in the paper: `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.
|
| 444 |
+
|
| 445 |
+
Hardswish is defined as:
|
| 446 |
+
|
| 447 |
+
.. math::
|
| 448 |
+
\text{Hardswish}(x) = \begin{cases}
|
| 449 |
+
0 & \text{if~} x \le -3, \\
|
| 450 |
+
x & \text{if~} x \ge +3, \\
|
| 451 |
+
x \cdot (x + 3) /6 & \text{otherwise}
|
| 452 |
+
\end{cases}
|
| 453 |
+
|
| 454 |
+
Args:
|
| 455 |
+
inplace: can optionally do the operation in-place. Default: ``False``
|
| 456 |
+
|
| 457 |
+
Shape:
|
| 458 |
+
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
|
| 459 |
+
- Output: :math:`(*)`, same shape as the input.
|
| 460 |
+
|
| 461 |
+
.. image:: ../scripts/activation_images/Hardswish.png
|
| 462 |
+
|
| 463 |
+
Examples::
|
| 464 |
+
|
| 465 |
+
>>> m = nn.Hardswish()
|
| 466 |
+
>>> input = torch.randn(2)
|
| 467 |
+
>>> output = m(input)
|
| 468 |
+
"""
|
| 469 |
+
|
| 470 |
+
__constants__ = ['inplace']
|
| 471 |
+
|
| 472 |
+
inplace: bool
|
| 473 |
+
|
| 474 |
+
def __init__(self, inplace : bool = False) -> None:
|
| 475 |
+
super().__init__()
|
| 476 |
+
self.inplace = inplace
|
| 477 |
+
|
| 478 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 479 |
+
return F.hardswish(input, self.inplace)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
class ELU(Module):
    r"""Exponential Linear Unit activation, applied element-wise.

    From `Fast and Accurate Deep Network Learning by Exponential Linear
    Units (ELUs) <https://arxiv.org/abs/1511.07289>`__:

    .. math::
        \text{ELU}(x) = \begin{cases}
        x, & \text{ if } x > 0\\
        \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
        \end{cases}

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.ELU()
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.alpha = alpha

    def forward(self, input: Tensor) -> Tensor:
        return F.elu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{suffix}'
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class CELU(Module):
    r"""Applies the CELU function element-wise.

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    See `Continuously Differentiable Exponential Linear Units
    <https://arxiv.org/abs/1704.07483>`_ for details.

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.CELU()
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.alpha = alpha

    def forward(self, input: Tensor) -> Tensor:
        return F.celu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{suffix}'
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
class SELU(Module):
    r"""Applies the SELU activation element-wise.

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    with :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    .. warning::
        When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
        ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
        in order to get `Self-Normalizing Neural Networks`_.
        See :func:`torch.nn.init.calculate_gain` for more information.

    Args:
        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.SELU()
        >>> output = m(torch.randn(2))

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.selu(input, self.inplace)

    def extra_repr(self) -> str:
        return 'inplace=True' if self.inplace else ''
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
class GLU(Module):
    r"""Applies the gated linear unit :math:`{GLU}(a, b)= a \otimes \sigma(b)`,
    where :math:`a` is the first half of the input along ``dim`` and
    :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> m = nn.GLU()
        >>> output = m(torch.randn(4, 2))
    """

    __constants__ = ['dim']
    dim: int

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: Tensor) -> Tensor:
        return F.glu(input, self.dim)

    def extra_repr(self) -> str:
        return 'dim={}'.format(self.dim)
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
class GELU(Module):
    r"""Applies the Gaussian Error Linear Units function
    :math:`\text{GELU}(x) = x * \Phi(x)`, where :math:`\Phi(x)` is the
    Cumulative Distribution Function for Gaussian Distribution.

    With ``approximate='tanh'``, GELU is estimated as:

    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))

    Args:
        approximate (str, optional): the gelu approximation algorithm to use:
            ``'none'`` | ``'tanh'``. Default: ``'none'``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.GELU()
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['approximate']
    approximate: str

    def __init__(self, approximate: str = 'none') -> None:
        super().__init__()
        self.approximate = approximate

    def forward(self, input: Tensor) -> Tensor:
        return F.gelu(input, approximate=self.approximate)

    def extra_repr(self) -> str:
        return f'approximate={self.approximate!r}'
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
class Hardshrink(Module):
    r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
        x, & \text{ if } x > \lambda \\
        x, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Hardshrink()
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.hardshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return str(self.lambd)
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
class LeakyReLU(Module):
    r"""Applies the element-wise function

    .. math::
        \text{LeakyReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        \text{negative\_slope} \times x, & \text{ otherwise }
        \end{cases}

    Args:
        negative_slope: Controls the angle of the negative slope (which is used for
          negative input values). Default: 1e-2
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional dimensions
        - Output: :math:`(*)`, same shape as the input

    Examples::

        >>> m = nn.LeakyReLU(0.1)
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['inplace', 'negative_slope']
    inplace: bool
    negative_slope: float

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.negative_slope = negative_slope

    def forward(self, input: Tensor) -> Tensor:
        return F.leaky_relu(input, self.negative_slope, self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'negative_slope={self.negative_slope}{suffix}'
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
class LogSigmoid(Module):
    r"""Applies the element-wise function
    :math:`\text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)`.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.LogSigmoid()
        >>> output = m(torch.randn(2))
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.logsigmoid(input)
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
class Softplus(Module):
    r"""Applies :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`
    element-wise.

    SoftPlus is a smooth approximation to the ReLU function and can be used
    to constrain the output of a machine to always be positive.
    For numerical stability the implementation reverts to the linear function
    when :math:`input \times \beta > threshold`.

    Args:
        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
        threshold: values above this revert to a linear function. Default: 20

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Softplus()
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['beta', 'threshold']
    beta: float
    threshold: float

    def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None:
        super().__init__()
        self.threshold = threshold
        self.beta = beta

    def forward(self, input: Tensor) -> Tensor:
        return F.softplus(input, self.beta, self.threshold)

    def extra_repr(self) -> str:
        return 'beta={}, threshold={}'.format(self.beta, self.threshold)
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
class Softshrink(Module):
    r"""Applies the soft shrinkage function element-wise.

    .. math::
        \text{SoftShrinkage}(x) =
        \begin{cases}
        x - \lambda, & \text{ if } x > \lambda \\
        x + \lambda, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Softshrink()
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.softshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return f'{self.lambd}'
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
def _check_arg_device(x: Optional[torch.Tensor]) -> bool:
|
| 907 |
+
if x is not None:
|
| 908 |
+
return x.device.type in ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
|
| 909 |
+
return True
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool:
|
| 913 |
+
if x is not None:
|
| 914 |
+
return x.requires_grad
|
| 915 |
+
return False
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
def _is_make_fx_tracing():
|
| 919 |
+
if not torch.jit.is_scripting():
|
| 920 |
+
torch_dispatch_mode_stack = torch.utils._python_dispatch._get_current_dispatch_mode_stack()
|
| 921 |
+
return any(type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode for x in torch_dispatch_mode_stack)
|
| 922 |
+
else:
|
| 923 |
+
return False
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
class MultiheadAttention(Module):
|
| 927 |
+
r"""Allows the model to jointly attend to information from different representation subspaces.
|
| 928 |
+
|
| 929 |
+
Method described in the paper:
|
| 930 |
+
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
| 931 |
+
|
| 932 |
+
Multi-Head Attention is defined as:
|
| 933 |
+
|
| 934 |
+
.. math::
|
| 935 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
| 936 |
+
|
| 937 |
+
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
| 938 |
+
|
| 939 |
+
``nn.MultiHeadAttention`` will use the optimized implementations of
|
| 940 |
+
``scaled_dot_product_attention()`` when possible.
|
| 941 |
+
|
| 942 |
+
In addition to support for the new ``scaled_dot_product_attention()``
|
| 943 |
+
function, for speeding up Inference, MHA will use
|
| 944 |
+
fastpath inference with support for Nested Tensors, iff:
|
| 945 |
+
|
| 946 |
+
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
|
| 947 |
+
- inputs are batched (3D) with ``batch_first==True``
|
| 948 |
+
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
|
| 949 |
+
- training is disabled (using ``.eval()``)
|
| 950 |
+
- ``add_bias_kv`` is ``False``
|
| 951 |
+
- ``add_zero_attn`` is ``False``
|
| 952 |
+
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
|
| 953 |
+
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
|
| 954 |
+
nor ``attn_mask`` is passed
|
| 955 |
+
- autocast is disabled
|
| 956 |
+
|
| 957 |
+
If the optimized inference fastpath implementation is in use, a
|
| 958 |
+
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
|
| 959 |
+
``query``/``key``/``value`` to represent padding more efficiently than using a
|
| 960 |
+
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
|
| 961 |
+
will be returned, and an additional speedup proportional to the fraction of the input
|
| 962 |
+
that is padding can be expected.
|
| 963 |
+
|
| 964 |
+
Args:
|
| 965 |
+
embed_dim: Total dimension of the model.
|
| 966 |
+
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
| 967 |
+
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
| 968 |
+
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
| 969 |
+
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
| 970 |
+
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
| 971 |
+
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
| 972 |
+
Default: ``False``.
|
| 973 |
+
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
| 974 |
+
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
| 975 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
| 976 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
| 977 |
+
|
| 978 |
+
Examples::
|
| 979 |
+
|
| 980 |
+
>>> # xdoctest: +SKIP
|
| 981 |
+
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
| 982 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
| 983 |
+
|
| 984 |
+
.. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
|
| 985 |
+
https://arxiv.org/abs/2205.14135
|
| 986 |
+
|
| 987 |
+
"""
|
| 988 |
+
|
| 989 |
+
__constants__ = ['batch_first']
|
| 990 |
+
bias_k: Optional[torch.Tensor]
|
| 991 |
+
bias_v: Optional[torch.Tensor]
|
| 992 |
+
|
| 993 |
+
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
             kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
    # Validate sizes before any allocation so the error surfaces early.
    if embed_dim <= 0 or num_heads <= 0:
        raise ValueError(
            f"embed_dim and num_heads must be greater than 0,"
            f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
        )
    factory_kwargs = {'device': device, 'dtype': dtype}
    super().__init__()
    self.embed_dim = embed_dim
    # Key/value feature dims default to embed_dim when not explicitly given.
    self.kdim = kdim if kdim is not None else embed_dim
    self.vdim = vdim if vdim is not None else embed_dim
    # When q/k/v all share embed_dim, a single packed in-projection is used.
    self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

    self.num_heads = num_heads
    self.dropout = dropout
    self.batch_first = batch_first
    self.head_dim = embed_dim // num_heads
    assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

    if not self._qkv_same_embed_dim:
        # Separate per-projection weights; the packed weight is registered as None
        # so state_dict/load_state_dict still see a consistent parameter set.
        self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
        self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
        self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
        self.register_parameter('in_proj_weight', None)
    else:
        # Packed (3*embed_dim, embed_dim) projection covering q, k and v.
        self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
        self.register_parameter('q_proj_weight', None)
        self.register_parameter('k_proj_weight', None)
        self.register_parameter('v_proj_weight', None)

    if bias:
        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
    else:
        self.register_parameter('in_proj_bias', None)
    self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    if add_bias_kv:
        # Learnable bias appended to the key/value sequences at dim=0.
        self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
    else:
        self.bias_k = self.bias_v = None

    self.add_zero_attn = add_zero_attn

    # Parameters above are allocated uninitialized (torch.empty); this fills them.
    self._reset_parameters()
|
| 1039 |
+
|
| 1040 |
+
def _reset_parameters(self):
    # Xavier-uniform-initialize whichever projection layout __init__ created:
    # the packed in_proj_weight, or the separate q/k/v projection weights.
    if self._qkv_same_embed_dim:
        xavier_uniform_(self.in_proj_weight)
    else:
        xavier_uniform_(self.q_proj_weight)
        xavier_uniform_(self.k_proj_weight)
        xavier_uniform_(self.v_proj_weight)

    if self.in_proj_bias is not None:
        # Biases start at zero; out_proj.bias exists whenever in_proj_bias does
        # (both are controlled by the same ``bias`` flag in __init__).
        constant_(self.in_proj_bias, 0.)
        constant_(self.out_proj.bias, 0.)
    if self.bias_k is not None:
        xavier_normal_(self.bias_k)
    if self.bias_v is not None:
        xavier_normal_(self.bias_v)
|
| 1055 |
+
|
| 1056 |
+
def __setstate__(self, state):
    # Checkpoints produced by v1.1.0 predate the _qkv_same_embed_dim attribute;
    # default it to True (packed projection layout) for backward compatibility.
    state.setdefault('_qkv_same_embed_dim', True)

    super().__setstate__(state)
|
| 1062 |
+
|
| 1063 |
+
def forward(
|
| 1064 |
+
self,
|
| 1065 |
+
query: Tensor,
|
| 1066 |
+
key: Tensor,
|
| 1067 |
+
value: Tensor,
|
| 1068 |
+
key_padding_mask: Optional[Tensor] = None,
|
| 1069 |
+
need_weights: bool = True,
|
| 1070 |
+
attn_mask: Optional[Tensor] = None,
|
| 1071 |
+
average_attn_weights: bool = True,
|
| 1072 |
+
is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
|
| 1073 |
+
r"""Compute attention outputs using query, key, and value embeddings.
|
| 1074 |
+
|
| 1075 |
+
Supports optional parameters for padding, masks and attention weights.
|
| 1076 |
+
|
| 1077 |
+
Args:
|
| 1078 |
+
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
|
| 1079 |
+
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
|
| 1080 |
+
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
|
| 1081 |
+
Queries are compared against key-value pairs to produce the output.
|
| 1082 |
+
See "Attention Is All You Need" for more details.
|
| 1083 |
+
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
|
| 1084 |
+
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
|
| 1085 |
+
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
|
| 1086 |
+
See "Attention Is All You Need" for more details.
|
| 1087 |
+
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
|
| 1088 |
+
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
|
| 1089 |
+
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
|
| 1090 |
+
See "Attention Is All You Need" for more details.
|
| 1091 |
+
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
|
| 1092 |
+
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
|
| 1093 |
+
Binary and float masks are supported.
|
| 1094 |
+
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
|
| 1095 |
+
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
|
| 1096 |
+
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
|
| 1097 |
+
Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
|
| 1098 |
+
and achieve the best performance for MHA.
|
| 1099 |
+
Default: ``True``.
|
| 1100 |
+
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
|
| 1101 |
+
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
|
| 1102 |
+
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
|
| 1103 |
+
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
|
| 1104 |
+
Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
|
| 1105 |
+
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
|
| 1106 |
+
the attention weight.
|
| 1107 |
+
If both attn_mask and key_padding_mask are supplied, their types should match.
|
| 1108 |
+
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
|
| 1109 |
+
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
|
| 1110 |
+
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
|
| 1111 |
+
is_causal: If specified, applies a causal mask as attention mask.
|
| 1112 |
+
Default: ``False``.
|
| 1113 |
+
Warning:
|
| 1114 |
+
``is_causal`` provides a hint that ``attn_mask`` is the
|
| 1115 |
+
causal mask. Providing incorrect hints can result in
|
| 1116 |
+
incorrect execution, including forward and backward
|
| 1117 |
+
compatibility.
|
| 1118 |
+
|
| 1119 |
+
Outputs:
|
| 1120 |
+
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
|
| 1121 |
+
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
|
| 1122 |
+
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
|
| 1123 |
+
embedding dimension ``embed_dim``.
|
| 1124 |
+
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
|
| 1125 |
+
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
|
| 1126 |
+
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
|
| 1127 |
+
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
| 1128 |
+
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
|
| 1129 |
+
|
| 1130 |
+
.. note::
|
| 1131 |
+
`batch_first` argument is ignored for unbatched inputs.
|
| 1132 |
+
"""
|
| 1133 |
+
why_not_fast_path = ''
|
| 1134 |
+
if ((attn_mask is not None and torch.is_floating_point(attn_mask))
|
| 1135 |
+
or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
|
| 1136 |
+
why_not_fast_path = "floating-point masks are not supported for fast path."
|
| 1137 |
+
|
| 1138 |
+
is_batched = query.dim() == 3
|
| 1139 |
+
|
| 1140 |
+
key_padding_mask = F._canonical_mask(
|
| 1141 |
+
mask=key_padding_mask,
|
| 1142 |
+
mask_name="key_padding_mask",
|
| 1143 |
+
other_type=F._none_or_dtype(attn_mask),
|
| 1144 |
+
other_name="attn_mask",
|
| 1145 |
+
target_type=query.dtype
|
| 1146 |
+
)
|
| 1147 |
+
|
| 1148 |
+
attn_mask = F._canonical_mask(
|
| 1149 |
+
mask=attn_mask,
|
| 1150 |
+
mask_name="attn_mask",
|
| 1151 |
+
other_type=None,
|
| 1152 |
+
other_name="",
|
| 1153 |
+
target_type=query.dtype,
|
| 1154 |
+
check_other=False,
|
| 1155 |
+
)
|
| 1156 |
+
|
| 1157 |
+
is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
|
| 1158 |
+
|
| 1159 |
+
if not is_fastpath_enabled:
|
| 1160 |
+
why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
|
| 1161 |
+
elif not is_batched:
|
| 1162 |
+
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
| 1163 |
+
elif query is not key or key is not value:
|
| 1164 |
+
# When lifting this restriction, don't forget to either
|
| 1165 |
+
# enforce that the dtypes all match or test cases where
|
| 1166 |
+
# they don't!
|
| 1167 |
+
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
| 1168 |
+
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
| 1169 |
+
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
| 1170 |
+
elif self.in_proj_weight is None:
|
| 1171 |
+
why_not_fast_path = "in_proj_weight was None"
|
| 1172 |
+
elif query.dtype != self.in_proj_weight.dtype:
|
| 1173 |
+
# this case will fail anyway, but at least they'll get a useful error message.
|
| 1174 |
+
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
| 1175 |
+
elif self.training:
|
| 1176 |
+
why_not_fast_path = "training is enabled"
|
| 1177 |
+
elif (self.num_heads % 2) != 0:
|
| 1178 |
+
why_not_fast_path = "self.num_heads is not even"
|
| 1179 |
+
elif not self.batch_first:
|
| 1180 |
+
why_not_fast_path = "batch_first was not True"
|
| 1181 |
+
elif self.bias_k is not None:
|
| 1182 |
+
why_not_fast_path = "self.bias_k was not None"
|
| 1183 |
+
elif self.bias_v is not None:
|
| 1184 |
+
why_not_fast_path = "self.bias_v was not None"
|
| 1185 |
+
elif self.add_zero_attn:
|
| 1186 |
+
why_not_fast_path = "add_zero_attn was enabled"
|
| 1187 |
+
elif not self._qkv_same_embed_dim:
|
| 1188 |
+
why_not_fast_path = "_qkv_same_embed_dim was not True"
|
| 1189 |
+
elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
|
| 1190 |
+
why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
|
| 1191 |
+
is not supported with NestedTensor input"
|
| 1192 |
+
elif torch.is_autocast_enabled():
|
| 1193 |
+
why_not_fast_path = "autocast is enabled"
|
| 1194 |
+
|
| 1195 |
+
if not why_not_fast_path:
|
| 1196 |
+
tensor_args = (
|
| 1197 |
+
query,
|
| 1198 |
+
key,
|
| 1199 |
+
value,
|
| 1200 |
+
self.in_proj_weight,
|
| 1201 |
+
self.in_proj_bias,
|
| 1202 |
+
self.out_proj.weight,
|
| 1203 |
+
self.out_proj.bias,
|
| 1204 |
+
)
|
| 1205 |
+
# We have to use list comprehensions below because TorchScript does not support
|
| 1206 |
+
# generator expressions.
|
| 1207 |
+
if torch.overrides.has_torch_function(tensor_args):
|
| 1208 |
+
why_not_fast_path = "some Tensor argument has_torch_function"
|
| 1209 |
+
elif _is_make_fx_tracing():
|
| 1210 |
+
why_not_fast_path = "we are running make_fx tracing"
|
| 1211 |
+
elif not all(_check_arg_device(x) for x in tensor_args):
|
| 1212 |
+
why_not_fast_path = ("some Tensor argument's device is neither one of "
|
| 1213 |
+
f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
|
| 1214 |
+
elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
|
| 1215 |
+
why_not_fast_path = ("grad is enabled and at least one of query or the "
|
| 1216 |
+
"input/output projection weights or biases requires_grad")
|
| 1217 |
+
if not why_not_fast_path:
|
| 1218 |
+
merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
|
| 1219 |
+
|
| 1220 |
+
if self.in_proj_bias is not None and self.in_proj_weight is not None:
|
| 1221 |
+
return torch._native_multi_head_attention(
|
| 1222 |
+
query,
|
| 1223 |
+
key,
|
| 1224 |
+
value,
|
| 1225 |
+
self.embed_dim,
|
| 1226 |
+
self.num_heads,
|
| 1227 |
+
self.in_proj_weight,
|
| 1228 |
+
self.in_proj_bias,
|
| 1229 |
+
self.out_proj.weight,
|
| 1230 |
+
self.out_proj.bias,
|
| 1231 |
+
merged_mask,
|
| 1232 |
+
need_weights,
|
| 1233 |
+
average_attn_weights,
|
| 1234 |
+
mask_type)
|
| 1235 |
+
|
| 1236 |
+
any_nested = query.is_nested or key.is_nested or value.is_nested
|
| 1237 |
+
assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
|
| 1238 |
+
f"The fast path was not hit because {why_not_fast_path}")
|
| 1239 |
+
|
| 1240 |
+
if self.batch_first and is_batched:
|
| 1241 |
+
# make sure that the transpose op does not affect the "is" property
|
| 1242 |
+
if key is value:
|
| 1243 |
+
if query is key:
|
| 1244 |
+
query = key = value = query.transpose(1, 0)
|
| 1245 |
+
else:
|
| 1246 |
+
query, key = (x.transpose(1, 0) for x in (query, key))
|
| 1247 |
+
value = key
|
| 1248 |
+
else:
|
| 1249 |
+
query, key, value = (x.transpose(1, 0) for x in (query, key, value))
|
| 1250 |
+
|
| 1251 |
+
if not self._qkv_same_embed_dim:
|
| 1252 |
+
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
| 1253 |
+
query, key, value, self.embed_dim, self.num_heads,
|
| 1254 |
+
self.in_proj_weight, self.in_proj_bias,
|
| 1255 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
| 1256 |
+
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
| 1257 |
+
training=self.training,
|
| 1258 |
+
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
| 1259 |
+
attn_mask=attn_mask,
|
| 1260 |
+
use_separate_proj_weight=True,
|
| 1261 |
+
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
|
| 1262 |
+
v_proj_weight=self.v_proj_weight,
|
| 1263 |
+
average_attn_weights=average_attn_weights,
|
| 1264 |
+
is_causal=is_causal)
|
| 1265 |
+
else:
|
| 1266 |
+
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
| 1267 |
+
query, key, value, self.embed_dim, self.num_heads,
|
| 1268 |
+
self.in_proj_weight, self.in_proj_bias,
|
| 1269 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
| 1270 |
+
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
| 1271 |
+
training=self.training,
|
| 1272 |
+
key_padding_mask=key_padding_mask,
|
| 1273 |
+
need_weights=need_weights,
|
| 1274 |
+
attn_mask=attn_mask,
|
| 1275 |
+
average_attn_weights=average_attn_weights,
|
| 1276 |
+
is_causal=is_causal)
|
| 1277 |
+
if self.batch_first and is_batched:
|
| 1278 |
+
return attn_output.transpose(1, 0), attn_output_weights
|
| 1279 |
+
else:
|
| 1280 |
+
return attn_output, attn_output_weights
|
| 1281 |
+
|
| 1282 |
+
def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]:
    r"""Determine the mask type and combine masks when both are given.

    A single supplied mask is passed through together with its type code.
    When both masks are supplied, each is broadcast to
    ``(batch_size, num_heads, seq_len, seq_len)`` and they are summed,
    and type ``2`` is reported.

    Args:
        attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
        key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
        query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
    Returns:
        merged_mask: merged mask
        mask_type: merged mask type (0, 1, or 2)
    """
    result_type: Optional[int] = None
    result_mask: Optional[Tensor] = None

    if key_padding_mask is not None:
        result_type = 1
        result_mask = key_padding_mask

    if attn_mask is not None:
        # query cannot be a nested tensor on this path, so .shape is well-defined
        n_batch, n_seq, _ = query.shape
        result_type = 2

        # Bring attn_mask up to 4D: (batch, heads, seq, seq)
        if attn_mask.dim() == 3:
            expanded_attn = attn_mask.view(n_batch, -1, n_seq, n_seq)
        else:  # 2D attn_mask
            expanded_attn = attn_mask.view(1, 1, n_seq, n_seq).expand(n_batch, self.num_heads, -1, -1)
        result_mask = expanded_attn

        if key_padding_mask is not None:
            expanded_pad = key_padding_mask.view(n_batch, 1, 1, n_seq).expand(-1, self.num_heads, -1, -1)
            result_mask = expanded_attn + expanded_pad

    # With neither mask supplied this returns (None, None).
    return result_mask, result_type
|
| 1323 |
+
|
| 1324 |
+
|
| 1325 |
+
class PReLU(Module):
    r"""Parametric ReLU applied element-wise.

    .. math::
        \text{PReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        ax, & \text{ otherwise }
        \end{cases}

    The slope :math:`a` on the negative side is learned. With no arguments,
    ``nn.PReLU()`` shares one :math:`a` across every channel; with
    ``nn.PReLU(nChannels)`` each input channel gets its own :math:`a`.

    .. note::
        For good results, avoid weight decay on :math:`a`.

    .. note::
        The channel dim is the input's 2nd dim; inputs with fewer than 2 dims
        are treated as having a single channel.

    Args:
        num_parameters (int): how many :math:`a` values to learn. Only 1 or
            the input's channel count are legitimate. Default: 1
        init (float): initial value of :math:`a`. Default: 0.25

    Shape:
        - Input: :math:`(*)` where `*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Attributes:
        weight (Tensor): learnable slopes of shape (:attr:`num_parameters`).

    Examples::

        >>> m = nn.PReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['num_parameters']
    num_parameters: int

    def __init__(self, num_parameters: int = 1, init: float = 0.25,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        # num_parameters is a plain int, so setting it before Module.__init__
        # is safe; preserved from the reference implementation.
        self.num_parameters = num_parameters
        super().__init__()
        self.init = init
        self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        # Every slope starts at the configured constant.
        torch.nn.init.constant_(self.weight, self.init)

    def forward(self, input: Tensor) -> Tensor:
        return F.prelu(input, self.weight)

    def extra_repr(self) -> str:
        return 'num_parameters={}'.format(self.num_parameters)
|
| 1395 |
+
|
| 1396 |
+
|
| 1397 |
+
class Softsign(Module):
    r"""Element-wise Softsign activation.

    Computes :math:`\text{SoftSign}(x) = \frac{x}{1 + |x|}`.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Softsign()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Stateless module: simply delegate to the functional form.
        return F.softsign(input)
|
| 1418 |
+
|
| 1419 |
+
|
| 1420 |
+
class Tanhshrink(Module):
    r"""Element-wise Tanhshrink activation.

    Computes :math:`\text{Tanhshrink}(x) = x - \tanh(x)`.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Tanhshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Stateless module: simply delegate to the functional form.
        return F.tanhshrink(input)
|
| 1441 |
+
|
| 1442 |
+
|
| 1443 |
+
class Softmin(Module):
    r"""Applies Softmin to an n-dimensional input Tensor.

    The output values lie in ``[0, 1]`` and sum to 1 along :attr:`dim`:

    .. math::
        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

    Args:
        dim (int): dimension along which Softmin is computed (each slice
            along it sums to 1).

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Returns:
        a Tensor of the same dimension and shape as the input, with values
        in the range [0, 1].

    Examples::

        >>> m = nn.Softmin(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        # Checkpoints from older versions may not carry 'dim'.
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmin(input, self.dim, _stacklevel=5)

    def extra_repr(self):
        return 'dim={}'.format(self.dim)
|
| 1491 |
+
|
| 1492 |
+
class Softmax(Module):
    r"""Applies Softmax to an n-dimensional input Tensor.

    The output values lie in ``[0, 1]`` and sum to 1 along :attr:`dim`:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    For sparse input tensors, unspecified values are treated as ``-inf``.

    Args:
        dim (int): dimension along which Softmax is computed (each slice
            along it sums to 1).

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Returns:
        a Tensor of the same dimension and shape as the input with values
        in the range [0, 1].

    .. note::
        This module doesn't pair directly with NLLLoss, which expects a Log
        between the Softmax and itself; prefer `LogSoftmax` (faster, better
        numerical properties).

    Examples::

        >>> m = nn.Softmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        # Checkpoints from older versions may not carry 'dim'.
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={}'.format(self.dim)
|
| 1549 |
+
|
| 1550 |
+
|
| 1551 |
+
class Softmax2d(Module):
    r"""Applies Softmax over features at each spatial location.

    For an image of ``Channels x Height x Width``, `Softmax` is applied to
    each location :math:`(Channels, h_i, w_j)`.

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)

    Returns:
        a Tensor of the same dimension and shape as the input with values
        in the range [0, 1].

    Examples::

        >>> m = nn.Softmax2d()
        >>> # you softmax over the 2nd dimension
        >>> input = torch.randn(2, 3, 12, 13)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        ndim = input.dim()
        # Only CHW and NCHW layouts are meaningful here.
        if ndim != 3 and ndim != 4:
            raise ValueError(
                f"Softmax2d: expected input to be 3D or 4D, got {ndim}D instead"
            )
        # dim=-3 is the channel axis in both supported layouts.
        return F.softmax(input, -3, _stacklevel=5)
|
| 1579 |
+
|
| 1580 |
+
|
| 1581 |
+
class LogSoftmax(Module):
    r"""Applies :math:`\log(\text{Softmax}(x))` to an n-dimensional input Tensor.

    Simplified formulation:

    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

    Args:
        dim (int): dimension along which LogSoftmax is computed.

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Returns:
        a Tensor of the same dimension and shape as the input with values
        in the range [-inf, 0).

    Examples::

        >>> m = nn.LogSoftmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        # Checkpoints from older versions may not carry 'dim'.
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.log_softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self):
        return 'dim={}'.format(self.dim)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/batchnorm.py
ADDED
|
@@ -0,0 +1,849 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Any
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer
|
| 6 |
+
|
| 7 |
+
from .. import functional as F
|
| 8 |
+
from .. import init
|
| 9 |
+
from ._functions import SyncBatchNorm as sync_batch_norm
|
| 10 |
+
from .lazy import LazyModuleMixin
|
| 11 |
+
from .module import Module
|
| 12 |
+
|
| 13 |
+
__all__ = ['BatchNorm1d', 'LazyBatchNorm1d', 'BatchNorm2d', 'LazyBatchNorm2d', 'BatchNorm3d',
|
| 14 |
+
'LazyBatchNorm3d', 'SyncBatchNorm']
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm.

    Owns the learnable affine parameters (``weight``/``bias``) and the
    running-statistics buffers (``running_mean``/``running_var``/
    ``num_batches_tracked``), each registered only when enabled.
    """

    # Bumped to 2 when num_batches_tracked was added; see _load_from_state_dict.
    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            # Learnable per-channel scale and shift; values set in reset_parameters().
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            # Register as None so attribute access and state_dict stay consistent.
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            # num_batches_tracked is always a long scalar: drop any caller dtype,
            # keep the device from factory_kwargs.
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        """Reset tracked statistics to their initial state (mean 0, var 1, count 0)."""
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        """Reset running statistics and, when affine, the scale/shift parameters."""
        self.reset_running_stats()
        if self.affine:
            # Identity transform by default: weight (gamma) = 1, bias (beta) = 0.
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        # Subclasses validate the expected input rank (e.g. 2D/3D for the 1d variant).
        raise NotImplementedError

    def extra_repr(self):
        # Fills the repr template directly from instance attributes.
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            # this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                # Backfill the missing buffer for old checkpoints. Prefer this
                # module's own buffer unless it is on the meta device (i.e. not
                # materialized), in which case use a fresh zero.
                state_dict[num_batches_tracked_key] = (
                    self.num_batches_tracked
                    if self.num_batches_tracked is not None and self.num_batches_tracked.device != torch.device('meta')
                    else torch.tensor(0, dtype=torch.long)
                )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class _BatchNorm(_NormBase):
    """Shared forward logic for the BatchNorm1d/2d/3d family."""

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        # Rank validation is delegated to the concrete subclass.
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class _LazyNormBase(LazyModuleMixin, _NormBase):
    """Norm base whose num_features is inferred from the first input.

    Parameters and buffers start uninitialized and are materialized by
    ``initialize_parameters`` on the first forward call.
    """

    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        # Restore the requested flags, then register uninitialized placeholders.
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            # Scalar long counter is created eagerly (no shape to infer);
            # dtype is fixed, so only the device factory kwarg is forwarded.
            self.num_batches_tracked = torch.tensor(
                0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})

    def reset_parameters(self) -> None:
        # Only meaningful once the lazy parameters have been materialized
        # (num_features == 0 is the pre-initialization placeholder).
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        """Materialize parameters/buffers using the channel dim of *input*."""
        if self.has_uninitialized_params():
            # Channel count is taken from the input's 2nd dimension.
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features,))  # type:ignore[union-attr]
                self.running_var.materialize((self.num_features,))  # type:ignore[union-attr]
            # Now that storage exists, apply the standard initialization.
            self.reset_parameters()
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input.

    Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default, the
    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
    At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
    equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
    moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        # BatchNorm1d accepts (N, C) or (N, C, L) only.
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(
                f"expected 2D or 3D input (got {input.dim()}D input)"
            )
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.

    Lazy initialization based on the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    # After the first forward materializes the parameters, the instance is
    # converted in place into a regular BatchNorm1d.
    cls_to_become = BatchNorm1d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(
                f"expected 2D or 3D input (got {input.dim()}D input)"
            )
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input.

    4D is a mini-batch of 2D inputs
    with additional channel dimension. Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        # BatchNorm2d accepts (N, C, H, W) only.
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    # After the first forward materializes the parameters, the instance is
    # converted in place into a regular BatchNorm2d.
    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input.

    5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        # BatchNorm3d accepts (N, C, D, H, W) only.
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    # After the first forward materializes the parameters, the instance is
    # converted in place into a regular BatchNorm3d.
    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over a N-Dimensional input.

    The N-D input is a mini-batch of [N-2]D inputs with additional channel dimension as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are sampled from
    :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping
    Network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happen within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        self.process_group = process_group

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError(
                f"expected at least 2D input (got {input.dim()}D input)"
            )

    def _check_non_zero_input_channels(self, input):
        if input.size(1) == 0:
            raise ValueError(
                "SyncBatchNorm number of input channels should be non-zero"
            )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        # If buffers are not to be tracked, ensure that they won't be updated
        running_mean = (
            self.running_mean if not self.training or self.track_running_stats else None
        )
        running_var = (
            self.running_var if not self.training or self.track_running_stats else None
        )

        # Don't sync batchnorm stats in inference mode (model.eval()).
        need_sync = (bn_training and self.training and
                     torch.distributed.is_available() and torch.distributed.is_initialized())
        if need_sync:
            # currently only GPU/PrivateUse1 input is supported
            if input.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
                raise ValueError("SyncBatchNorm expected input tensor to be on GPU or "
                                 f"{torch._C._get_privateuse1_backend_name()}")

            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self.weight,
                self.bias,
                bn_training,
                exponential_average_factor,
                self.eps,
            )
        else:
            assert bn_training
            return sync_batch_norm.apply(
                input,
                self.weight,
                self.bias,
                running_mean,
                running_var,
                self.eps,
                exponential_average_factor,
                process_group,  # type: ignore[possibly-undefined]
                world_size,  # type: ignore[possibly-undefined]
            )

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100),
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> # xdoctest: +SKIP("distributed")
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            module_output.training = module.training
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group)
            )
        del module
        return module_output
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/channelshuffle.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .module import Module
|
| 2 |
+
from .. import functional as F
|
| 3 |
+
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
|
| 6 |
+
__all__ = ['ChannelShuffle']
|
| 7 |
+
|
| 8 |
+
class ChannelShuffle(Module):
    r"""Divide the channels of a tensor into groups and rearrange them.

    Given an input of shape :math:`(*, C, H, W)`, the channels are split
    into ``g`` groups, viewed as :math:`(*, \frac{C}{g}, g, H, W)`, and
    interleaved so that channels from different groups alternate. The
    output tensor has exactly the same shape as the input.

    Args:
        groups (int): number of groups to divide channels in.

    Examples::

        >>> # xdoctest: +IGNORE_WANT("FIXME: incorrect want")
        >>> channel_shuffle = nn.ChannelShuffle(2)
        >>> input = torch.randn(1, 4, 2, 2)
        >>> output = channel_shuffle(input)
    """

    __constants__ = ['groups']
    groups: int

    def __init__(self, groups: int) -> None:
        super().__init__()
        # Number of channel groups; the shuffle itself is delegated to the
        # functional implementation in forward().
        self.groups = groups

    def forward(self, input: Tensor) -> Tensor:
        # Delegate to the functional op so scripting/tracing see one call.
        return F.channel_shuffle(input, self.groups)

    def extra_repr(self) -> str:
        # Rendered inside the module repr, e.g. "ChannelShuffle(groups=2)".
        return 'groups={}'.format(self.groups)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py
ADDED
|
@@ -0,0 +1,911 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from collections import OrderedDict, abc as container_abcs
|
| 3 |
+
from itertools import chain, islice
|
| 4 |
+
import operator
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from .module import Module
|
| 8 |
+
from ..parameter import Parameter
|
| 9 |
+
from torch._jit_internal import _copy_to_script_wrapper
|
| 10 |
+
|
| 11 |
+
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
|
| 12 |
+
from typing_extensions import Self
|
| 13 |
+
|
| 14 |
+
__all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict']
|
| 15 |
+
|
| 16 |
+
T = TypeVar('T', bound=Module)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
|
| 20 |
+
def _addindent(s_, numSpaces):
|
| 21 |
+
s = s_.split('\n')
|
| 22 |
+
# don't do anything for single-line stuff
|
| 23 |
+
if len(s) == 1:
|
| 24 |
+
return s_
|
| 25 |
+
first = s.pop(0)
|
| 26 |
+
s = [(numSpaces * ' ') + line for line in s]
|
| 27 |
+
s = '\n'.join(s)
|
| 28 |
+
s = first + '\n' + s
|
| 29 |
+
return s
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Container(Module):
    """Deprecated pre-``nn.Module`` container; subclass ``nn.Module`` instead."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        # Emit a plain UserWarning: a DeprecationWarning would be hidden by
        # the default warning filters.
        warnings.warn("nn.Container is deprecated. All of it's functionality "
                      "is now implemented in nn.Module. Subclass that instead.")
        # Register every keyword argument as a named child module.
        for name, child in kwargs.items():
            self.add_module(name, child)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Sequential(Module):
    r"""A sequential container.

    Modules will be added to it in the order they are passed in the
    constructor. Alternatively, an ``OrderedDict`` of modules can be
    passed in. The ``forward()`` method of ``Sequential`` accepts any
    input and forwards it to the first module it contains. It then
    "chains" outputs to inputs sequentially for each subsequent module,
    finally returning the output of the last module.

    The value a ``Sequential`` provides over manually calling a sequence
    of modules is that it allows treating the whole container as a
    single module, such that performing a transformation on the
    ``Sequential`` applies to each of the modules it stores (which are
    each a registered submodule of the ``Sequential``).

    What's the difference between a ``Sequential`` and a
    :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
    sounds like--a list for storing ``Module`` s! On the other hand,
    the layers in a ``Sequential`` are connected in a cascading way.

    Example::

        # Using Sequential to create a small model. When `model` is run,
        # input will first be passed to `Conv2d(1,20,5)`. The output of
        # `Conv2d(1,20,5)` will be used as the input to the first
        # `ReLU`; the output of the first `ReLU` will become the input
        # for `Conv2d(20,64,5)`. Finally, the output of
        # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
        model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Using Sequential with OrderedDict. This is functionally the
        # same as the above code
        model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
    """

    # Submodules are kept in Module's ordered _modules dict, keyed by their
    # string index ("0", "1", ...) or by the OrderedDict key given at init.
    _modules: Dict[str, Module]  # type: ignore[assignment]

    @overload
    def __init__(self, *args: Module) -> None:
        ...

    @overload
    def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
        ...

    def __init__(self, *args):
        super().__init__()
        # A single OrderedDict argument names the submodules explicitly;
        # otherwise positional modules are registered under "0", "1", ...
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx) -> T:  # type: ignore[misc, type-var]
        """Get the idx-th item of the iterator."""
        size = len(self)
        # operator.index rejects non-integral indices (e.g. floats).
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError(f'index {idx} is out of range')
        # Normalize a negative index, then advance the iterator to it.
        idx %= size
        return next(islice(iterator, idx, None))

    @_copy_to_script_wrapper
    def __getitem__(self, idx: Union[slice, int]) -> Union['Sequential', T]:
        """Return the idx-th module, or a new ``Sequential`` for a slice."""
        if isinstance(idx, slice):
            # Slicing preserves the original keys of the selected modules.
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: Module) -> None:
        """Replace the module at position ``idx`` (keeps its existing key)."""
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        # setattr routes through Module.__setattr__, re-registering the child.
        return setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        """Delete the module(s) at ``idx`` and renumber the remaining keys."""
        if isinstance(idx, slice):
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)
        # To preserve numbering
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        """Return the number of contained modules."""
        return len(self._modules)

    def __add__(self, other) -> 'Sequential':
        """Return a new ``Sequential`` with the modules of both operands."""
        if isinstance(other, Sequential):
            ret = Sequential()
            for layer in self:
                ret.append(layer)
            for layer in other:
                ret.append(layer)
            return ret
        else:
            raise ValueError('add operator supports only objects '
                             f'of Sequential class, but {str(type(other))} is given.')

    def pop(self, key: Union[int, slice]) -> Module:
        """Remove and return the module(s) at ``key``."""
        v = self[key]
        del self[key]
        return v

    def __iadd__(self, other) -> Self:
        """Append the modules of ``other`` in place (``self += other``)."""
        if isinstance(other, Sequential):
            offset = len(self)
            for i, module in enumerate(other):
                self.add_module(str(i + offset), module)
            return self
        else:
            raise ValueError('add operator supports only objects '
                             f'of Sequential class, but {str(type(other))} is given.')

    def __mul__(self, other: int) -> 'Sequential':
        """Return a new ``Sequential`` repeating this one ``other`` times.

        Note: the repeated entries reference the *same* module objects,
        so their parameters are shared across repetitions.
        """
        if not isinstance(other, int):
            raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
        elif (other <= 0):
            raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
        else:
            combined = Sequential()
            offset = 0
            for _ in range(other):
                for module in self:
                    combined.add_module(str(offset), module)
                    offset += 1
            return combined

    def __rmul__(self, other: int) -> 'Sequential':
        """Support ``int * Sequential`` via the same logic as ``__mul__``."""
        return self.__mul__(other)

    def __imul__(self, other: int) -> Self:
        """Repeat this container's modules in place (``self *= other``)."""
        if not isinstance(other, int):
            raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
        elif (other <= 0):
            raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
        else:
            len_original = len(self)
            offset = len(self)
            # Append (other - 1) additional copies of the original span,
            # re-registering the same module objects under new indices.
            for _ in range(other - 1):
                for i in range(len_original):
                    self.add_module(str(i + offset), self._modules[str(i)])
                offset += len_original
            return self

    @_copy_to_script_wrapper
    def __dir__(self):
        # Hide the purely numeric submodule keys from dir() output.
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        """Iterate over the contained modules in registration order."""
        return iter(self._modules.values())

    # NB: We can't really type check this function as the type of input
    # may change dynamically (as is tested in
    # TestScript.test_sequential_intermediary_types). Cannot annotate
    # with Any as TorchScript expects a more precise type
    def forward(self, input):
        # Thread the input through each module in order.
        for module in self:
            input = module(input)
        return input

    def append(self, module: Module) -> 'Sequential':
        r"""Append a given module to the end.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def insert(self, index: int, module: Module) -> 'Sequential':
        """Insert ``module`` before position ``index`` (negative ok)."""
        if not isinstance(module, Module):
            raise AssertionError(
                f'module should be of type: {Module}')
        n = len(self._modules)
        if not (-n <= index <= n):
            raise IndexError(
                f'Index out of range: {index}')
        if index < 0:
            index += n
        # Shift subsequent entries up by one key, then drop the module in.
        for i in range(n, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module
        return self

    def extend(self, sequential) -> 'Sequential':
        """Append every module from the given iterable to the end."""
        for layer in sequential:
            self.append(layer)
        return self
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class ModuleList(Module):
    r"""Holds submodules in a list.

    :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
    modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    Args:
        modules (iterable, optional): an iterable of modules to add

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

            def forward(self, x):
                # ModuleList can act as an iterable, or be indexed using ints
                for i, l in enumerate(self.linears):
                    x = self.linears[i // 2](x) + l(x)
                return x
    """

    # Submodules live in Module's ordered _modules dict under the keys
    # "0", "1", ... mirroring list positions.
    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            # Delegates to extend() via __iadd__.
            self += modules

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules."""
        # operator.index rejects non-integral indices (e.g. floats).
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f'index {idx} is out of range')
        if idx < 0:
            idx += len(self)
        return str(idx)

    @_copy_to_script_wrapper
    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, 'ModuleList']:
        """Return the idx-th module, or a new ``ModuleList`` for a slice."""
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: Module) -> None:
        """Replace the module at position ``idx``."""
        idx = self._get_abs_string_index(idx)
        # setattr routes through Module.__setattr__, re-registering the child.
        return setattr(self, str(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        """Delete the module(s) at ``idx`` and renumber the remaining keys."""
        if isinstance(idx, slice):
            for k in range(len(self._modules))[idx]:
                delattr(self, str(k))
        else:
            delattr(self, self._get_abs_string_index(idx))
        # To preserve numbering, self._modules is being reconstructed with modules after deletion
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        """Return the number of contained modules."""
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        """Iterate over the contained modules in list order."""
        return iter(self._modules.values())

    def __iadd__(self, modules: Iterable[Module]) -> Self:
        """In-place concatenation (``self += modules``); same as extend()."""
        return self.extend(modules)

    def __add__(self, other: Iterable[Module]) -> 'ModuleList':
        """Return a new ``ModuleList`` with the modules of both operands."""
        combined = ModuleList()
        for i, module in enumerate(chain(self, other)):
            combined.add_module(str(i), module)
        return combined

    def __repr__(self):
        """Return a custom repr for ModuleList that compresses repeated module representations."""
        list_of_reprs = [repr(item) for item in self]
        if len(list_of_reprs) == 0:
            return self._get_name() + '()'

        # Runs of identical reprs are collapsed into "(start-end): N x <repr>".
        start_end_indices = [[0, 0]]
        repeated_blocks = [list_of_reprs[0]]
        for i, r in enumerate(list_of_reprs[1:], 1):
            if r == repeated_blocks[-1]:
                # Same repr as the previous entry: extend the current run.
                start_end_indices[-1][1] += 1
                continue

            start_end_indices.append([i, i])
            repeated_blocks.append(r)

        lines = []
        main_str = self._get_name() + '('
        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
            local_repr = f"({start_id}): {b}"  # default repr

            if start_id != end_id:
                n = end_id - start_id + 1
                local_repr = f"({start_id}-{end_id}): {n} x {b}"

            local_repr = _addindent(local_repr, 2)
            lines.append(local_repr)

        main_str += '\n  ' + '\n  '.join(lines) + '\n'
        main_str += ')'
        return main_str

    @_copy_to_script_wrapper
    def __dir__(self):
        # Hide the purely numeric submodule keys from dir() output.
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def insert(self, index: int, module: Module) -> None:
        r"""Insert a given module before a given index in the list.

        Args:
            index (int): index to insert.
            module (nn.Module): module to insert
        """
        # Shift subsequent entries up by one key, then drop the module in.
        for i in range(len(self._modules), index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def append(self, module: Module) -> 'ModuleList':
        r"""Append a given module to the end of the list.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def pop(self, key: Union[int, slice]) -> Module:
        """Remove and return the module(s) at ``key``."""
        v = self[key]
        del self[key]
        return v

    def extend(self, modules: Iterable[Module]) -> Self:
        r"""Append modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): iterable of modules to append
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleList.extend should be called with an "
                            "iterable, but got " + type(modules).__name__)
        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self

    # remove forward altogether to fallback on Module's _forward_unimplemented
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
class ModuleDict(Module):
    r"""Holds submodules in a dictionary.

    :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
    but modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
      ``OrderedDict``, ``dict`` (started from Python 3.6) or another
      :class:`~torch.nn.ModuleDict` (the argument to
      :meth:`~torch.nn.ModuleDict.update`).

    Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict`` before Python version 3.6) does not
    preserve the order of the merged mapping.

    Args:
        modules (iterable, optional): a mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module)

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.choices = nn.ModuleDict({
                        'conv': nn.Conv2d(10, 10, 3),
                        'pool': nn.MaxPool2d(3)
                })
                self.activations = nn.ModuleDict([
                        ['lrelu', nn.LeakyReLU()],
                        ['prelu', nn.PReLU()]
                ])

            def forward(self, x, choice, act):
                x = self.choices[choice](x)
                x = self.activations[act](x)
                return x
    """

    # Submodules are stored in Module's ordered _modules dict under their
    # user-provided string keys.
    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self.update(modules)

    @_copy_to_script_wrapper
    def __getitem__(self, key: str) -> Module:
        """Return the module registered under ``key``."""
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        """Register ``module`` under ``key`` (overwrites an existing entry)."""
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        """Remove the module registered under ``key``."""
        del self._modules[key]

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        """Return the number of contained modules."""
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[str]:
        """Iterate over the keys, in insertion order."""
        return iter(self._modules)

    @_copy_to_script_wrapper
    def __contains__(self, key: str) -> bool:
        """Return whether a module is registered under ``key``."""
        return key in self._modules

    def clear(self) -> None:
        """Remove all items from the ModuleDict."""
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (str): key to pop from the ModuleDict
        """
        v = self[key]
        del self[key]
        return v

    @_copy_to_script_wrapper
    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ModuleDict keys."""
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs."""
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> Iterable[Module]:
        r"""Return an iterable of the ModuleDict values."""
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(modules).__name__)

        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
        else:
            # modules here can be a list with two items
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    raise TypeError("ModuleDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is" +
                                    type(m).__name__)
                if not len(m) == 2:
                    raise ValueError("ModuleDict update sequence element "
                                     "#" + str(j) + " has length " + str(len(m)) +
                                     "; 2 is required")
                # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
                # that's too cumbersome to type correctly with overloads, so we add an ignore here
                self[m[0]] = m[1]  # type: ignore[assignment]

    # remove forward altogether to fallback on Module's _forward_unimplemented
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
class ParameterList(Module):
|
| 548 |
+
r"""Holds parameters in a list.
|
| 549 |
+
|
| 550 |
+
:class:`~torch.nn.ParameterList` can be used like a regular Python
|
| 551 |
+
list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
|
| 552 |
+
and will be visible by all :class:`~torch.nn.Module` methods.
|
| 553 |
+
|
| 554 |
+
Note that the constructor, assigning an element of the list, the
|
| 555 |
+
:meth:`~torch.nn.ParameterDict.append` method and the :meth:`~torch.nn.ParameterDict.extend`
|
| 556 |
+
method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.
|
| 557 |
+
|
| 558 |
+
Args:
|
| 559 |
+
parameters (iterable, optional): an iterable of elements to add to the list.
|
| 560 |
+
|
| 561 |
+
Example::
|
| 562 |
+
|
| 563 |
+
class MyModule(nn.Module):
|
| 564 |
+
def __init__(self):
|
| 565 |
+
super().__init__()
|
| 566 |
+
self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
|
| 567 |
+
|
| 568 |
+
def forward(self, x):
|
| 569 |
+
# ParameterList can act as an iterable, or be indexed using ints
|
| 570 |
+
for i, p in enumerate(self.params):
|
| 571 |
+
x = self.params[i // 2].mm(x) + p.mm(x)
|
| 572 |
+
return x
|
| 573 |
+
"""
|
| 574 |
+
|
| 575 |
+
def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
|
| 576 |
+
super().__init__()
|
| 577 |
+
self._size = 0
|
| 578 |
+
if values is not None:
|
| 579 |
+
self += values
|
| 580 |
+
|
| 581 |
+
def _get_abs_string_index(self, idx):
|
| 582 |
+
"""Get the absolute index for the list of modules."""
|
| 583 |
+
idx = operator.index(idx)
|
| 584 |
+
if not (-len(self) <= idx < len(self)):
|
| 585 |
+
raise IndexError(f'index {idx} is out of range')
|
| 586 |
+
if idx < 0:
|
| 587 |
+
idx += len(self)
|
| 588 |
+
return str(idx)
|
| 589 |
+
|
| 590 |
+
@overload
|
| 591 |
+
def __getitem__(self, idx: int) -> Any:
|
| 592 |
+
...
|
| 593 |
+
|
| 594 |
+
@overload
|
| 595 |
+
def __getitem__(self: T, idx: slice) -> T:
|
| 596 |
+
...
|
| 597 |
+
|
| 598 |
+
def __getitem__(self, idx):
|
| 599 |
+
if isinstance(idx, slice):
|
| 600 |
+
start, stop, step = idx.indices(len(self))
|
| 601 |
+
out = self.__class__()
|
| 602 |
+
for i in range(start, stop, step):
|
| 603 |
+
out.append(self[i])
|
| 604 |
+
return out
|
| 605 |
+
else:
|
| 606 |
+
idx = self._get_abs_string_index(idx)
|
| 607 |
+
return getattr(self, str(idx))
|
| 608 |
+
|
| 609 |
+
def __setitem__(self, idx: int, param: Any) -> None:
|
| 610 |
+
# Note that all other function that add an entry to the list part of
|
| 611 |
+
# the ParameterList end up here. So this is the only place where we need
|
| 612 |
+
# to wrap things into Parameter if needed.
|
| 613 |
+
# Objects added via setattr() are not in the list part and thus won't
|
| 614 |
+
# call into this function.
|
| 615 |
+
idx = self._get_abs_string_index(idx)
|
| 616 |
+
if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
|
| 617 |
+
param = Parameter(param)
|
| 618 |
+
return setattr(self, str(idx), param)
|
| 619 |
+
|
| 620 |
+
def __len__(self) -> int:
|
| 621 |
+
return self._size
|
| 622 |
+
|
| 623 |
+
def __iter__(self) -> Iterator[Any]:
|
| 624 |
+
return iter(self[i] for i in range(len(self)))
|
| 625 |
+
|
| 626 |
+
def __iadd__(self, parameters: Iterable[Any]) -> Self:
|
| 627 |
+
return self.extend(parameters)
|
| 628 |
+
|
| 629 |
+
def __dir__(self):
|
| 630 |
+
keys = super().__dir__()
|
| 631 |
+
keys = [key for key in keys if not key.isdigit()]
|
| 632 |
+
return keys
|
| 633 |
+
|
| 634 |
+
def append(self, value: Any) -> 'ParameterList':
|
| 635 |
+
"""Append a given value at the end of the list.
|
| 636 |
+
|
| 637 |
+
Args:
|
| 638 |
+
value (Any): value to append
|
| 639 |
+
"""
|
| 640 |
+
new_idx = len(self)
|
| 641 |
+
self._size += 1
|
| 642 |
+
self[new_idx] = value
|
| 643 |
+
return self
|
| 644 |
+
|
| 645 |
+
def extend(self, values: Iterable[Any]) -> Self:
|
| 646 |
+
"""Append values from a Python iterable to the end of the list.
|
| 647 |
+
|
| 648 |
+
Args:
|
| 649 |
+
values (iterable): iterable of values to append
|
| 650 |
+
"""
|
| 651 |
+
# Tensor is an iterable but we never want to unpack it here
|
| 652 |
+
if not isinstance(values, container_abcs.Iterable) or isinstance(values, torch.Tensor):
|
| 653 |
+
raise TypeError("ParameterList.extend should be called with an "
|
| 654 |
+
"iterable, but got " + type(values).__name__)
|
| 655 |
+
for value in values:
|
| 656 |
+
self.append(value)
|
| 657 |
+
return self
|
| 658 |
+
|
| 659 |
+
def extra_repr(self) -> str:
|
| 660 |
+
child_lines = []
|
| 661 |
+
for k, p in enumerate(self):
|
| 662 |
+
if isinstance(p, torch.Tensor):
|
| 663 |
+
size_str = 'x'.join(str(size) for size in p.size())
|
| 664 |
+
if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
|
| 665 |
+
device_str = f' ({p.device})'
|
| 666 |
+
else:
|
| 667 |
+
device_str = ''
|
| 668 |
+
parastr = '{} containing: [{} of size {}{}]'.format(
|
| 669 |
+
"Parameter" if isinstance(p, Parameter) else "Tensor",
|
| 670 |
+
p.dtype, size_str, device_str)
|
| 671 |
+
child_lines.append(' (' + str(k) + '): ' + parastr)
|
| 672 |
+
else:
|
| 673 |
+
child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__)
|
| 674 |
+
|
| 675 |
+
tmpstr = '\n'.join(child_lines)
|
| 676 |
+
return tmpstr
|
| 677 |
+
|
| 678 |
+
def __call__(self, *args, **kwargs):
|
| 679 |
+
raise RuntimeError('ParameterList should not be called.')
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
class ParameterDict(Module):
|
| 683 |
+
r"""Holds parameters in a dictionary.
|
| 684 |
+
|
| 685 |
+
ParameterDict can be indexed like a regular Python dictionary, but Parameters it
|
| 686 |
+
contains are properly registered, and will be visible by all Module methods.
|
| 687 |
+
Other objects are treated as would be done by a regular Python dictionary
|
| 688 |
+
|
| 689 |
+
:class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
|
| 690 |
+
:meth:`~torch.nn.ParameterDict.update` with other unordered mapping
|
| 691 |
+
types (e.g., Python's plain ``dict``) does not preserve the order of the
|
| 692 |
+
merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
|
| 693 |
+
will preserve their ordering.
|
| 694 |
+
|
| 695 |
+
Note that the constructor, assigning an element of the dictionary and the
|
| 696 |
+
:meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
|
| 697 |
+
:class:`~torch.nn.Parameter`.
|
| 698 |
+
|
| 699 |
+
Args:
|
| 700 |
+
values (iterable, optional): a mapping (dictionary) of
|
| 701 |
+
(string : Any) or an iterable of key-value pairs
|
| 702 |
+
of type (string, Any)
|
| 703 |
+
|
| 704 |
+
Example::
|
| 705 |
+
|
| 706 |
+
class MyModule(nn.Module):
|
| 707 |
+
def __init__(self):
|
| 708 |
+
super().__init__()
|
| 709 |
+
self.params = nn.ParameterDict({
|
| 710 |
+
'left': nn.Parameter(torch.randn(5, 10)),
|
| 711 |
+
'right': nn.Parameter(torch.randn(5, 10))
|
| 712 |
+
})
|
| 713 |
+
|
| 714 |
+
def forward(self, x, choice):
|
| 715 |
+
x = self.params[choice].mm(x)
|
| 716 |
+
return x
|
| 717 |
+
"""
|
| 718 |
+
|
| 719 |
+
def __init__(self, parameters: Any = None) -> None:
|
| 720 |
+
super().__init__()
|
| 721 |
+
self._keys: Dict[str, None] = {}
|
| 722 |
+
if parameters is not None:
|
| 723 |
+
self.update(parameters)
|
| 724 |
+
|
| 725 |
+
def _key_to_attr(self, key: str) -> str:
|
| 726 |
+
if not isinstance(key, str):
|
| 727 |
+
raise TypeError("Index given to ParameterDict cannot be used as a key as it is "
|
| 728 |
+
f"not a string (type is '{type(key).__name__}'). Open an issue on "
|
| 729 |
+
"github if you need non-string keys.")
|
| 730 |
+
else:
|
| 731 |
+
# Use the key as-is so that `.named_parameters()` returns the right thing
|
| 732 |
+
return key
|
| 733 |
+
|
| 734 |
+
def __getitem__(self, key: str) -> Any:
|
| 735 |
+
attr = self._key_to_attr(key)
|
| 736 |
+
return getattr(self, attr)
|
| 737 |
+
|
| 738 |
+
def __setitem__(self, key: str, value: Any) -> None:
|
| 739 |
+
# Note that all other function that add an entry to the dictionary part of
|
| 740 |
+
# the ParameterDict end up here. So this is the only place where we need
|
| 741 |
+
# to wrap things into Parameter if needed.
|
| 742 |
+
# Objects added via setattr() are not in the dictionary part and thus won't
|
| 743 |
+
# call into this function.
|
| 744 |
+
self._keys[key] = None
|
| 745 |
+
attr = self._key_to_attr(key)
|
| 746 |
+
if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
|
| 747 |
+
value = Parameter(value)
|
| 748 |
+
setattr(self, attr, value)
|
| 749 |
+
|
| 750 |
+
def __delitem__(self, key: str) -> None:
|
| 751 |
+
del self._keys[key]
|
| 752 |
+
attr = self._key_to_attr(key)
|
| 753 |
+
delattr(self, attr)
|
| 754 |
+
|
| 755 |
+
def __len__(self) -> int:
|
| 756 |
+
return len(self._keys)
|
| 757 |
+
|
| 758 |
+
def __iter__(self) -> Iterator[str]:
|
| 759 |
+
return iter(self._keys)
|
| 760 |
+
|
| 761 |
+
def __reversed__(self) -> Iterator[str]:
|
| 762 |
+
return reversed(list(self._keys))
|
| 763 |
+
|
| 764 |
+
def copy(self) -> 'ParameterDict':
|
| 765 |
+
"""Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
|
| 766 |
+
# We have to use an OrderedDict because the ParameterDict constructor
|
| 767 |
+
# behaves differently on plain dict vs OrderedDict
|
| 768 |
+
return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))
|
| 769 |
+
|
| 770 |
+
def __contains__(self, key: str) -> bool:
|
| 771 |
+
return key in self._keys
|
| 772 |
+
|
| 773 |
+
def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
|
| 774 |
+
"""Set the default for a key in the Parameterdict.
|
| 775 |
+
|
| 776 |
+
If key is in the ParameterDict, return its value.
|
| 777 |
+
If not, insert `key` with a parameter `default` and return `default`.
|
| 778 |
+
`default` defaults to `None`.
|
| 779 |
+
|
| 780 |
+
Args:
|
| 781 |
+
key (str): key to set default for
|
| 782 |
+
default (Any): the parameter set to the key
|
| 783 |
+
"""
|
| 784 |
+
if key not in self:
|
| 785 |
+
self[key] = default
|
| 786 |
+
return self[key]
|
| 787 |
+
|
| 788 |
+
def clear(self) -> None:
|
| 789 |
+
"""Remove all items from the ParameterDict."""
|
| 790 |
+
for k in self._keys.copy():
|
| 791 |
+
del self[k]
|
| 792 |
+
|
| 793 |
+
def pop(self, key: str) -> Any:
|
| 794 |
+
r"""Remove key from the ParameterDict and return its parameter.
|
| 795 |
+
|
| 796 |
+
Args:
|
| 797 |
+
key (str): key to pop from the ParameterDict
|
| 798 |
+
"""
|
| 799 |
+
v = self[key]
|
| 800 |
+
del self[key]
|
| 801 |
+
return v
|
| 802 |
+
|
| 803 |
+
def popitem(self) -> Tuple[str, Any]:
|
| 804 |
+
"""Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
|
| 805 |
+
k, _ = self._keys.popitem()
|
| 806 |
+
# We need the key in the _keys to be able to access/del
|
| 807 |
+
self._keys[k] = None
|
| 808 |
+
val = self[k]
|
| 809 |
+
del self[k]
|
| 810 |
+
return k, val
|
| 811 |
+
|
| 812 |
+
def get(self, key: str, default: Optional[Any] = None) -> Any:
|
| 813 |
+
r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.
|
| 814 |
+
|
| 815 |
+
Args:
|
| 816 |
+
key (str): key to get from the ParameterDict
|
| 817 |
+
default (Parameter, optional): value to return if key not present
|
| 818 |
+
"""
|
| 819 |
+
return self[key] if key in self else default
|
| 820 |
+
|
| 821 |
+
def fromkeys(self, keys: Iterable[str], default: Optional[Any] = None) -> 'ParameterDict':
|
| 822 |
+
r"""Return a new ParameterDict with the keys provided.
|
| 823 |
+
|
| 824 |
+
Args:
|
| 825 |
+
keys (iterable, string): keys to make the new ParameterDict from
|
| 826 |
+
default (Parameter, optional): value to set for all keys
|
| 827 |
+
"""
|
| 828 |
+
return ParameterDict((k, default) for k in keys)
|
| 829 |
+
|
| 830 |
+
def keys(self) -> Iterable[str]:
|
| 831 |
+
r"""Return an iterable of the ParameterDict keys."""
|
| 832 |
+
return self._keys.keys()
|
| 833 |
+
|
| 834 |
+
def items(self) -> Iterable[Tuple[str, Any]]:
|
| 835 |
+
r"""Return an iterable of the ParameterDict key/value pairs."""
|
| 836 |
+
return ((k, self[k]) for k in self._keys)
|
| 837 |
+
|
| 838 |
+
def values(self) -> Iterable[Any]:
|
| 839 |
+
r"""Return an iterable of the ParameterDict values."""
|
| 840 |
+
return (self[k] for k in self._keys)
|
| 841 |
+
|
| 842 |
+
def update(self, parameters: Union[Mapping[str, Any], 'ParameterDict']) -> None:
|
| 843 |
+
r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.
|
| 844 |
+
|
| 845 |
+
.. note::
|
| 846 |
+
If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
|
| 847 |
+
an iterable of key-value pairs, the order of new elements in it is preserved.
|
| 848 |
+
|
| 849 |
+
Args:
|
| 850 |
+
parameters (iterable): a mapping (dictionary) from string to
|
| 851 |
+
:class:`~torch.nn.Parameter`, or an iterable of
|
| 852 |
+
key-value pairs of type (string, :class:`~torch.nn.Parameter`)
|
| 853 |
+
"""
|
| 854 |
+
if not isinstance(parameters, container_abcs.Iterable):
|
| 855 |
+
raise TypeError("ParametersDict.update should be called with an "
|
| 856 |
+
"iterable of key/value pairs, but got " +
|
| 857 |
+
type(parameters).__name__)
|
| 858 |
+
|
| 859 |
+
if isinstance(parameters, (OrderedDict, ParameterDict)):
|
| 860 |
+
for key, parameter in parameters.items():
|
| 861 |
+
self[key] = parameter
|
| 862 |
+
elif isinstance(parameters, container_abcs.Mapping):
|
| 863 |
+
for key, parameter in sorted(parameters.items()):
|
| 864 |
+
self[key] = parameter
|
| 865 |
+
else:
|
| 866 |
+
for j, p in enumerate(parameters):
|
| 867 |
+
if not isinstance(p, container_abcs.Iterable):
|
| 868 |
+
raise TypeError("ParameterDict update sequence element "
|
| 869 |
+
"#" + str(j) + " should be Iterable; is" +
|
| 870 |
+
type(p).__name__)
|
| 871 |
+
if not len(p) == 2:
|
| 872 |
+
raise ValueError("ParameterDict update sequence element "
|
| 873 |
+
"#" + str(j) + " has length " + str(len(p)) +
|
| 874 |
+
"; 2 is required")
|
| 875 |
+
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
|
| 876 |
+
self[p[0]] = p[1] # type: ignore[assignment]
|
| 877 |
+
|
| 878 |
+
def extra_repr(self) -> str:
|
| 879 |
+
child_lines = []
|
| 880 |
+
for k, p in self.items():
|
| 881 |
+
if isinstance(p, torch.Tensor):
|
| 882 |
+
size_str = 'x'.join(str(size) for size in p.size())
|
| 883 |
+
if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
|
| 884 |
+
device_str = f' ({p.device})'
|
| 885 |
+
else:
|
| 886 |
+
device_str = ''
|
| 887 |
+
parastr = '{} containing: [{} of size {}{}]'.format(
|
| 888 |
+
"Parameter" if isinstance(p, Parameter) else "Tensor",
|
| 889 |
+
torch.typename(p), size_str, device_str)
|
| 890 |
+
child_lines.append(' (' + str(k) + '): ' + parastr)
|
| 891 |
+
else:
|
| 892 |
+
child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__)
|
| 893 |
+
tmpstr = '\n'.join(child_lines)
|
| 894 |
+
return tmpstr
|
| 895 |
+
|
| 896 |
+
def __call__(self, input):
|
| 897 |
+
raise RuntimeError('ParameterDict should not be called.')
|
| 898 |
+
|
| 899 |
+
def __or__(self, other: 'ParameterDict') -> 'ParameterDict':
|
| 900 |
+
copy = self.copy()
|
| 901 |
+
copy.update(other)
|
| 902 |
+
return copy
|
| 903 |
+
|
| 904 |
+
def __ror__(self, other: 'ParameterDict') -> 'ParameterDict':
|
| 905 |
+
copy = other.copy()
|
| 906 |
+
copy.update(self)
|
| 907 |
+
return copy
|
| 908 |
+
|
| 909 |
+
def __ior__(self, other : 'ParameterDict') -> Self:
|
| 910 |
+
self.update(other)
|
| 911 |
+
return self
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/dropout.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .module import Module
|
| 2 |
+
from .. import functional as F
|
| 3 |
+
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
|
| 6 |
+
__all__ = ['Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout']
|
| 7 |
+
|
| 8 |
+
class _DropoutNd(Module):
|
| 9 |
+
__constants__ = ['p', 'inplace']
|
| 10 |
+
p: float
|
| 11 |
+
inplace: bool
|
| 12 |
+
|
| 13 |
+
def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
|
| 14 |
+
super().__init__()
|
| 15 |
+
if p < 0 or p > 1:
|
| 16 |
+
raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
|
| 17 |
+
self.p = p
|
| 18 |
+
self.inplace = inplace
|
| 19 |
+
|
| 20 |
+
def extra_repr(self) -> str:
|
| 21 |
+
return f'p={self.p}, inplace={self.inplace}'
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Dropout(_DropoutNd):
|
| 25 |
+
r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p`.
|
| 26 |
+
|
| 27 |
+
The zeroed elements are chosen independently for each forward call and are sampled from a Bernoulli distribution.
|
| 28 |
+
|
| 29 |
+
Each channel will be zeroed out independently on every forward call.
|
| 30 |
+
|
| 31 |
+
This has proven to be an effective technique for regularization and
|
| 32 |
+
preventing the co-adaptation of neurons as described in the paper
|
| 33 |
+
`Improving neural networks by preventing co-adaptation of feature
|
| 34 |
+
detectors`_ .
|
| 35 |
+
|
| 36 |
+
Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
|
| 37 |
+
training. This means that during evaluation the module simply computes an
|
| 38 |
+
identity function.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
p: probability of an element to be zeroed. Default: 0.5
|
| 42 |
+
inplace: If set to ``True``, will do this operation in-place. Default: ``False``
|
| 43 |
+
|
| 44 |
+
Shape:
|
| 45 |
+
- Input: :math:`(*)`. Input can be of any shape
|
| 46 |
+
- Output: :math:`(*)`. Output is of the same shape as input
|
| 47 |
+
|
| 48 |
+
Examples::
|
| 49 |
+
|
| 50 |
+
>>> m = nn.Dropout(p=0.2)
|
| 51 |
+
>>> input = torch.randn(20, 16)
|
| 52 |
+
>>> output = m(input)
|
| 53 |
+
|
| 54 |
+
.. _Improving neural networks by preventing co-adaptation of feature
|
| 55 |
+
detectors: https://arxiv.org/abs/1207.0580
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 59 |
+
return F.dropout(input, self.p, self.training, self.inplace)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class Dropout1d(_DropoutNd):
|
| 63 |
+
r"""Randomly zero out entire channels.
|
| 64 |
+
|
| 65 |
+
A channel is a 1D feature map,
|
| 66 |
+
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
|
| 67 |
+
batched input is a 1D tensor :math:`\text{input}[i, j]`.
|
| 68 |
+
|
| 69 |
+
Each channel will be zeroed out independently on every forward call with
|
| 70 |
+
probability :attr:`p` using samples from a Bernoulli distribution.
|
| 71 |
+
|
| 72 |
+
Usually the input comes from :class:`nn.Conv1d` modules.
|
| 73 |
+
|
| 74 |
+
As described in the paper
|
| 75 |
+
`Efficient Object Localization Using Convolutional Networks`_ ,
|
| 76 |
+
if adjacent pixels within feature maps are strongly correlated
|
| 77 |
+
(as is normally the case in early convolution layers) then i.i.d. dropout
|
| 78 |
+
will not regularize the activations and will otherwise just result
|
| 79 |
+
in an effective learning rate decrease.
|
| 80 |
+
|
| 81 |
+
In this case, :func:`nn.Dropout1d` will help promote independence between
|
| 82 |
+
feature maps and should be used instead.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
p (float, optional): probability of an element to be zero-ed.
|
| 86 |
+
inplace (bool, optional): If set to ``True``, will do this operation
|
| 87 |
+
in-place
|
| 88 |
+
|
| 89 |
+
Shape:
|
| 90 |
+
- Input: :math:`(N, C, L)` or :math:`(C, L)`.
|
| 91 |
+
- Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
|
| 92 |
+
|
| 93 |
+
Examples::
|
| 94 |
+
|
| 95 |
+
>>> m = nn.Dropout1d(p=0.2)
|
| 96 |
+
>>> input = torch.randn(20, 16, 32)
|
| 97 |
+
>>> output = m(input)
|
| 98 |
+
|
| 99 |
+
.. _Efficient Object Localization Using Convolutional Networks:
|
| 100 |
+
https://arxiv.org/abs/1411.4280
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 104 |
+
return F.dropout1d(input, self.p, self.training, self.inplace)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class Dropout2d(_DropoutNd):
|
| 108 |
+
r"""Randomly zero out entire channels.
|
| 109 |
+
|
| 110 |
+
A channel is a 2D feature map,
|
| 111 |
+
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
|
| 112 |
+
batched input is a 2D tensor :math:`\text{input}[i, j]`.
|
| 113 |
+
|
| 114 |
+
Each channel will be zeroed out independently on every forward call with
|
| 115 |
+
probability :attr:`p` using samples from a Bernoulli distribution.
|
| 116 |
+
|
| 117 |
+
Usually the input comes from :class:`nn.Conv2d` modules.
|
| 118 |
+
|
| 119 |
+
As described in the paper
|
| 120 |
+
`Efficient Object Localization Using Convolutional Networks`_ ,
|
| 121 |
+
if adjacent pixels within feature maps are strongly correlated
|
| 122 |
+
(as is normally the case in early convolution layers) then i.i.d. dropout
|
| 123 |
+
will not regularize the activations and will otherwise just result
|
| 124 |
+
in an effective learning rate decrease.
|
| 125 |
+
|
| 126 |
+
In this case, :func:`nn.Dropout2d` will help promote independence between
|
| 127 |
+
feature maps and should be used instead.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
p (float, optional): probability of an element to be zero-ed.
|
| 131 |
+
inplace (bool, optional): If set to ``True``, will do this operation
|
| 132 |
+
in-place
|
| 133 |
+
|
| 134 |
+
.. warning ::
|
| 135 |
+
Due to historical reasons, this class will perform 1D channel-wise dropout
|
| 136 |
+
for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT
|
| 137 |
+
support inputs without a batch dimension of shape :math:`(C, H, W)`. This
|
| 138 |
+
behavior will change in a future release to interpret 3D inputs as no-batch-dim
|
| 139 |
+
inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`.
|
| 140 |
+
|
| 141 |
+
Shape:
|
| 142 |
+
- Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
|
| 143 |
+
- Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).
|
| 144 |
+
|
| 145 |
+
Examples::
|
| 146 |
+
|
| 147 |
+
>>> m = nn.Dropout2d(p=0.2)
|
| 148 |
+
>>> input = torch.randn(20, 16, 32, 32)
|
| 149 |
+
>>> output = m(input)
|
| 150 |
+
|
| 151 |
+
.. _Efficient Object Localization Using Convolutional Networks:
|
| 152 |
+
https://arxiv.org/abs/1411.4280
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 156 |
+
return F.dropout2d(input, self.p, self.training, self.inplace)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class Dropout3d(_DropoutNd):
|
| 160 |
+
r"""Randomly zero out entire channels.
|
| 161 |
+
|
| 162 |
+
A channel is a 3D feature map,
|
| 163 |
+
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
|
| 164 |
+
batched input is a 3D tensor :math:`\text{input}[i, j]`.
|
| 165 |
+
|
| 166 |
+
Each channel will be zeroed out independently on every forward call with
|
| 167 |
+
probability :attr:`p` using samples from a Bernoulli distribution.
|
| 168 |
+
|
| 169 |
+
Usually the input comes from :class:`nn.Conv3d` modules.
|
| 170 |
+
|
| 171 |
+
As described in the paper
|
| 172 |
+
`Efficient Object Localization Using Convolutional Networks`_ ,
|
| 173 |
+
if adjacent pixels within feature maps are strongly correlated
|
| 174 |
+
(as is normally the case in early convolution layers) then i.i.d. dropout
|
| 175 |
+
will not regularize the activations and will otherwise just result
|
| 176 |
+
in an effective learning rate decrease.
|
| 177 |
+
|
| 178 |
+
In this case, :func:`nn.Dropout3d` will help promote independence between
|
| 179 |
+
feature maps and should be used instead.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
p (float, optional): probability of an element to be zeroed.
|
| 183 |
+
inplace (bool, optional): If set to ``True``, will do this operation
|
| 184 |
+
in-place
|
| 185 |
+
|
| 186 |
+
Shape:
|
| 187 |
+
- Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
|
| 188 |
+
- Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
|
| 189 |
+
|
| 190 |
+
Examples::
|
| 191 |
+
|
| 192 |
+
>>> m = nn.Dropout3d(p=0.2)
|
| 193 |
+
>>> input = torch.randn(20, 16, 4, 32, 32)
|
| 194 |
+
>>> output = m(input)
|
| 195 |
+
|
| 196 |
+
.. _Efficient Object Localization Using Convolutional Networks:
|
| 197 |
+
https://arxiv.org/abs/1411.4280
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 201 |
+
return F.dropout3d(input, self.p, self.training, self.inplace)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class AlphaDropout(_DropoutNd):
|
| 205 |
+
r"""Applies Alpha Dropout over the input.
|
| 206 |
+
|
| 207 |
+
Alpha Dropout is a type of Dropout that maintains the self-normalizing
|
| 208 |
+
property.
|
| 209 |
+
For an input with zero mean and unit standard deviation, the output of
|
| 210 |
+
Alpha Dropout maintains the original mean and standard deviation of the
|
| 211 |
+
input.
|
| 212 |
+
Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
|
| 213 |
+
that the outputs have zero mean and unit standard deviation.
|
| 214 |
+
|
| 215 |
+
During training, it randomly masks some of the elements of the input
|
| 216 |
+
tensor with probability *p* using samples from a bernoulli distribution.
|
| 217 |
+
The elements to masked are randomized on every forward call, and scaled
|
| 218 |
+
and shifted to maintain zero mean and unit standard deviation.
|
| 219 |
+
|
| 220 |
+
During evaluation the module simply computes an identity function.
|
| 221 |
+
|
| 222 |
+
More details can be found in the paper `Self-Normalizing Neural Networks`_ .
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
p (float): probability of an element to be dropped. Default: 0.5
|
| 226 |
+
inplace (bool, optional): If set to ``True``, will do this operation
|
| 227 |
+
in-place
|
| 228 |
+
|
| 229 |
+
Shape:
|
| 230 |
+
- Input: :math:`(*)`. Input can be of any shape
|
| 231 |
+
- Output: :math:`(*)`. Output is of the same shape as input
|
| 232 |
+
|
| 233 |
+
Examples::
|
| 234 |
+
|
| 235 |
+
>>> m = nn.AlphaDropout(p=0.2)
|
| 236 |
+
>>> input = torch.randn(20, 16)
|
| 237 |
+
>>> output = m(input)
|
| 238 |
+
|
| 239 |
+
.. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 243 |
+
return F.alpha_dropout(input, self.p, self.training)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class FeatureAlphaDropout(_DropoutNd):
|
| 247 |
+
r"""Randomly masks out entire channels.
|
| 248 |
+
|
| 249 |
+
A channel is a feature map,
|
| 250 |
+
e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
|
| 251 |
+
is a tensor :math:`\text{input}[i, j]` of the input tensor). Instead of
|
| 252 |
+
setting activations to zero, as in regular Dropout, the activations are set
|
| 253 |
+
to the negative saturation value of the SELU activation function. More details
|
| 254 |
+
can be found in the paper `Self-Normalizing Neural Networks`_ .
|
| 255 |
+
|
| 256 |
+
Each element will be masked independently for each sample on every forward
|
| 257 |
+
call with probability :attr:`p` using samples from a Bernoulli distribution.
|
| 258 |
+
The elements to be masked are randomized on every forward call, and scaled
|
| 259 |
+
and shifted to maintain zero mean and unit variance.
|
| 260 |
+
|
| 261 |
+
Usually the input comes from :class:`nn.AlphaDropout` modules.
|
| 262 |
+
|
| 263 |
+
As described in the paper
|
| 264 |
+
`Efficient Object Localization Using Convolutional Networks`_ ,
|
| 265 |
+
if adjacent pixels within feature maps are strongly correlated
|
| 266 |
+
(as is normally the case in early convolution layers) then i.i.d. dropout
|
| 267 |
+
will not regularize the activations and will otherwise just result
|
| 268 |
+
in an effective learning rate decrease.
|
| 269 |
+
|
| 270 |
+
In this case, :func:`nn.AlphaDropout` will help promote independence between
|
| 271 |
+
feature maps and should be used instead.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
p (float, optional): probability of an element to be zeroed. Default: 0.5
|
| 275 |
+
inplace (bool, optional): If set to ``True``, will do this operation
|
| 276 |
+
in-place
|
| 277 |
+
|
| 278 |
+
Shape:
|
| 279 |
+
- Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
|
| 280 |
+
- Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
|
| 281 |
+
|
| 282 |
+
Examples::
|
| 283 |
+
|
| 284 |
+
>>> m = nn.FeatureAlphaDropout(p=0.2)
|
| 285 |
+
>>> input = torch.randn(20, 16, 4, 32, 32)
|
| 286 |
+
>>> output = m(input)
|
| 287 |
+
|
| 288 |
+
.. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
|
| 289 |
+
.. _Efficient Object Localization Using Convolutional Networks:
|
| 290 |
+
https://arxiv.org/abs/1411.4280
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 294 |
+
return F.feature_alpha_dropout(input, self.p, self.training)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/flatten.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .module import Module
|
| 2 |
+
|
| 3 |
+
from typing import Tuple, Union
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from torch.types import _size
|
| 6 |
+
|
| 7 |
+
__all__ = ['Flatten', 'Unflatten']
|
| 8 |
+
|
| 9 |
+
class Flatten(Module):
    r"""Flatten a contiguous span of dimensions into a single dimension.

    For use with :class:`~nn.Sequential`; see :meth:`torch.flatten` for details.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::
        >>> input = torch.randn(32, 1, 5, 5)
        >>> # With default parameters
        >>> m = nn.Flatten()
        >>> m(input).size()
        torch.Size([32, 25])
        >>> # With non-default parameters
        >>> m = nn.Flatten(0, 2)
        >>> m(input).size()
        torch.Size([160, 5])
    """

    __constants__ = ['start_dim', 'end_dim']
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        # Record the inclusive dimension span that ``forward`` collapses.
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        # Delegate to Tensor.flatten, which already handles negative dims.
        return input.flatten(self.start_dim, self.end_dim)

    def extra_repr(self) -> str:
        return 'start_dim={}, end_dim={}'.format(self.start_dim, self.end_dim)
class Unflatten(Module):
    r"""Expand one dimension of a tensor into a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`dim` selects the dimension of the input tensor to be unflattened:
      an `int` for plain tensors, a `str` (the dimension name) for named tensors.

    * :attr:`unflattened_size` is the new shape of the unflattened dimension:
      a `tuple`/`list` of ints or a `torch.Size` for tensor input, or a
      `NamedShape` (tuple of `(name, size)` tuples) for named-tensor input.

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
          :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension

    Examples:
        >>> input = torch.randn(2, 50)
        >>> m = nn.Sequential(nn.Linear(50, 50), nn.Unflatten(1, (2, 5, 5)))
        >>> m(input).size()
        torch.Size([2, 2, 5, 5])
        >>> # With a namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=('N', 'features'))
        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
        >>> unflatten(input).size()
        torch.Size([2, 2, 5, 5])
    """

    NamedShape = Tuple[Tuple[str, int]]

    __constants__ = ['dim', 'unflattened_size']
    dim: Union[int, str]
    unflattened_size: Union[_size, NamedShape]

    def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None:
        super().__init__()

        # The size specification must match how the dimension is addressed:
        # plain ints for positional dims, (name, size) pairs for named dims.
        if isinstance(dim, str):
            self._require_tuple_tuple(unflattened_size)
        elif isinstance(dim, int):
            self._require_tuple_int(unflattened_size)
        else:
            raise TypeError("invalid argument type for dim parameter")

        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input):
        # Named-tensor form: a tuple whose elements are all tuples.
        if not isinstance(input, tuple):
            raise TypeError("unflattened_size must be a tuple of tuples, " +
                            f"but found type {type(input).__name__}")
        for idx, elem in enumerate(input):
            if not isinstance(elem, tuple):
                raise TypeError("unflattened_size must be tuple of tuples, " +
                                f"but found element of type {type(elem).__name__} at pos {idx}")

    def _require_tuple_int(self, input):
        # Positional form: a tuple or list whose elements are all ints.
        if not isinstance(input, (tuple, list)):
            raise TypeError(f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}")
        for idx, elem in enumerate(input):
            if not isinstance(elem, int):
                raise TypeError("unflattened_size must be tuple of ints, " +
                                f"but found element of type {type(elem).__name__} at pos {idx}")

    def forward(self, input: Tensor) -> Tensor:
        return input.unflatten(self.dim, self.unflattened_size)

    def extra_repr(self) -> str:
        return 'dim={}, unflattened_size={}'.format(self.dim, self.unflattened_size)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/normalization.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numbers
|
| 3 |
+
from torch.nn.parameter import Parameter
|
| 4 |
+
from .module import Module
|
| 5 |
+
from ._functions import CrossMapLRN2d as _cross_map_lrn2d
|
| 6 |
+
from .. import functional as F
|
| 7 |
+
from .. import init
|
| 8 |
+
|
| 9 |
+
from torch import Tensor, Size
|
| 10 |
+
from typing import Union, List, Tuple
|
| 11 |
+
|
| 12 |
+
__all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm']
|
| 13 |
+
|
| 14 |
+
class LocalResponseNorm(Module):
    r"""Apply local response normalization across channels.

    The input signal is composed of several planes, with channels occupying
    the second dimension; each value is normalized by its neighbouring
    channels:

    .. math::
        b_{c} = a_{c}\left(k + \frac{\alpha}{n}
        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}

    Args:
        size: amount of neighbouring channels used for normalization
        alpha: multiplicative factor. Default: 0.0001
        beta: exponent. Default: 0.75
        k: additive factor. Default: 1

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> lrn = nn.LocalResponseNorm(2)
        >>> output_2d = lrn(torch.randn(32, 5, 24, 24))
        >>> output_4d = lrn(torch.randn(16, 5, 7, 7, 7, 7))
    """

    __constants__ = ['size', 'alpha', 'beta', 'k']
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.) -> None:
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        # The actual normalization lives in the functional implementation.
        return F.local_response_norm(
            input, self.size, self.alpha, self.beta, self.k)

    def extra_repr(self):
        return f'{self.size}, alpha={self.alpha}, beta={self.beta}, k={self.k}'
class CrossMapLRN2d(Module):
    # Local response normalization across channel maps of a 4D input,
    # delegating to the autograd Function imported from ``._functions``.
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1) -> None:
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        # ``apply`` invokes the custom autograd Function with our parameters.
        return _cross_map_lrn2d.apply(
            input, self.size, self.alpha, self.beta, self.k)

    def extra_repr(self) -> str:
        return f'{self.size}, alpha={self.alpha}, beta={self.beta}, k={self.k}'
|
| 86 |
+
_shape_t = Union[int, List[int], Size]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class LayerNorm(Module):
    r"""Apply Layer Normalization over a mini-batch of inputs.

    As described in `Layer Normalization <https://arxiv.org/abs/1607.06450>`__:

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are computed over the last `D` dimensions,
    where `D` is the length of :attr:`normalized_shape` (a single integer is
    treated as a one-element list). The standard-deviation is calculated via
    the biased estimator, equivalent to `torch.var(input, unbiased=False)`.
    :math:`\gamma` and :math:`\beta` are learnable per-element affine
    parameters of shape :attr:`normalized_shape` when
    :attr:`elementwise_affine` is ``True``.

    .. note::
        Unlike Batch/Instance Normalization, which apply a scalar scale and
        bias per channel via :attr:`affine`, Layer Normalization applies
        per-element scale and bias via :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): shape of the trailing
            input dimensions that are normalized over.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: when ``True``, learn per-element affine parameters
            (weights initialized to 1, biases to 0). Default: ``True``.
        bias: when ``False``, no additive bias is learned (only relevant if
            :attr:`elementwise_affine` is ``True``). Default: ``True``.

    Attributes:
        weight: learnable weights of shape :math:`\text{normalized\_shape}`,
            initialized to 1, when :attr:`elementwise_affine` is ``True``.
        bias: learnable bias of shape :math:`\text{normalized\_shape}`,
            initialized to 0, when :attr:`elementwise_affine` is ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> # NLP example: normalize over the embedding dimension
        >>> layer_norm = nn.LayerNorm(10)
        >>> layer_norm(torch.randn(20, 5, 10))
        >>> # Image example: normalize over channel and spatial dimensions
        >>> layer_norm = nn.LayerNorm([5, 10, 10])
        >>> layer_norm(torch.randn(20, 5, 10, 10))
    """

    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        # A bare int means "normalize over the last dimension of this size".
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            if not bias:
                self.register_parameter('bias', None)
            else:
                self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Identity-affine initialization: weight -> 1, bias -> 0.
        if self.elementwise_affine:
            init.ones_(self.weight)
            if self.bias is not None:
                init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return (f'{self.normalized_shape}, eps={self.eps}, '
                f'elementwise_affine={self.elementwise_affine}')
| 209 |
+
class GroupNorm(Module):
    r"""Apply Group Normalization over a mini-batch of inputs.

    As described in `Group Normalization <https://arxiv.org/abs/1803.08494>`__:

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The input channels are split into :attr:`num_groups` groups of
    ``num_channels / num_groups`` channels each; :attr:`num_channels` must be
    divisible by :attr:`num_groups`. The mean and standard-deviation are
    computed separately over each group; the standard-deviation is calculated
    via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.
    When :attr:`affine` is ``True``, :math:`\gamma` and :math:`\beta` are
    learnable per-channel vectors of size :attr:`num_channels`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: when ``True``, learn per-channel affine parameters
            (weights initialized to 1, biases to 0). Default: ``True``.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 6, 10, 10)
        >>> # Separate 6 channels into 3 groups
        >>> m = nn.GroupNorm(3, 6)
        >>> # 6 groups of 1 channel each (equivalent to InstanceNorm)
        >>> m = nn.GroupNorm(6, 6)
        >>> # One group of all 6 channels (equivalent to LayerNorm)
        >>> m = nn.GroupNorm(1, 6)
        >>> output = m(input)
    """

    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')

        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if not affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            self.weight = Parameter(torch.empty(num_channels, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_channels, **factory_kwargs))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Identity-affine initialization: weight -> 1, bias -> 0.
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return (f'{self.num_groups}, {self.num_channels}, eps={self.eps}, '
                f'affine={self.affine}')
| 294 |
+
|
| 295 |
+
# TODO: ContrastiveNorm2d
|
| 296 |
+
# TODO: DivisiveNorm2d
|
| 297 |
+
# TODO: SubtractiveNorm2d
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/padding.py
ADDED
|
@@ -0,0 +1,801 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .module import Module
|
| 2 |
+
from .utils import _pair, _quadruple, _ntuple
|
| 3 |
+
from .. import functional as F
|
| 4 |
+
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from ..common_types import _size_2_t, _size_4_t, _size_6_t
|
| 7 |
+
from typing import Sequence, Tuple
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# TODO: grad_output size asserts in THNN
|
| 11 |
+
|
| 12 |
+
__all__ = ['CircularPad1d', 'CircularPad2d', 'CircularPad3d', 'ConstantPad1d', 'ConstantPad2d',
|
| 13 |
+
'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d',
|
| 14 |
+
'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d', 'ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d']
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class _CircularPadNd(Module):
    # Shared machinery for the CircularPad{1,2,3}d modules: subclasses
    # populate ``padding`` and implement the input-rank check.
    __constants__ = ['padding']
    padding: Sequence[int]

    def _check_input_dim(self, input):
        # Subclasses must validate the input rank before padding is applied.
        raise NotImplementedError

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        return F.pad(input, self.padding, 'circular')

    def extra_repr(self) -> str:
        return str(self.padding)
| 32 |
+
class CircularPad1d(_CircularPadNd):
|
| 33 |
+
r"""Pads the input tensor using circular padding of the input boundary.
|
| 34 |
+
|
| 35 |
+
Tensor values at the beginning of the dimension are used to pad the end,
|
| 36 |
+
and values at the end are used to pad the beginning. If negative padding is
|
| 37 |
+
applied then the ends of the tensor get removed.
|
| 38 |
+
|
| 39 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 43 |
+
padding in all boundaries. If a 2-`tuple`, uses
|
| 44 |
+
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
|
| 45 |
+
|
| 46 |
+
Shape:
|
| 47 |
+
- Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
|
| 48 |
+
- Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
|
| 49 |
+
|
| 50 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 51 |
+
|
| 52 |
+
Examples::
|
| 53 |
+
|
| 54 |
+
>>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this")
|
| 55 |
+
>>> m = nn.CircularPad1d(2)
|
| 56 |
+
>>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
|
| 57 |
+
>>> input
|
| 58 |
+
tensor([[[0., 1., 2., 3.],
|
| 59 |
+
[4., 5., 6., 7.]]])
|
| 60 |
+
>>> m(input)
|
| 61 |
+
tensor([[[2., 3., 0., 1., 2., 3., 0., 1.],
|
| 62 |
+
[6., 7., 4., 5., 6., 7., 4., 5.]]])
|
| 63 |
+
>>> # using different paddings for different sides
|
| 64 |
+
>>> m = nn.CircularPad1d((3, 1))
|
| 65 |
+
>>> m(input)
|
| 66 |
+
tensor([[[1., 2., 3., 0., 1., 2., 3., 0.],
|
| 67 |
+
[5., 6., 7., 4., 5., 6., 7., 4.]]])
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
padding: Tuple[int, int]
|
| 71 |
+
|
| 72 |
+
def __init__(self, padding: _size_2_t) -> None:
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.padding = _pair(padding)
|
| 75 |
+
|
| 76 |
+
def _check_input_dim(self, input):
|
| 77 |
+
if input.dim() != 2 and input.dim() != 3:
|
| 78 |
+
raise ValueError(
|
| 79 |
+
f"expected 2D or 3D input (got {input.dim()}D input)"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class CircularPad2d(_CircularPadNd):
|
| 84 |
+
r"""Pads the input tensor using circular padding of the input boundary.
|
| 85 |
+
|
| 86 |
+
Tensor values at the beginning of the dimension are used to pad the end,
|
| 87 |
+
and values at the end are used to pad the beginning. If negative padding is
|
| 88 |
+
applied then the ends of the tensor get removed.
|
| 89 |
+
|
| 90 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 94 |
+
padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
|
| 95 |
+
:math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
|
| 96 |
+
|
| 97 |
+
Shape:
|
| 98 |
+
- Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
|
| 99 |
+
- Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
|
| 100 |
+
|
| 101 |
+
:math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
|
| 102 |
+
|
| 103 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 104 |
+
|
| 105 |
+
Examples::
|
| 106 |
+
|
| 107 |
+
>>> m = nn.CircularPad2d(2)
|
| 108 |
+
>>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
|
| 109 |
+
>>> input
|
| 110 |
+
tensor([[[[0., 1., 2.],
|
| 111 |
+
[3., 4., 5.],
|
| 112 |
+
[6., 7., 8.]]]])
|
| 113 |
+
>>> m(input)
|
| 114 |
+
tensor([[[[4., 5., 3., 4., 5., 3., 4.],
|
| 115 |
+
[7., 8., 6., 7., 8., 6., 7.],
|
| 116 |
+
[1., 2., 0., 1., 2., 0., 1.],
|
| 117 |
+
[4., 5., 3., 4., 5., 3., 4.],
|
| 118 |
+
[7., 8., 6., 7., 8., 6., 7.],
|
| 119 |
+
[1., 2., 0., 1., 2., 0., 1.],
|
| 120 |
+
[4., 5., 3., 4., 5., 3., 4.]]]])
|
| 121 |
+
>>> # using different paddings for different sides
|
| 122 |
+
>>> m = nn.CircularPad2d((1, 1, 2, 0))
|
| 123 |
+
>>> m(input)
|
| 124 |
+
tensor([[[[5., 3., 4., 5., 3.],
|
| 125 |
+
[8., 6., 7., 8., 6.],
|
| 126 |
+
[2., 0., 1., 2., 0.],
|
| 127 |
+
[5., 3., 4., 5., 3.],
|
| 128 |
+
[8., 6., 7., 8., 6.]]]])
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
padding: Tuple[int, int, int, int]
|
| 132 |
+
|
| 133 |
+
def __init__(self, padding: _size_4_t) -> None:
|
| 134 |
+
super().__init__()
|
| 135 |
+
self.padding = _quadruple(padding)
|
| 136 |
+
|
| 137 |
+
def _check_input_dim(self, input):
|
| 138 |
+
if input.dim() != 3 and input.dim() != 4:
|
| 139 |
+
raise ValueError(
|
| 140 |
+
f"expected 3D or 4D input (got {input.dim()}D input)"
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class CircularPad3d(_CircularPadNd):
|
| 145 |
+
r"""Pads the input tensor using circular padding of the input boundary.
|
| 146 |
+
|
| 147 |
+
Tensor values at the beginning of the dimension are used to pad the end,
|
| 148 |
+
and values at the end are used to pad the beginning. If negative padding is
|
| 149 |
+
applied then the ends of the tensor get removed.
|
| 150 |
+
|
| 151 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 155 |
+
padding in all boundaries. If a 6-`tuple`, uses
|
| 156 |
+
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
|
| 157 |
+
:math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
|
| 158 |
+
:math:`\text{padding\_front}`, :math:`\text{padding\_back}`)
|
| 159 |
+
|
| 160 |
+
Shape:
|
| 161 |
+
- Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
|
| 162 |
+
- Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
|
| 163 |
+
where
|
| 164 |
+
|
| 165 |
+
:math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`
|
| 166 |
+
|
| 167 |
+
:math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
|
| 168 |
+
|
| 169 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 170 |
+
|
| 171 |
+
Examples::
|
| 172 |
+
|
| 173 |
+
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
| 174 |
+
>>> m = nn.CircularPad3d(3)
|
| 175 |
+
>>> input = torch.randn(16, 3, 8, 320, 480)
|
| 176 |
+
>>> output = m(input)
|
| 177 |
+
>>> # using different paddings for different sides
|
| 178 |
+
>>> m = nn.CircularPad3d((3, 3, 6, 6, 1, 1))
|
| 179 |
+
>>> output = m(input)
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
padding: Tuple[int, int, int, int, int, int]
|
| 183 |
+
|
| 184 |
+
def __init__(self, padding: _size_6_t) -> None:
|
| 185 |
+
super().__init__()
|
| 186 |
+
self.padding = _ntuple(6)(padding)
|
| 187 |
+
|
| 188 |
+
def _check_input_dim(self, input):
|
| 189 |
+
if input.dim() != 4 and input.dim() != 5:
|
| 190 |
+
raise ValueError(
|
| 191 |
+
f"expected 4D or 5D input (got {input.dim()}D input)"
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class _ConstantPadNd(Module):
|
| 196 |
+
__constants__ = ['padding', 'value']
|
| 197 |
+
value: float
|
| 198 |
+
padding: Sequence[int]
|
| 199 |
+
|
| 200 |
+
def __init__(self, value: float) -> None:
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.value = value
|
| 203 |
+
|
| 204 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 205 |
+
return F.pad(input, self.padding, 'constant', self.value)
|
| 206 |
+
|
| 207 |
+
def extra_repr(self) -> str:
|
| 208 |
+
return f'padding={self.padding}, value={self.value}'
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class ConstantPad1d(_ConstantPadNd):
|
| 212 |
+
r"""Pads the input tensor boundaries with a constant value.
|
| 213 |
+
|
| 214 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 218 |
+
padding in both boundaries. If a 2-`tuple`, uses
|
| 219 |
+
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
|
| 220 |
+
|
| 221 |
+
Shape:
|
| 222 |
+
- Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
|
| 223 |
+
- Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
|
| 224 |
+
|
| 225 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 226 |
+
|
| 227 |
+
Examples::
|
| 228 |
+
|
| 229 |
+
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
| 230 |
+
>>> m = nn.ConstantPad1d(2, 3.5)
|
| 231 |
+
>>> input = torch.randn(1, 2, 4)
|
| 232 |
+
>>> input
|
| 233 |
+
tensor([[[-1.0491, -0.7152, -0.0749, 0.8530],
|
| 234 |
+
[-1.3287, 1.8966, 0.1466, -0.2771]]])
|
| 235 |
+
>>> m(input)
|
| 236 |
+
tensor([[[ 3.5000, 3.5000, -1.0491, -0.7152, -0.0749, 0.8530, 3.5000,
|
| 237 |
+
3.5000],
|
| 238 |
+
[ 3.5000, 3.5000, -1.3287, 1.8966, 0.1466, -0.2771, 3.5000,
|
| 239 |
+
3.5000]]])
|
| 240 |
+
>>> m = nn.ConstantPad1d(2, 3.5)
|
| 241 |
+
>>> input = torch.randn(1, 2, 3)
|
| 242 |
+
>>> input
|
| 243 |
+
tensor([[[ 1.6616, 1.4523, -1.1255],
|
| 244 |
+
[-3.6372, 0.1182, -1.8652]]])
|
| 245 |
+
>>> m(input)
|
| 246 |
+
tensor([[[ 3.5000, 3.5000, 1.6616, 1.4523, -1.1255, 3.5000, 3.5000],
|
| 247 |
+
[ 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000, 3.5000]]])
|
| 248 |
+
>>> # using different paddings for different sides
|
| 249 |
+
>>> m = nn.ConstantPad1d((3, 1), 3.5)
|
| 250 |
+
>>> m(input)
|
| 251 |
+
tensor([[[ 3.5000, 3.5000, 3.5000, 1.6616, 1.4523, -1.1255, 3.5000],
|
| 252 |
+
[ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]])
|
| 253 |
+
"""
|
| 254 |
+
|
| 255 |
+
padding: Tuple[int, int]
|
| 256 |
+
|
| 257 |
+
def __init__(self, padding: _size_2_t, value: float):
|
| 258 |
+
super().__init__(value)
|
| 259 |
+
self.padding = _pair(padding)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class ConstantPad2d(_ConstantPadNd):
|
| 263 |
+
r"""Pads the input tensor boundaries with a constant value.
|
| 264 |
+
|
| 265 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 269 |
+
padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
|
| 270 |
+
:math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
|
| 271 |
+
|
| 272 |
+
Shape:
|
| 273 |
+
- Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
|
| 274 |
+
- Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
|
| 275 |
+
|
| 276 |
+
:math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
|
| 277 |
+
|
| 278 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 279 |
+
|
| 280 |
+
Examples::
|
| 281 |
+
|
| 282 |
+
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
| 283 |
+
>>> m = nn.ConstantPad2d(2, 3.5)
|
| 284 |
+
>>> input = torch.randn(1, 2, 2)
|
| 285 |
+
>>> input
|
| 286 |
+
tensor([[[ 1.6585, 0.4320],
|
| 287 |
+
[-0.8701, -0.4649]]])
|
| 288 |
+
>>> m(input)
|
| 289 |
+
tensor([[[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000],
|
| 290 |
+
[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000],
|
| 291 |
+
[ 3.5000, 3.5000, 1.6585, 0.4320, 3.5000, 3.5000],
|
| 292 |
+
[ 3.5000, 3.5000, -0.8701, -0.4649, 3.5000, 3.5000],
|
| 293 |
+
[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000],
|
| 294 |
+
[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000]]])
|
| 295 |
+
>>> # using different paddings for different sides
|
| 296 |
+
>>> m = nn.ConstantPad2d((3, 0, 2, 1), 3.5)
|
| 297 |
+
>>> m(input)
|
| 298 |
+
tensor([[[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000],
|
| 299 |
+
[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000],
|
| 300 |
+
[ 3.5000, 3.5000, 3.5000, 1.6585, 0.4320],
|
| 301 |
+
[ 3.5000, 3.5000, 3.5000, -0.8701, -0.4649],
|
| 302 |
+
[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000]]])
|
| 303 |
+
"""
|
| 304 |
+
|
| 305 |
+
__constants__ = ['padding', 'value']
|
| 306 |
+
padding: Tuple[int, int, int, int]
|
| 307 |
+
|
| 308 |
+
def __init__(self, padding: _size_4_t, value: float) -> None:
|
| 309 |
+
super().__init__(value)
|
| 310 |
+
self.padding = _quadruple(padding)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class ConstantPad3d(_ConstantPadNd):
|
| 314 |
+
r"""Pads the input tensor boundaries with a constant value.
|
| 315 |
+
|
| 316 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 317 |
+
|
| 318 |
+
Args:
|
| 319 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 320 |
+
padding in all boundaries. If a 6-`tuple`, uses
|
| 321 |
+
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
|
| 322 |
+
:math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
|
| 323 |
+
:math:`\text{padding\_front}`, :math:`\text{padding\_back}`)
|
| 324 |
+
|
| 325 |
+
Shape:
|
| 326 |
+
- Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
|
| 327 |
+
- Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
|
| 328 |
+
:math:`(C, D_{out}, H_{out}, W_{out})`, where
|
| 329 |
+
|
| 330 |
+
:math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`
|
| 331 |
+
|
| 332 |
+
:math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
|
| 333 |
+
|
| 334 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 335 |
+
|
| 336 |
+
Examples::
|
| 337 |
+
|
| 338 |
+
>>> m = nn.ConstantPad3d(3, 3.5)
|
| 339 |
+
>>> input = torch.randn(16, 3, 10, 20, 30)
|
| 340 |
+
>>> output = m(input)
|
| 341 |
+
>>> # using different paddings for different sides
|
| 342 |
+
>>> m = nn.ConstantPad3d((3, 3, 6, 6, 0, 1), 3.5)
|
| 343 |
+
>>> output = m(input)
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
padding: Tuple[int, int, int, int, int, int]
|
| 347 |
+
|
| 348 |
+
def __init__(self, padding: _size_6_t, value: float) -> None:
|
| 349 |
+
super().__init__(value)
|
| 350 |
+
self.padding = _ntuple(6)(padding)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class _ReflectionPadNd(Module):
|
| 354 |
+
__constants__ = ['padding']
|
| 355 |
+
padding: Sequence[int]
|
| 356 |
+
|
| 357 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 358 |
+
return F.pad(input, self.padding, 'reflect')
|
| 359 |
+
|
| 360 |
+
def extra_repr(self) -> str:
|
| 361 |
+
return f'{self.padding}'
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
class ReflectionPad1d(_ReflectionPadNd):
|
| 365 |
+
r"""Pads the input tensor using the reflection of the input boundary.
|
| 366 |
+
|
| 367 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 371 |
+
padding in all boundaries. If a 2-`tuple`, uses
|
| 372 |
+
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
|
| 373 |
+
|
| 374 |
+
Shape:
|
| 375 |
+
- Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
|
| 376 |
+
- Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
|
| 377 |
+
|
| 378 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 379 |
+
|
| 380 |
+
Examples::
|
| 381 |
+
|
| 382 |
+
>>> m = nn.ReflectionPad1d(2)
|
| 383 |
+
>>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles")
|
| 384 |
+
>>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
|
| 385 |
+
>>> input
|
| 386 |
+
tensor([[[0., 1., 2., 3.],
|
| 387 |
+
[4., 5., 6., 7.]]])
|
| 388 |
+
>>> m(input)
|
| 389 |
+
tensor([[[2., 1., 0., 1., 2., 3., 2., 1.],
|
| 390 |
+
[6., 5., 4., 5., 6., 7., 6., 5.]]])
|
| 391 |
+
>>> # using different paddings for different sides
|
| 392 |
+
>>> m = nn.ReflectionPad1d((3, 1))
|
| 393 |
+
>>> m(input)
|
| 394 |
+
tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
|
| 395 |
+
[7., 6., 5., 4., 5., 6., 7., 6.]]])
|
| 396 |
+
"""
|
| 397 |
+
|
| 398 |
+
padding: Tuple[int, int]
|
| 399 |
+
|
| 400 |
+
def __init__(self, padding: _size_2_t) -> None:
|
| 401 |
+
super().__init__()
|
| 402 |
+
self.padding = _pair(padding)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class ReflectionPad2d(_ReflectionPadNd):
|
| 406 |
+
r"""Pads the input tensor using the reflection of the input boundary.
|
| 407 |
+
|
| 408 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 409 |
+
|
| 410 |
+
Args:
|
| 411 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 412 |
+
padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
|
| 413 |
+
:math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
|
| 414 |
+
Note that padding size should be less than the corresponding input dimension.
|
| 415 |
+
|
| 416 |
+
Shape:
|
| 417 |
+
- Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
|
| 418 |
+
- Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})` where
|
| 419 |
+
|
| 420 |
+
:math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
|
| 421 |
+
|
| 422 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 423 |
+
|
| 424 |
+
Examples::
|
| 425 |
+
|
| 426 |
+
>>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this")
|
| 427 |
+
>>> m = nn.ReflectionPad2d(2)
|
| 428 |
+
>>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
|
| 429 |
+
>>> input
|
| 430 |
+
tensor([[[[0., 1., 2.],
|
| 431 |
+
[3., 4., 5.],
|
| 432 |
+
[6., 7., 8.]]]])
|
| 433 |
+
>>> m(input)
|
| 434 |
+
tensor([[[[8., 7., 6., 7., 8., 7., 6.],
|
| 435 |
+
[5., 4., 3., 4., 5., 4., 3.],
|
| 436 |
+
[2., 1., 0., 1., 2., 1., 0.],
|
| 437 |
+
[5., 4., 3., 4., 5., 4., 3.],
|
| 438 |
+
[8., 7., 6., 7., 8., 7., 6.],
|
| 439 |
+
[5., 4., 3., 4., 5., 4., 3.],
|
| 440 |
+
[2., 1., 0., 1., 2., 1., 0.]]]])
|
| 441 |
+
>>> # using different paddings for different sides
|
| 442 |
+
>>> m = nn.ReflectionPad2d((1, 1, 2, 0))
|
| 443 |
+
>>> m(input)
|
| 444 |
+
tensor([[[[7., 6., 7., 8., 7.],
|
| 445 |
+
[4., 3., 4., 5., 4.],
|
| 446 |
+
[1., 0., 1., 2., 1.],
|
| 447 |
+
[4., 3., 4., 5., 4.],
|
| 448 |
+
[7., 6., 7., 8., 7.]]]])
|
| 449 |
+
"""
|
| 450 |
+
|
| 451 |
+
padding: Tuple[int, int, int, int]
|
| 452 |
+
|
| 453 |
+
def __init__(self, padding: _size_4_t) -> None:
|
| 454 |
+
super().__init__()
|
| 455 |
+
self.padding = _quadruple(padding)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
class ReflectionPad3d(_ReflectionPadNd):
|
| 459 |
+
r"""Pads the input tensor using the reflection of the input boundary.
|
| 460 |
+
|
| 461 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 462 |
+
|
| 463 |
+
Args:
|
| 464 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 465 |
+
padding in all boundaries. If a 6-`tuple`, uses
|
| 466 |
+
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
|
| 467 |
+
:math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
|
| 468 |
+
:math:`\text{padding\_front}`, :math:`\text{padding\_back}`)
|
| 469 |
+
|
| 470 |
+
Shape:
|
| 471 |
+
- Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
|
| 472 |
+
- Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
|
| 473 |
+
where
|
| 474 |
+
|
| 475 |
+
:math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`
|
| 476 |
+
|
| 477 |
+
:math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
|
| 478 |
+
|
| 479 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 480 |
+
|
| 481 |
+
Examples::
|
| 482 |
+
|
| 483 |
+
>>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this")
|
| 484 |
+
>>> m = nn.ReflectionPad3d(1)
|
| 485 |
+
>>> input = torch.arange(8, dtype=torch.float).reshape(1, 1, 2, 2, 2)
|
| 486 |
+
>>> m(input)
|
| 487 |
+
tensor([[[[[7., 6., 7., 6.],
|
| 488 |
+
[5., 4., 5., 4.],
|
| 489 |
+
[7., 6., 7., 6.],
|
| 490 |
+
[5., 4., 5., 4.]],
|
| 491 |
+
[[3., 2., 3., 2.],
|
| 492 |
+
[1., 0., 1., 0.],
|
| 493 |
+
[3., 2., 3., 2.],
|
| 494 |
+
[1., 0., 1., 0.]],
|
| 495 |
+
[[7., 6., 7., 6.],
|
| 496 |
+
[5., 4., 5., 4.],
|
| 497 |
+
[7., 6., 7., 6.],
|
| 498 |
+
[5., 4., 5., 4.]],
|
| 499 |
+
[[3., 2., 3., 2.],
|
| 500 |
+
[1., 0., 1., 0.],
|
| 501 |
+
[3., 2., 3., 2.],
|
| 502 |
+
[1., 0., 1., 0.]]]]])
|
| 503 |
+
"""
|
| 504 |
+
|
| 505 |
+
padding: Tuple[int, int, int, int, int, int]
|
| 506 |
+
|
| 507 |
+
def __init__(self, padding: _size_6_t) -> None:
|
| 508 |
+
super().__init__()
|
| 509 |
+
self.padding = _ntuple(6)(padding)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
class _ReplicationPadNd(Module):
|
| 513 |
+
__constants__ = ['padding']
|
| 514 |
+
padding: Sequence[int]
|
| 515 |
+
|
| 516 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 517 |
+
return F.pad(input, self.padding, 'replicate')
|
| 518 |
+
|
| 519 |
+
def extra_repr(self) -> str:
|
| 520 |
+
return f'{self.padding}'
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
class ReplicationPad1d(_ReplicationPadNd):
|
| 524 |
+
r"""Pads the input tensor using replication of the input boundary.
|
| 525 |
+
|
| 526 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 527 |
+
|
| 528 |
+
Args:
|
| 529 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 530 |
+
padding in all boundaries. If a 2-`tuple`, uses
|
| 531 |
+
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
|
| 532 |
+
|
| 533 |
+
Shape:
|
| 534 |
+
- Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
|
| 535 |
+
- Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
|
| 536 |
+
|
| 537 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 538 |
+
|
| 539 |
+
Examples::
|
| 540 |
+
|
| 541 |
+
>>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this")
|
| 542 |
+
>>> m = nn.ReplicationPad1d(2)
|
| 543 |
+
>>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
|
| 544 |
+
>>> input
|
| 545 |
+
tensor([[[0., 1., 2., 3.],
|
| 546 |
+
[4., 5., 6., 7.]]])
|
| 547 |
+
>>> m(input)
|
| 548 |
+
tensor([[[0., 0., 0., 1., 2., 3., 3., 3.],
|
| 549 |
+
[4., 4., 4., 5., 6., 7., 7., 7.]]])
|
| 550 |
+
>>> # using different paddings for different sides
|
| 551 |
+
>>> m = nn.ReplicationPad1d((3, 1))
|
| 552 |
+
>>> m(input)
|
| 553 |
+
tensor([[[0., 0., 0., 0., 1., 2., 3., 3.],
|
| 554 |
+
[4., 4., 4., 4., 5., 6., 7., 7.]]])
|
| 555 |
+
"""
|
| 556 |
+
|
| 557 |
+
padding: Tuple[int, int]
|
| 558 |
+
|
| 559 |
+
def __init__(self, padding: _size_2_t) -> None:
|
| 560 |
+
super().__init__()
|
| 561 |
+
self.padding = _pair(padding)
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
class ReplicationPad2d(_ReplicationPadNd):
|
| 565 |
+
r"""Pads the input tensor using replication of the input boundary.
|
| 566 |
+
|
| 567 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 568 |
+
|
| 569 |
+
Args:
|
| 570 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 571 |
+
padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
|
| 572 |
+
:math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
|
| 573 |
+
|
| 574 |
+
Shape:
|
| 575 |
+
- Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
|
| 576 |
+
- Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
|
| 577 |
+
|
| 578 |
+
:math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
|
| 579 |
+
|
| 580 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 581 |
+
|
| 582 |
+
Examples::
|
| 583 |
+
|
| 584 |
+
>>> m = nn.ReplicationPad2d(2)
|
| 585 |
+
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
| 586 |
+
>>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
|
| 587 |
+
>>> input
|
| 588 |
+
tensor([[[[0., 1., 2.],
|
| 589 |
+
[3., 4., 5.],
|
| 590 |
+
[6., 7., 8.]]]])
|
| 591 |
+
>>> m(input)
|
| 592 |
+
tensor([[[[0., 0., 0., 1., 2., 2., 2.],
|
| 593 |
+
[0., 0., 0., 1., 2., 2., 2.],
|
| 594 |
+
[0., 0., 0., 1., 2., 2., 2.],
|
| 595 |
+
[3., 3., 3., 4., 5., 5., 5.],
|
| 596 |
+
[6., 6., 6., 7., 8., 8., 8.],
|
| 597 |
+
[6., 6., 6., 7., 8., 8., 8.],
|
| 598 |
+
[6., 6., 6., 7., 8., 8., 8.]]]])
|
| 599 |
+
>>> # using different paddings for different sides
|
| 600 |
+
>>> m = nn.ReplicationPad2d((1, 1, 2, 0))
|
| 601 |
+
>>> m(input)
|
| 602 |
+
tensor([[[[0., 0., 1., 2., 2.],
|
| 603 |
+
[0., 0., 1., 2., 2.],
|
| 604 |
+
[0., 0., 1., 2., 2.],
|
| 605 |
+
[3., 3., 4., 5., 5.],
|
| 606 |
+
[6., 6., 7., 8., 8.]]]])
|
| 607 |
+
"""
|
| 608 |
+
|
| 609 |
+
padding: Tuple[int, int, int, int]
|
| 610 |
+
|
| 611 |
+
def __init__(self, padding: _size_4_t) -> None:
|
| 612 |
+
super().__init__()
|
| 613 |
+
self.padding = _quadruple(padding)
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
class ReplicationPad3d(_ReplicationPadNd):
|
| 617 |
+
r"""Pads the input tensor using replication of the input boundary.
|
| 618 |
+
|
| 619 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 620 |
+
|
| 621 |
+
Args:
|
| 622 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 623 |
+
padding in all boundaries. If a 6-`tuple`, uses
|
| 624 |
+
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
|
| 625 |
+
:math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
|
| 626 |
+
:math:`\text{padding\_front}`, :math:`\text{padding\_back}`)
|
| 627 |
+
|
| 628 |
+
Shape:
|
| 629 |
+
- Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
|
| 630 |
+
- Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
|
| 631 |
+
where
|
| 632 |
+
|
| 633 |
+
:math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`
|
| 634 |
+
|
| 635 |
+
:math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
|
| 636 |
+
|
| 637 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 638 |
+
|
| 639 |
+
Examples::
|
| 640 |
+
|
| 641 |
+
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
| 642 |
+
>>> m = nn.ReplicationPad3d(3)
|
| 643 |
+
>>> input = torch.randn(16, 3, 8, 320, 480)
|
| 644 |
+
>>> output = m(input)
|
| 645 |
+
>>> # using different paddings for different sides
|
| 646 |
+
>>> m = nn.ReplicationPad3d((3, 3, 6, 6, 1, 1))
|
| 647 |
+
>>> output = m(input)
|
| 648 |
+
"""
|
| 649 |
+
|
| 650 |
+
padding: Tuple[int, int, int, int, int, int]
|
| 651 |
+
|
| 652 |
+
def __init__(self, padding: _size_6_t) -> None:
|
| 653 |
+
super().__init__()
|
| 654 |
+
self.padding = _ntuple(6)(padding)
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
class ZeroPad1d(ConstantPad1d):
|
| 658 |
+
r"""Pads the input tensor boundaries with zero.
|
| 659 |
+
|
| 660 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 661 |
+
|
| 662 |
+
Args:
|
| 663 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 664 |
+
padding in both boundaries. If a 2-`tuple`, uses
|
| 665 |
+
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
|
| 666 |
+
|
| 667 |
+
Shape:
|
| 668 |
+
- Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
|
| 669 |
+
- Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
|
| 670 |
+
|
| 671 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 672 |
+
|
| 673 |
+
Examples::
|
| 674 |
+
|
| 675 |
+
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
| 676 |
+
>>> m = nn.ZeroPad1d(2)
|
| 677 |
+
>>> input = torch.randn(1, 2, 4)
|
| 678 |
+
>>> input
|
| 679 |
+
tensor([[[-1.0491, -0.7152, -0.0749, 0.8530],
|
| 680 |
+
[-1.3287, 1.8966, 0.1466, -0.2771]]])
|
| 681 |
+
>>> m(input)
|
| 682 |
+
tensor([[[ 0.0000, 0.0000, -1.0491, -0.7152, -0.0749, 0.8530, 0.0000,
|
| 683 |
+
0.0000],
|
| 684 |
+
[ 0.0000, 0.0000, -1.3287, 1.8966, 0.1466, -0.2771, 0.0000,
|
| 685 |
+
0.0000]]])
|
| 686 |
+
>>> m = nn.ZeroPad1d(2)
|
| 687 |
+
>>> input = torch.randn(1, 2, 3)
|
| 688 |
+
>>> input
|
| 689 |
+
tensor([[[ 1.6616, 1.4523, -1.1255],
|
| 690 |
+
[-3.6372, 0.1182, -1.8652]]])
|
| 691 |
+
>>> m(input)
|
| 692 |
+
tensor([[[ 0.0000, 0.0000, 1.6616, 1.4523, -1.1255, 0.0000, 0.0000],
|
| 693 |
+
[ 0.0000, 0.0000, -3.6372, 0.1182, -1.8652, 0.0000, 0.0000]]])
|
| 694 |
+
>>> # using different paddings for different sides
|
| 695 |
+
>>> m = nn.ZeroPad1d((3, 1))
|
| 696 |
+
>>> m(input)
|
| 697 |
+
tensor([[[ 0.0000, 0.0000, 0.0000, 1.6616, 1.4523, -1.1255, 0.0000],
|
| 698 |
+
[ 0.0000, 0.0000, 0.0000, -3.6372, 0.1182, -1.8652, 0.0000]]])
|
| 699 |
+
"""
|
| 700 |
+
|
| 701 |
+
padding: Tuple[int, int]
|
| 702 |
+
|
| 703 |
+
def __init__(self, padding: _size_2_t) -> None:
|
| 704 |
+
super().__init__(padding, 0.)
|
| 705 |
+
|
| 706 |
+
def extra_repr(self) -> str:
|
| 707 |
+
return f'{self.padding}'
|
| 708 |
+
|
| 709 |
+
class ZeroPad2d(ConstantPad2d):
|
| 710 |
+
r"""Pads the input tensor boundaries with zero.
|
| 711 |
+
|
| 712 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 713 |
+
|
| 714 |
+
Args:
|
| 715 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 716 |
+
padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
|
| 717 |
+
:math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
|
| 718 |
+
|
| 719 |
+
Shape:
|
| 720 |
+
- Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
|
| 721 |
+
- Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
|
| 722 |
+
|
| 723 |
+
:math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
|
| 724 |
+
|
| 725 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 726 |
+
|
| 727 |
+
Examples::
|
| 728 |
+
|
| 729 |
+
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
|
| 730 |
+
>>> m = nn.ZeroPad2d(2)
|
| 731 |
+
>>> input = torch.randn(1, 1, 3, 3)
|
| 732 |
+
>>> input
|
| 733 |
+
tensor([[[[-0.1678, -0.4418, 1.9466],
|
| 734 |
+
[ 0.9604, -0.4219, -0.5241],
|
| 735 |
+
[-0.9162, -0.5436, -0.6446]]]])
|
| 736 |
+
>>> m(input)
|
| 737 |
+
tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
|
| 738 |
+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
|
| 739 |
+
[ 0.0000, 0.0000, -0.1678, -0.4418, 1.9466, 0.0000, 0.0000],
|
| 740 |
+
[ 0.0000, 0.0000, 0.9604, -0.4219, -0.5241, 0.0000, 0.0000],
|
| 741 |
+
[ 0.0000, 0.0000, -0.9162, -0.5436, -0.6446, 0.0000, 0.0000],
|
| 742 |
+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
|
| 743 |
+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
|
| 744 |
+
>>> # using different paddings for different sides
|
| 745 |
+
>>> m = nn.ZeroPad2d((1, 1, 2, 0))
|
| 746 |
+
>>> m(input)
|
| 747 |
+
tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
|
| 748 |
+
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
|
| 749 |
+
[ 0.0000, -0.1678, -0.4418, 1.9466, 0.0000],
|
| 750 |
+
[ 0.0000, 0.9604, -0.4219, -0.5241, 0.0000],
|
| 751 |
+
[ 0.0000, -0.9162, -0.5436, -0.6446, 0.0000]]]])
|
| 752 |
+
"""
|
| 753 |
+
|
| 754 |
+
padding: Tuple[int, int, int, int]
|
| 755 |
+
|
| 756 |
+
def __init__(self, padding: _size_4_t) -> None:
|
| 757 |
+
super().__init__(padding, 0.)
|
| 758 |
+
|
| 759 |
+
def extra_repr(self) -> str:
|
| 760 |
+
return f'{self.padding}'
|
| 761 |
+
|
| 762 |
+
class ZeroPad3d(ConstantPad3d):
|
| 763 |
+
r"""Pads the input tensor boundaries with zero.
|
| 764 |
+
|
| 765 |
+
For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
|
| 766 |
+
|
| 767 |
+
Args:
|
| 768 |
+
padding (int, tuple): the size of the padding. If is `int`, uses the same
|
| 769 |
+
padding in all boundaries. If a 6-`tuple`, uses
|
| 770 |
+
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
|
| 771 |
+
:math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
|
| 772 |
+
:math:`\text{padding\_front}`, :math:`\text{padding\_back}`)
|
| 773 |
+
|
| 774 |
+
Shape:
|
| 775 |
+
- Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
|
| 776 |
+
- Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
|
| 777 |
+
:math:`(C, D_{out}, H_{out}, W_{out})`, where
|
| 778 |
+
|
| 779 |
+
:math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`
|
| 780 |
+
|
| 781 |
+
:math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
|
| 782 |
+
|
| 783 |
+
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
|
| 784 |
+
|
| 785 |
+
Examples::
|
| 786 |
+
|
| 787 |
+
>>> m = nn.ZeroPad3d(3)
|
| 788 |
+
>>> input = torch.randn(16, 3, 10, 20, 30)
|
| 789 |
+
>>> output = m(input)
|
| 790 |
+
>>> # using different paddings for different sides
|
| 791 |
+
>>> m = nn.ZeroPad3d((3, 3, 6, 6, 0, 1))
|
| 792 |
+
>>> output = m(input)
|
| 793 |
+
"""
|
| 794 |
+
|
| 795 |
+
padding: Tuple[int, int, int, int, int, int]
|
| 796 |
+
|
| 797 |
+
def __init__(self, padding: _size_6_t) -> None:
|
| 798 |
+
super().__init__(padding, 0.)
|
| 799 |
+
|
| 800 |
+
def extra_repr(self) -> str:
|
| 801 |
+
return f'{self.padding}'
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/pixelshuffle.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .module import Module
|
| 2 |
+
from .. import functional as F
|
| 3 |
+
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
|
| 6 |
+
__all__ = ['PixelShuffle', 'PixelUnshuffle']
|
| 7 |
+
|
| 8 |
+
class PixelShuffle(Module):
|
| 9 |
+
r"""Rearrange elements in a tensor according to an upscaling factor.
|
| 10 |
+
|
| 11 |
+
Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
|
| 12 |
+
to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor.
|
| 13 |
+
|
| 14 |
+
This is useful for implementing efficient sub-pixel convolution
|
| 15 |
+
with a stride of :math:`1/r`.
|
| 16 |
+
|
| 17 |
+
See the paper:
|
| 18 |
+
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
|
| 19 |
+
by Shi et. al (2016) for more details.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
upscale_factor (int): factor to increase spatial resolution by
|
| 23 |
+
|
| 24 |
+
Shape:
|
| 25 |
+
- Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
|
| 26 |
+
- Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where
|
| 27 |
+
|
| 28 |
+
.. math::
|
| 29 |
+
C_{out} = C_{in} \div \text{upscale\_factor}^2
|
| 30 |
+
|
| 31 |
+
.. math::
|
| 32 |
+
H_{out} = H_{in} \times \text{upscale\_factor}
|
| 33 |
+
|
| 34 |
+
.. math::
|
| 35 |
+
W_{out} = W_{in} \times \text{upscale\_factor}
|
| 36 |
+
|
| 37 |
+
Examples::
|
| 38 |
+
|
| 39 |
+
>>> pixel_shuffle = nn.PixelShuffle(3)
|
| 40 |
+
>>> input = torch.randn(1, 9, 4, 4)
|
| 41 |
+
>>> output = pixel_shuffle(input)
|
| 42 |
+
>>> print(output.size())
|
| 43 |
+
torch.Size([1, 1, 12, 12])
|
| 44 |
+
|
| 45 |
+
.. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
|
| 46 |
+
https://arxiv.org/abs/1609.05158
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
__constants__ = ['upscale_factor']
|
| 50 |
+
upscale_factor: int
|
| 51 |
+
|
| 52 |
+
def __init__(self, upscale_factor: int) -> None:
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.upscale_factor = upscale_factor
|
| 55 |
+
|
| 56 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 57 |
+
return F.pixel_shuffle(input, self.upscale_factor)
|
| 58 |
+
|
| 59 |
+
def extra_repr(self) -> str:
|
| 60 |
+
return f'upscale_factor={self.upscale_factor}'
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class PixelUnshuffle(Module):
|
| 64 |
+
r"""Reverse the PixelShuffle operation.
|
| 65 |
+
|
| 66 |
+
Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements
|
| 67 |
+
in a tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
|
| 68 |
+
:math:`(*, C \times r^2, H, W)`, where r is a downscale factor.
|
| 69 |
+
|
| 70 |
+
See the paper:
|
| 71 |
+
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
|
| 72 |
+
by Shi et. al (2016) for more details.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
downscale_factor (int): factor to decrease spatial resolution by
|
| 76 |
+
|
| 77 |
+
Shape:
|
| 78 |
+
- Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
|
| 79 |
+
- Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where
|
| 80 |
+
|
| 81 |
+
.. math::
|
| 82 |
+
C_{out} = C_{in} \times \text{downscale\_factor}^2
|
| 83 |
+
|
| 84 |
+
.. math::
|
| 85 |
+
H_{out} = H_{in} \div \text{downscale\_factor}
|
| 86 |
+
|
| 87 |
+
.. math::
|
| 88 |
+
W_{out} = W_{in} \div \text{downscale\_factor}
|
| 89 |
+
|
| 90 |
+
Examples::
|
| 91 |
+
|
| 92 |
+
>>> pixel_unshuffle = nn.PixelUnshuffle(3)
|
| 93 |
+
>>> input = torch.randn(1, 1, 12, 12)
|
| 94 |
+
>>> output = pixel_unshuffle(input)
|
| 95 |
+
>>> print(output.size())
|
| 96 |
+
torch.Size([1, 9, 4, 4])
|
| 97 |
+
|
| 98 |
+
.. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
|
| 99 |
+
https://arxiv.org/abs/1609.05158
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
__constants__ = ['downscale_factor']
|
| 103 |
+
downscale_factor: int
|
| 104 |
+
|
| 105 |
+
def __init__(self, downscale_factor: int) -> None:
|
| 106 |
+
super().__init__()
|
| 107 |
+
self.downscale_factor = downscale_factor
|
| 108 |
+
|
| 109 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 110 |
+
return F.pixel_unshuffle(input, self.downscale_factor)
|
| 111 |
+
|
| 112 |
+
def extra_repr(self) -> str:
|
| 113 |
+
return f'downscale_factor={self.downscale_factor}'
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/pooling.py
ADDED
|
@@ -0,0 +1,1306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
|
| 3 |
+
from torch import Tensor
|
| 4 |
+
from .module import Module
|
| 5 |
+
from .utils import _single, _pair, _triple
|
| 6 |
+
from .. import functional as F
|
| 7 |
+
|
| 8 |
+
from ..common_types import (_size_any_t, _size_1_t, _size_2_t, _size_3_t,
|
| 9 |
+
_ratio_3_t, _ratio_2_t, _size_any_opt_t, _size_2_opt_t, _size_3_opt_t)
|
| 10 |
+
|
| 11 |
+
__all__ = ['MaxPool1d', 'MaxPool2d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
|
| 12 |
+
'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'FractionalMaxPool2d', 'FractionalMaxPool3d', 'LPPool1d',
|
| 13 |
+
'LPPool2d', 'LPPool3d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
|
| 14 |
+
'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d']
|
| 15 |
+
|
| 16 |
+
class _MaxPoolNd(Module):
|
| 17 |
+
__constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
|
| 18 |
+
'return_indices', 'ceil_mode']
|
| 19 |
+
return_indices: bool
|
| 20 |
+
ceil_mode: bool
|
| 21 |
+
|
| 22 |
+
def __init__(self, kernel_size: _size_any_t, stride: Optional[_size_any_t] = None,
|
| 23 |
+
padding: _size_any_t = 0, dilation: _size_any_t = 1,
|
| 24 |
+
return_indices: bool = False, ceil_mode: bool = False) -> None:
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.kernel_size = kernel_size
|
| 27 |
+
self.stride = stride if (stride is not None) else kernel_size
|
| 28 |
+
self.padding = padding
|
| 29 |
+
self.dilation = dilation
|
| 30 |
+
self.return_indices = return_indices
|
| 31 |
+
self.ceil_mode = ceil_mode
|
| 32 |
+
|
| 33 |
+
def extra_repr(self) -> str:
|
| 34 |
+
return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \
|
| 35 |
+
', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class MaxPool1d(_MaxPoolNd):
    r"""Applies a 1D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, L)`
    and output :math:`(N, C, L_{out})` can be precisely described as:

    .. math::
        out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
                input(N_i, C_j, stride \times k + m)

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
    for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the
    sliding window. This `link`_ has a nice visualization of the pooling parameters.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    Args:
        kernel_size: The size of the sliding window, must be > 0.
        stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`.
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        return_indices: If ``True``, will return the argmax along with the max values.
                        Useful for :class:`torch.nn.MaxUnpool1d` later
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
                   ensures that every element in the input tensor is covered by a sliding window.

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
                    \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor

    Examples::

        >>> # pool of size=3, stride=2
        >>> m = nn.MaxPool1d(3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    kernel_size: _size_1_t
    stride: _size_1_t
    padding: _size_1_t
    dilation: _size_1_t

    def forward(self, input: Tensor):
        # No return annotation: the call yields a (values, indices) pair when
        # ``self.return_indices`` is True, a single Tensor otherwise.
        return F.max_pool1d(
            input,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class MaxPool2d(_MaxPoolNd):
    r"""Applies a 2D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,
    output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                    & \text{input}(N_i, C_j, \text{stride[0]} \times h + m,
                                                   \text{stride[1]} \times w + n)
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
    for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
    It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Args:
        kernel_size: the size of the window to take a max over
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides
        dilation: a parameter that controls the stride of elements in the window
        return_indices: if ``True``, will return the max indices along with the outputs.
            Useful for :class:`torch.nn.MaxUnpool2d` later
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]}
                    \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
                    \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor

    Examples::

        >>> # pool of square window of size=3, stride=2
        >>> m = nn.MaxPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = nn.MaxPool2d((3, 2), stride=(2, 1))
        >>> input = torch.randn(20, 16, 50, 32)
        >>> output = m(input)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    kernel_size: _size_2_t
    stride: _size_2_t
    padding: _size_2_t
    dilation: _size_2_t

    def forward(self, input: Tensor):
        # No return annotation: the call yields a (values, indices) pair when
        # ``self.return_indices`` is True, a single Tensor otherwise.
        return F.max_pool2d(
            input,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class MaxPool3d(_MaxPoolNd):
    r"""Applies a 3D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
    output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                              & \text{input}(N_i, C_j, \text{stride[0]} \times d + k,
                                                             \text{stride[1]} \times h + m, \text{stride[2]} \times w + n)
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
    for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
    It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Args:
        kernel_size: the size of the window to take a max over
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on all three sides
        dilation: a parameter that controls the stride of elements in the window
        return_indices: if ``True``, will return the max indices along with the outputs.
            Useful for :class:`torch.nn.MaxUnpool3d` later
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
                (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times
                (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
                (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples::

        >>> # pool of square window of size=3, stride=2
        >>> m = nn.MaxPool3d(3, stride=2)
        >>> # pool of non-square window
        >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2))
        >>> input = torch.randn(20, 16, 50, 44, 31)
        >>> output = m(input)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """  # noqa: E501

    kernel_size: _size_3_t
    stride: _size_3_t
    padding: _size_3_t
    dilation: _size_3_t

    def forward(self, input: Tensor):
        # No return annotation: the call yields a (values, indices) pair when
        # ``self.return_indices`` is True, a single Tensor otherwise.
        return F.max_pool3d(
            input,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class _MaxUnpoolNd(Module):
|
| 247 |
+
|
| 248 |
+
def extra_repr(self) -> str:
|
| 249 |
+
return f'kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}'
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class MaxUnpool1d(_MaxUnpoolNd):
    r"""Computes a partial inverse of :class:`MaxPool1d`.

    :class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost.

    :class:`MaxUnpool1d` takes in as input the output of :class:`MaxPool1d`
    including the indices of the maximal values and computes a partial inverse
    in which all non-maximal values are set to zero.

    Note:
        This operation may behave nondeterministically when the input indices has repeat values.
        See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.

    .. note:: :class:`MaxPool1d` can map several input sizes to the same output
              sizes. Hence, the inversion process can get ambiguous.
              To accommodate this, you can provide the needed output size
              as an additional argument :attr:`output_size` in the forward call.
              See the Inputs and Example below.

    Args:
        kernel_size (int or tuple): Size of the max pooling window.
        stride (int or tuple): Stride of the max pooling window.
            It is set to :attr:`kernel_size` by default.
        padding (int or tuple): Padding that was added to the input

    Inputs:
        - `input`: the input Tensor to invert
        - `indices`: the indices given out by :class:`~torch.nn.MaxPool1d`
        - `output_size` (optional): the targeted output size

    Shape:
        - Input: :math:`(N, C, H_{in})` or :math:`(C, H_{in})`.
        - Output: :math:`(N, C, H_{out})` or :math:`(C, H_{out})`, where

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]

          or as given by :attr:`output_size` in the call operator

    Example::

        >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?")
        >>> pool = nn.MaxPool1d(2, stride=2, return_indices=True)
        >>> unpool = nn.MaxUnpool1d(2, stride=2)
        >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8]]])
        >>> output, indices = pool(input)
        >>> unpool(output, indices)
        tensor([[[ 0.,  2.,  0.,  4.,  0.,  6.,  0., 8.]]])

        >>> # Example showcasing the use of output_size
        >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8, 9]]])
        >>> output, indices = pool(input)
        >>> unpool(output, indices, output_size=input.size())
        tensor([[[ 0.,  2.,  0.,  4.,  0.,  6.,  0., 8.,  0.]]])

        >>> unpool(output, indices)
        tensor([[[ 0.,  2.,  0.,  4.,  0.,  6.,  0., 8.]]])
    """

    kernel_size: _size_1_t
    stride: _size_1_t
    padding: _size_1_t

    def __init__(self, kernel_size: _size_1_t, stride: Optional[_size_1_t] = None, padding: _size_1_t = 0) -> None:
        super().__init__()
        # ``stride`` defaults to the kernel size, mirroring MaxPool1d.
        if stride is None:
            stride = kernel_size
        self.kernel_size = _single(kernel_size)
        self.stride = _single(stride)
        self.padding = _single(padding)

    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        # ``output_size`` disambiguates cases where several input lengths map
        # to the same pooled length.
        return F.max_unpool1d(input, indices, kernel_size=self.kernel_size,
                              stride=self.stride, padding=self.padding,
                              output_size=output_size)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class MaxUnpool2d(_MaxUnpoolNd):
    r"""Computes a partial inverse of :class:`MaxPool2d`.

    :class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost.

    :class:`MaxUnpool2d` takes in as input the output of :class:`MaxPool2d`
    including the indices of the maximal values and computes a partial inverse
    in which all non-maximal values are set to zero.

    Note:
        This operation may behave nondeterministically when the input indices has repeat values.
        See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.

    .. note:: :class:`MaxPool2d` can map several input sizes to the same output
              sizes. Hence, the inversion process can get ambiguous.
              To accommodate this, you can provide the needed output size
              as an additional argument :attr:`output_size` in the forward call.
              See the Inputs and Example below.

    Args:
        kernel_size (int or tuple): Size of the max pooling window.
        stride (int or tuple): Stride of the max pooling window.
            It is set to :attr:`kernel_size` by default.
        padding (int or tuple): Padding that was added to the input

    Inputs:
        - `input`: the input Tensor to invert
        - `indices`: the indices given out by :class:`~torch.nn.MaxPool2d`
        - `output_size` (optional): the targeted output size

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}

          .. math::
              W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}

          or as given by :attr:`output_size` in the call operator

    Example::

        >>> pool = nn.MaxPool2d(2, stride=2, return_indices=True)
        >>> unpool = nn.MaxUnpool2d(2, stride=2)
        >>> input = torch.tensor([[[[ 1.,  2.,  3.,  4.],
                                    [ 5.,  6.,  7.,  8.],
                                    [ 9., 10., 11., 12.],
                                    [13., 14., 15., 16.]]]])
        >>> output, indices = pool(input)
        >>> unpool(output, indices)
        tensor([[[[  0.,   0.,   0.,   0.],
                  [  0.,   6.,   0.,   8.],
                  [  0.,   0.,   0.,   0.],
                  [  0.,  14.,   0.,  16.]]]])
        >>> # Now using output_size to resolve an ambiguous size for the inverse
        >>> input = torch.torch.tensor([[[[ 1.,  2.,  3.,  4.,  5.],
                                          [ 6.,  7.,  8.,  9., 10.],
                                          [11., 12., 13., 14., 15.],
                                          [16., 17., 18., 19., 20.]]]])
        >>> output, indices = pool(input)
        >>> # This call will not work without specifying output_size
        >>> unpool(output, indices, output_size=input.size())
        tensor([[[[ 0.,  0.,  0.,  0.,  0.],
                  [ 0.,  7.,  0.,  9.,  0.],
                  [ 0.,  0.,  0.,  0.,  0.],
                  [ 0., 17.,  0., 19.,  0.]]]])
    """

    kernel_size: _size_2_t
    stride: _size_2_t
    padding: _size_2_t

    def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0) -> None:
        super().__init__()
        # ``stride`` defaults to the kernel size, mirroring MaxPool2d.
        if stride is None:
            stride = kernel_size
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)

    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        # ``output_size`` disambiguates cases where several input sizes map
        # to the same pooled size.
        return F.max_unpool2d(input, indices, kernel_size=self.kernel_size,
                              stride=self.stride, padding=self.padding,
                              output_size=output_size)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class MaxUnpool3d(_MaxUnpoolNd):
    r"""Computes a partial inverse of :class:`MaxPool3d`.

    :class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost.
    :class:`MaxUnpool3d` takes in as input the output of :class:`MaxPool3d`
    including the indices of the maximal values and computes a partial inverse
    in which all non-maximal values are set to zero.

    Note:
        This operation may behave nondeterministically when the input indices has repeat values.
        See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.

    .. note:: :class:`MaxPool3d` can map several input sizes to the same output
              sizes. Hence, the inversion process can get ambiguous.
              To accommodate this, you can provide the needed output size
              as an additional argument :attr:`output_size` in the forward call.
              See the Inputs section below.

    Args:
        kernel_size (int or tuple): Size of the max pooling window.
        stride (int or tuple): Stride of the max pooling window.
            It is set to :attr:`kernel_size` by default.
        padding (int or tuple): Padding that was added to the input

    Inputs:
        - `input`: the input Tensor to invert
        - `indices`: the indices given out by :class:`~torch.nn.MaxPool3d`
        - `output_size` (optional): the targeted output size

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where

          .. math::
              D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}

          .. math::
              W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]}

          or as given by :attr:`output_size` in the call operator

    Example::

        >>> # pool of square window of size=3, stride=2
        >>> pool = nn.MaxPool3d(3, stride=2, return_indices=True)
        >>> unpool = nn.MaxUnpool3d(3, stride=2)
        >>> output, indices = pool(torch.randn(20, 16, 51, 33, 15))
        >>> unpooled_output = unpool(output, indices)
        >>> unpooled_output.size()
        torch.Size([20, 16, 51, 33, 15])
    """

    kernel_size: _size_3_t
    stride: _size_3_t
    padding: _size_3_t

    def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0) -> None:
        super().__init__()
        # ``stride`` defaults to the kernel size, mirroring MaxPool3d.
        if stride is None:
            stride = kernel_size
        self.kernel_size = _triple(kernel_size)
        self.stride = _triple(stride)
        self.padding = _triple(padding)

    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        # ``output_size`` disambiguates cases where several input sizes map
        # to the same pooled size.
        return F.max_unpool3d(input, indices, kernel_size=self.kernel_size,
                              stride=self.stride, padding=self.padding,
                              output_size=output_size)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class _AvgPoolNd(Module):
    """Shared base class for the ``AvgPool{1,2,3}d`` modules.

    Holds the TorchScript constant list common to all average-pooling
    variants and renders the shared hyper-parameters for ``repr()``.
    """

    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad']

    def extra_repr(self) -> str:
        """Return the extra representation string shown in ``repr()``."""
        # Same fields, same order, same rendering as str(); only the
        # formatting mechanism differs from an inline f-string.
        template = 'kernel_size={}, stride={}, padding={}'
        return template.format(self.kernel_size, self.stride, self.padding)
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
class AvgPool1d(_AvgPoolNd):
    r"""Applies a 1D average pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, L)`,
    output :math:`(N, C, L_{out})` and :attr:`kernel_size` :math:`k`
    can be precisely described as:

    .. math::

        \text{out}(N_i, C_j, l) = \frac{1}{k} \sum_{m=0}^{k-1}
                                  \text{input}(N_i, C_j, \text{stride} \times l + m)

    If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
    for :attr:`padding` number of points.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be
    an ``int`` or a one-element tuple.

    Args:
        kernel_size: the size of the window
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: implicit zero padding to be added on both sides
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
        count_include_pad: when True, will include the zero-padding in the averaging calculation

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} +
              2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

          Per the note above, if ``ceil_mode`` is True and :math:`(L_{out} - 1) \times \text{stride} \geq L_{in}
          + \text{padding}`, we skip the last window as it would start in the right padded region, resulting in
          :math:`L_{out}` being reduced by one.

    Examples::

        >>> # pool with window of size=3, stride=2
        >>> m = nn.AvgPool1d(3, stride=2)
        >>> m(torch.tensor([[[1., 2, 3, 4, 5, 6, 7]]]))
        tensor([[[2., 4., 6.]]])
    """

    kernel_size: _size_1_t
    stride: _size_1_t
    padding: _size_1_t
    ceil_mode: bool
    count_include_pad: bool

    # NOTE: `stride` is annotated Optional here (previously `_size_1_t = None`,
    # an implicit-Optional that modern type checkers reject); runtime behavior
    # is unchanged and matches the sibling AvgPool2d/AvgPool3d signatures.
    def __init__(self, kernel_size: _size_1_t, stride: Optional[_size_1_t] = None, padding: _size_1_t = 0,
                 ceil_mode: bool = False, count_include_pad: bool = True) -> None:
        super().__init__()
        # Normalize everything to 1-tuples up front so extra_repr and
        # F.avg_pool1d always see a uniform representation.
        self.kernel_size = _single(kernel_size)
        # A stride of None means "use the kernel size" (non-overlapping windows).
        self.stride = _single(stride if stride is not None else kernel_size)
        self.padding = _single(padding)
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad

    def forward(self, input: Tensor) -> Tensor:
        """Apply 1D average pooling to ``input`` and return the pooled tensor."""
        return F.avg_pool1d(
            input, self.kernel_size, self.stride, self.padding, self.ceil_mode,
            self.count_include_pad)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
class AvgPool2d(_AvgPoolNd):
    r"""Applies a 2D average pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,
    output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)`
    can be precisely described as:

    .. math::

        out(N_i, C_j, h, w)  = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
                               input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n)

    If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
    for :attr:`padding` number of points.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Args:
        kernel_size: the size of the window
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: implicit zero padding to be added on both sides
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
        count_include_pad: when True, will include the zero-padding in the averaging calculation
        divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
                \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
                \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          Per the note above, if ``ceil_mode`` is True and :math:`(H_{out} - 1)\times \text{stride}[0]\geq H_{in}
          + \text{padding}[0]`, we skip the last window as it would start in the bottom padded region,
          resulting in :math:`H_{out}` being reduced by one.

          The same applies for :math:`W_{out}`.

    Examples::

        >>> # pool of square window of size=3, stride=2
        >>> m = nn.AvgPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = nn.AvgPool2d((3, 2), stride=(2, 1))
        >>> input = torch.randn(20, 16, 50, 32)
        >>> output = m(input)
    """

    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override']

    kernel_size: _size_2_t
    stride: _size_2_t
    padding: _size_2_t
    ceil_mode: bool
    count_include_pad: bool
    # Annotation added for consistency with the other attributes (it was
    # assigned in __init__ but missing from the class-level annotations).
    divisor_override: Optional[int]

    def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0,
                 ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> None:
        super().__init__()
        # Unlike AvgPool1d, the arguments are stored as given (int or tuple);
        # F.avg_pool2d accepts either form directly.
        self.kernel_size = kernel_size
        # A stride of None means "use the kernel size" (non-overlapping windows).
        self.stride = stride if (stride is not None) else kernel_size
        self.padding = padding
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad
        self.divisor_override = divisor_override

    def forward(self, input: Tensor) -> Tensor:
        """Apply 2D average pooling to ``input`` and return the pooled tensor."""
        return F.avg_pool2d(input, self.kernel_size, self.stride,
                            self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override)
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
class AvgPool3d(_AvgPoolNd):
    r"""Applies a 3D average pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
    output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            \text{out}(N_i, C_j, d, h, w) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
                                              & \frac{\text{input}(N_i, C_j, \text{stride}[0] \times d + k,
                                                      \text{stride}[1] \times h + m, \text{stride}[2] \times w + n)}
                                                     {kD \times kH \times kW}
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
    for :attr:`padding` number of points.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride` can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Args:
        kernel_size: the size of the window
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: implicit zero padding to be added on all three sides
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
        count_include_pad: when True, will include the zero-padding in the averaging calculation
        divisor_override: if specified, it will be used as divisor, otherwise :attr:`kernel_size` will be used

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
          :math:`(C, D_{out}, H_{out}, W_{out})`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
                    \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
                    \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
                    \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor

          Per the note above, if ``ceil_mode`` is True and :math:`(D_{out} - 1)\times \text{stride}[0]\geq D_{in}
          + \text{padding}[0]`, we skip the last window as it would start in the padded region,
          resulting in :math:`D_{out}` being reduced by one.

          The same applies for :math:`W_{out}` and :math:`H_{out}`.

    Examples::

        >>> # pool of square window of size=3, stride=2
        >>> m = nn.AvgPool3d(3, stride=2)
        >>> # pool of non-square window
        >>> m = nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2))
        >>> input = torch.randn(20, 16, 50, 44, 31)
        >>> output = m(input)
    """

    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override']

    kernel_size: _size_3_t
    stride: _size_3_t
    padding: _size_3_t
    ceil_mode: bool
    count_include_pad: bool
    # Annotation added for consistency with the other attributes (it was
    # assigned in __init__ but missing from the class-level annotations).
    divisor_override: Optional[int]

    def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0,
                 ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> None:
        super().__init__()
        # Arguments are stored as given (int or tuple); F.avg_pool3d accepts both.
        self.kernel_size = kernel_size
        # A stride of None means "use the kernel size" (non-overlapping windows).
        self.stride = stride if (stride is not None) else kernel_size
        self.padding = padding
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad
        self.divisor_override = divisor_override

    def forward(self, input: Tensor) -> Tensor:
        """Apply 3D average pooling to ``input`` and return the pooled tensor."""
        return F.avg_pool3d(input, self.kernel_size, self.stride,
                            self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override)

    def __setstate__(self, d):
        """Restore pickled state, back-filling attributes that old checkpoints lack."""
        super().__setstate__(d)
        # Modules serialized before these options existed have no such keys;
        # default them so unpickled instances behave like freshly constructed ones.
        self.__dict__.setdefault('padding', 0)
        self.__dict__.setdefault('ceil_mode', False)
        self.__dict__.setdefault('count_include_pad', True)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class FractionalMaxPool2d(Module):
    r"""Applies a 2D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham

    The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined.

    Args:
        kernel_size: the size of the window to take a max over.
                     Can be a single number k (for a square kernel of k x k) or a tuple `(kh, kw)`
        output_size: the target output size of the image of the form `oH x oW`.
                     Can be a tuple `(oH, oW)` or a single number oH for a square image `oH x oH`.
                     Note that we must have :math:`kH + oH - 1 <= H_{in}` and :math:`kW + oW - 1 <= W_{in}`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
                      This has to be a number or tuple in the range (0, 1).
                      Note that we must have :math:`kH + (output\_ratio\_H * H_{in}) - 1 <= H_{in}`
                      and :math:`kW + (output\_ratio\_W * W_{in}) - 1 <= W_{in}`
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to :meth:`nn.MaxUnpool2d`. Default: ``False``

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
          :math:`(H_{out}, W_{out})=\text{output\_size}` or
          :math:`(H_{out}, W_{out})=\text{output\_ratio} \times (H_{in}, W_{in})`.

    Examples:
        >>> # pool of square window of size=3, and target output size 13x12
        >>> m = nn.FractionalMaxPool2d(3, output_size=(13, 12))
        >>> # pool of square window and target output size being half of input image size
        >>> m = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
        >>> input = torch.randn(20, 16, 50, 32)
        >>> output = m(input)

    .. _Fractional MaxPooling:
        https://arxiv.org/abs/1412.6071
    """

    __constants__ = ['kernel_size', 'return_indices', 'output_size',
                     'output_ratio']

    kernel_size: _size_2_t
    return_indices: bool
    output_size: _size_2_t
    output_ratio: _ratio_2_t

    def __init__(self, kernel_size: _size_2_t, output_size: Optional[_size_2_t] = None,
                 output_ratio: Optional[_ratio_2_t] = None,
                 return_indices: bool = False, _random_samples=None) -> None:
        super().__init__()
        self.kernel_size = _pair(kernel_size)
        self.return_indices = return_indices
        # _random_samples is registered as a buffer (not a plain attribute) so
        # it follows the module across .to()/.cuda() and is saved in state_dict;
        # None (the default) means fresh random samples are drawn per forward.
        self.register_buffer('_random_samples', _random_samples)
        # Normalize to 2-tuples, preserving None to mean "not specified".
        self.output_size = _pair(output_size) if output_size is not None else None
        self.output_ratio = _pair(output_ratio) if output_ratio is not None else None
        # Exactly one of output_size / output_ratio must be provided.
        if output_size is None and output_ratio is None:
            raise ValueError("FractionalMaxPool2d requires specifying either "
                             "an output size, or a pooling ratio")
        if output_size is not None and output_ratio is not None:
            raise ValueError("only one of output_size and output_ratio may be specified")
        if self.output_ratio is not None:
            # Ratios must lie strictly inside (0, 1) in both dimensions.
            if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1):
                raise ValueError(f"output_ratio must be between 0 and 1 (got {output_ratio})")

    def forward(self, input: Tensor):
        # Returns (output, indices) when return_indices=True, otherwise just output.
        return F.fractional_max_pool2d(
            input, self.kernel_size, self.output_size, self.output_ratio,
            self.return_indices,
            _random_samples=self._random_samples)
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
class FractionalMaxPool3d(Module):
    r"""Applies a 3D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham

    The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined.

    Args:
        kernel_size: the size of the window to take a max over.
                     Can be a single number k (for a cubic kernel of k x k x k) or a tuple `(kt, kh, kw)`
        output_size: the target output size of the image of the form `oT x oH x oW`.
                     Can be a tuple `(oT, oH, oW)` or a single number oH for a cubic output `oH x oH x oH`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
                      This has to be a number or tuple in the range (0, 1)
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to :meth:`nn.MaxUnpool3d`. Default: ``False``

    Shape:
        - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})`

    Examples:
        >>> # pool of cubic window of size=3, and target output size 13x12x11
        >>> m = nn.FractionalMaxPool3d(3, output_size=(13, 12, 11))
        >>> # pool of cubic window and target output size being half of input size
        >>> m = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5))
        >>> input = torch.randn(20, 16, 50, 32, 16)
        >>> output = m(input)

    .. _Fractional MaxPooling:
        https://arxiv.org/abs/1412.6071
    """

    __constants__ = ['kernel_size', 'return_indices', 'output_size',
                     'output_ratio']
    kernel_size: _size_3_t
    return_indices: bool
    output_size: _size_3_t
    output_ratio: _ratio_3_t

    def __init__(self, kernel_size: _size_3_t, output_size: Optional[_size_3_t] = None,
                 output_ratio: Optional[_ratio_3_t] = None,
                 return_indices: bool = False, _random_samples=None) -> None:
        super().__init__()
        self.kernel_size = _triple(kernel_size)
        self.return_indices = return_indices
        # _random_samples is registered as a buffer (not a plain attribute) so
        # it follows the module across .to()/.cuda() and is saved in state_dict;
        # None (the default) means fresh random samples are drawn per forward.
        self.register_buffer('_random_samples', _random_samples)
        # Normalize to 3-tuples, preserving None to mean "not specified".
        self.output_size = _triple(output_size) if output_size is not None else None
        self.output_ratio = _triple(output_ratio) if output_ratio is not None else None
        # Exactly one of output_size / output_ratio must be provided.
        if output_size is None and output_ratio is None:
            raise ValueError("FractionalMaxPool3d requires specifying either "
                             "an output size, or a pooling ratio")
        if output_size is not None and output_ratio is not None:
            raise ValueError("only one of output_size and output_ratio may be specified")
        if self.output_ratio is not None:
            # Ratios must lie strictly inside (0, 1) in all three dimensions.
            if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1 and 0 < self.output_ratio[2] < 1):
                raise ValueError(f"output_ratio must be between 0 and 1 (got {output_ratio})")

    def forward(self, input: Tensor):
        # Returns (output, indices) when return_indices=True, otherwise just output.
        return F.fractional_max_pool3d(
            input, self.kernel_size, self.output_size, self.output_ratio,
            self.return_indices,
            _random_samples=self._random_samples)
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
class _LPPoolNd(Module):
    """Shared base class for the ``LPPool{1,2,3}d`` modules.

    Stores the power-average pooling hyper-parameters; the dimension-specific
    subclasses only implement ``forward``.
    """

    __constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode']

    norm_type: float
    ceil_mode: bool

    def __init__(self, norm_type: float, kernel_size: _size_any_t, stride: Optional[_size_any_t] = None,
                 ceil_mode: bool = False) -> None:
        super().__init__()
        self.norm_type = norm_type
        self.kernel_size = kernel_size
        # Note: stride may stay None here; F.lp_pool*d interprets None as
        # "stride equal to kernel_size".
        self.stride = stride
        self.ceil_mode = ceil_mode

    def extra_repr(self) -> str:
        """Return the extra representation string shown in ``repr()``."""
        return (f'norm_type={self.norm_type}, kernel_size={self.kernel_size}, '
                f'stride={self.stride}, ceil_mode={self.ceil_mode}')
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
class LPPool1d(_LPPoolNd):
    r"""Applies a 1D power-average pooling over an input signal composed of several input planes.

    On each window, the function computed is:

    .. math::
        f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}

    - At p = :math:`\infty`, one gets Max Pooling
    - At p = 1, one gets Sum Pooling (which is proportional to Average Pooling)

    .. note:: If the sum to the power of `p` is zero, the gradient of this function is
              not defined. This implementation will set the gradient to zero in this case.

    Args:
        kernel_size: a single int, the size of the window
        stride: a single int, the stride of the window. Default value is :attr:`kernel_size`
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where

          .. math::
              L_{out} = \left\lfloor\frac{L_{in} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

    Examples::
        >>> # power-2 pool of window of length 3, with stride 2.
        >>> m = nn.LPPool1d(2, 3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)
    """

    kernel_size: _size_1_t
    stride: _size_1_t

    def forward(self, input: Tensor) -> Tensor:
        """Apply 1D power-average pooling to ``input``."""
        # norm_type may have been stored as an int; the functional expects a float power.
        power = float(self.norm_type)
        return F.lp_pool1d(input, power, self.kernel_size, self.stride, self.ceil_mode)
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
class LPPool2d(_LPPoolNd):
    r"""Applies a 2D power-average pooling over an input signal composed of several input planes.

    On each window, the function computed is:

    .. math::
        f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}

    - At p = :math:`\infty`, one gets Max Pooling
    - At p = 1, one gets Sum Pooling (which is proportional to average pooling)

    The parameters :attr:`kernel_size`, :attr:`stride` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    .. note:: If the sum to the power of `p` is zero, the gradient of this function is
              not defined. This implementation will set the gradient to zero in this case.

    Args:
        kernel_size: the size of the window
        stride: the stride of the window. Default value is :attr:`kernel_size`
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

    Examples::

        >>> # power-2 pool of square window of size=3, stride=2
        >>> m = nn.LPPool2d(2, 3, stride=2)
        >>> # pool of non-square window of power 1.2
        >>> m = nn.LPPool2d(1.2, (3, 2), stride=(2, 1))
        >>> input = torch.randn(20, 16, 50, 32)
        >>> output = m(input)
    """

    kernel_size: _size_2_t
    stride: _size_2_t

    def forward(self, input: Tensor) -> Tensor:
        """Apply 2D power-average pooling to ``input``."""
        # norm_type may have been stored as an int; the functional expects a float power.
        power = float(self.norm_type)
        return F.lp_pool2d(input, power, self.kernel_size, self.stride, self.ceil_mode)
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
class LPPool3d(_LPPoolNd):
|
| 1004 |
+
r"""Applies a 3D power-average pooling over an input signal composed of several input planes.
|
| 1005 |
+
|
| 1006 |
+
On each window, the function computed is:
|
| 1007 |
+
|
| 1008 |
+
.. math::
|
| 1009 |
+
f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
|
| 1010 |
+
|
| 1011 |
+
- At p = :math:`\infty`, one gets Max Pooling
|
| 1012 |
+
- At p = 1, one gets Sum Pooling (which is proportional to average pooling)
|
| 1013 |
+
|
| 1014 |
+
The parameters :attr:`kernel_size`, :attr:`stride` can either be:
|
| 1015 |
+
|
| 1016 |
+
- a single ``int`` -- in which case the same value is used for the height, width and depth dimension
|
| 1017 |
+
- a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
|
| 1018 |
+
the second `int` for the height dimension and the third `int` for the width dimension
|
| 1019 |
+
|
| 1020 |
+
.. note:: If the sum to the power of `p` is zero, the gradient of this function is
|
| 1021 |
+
not defined. This implementation will set the gradient to zero in this case.
|
| 1022 |
+
|
| 1023 |
+
Args:
|
| 1024 |
+
kernel_size: the size of the window
|
| 1025 |
+
stride: the stride of the window. Default value is :attr:`kernel_size`
|
| 1026 |
+
ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
|
| 1027 |
+
|
| 1028 |
+
Shape:
|
| 1029 |
+
- Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
|
| 1030 |
+
- Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
|
| 1031 |
+
:math:`(C, D_{out}, H_{out}, W_{out})`, where
|
| 1032 |
+
|
| 1033 |
+
.. math::
|
| 1034 |
+
D_{out} = \left\lfloor\frac{D_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
|
| 1035 |
+
|
| 1036 |
+
.. math::
|
| 1037 |
+
H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
|
| 1038 |
+
|
| 1039 |
+
.. math::
|
| 1040 |
+
W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
|
| 1041 |
+
|
| 1042 |
+
Examples::
|
| 1043 |
+
|
| 1044 |
+
>>> # power-2 pool of square window of size=3, stride=2
|
| 1045 |
+
>>> m = nn.LPPool3d(2, 3, stride=2)
|
| 1046 |
+
>>> # pool of non-square window of power 1.2
|
| 1047 |
+
>>> m = nn.LPPool3d(1.2, (3, 2, 2), stride=(2, 1, 2))
|
| 1048 |
+
>>> input = torch.randn(20, 16, 50, 44, 31)
|
| 1049 |
+
>>> output = m(input)
|
| 1050 |
+
|
| 1051 |
+
"""
|
| 1052 |
+
|
| 1053 |
+
kernel_size: _size_3_t
|
| 1054 |
+
stride: _size_3_t
|
| 1055 |
+
|
| 1056 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 1057 |
+
return F.lp_pool3d(input, float(self.norm_type), self.kernel_size,
|
| 1058 |
+
self.stride, self.ceil_mode)
|
| 1059 |
+
|
| 1060 |
+
|
| 1061 |
+
class _AdaptiveMaxPoolNd(Module):
    """Shared base class for the ``AdaptiveMaxPool{1,2,3}d`` modules.

    Stores the target output size and whether pooling indices should be
    returned alongside the pooled values.
    """

    __constants__ = ['output_size', 'return_indices']
    return_indices: bool

    def __init__(self, output_size: _size_any_opt_t, return_indices: bool = False) -> None:
        super().__init__()
        self.output_size = output_size
        self.return_indices = return_indices

    def extra_repr(self) -> str:
        """Return the extra representation string shown in ``repr()``."""
        return 'output_size={}'.format(self.output_size)
|
| 1072 |
+
|
| 1073 |
+
# FIXME (by @ssnl): Improve adaptive pooling docs: specify what the input and
|
| 1074 |
+
# output shapes are, and how the operation computes output.
|
| 1075 |
+
|
| 1076 |
+
|
| 1077 |
+
class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
    r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.

    The output size is :math:`L_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Args:
        output_size: the target output size :math:`L_{out}`.
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to nn.MaxUnpool1d. Default: ``False``

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
          :math:`L_{out}=\text{output\_size}`.

    Examples:
        >>> # target output size of 5
        >>> m = nn.AdaptiveMaxPool1d(5)
        >>> input = torch.randn(1, 64, 8)
        >>> output = m(input)

    """

    output_size: _size_1_t

    def forward(self, input: Tensor):
        """Pool ``input`` down to the configured output length.

        Returns (output, indices) when ``return_indices`` is True, otherwise
        just the pooled output.
        """
        target = self.output_size
        return F.adaptive_max_pool1d(input, target, self.return_indices)
|
| 1105 |
+
|
| 1106 |
+
|
| 1107 |
+
class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
    r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.

    The pooled output is always :math:`H_{out} \times W_{out}`, whatever the
    spatial size of the input; the number of channels is left unchanged.

    Args:
        output_size: the target output size of the image of the form :math:`H_{out} \times W_{out}`.
            Can be a tuple :math:`(H_{out}, W_{out})` or a single :math:`H_{out}` for a
            square image :math:`H_{out} \times H_{out}`. Each entry may be an
            ``int`` or ``None``; ``None`` keeps the corresponding input size.
        return_indices: if ``True``, the indices of the maxima are returned
            together with the outputs (useful to pass to nn.MaxUnpool2d).
            Default: ``False``

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
          :math:`(H_{out}, W_{out})=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7
        >>> m = nn.AdaptiveMaxPool2d((5, 7))
        >>> output = m(torch.randn(1, 64, 8, 9))
        >>> # target output size of 7x7 (square)
        >>> m = nn.AdaptiveMaxPool2d(7)
        >>> output = m(torch.randn(1, 64, 10, 9))
        >>> # keep the input height, width of 7
        >>> m = nn.AdaptiveMaxPool2d((None, 7))
        >>> output = m(torch.randn(1, 64, 10, 9))

    """

    output_size: _size_2_opt_t

    # No return annotation: the result is a Tensor, or a (Tensor, Tensor)
    # pair of values and indices when return_indices is set.
    def forward(self, input: Tensor):
        pooled = F.adaptive_max_pool2d(input, self.output_size, self.return_indices)
        return pooled
|
| 1147 |
+
|
| 1148 |
+
|
| 1149 |
+
class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
    r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.

    The pooled output is always :math:`D_{out} \times H_{out} \times W_{out}`,
    whatever the spatial size of the input; the number of channels is left
    unchanged.

    Args:
        output_size: the target output size of the image of the form :math:`D_{out} \times H_{out} \times W_{out}`.
            Can be a tuple :math:`(D_{out}, H_{out}, W_{out})` or a single
            :math:`D_{out}` for a cube :math:`D_{out} \times D_{out} \times D_{out}`.
            Each entry may be an ``int`` or ``None``; ``None`` keeps the
            corresponding input size.

        return_indices: if ``True``, the indices of the maxima are returned
            together with the outputs (useful to pass to nn.MaxUnpool3d).
            Default: ``False``

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
          where :math:`(D_{out}, H_{out}, W_{out})=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7x9
        >>> m = nn.AdaptiveMaxPool3d((5, 7, 9))
        >>> output = m(torch.randn(1, 64, 8, 9, 10))
        >>> # target output size of 7x7x7 (cube)
        >>> m = nn.AdaptiveMaxPool3d(7)
        >>> output = m(torch.randn(1, 64, 10, 9, 8))
        >>> # keep the input height and width, depth of 7
        >>> m = nn.AdaptiveMaxPool3d((7, None, None))
        >>> output = m(torch.randn(1, 64, 10, 9, 8))

    """

    output_size: _size_3_opt_t

    # No return annotation: the result is a Tensor, or a (Tensor, Tensor)
    # pair of values and indices when return_indices is set.
    def forward(self, input: Tensor):
        pooled = F.adaptive_max_pool3d(input, self.output_size, self.return_indices)
        return pooled
|
| 1190 |
+
|
| 1191 |
+
|
| 1192 |
+
class _AdaptiveAvgPoolNd(Module):
|
| 1193 |
+
__constants__ = ['output_size']
|
| 1194 |
+
|
| 1195 |
+
def __init__(self, output_size: _size_any_opt_t) -> None:
|
| 1196 |
+
super().__init__()
|
| 1197 |
+
self.output_size = output_size
|
| 1198 |
+
|
| 1199 |
+
def extra_repr(self) -> str:
|
| 1200 |
+
return f'output_size={self.output_size}'
|
| 1201 |
+
|
| 1202 |
+
|
| 1203 |
+
class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
    r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.

    The pooled output always has length :math:`L_{out}`, whatever the input
    length; the number of channels is left unchanged.

    Args:
        output_size: the target output size :math:`L_{out}`.

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
          :math:`L_{out}=\text{output\_size}`.

    Examples:
        >>> # target output size of 5
        >>> m = nn.AdaptiveAvgPool1d(5)
        >>> output = m(torch.randn(1, 64, 8))

    """

    output_size: _size_1_t

    def forward(self, input: Tensor) -> Tensor:
        pooled = F.adaptive_avg_pool1d(input, self.output_size)
        return pooled
|
| 1229 |
+
|
| 1230 |
+
|
| 1231 |
+
class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
    r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.

    The pooled output is always H x W, whatever the spatial size of the
    input; the number of channels is left unchanged.

    Args:
        output_size: the target output size of the image of the form H x W.
            Can be a tuple (H, W) or a single H for a square image H x H.
            H and W can each be an ``int``, or ``None`` to keep the
            corresponding input size.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, S_{0}, S_{1})` or :math:`(C, S_{0}, S_{1})`, where
          :math:`S=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7
        >>> m = nn.AdaptiveAvgPool2d((5, 7))
        >>> output = m(torch.randn(1, 64, 8, 9))
        >>> # target output size of 7x7 (square)
        >>> m = nn.AdaptiveAvgPool2d(7)
        >>> output = m(torch.randn(1, 64, 10, 9))
        >>> # keep the input height, width of 7
        >>> m = nn.AdaptiveAvgPool2d((None, 7))
        >>> output = m(torch.randn(1, 64, 10, 9))

    """

    output_size: _size_2_opt_t

    def forward(self, input: Tensor) -> Tensor:
        pooled = F.adaptive_avg_pool2d(input, self.output_size)
        return pooled
|
| 1268 |
+
|
| 1269 |
+
|
| 1270 |
+
class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
    r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.

    The pooled output is always D x H x W, whatever the spatial size of the
    input; the number of channels is left unchanged.

    Args:
        output_size: the target output size of the form D x H x W.
            Can be a tuple (D, H, W) or a single number D for a cube D x D x D.
            D, H and W can each be an ``int``, or ``None`` to keep the
            corresponding input size.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, S_{0}, S_{1}, S_{2})` or :math:`(C, S_{0}, S_{1}, S_{2})`,
          where :math:`S=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7x9
        >>> m = nn.AdaptiveAvgPool3d((5, 7, 9))
        >>> output = m(torch.randn(1, 64, 8, 9, 10))
        >>> # target output size of 7x7x7 (cube)
        >>> m = nn.AdaptiveAvgPool3d(7)
        >>> output = m(torch.randn(1, 64, 10, 9, 8))
        >>> # keep the input height and width, depth of 7
        >>> m = nn.AdaptiveAvgPool3d((7, None, None))
        >>> output = m(torch.randn(1, 64, 10, 9, 8))

    """

    output_size: _size_3_opt_t

    def forward(self, input: Tensor) -> Tensor:
        pooled = F.adaptive_avg_pool3d(input, self.output_size)
        return pooled
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/transformer.py
ADDED
|
@@ -0,0 +1,975 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Optional, Any, Union, Callable
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import warnings
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
from .. import functional as F
|
| 8 |
+
from .module import Module
|
| 9 |
+
from .activation import MultiheadAttention
|
| 10 |
+
from .container import ModuleList
|
| 11 |
+
from ..init import xavier_uniform_
|
| 12 |
+
from .dropout import Dropout
|
| 13 |
+
from .linear import Linear
|
| 14 |
+
from .normalization import LayerNorm
|
| 15 |
+
|
| 16 |
+
__all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer']
|
| 17 |
+
|
| 18 |
+
def _generate_square_subsequent_mask(
|
| 19 |
+
sz: int,
|
| 20 |
+
device: Optional[torch.device] = None,
|
| 21 |
+
dtype: Optional[torch.dtype] = None,
|
| 22 |
+
) -> Tensor:
|
| 23 |
+
r"""Generate a square causal mask for the sequence.
|
| 24 |
+
|
| 25 |
+
The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
| 26 |
+
"""
|
| 27 |
+
if device is None:
|
| 28 |
+
device = torch.device('cpu')
|
| 29 |
+
if dtype is None:
|
| 30 |
+
dtype = torch.float32
|
| 31 |
+
return torch.triu(
|
| 32 |
+
torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
|
| 33 |
+
diagonal=1,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _get_seq_len(
|
| 38 |
+
src: Tensor,
|
| 39 |
+
batch_first: bool
|
| 40 |
+
) -> Optional[int]:
|
| 41 |
+
|
| 42 |
+
if src.is_nested:
|
| 43 |
+
return None
|
| 44 |
+
else:
|
| 45 |
+
src_size = src.size()
|
| 46 |
+
if len(src_size) == 2:
|
| 47 |
+
# unbatched: S, E
|
| 48 |
+
return src_size[0]
|
| 49 |
+
else:
|
| 50 |
+
# batched: B, S, E if batch_first else S, B, E
|
| 51 |
+
seq_len_pos = 1 if batch_first else 0
|
| 52 |
+
return src_size[seq_len_pos]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Transformer(Module):
    r"""A transformer model.

    User is able to modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
            other attention and feedforward operations, otherwise after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)

    Note: A full example to apply nn.Transformer module for the word language model is available in
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        # device/dtype are forwarded to every submodule constructor.
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        # Internal API-usage telemetry hook.
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")

        # Use a caller-supplied encoder if given; otherwise build a standard
        # stack of TransformerEncoderLayer modules with a final LayerNorm.
        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    bias, **factory_kwargs)
            encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        # Same choice for the decoder side.
        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    bias, **factory_kwargs)
            decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        # Xavier-initialize all matrix-shaped parameters (see _reset_parameters).
        self._reset_parameters()

        # Cached for the shape validation performed in forward().
        self.d_model = d_model
        self.nhead = nhead

        self.batch_first = batch_first

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
                src_is_causal: Optional[bool] = None, tgt_is_causal: Optional[bool] = None,
                memory_is_causal: bool = False) -> Tensor:
        r"""Take in and process masked source/target sequences.

        .. note::

            If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are
            not allowed to participate in the attention,
            which is the opposite of the definition for :attr:`attn_mask`
            in :func:`torch.nn.functional.scaled_dot_product_attention`.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the Tensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
            src_is_causal: If specified, applies a causal mask as ``src_mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``src_is_causal`` provides a hint that ``src_mask`` is
                the causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
                the causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
            memory_is_causal: If specified, applies a causal mask as
                ``memory_mask``.
                Default: ``False``.
                Warning:
                ``memory_is_causal`` provides a hint that
                ``memory_mask`` is the causal mask. Providing incorrect
                hints can result in incorrect execution, including
                forward and backward compatibility.

        Shape:
            - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
              `(N, S, E)` if `batch_first=True`.
            - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.
            - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
            - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.

        Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked
        positions. If a BoolTensor is provided, positions with ``True``
        are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
        is provided, it will be added to the attention weight.
        [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
        the attention. If a BoolTensor is provided, the positions with the
        value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.

        - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
          `(N, T, E)` if `batch_first=True`.

        Note: Due to the multi-head attention architecture in the transformer model,
        the output sequence length of a transformer is same as the input sequence
        (i.e. target) length of the decoder.

        where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
        batch size, :math:`E` is the feature number

        Examples:
            >>> # xdoctest: +SKIP
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """
        # The batch dimension lives at index 0 or 1 depending on batch_first;
        # for batched inputs src and tgt must agree on it.
        is_batched = src.dim() == 3
        if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")
        elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")

        # Both sequences must carry d_model features in the last dimension.
        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        # Encode the source, then decode the target against the encoder memory.
        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
                              is_causal=src_is_causal)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask,
                              tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
        return output

    @staticmethod
    def generate_square_subsequent_mask(
            sz: int,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
    ) -> Tensor:
        r"""Generate a square causal mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
        """
        # Thin public wrapper over the module-level helper.
        return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        # Xavier-uniform init for every parameter with more than one
        # dimension (weight matrices); vectors such as biases keep their
        # submodule defaults.
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class TransformerEncoder(Module):
|
| 245 |
+
r"""TransformerEncoder is a stack of N encoder layers.
|
| 246 |
+
|
| 247 |
+
Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
| 251 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
| 252 |
+
norm: the layer normalization component (optional).
|
| 253 |
+
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
| 254 |
+
(and convert back on output). This will improve the overall performance of
|
| 255 |
+
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
| 256 |
+
|
| 257 |
+
Examples::
|
| 258 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
| 259 |
+
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
|
| 260 |
+
>>> src = torch.rand(10, 32, 512)
|
| 261 |
+
>>> out = transformer_encoder(src)
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
__constants__ = ['norm']
|
| 265 |
+
|
| 266 |
+
def __init__(
|
| 267 |
+
self,
|
| 268 |
+
encoder_layer: "TransformerEncoderLayer",
|
| 269 |
+
num_layers: int,
|
| 270 |
+
norm: Optional[Module] = None,
|
| 271 |
+
enable_nested_tensor: bool = True,
|
| 272 |
+
mask_check: bool = True
|
| 273 |
+
) -> None:
|
| 274 |
+
super().__init__()
|
| 275 |
+
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
|
| 276 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
| 277 |
+
self.num_layers = num_layers
|
| 278 |
+
self.norm = norm
|
| 279 |
+
# this attribute saves the value providedat object construction
|
| 280 |
+
self.enable_nested_tensor = enable_nested_tensor
|
| 281 |
+
# this attribute controls whether nested tensors are used
|
| 282 |
+
self.use_nested_tensor = enable_nested_tensor
|
| 283 |
+
self.mask_check = mask_check
|
| 284 |
+
|
| 285 |
+
enc_layer = "encoder_layer"
|
| 286 |
+
why_not_sparsity_fast_path = ''
|
| 287 |
+
if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
|
| 288 |
+
why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
|
| 289 |
+
elif encoder_layer.norm_first :
|
| 290 |
+
why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
|
| 291 |
+
elif not encoder_layer.self_attn.batch_first:
|
| 292 |
+
why_not_sparsity_fast_path = (f"{enc_layer}.self_attn.batch_first was not True" +
|
| 293 |
+
"(use batch_first for better inference performance)")
|
| 294 |
+
elif not encoder_layer.self_attn._qkv_same_embed_dim:
|
| 295 |
+
why_not_sparsity_fast_path = f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
|
| 296 |
+
elif encoder_layer.self_attn.in_proj_bias is None:
|
| 297 |
+
why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
|
| 298 |
+
elif not encoder_layer.activation_relu_or_gelu:
|
| 299 |
+
why_not_sparsity_fast_path = f"{enc_layer}.activation_relu_or_gelu was not True"
|
| 300 |
+
elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps) :
|
| 301 |
+
why_not_sparsity_fast_path = f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
|
| 302 |
+
elif encoder_layer.self_attn.num_heads % 2 == 1:
|
| 303 |
+
why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"
|
| 304 |
+
|
| 305 |
+
if enable_nested_tensor and why_not_sparsity_fast_path:
|
| 306 |
+
warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
|
| 307 |
+
self.use_nested_tensor = False
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
    def forward(
            self,
            src: Tensor,
            mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: Optional[bool] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``mask``.
                Default: ``None``; try to detect a causal mask.
            Warning:
            ``is_causal`` provides a hint that ``mask`` is the
            causal mask. Providing incorrect hints can result in
            incorrect execution, including forward and backward
            compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        # Canonicalize the padding mask and check its dtype is consistent
        # with the attention mask's dtype.
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(mask),
            other_name="mask",
            target_type=src.dtype
        )

        # Canonicalize the attention mask on its own (no cross-check needed:
        # it was already compared against src_key_padding_mask above).
        mask = F._canonical_mask(
            mask=mask,
            mask_name="mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        output = src
        # Set to True only if we convert `output` to a nested tensor below;
        # used at the end to convert back to a padded tensor.
        convert_to_nested = False
        first_layer = self.layers[0]
        src_key_padding_mask_for_layers = src_key_padding_mask
        # Empty string means the sparsity fast path is (so far) eligible;
        # the first disqualifying reason is recorded here.
        why_not_sparsity_fast_path = ''
        str_first_layer = "self.layers[0]"
        batch_first = first_layer.self_attn.batch_first
        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()

        # Runtime eligibility checks for the nested-tensor fast path
        # (static checks were already done in __init__ via use_nested_tensor).
        if not is_fastpath_enabled:
            why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
        elif not hasattr(self, "use_nested_tensor"):
            # Modules deserialized from old checkpoints may lack the attribute.
            why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
        elif not self.use_nested_tensor:
            why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True"
        elif first_layer.training:
            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif src_key_padding_mask is None:
            why_not_sparsity_fast_path = "src_key_padding_mask was None"
        elif (((not hasattr(self, "mask_check")) or self.mask_check)
                and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
        elif output.is_nested:
            why_not_sparsity_fast_path = "NestedTensor input is not supported"
        elif mask is not None:
            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            # All tensors the fused kernel will touch; used for the
            # torch_function / device / requires_grad checks below.
            tensor_args = (
                src,
                first_layer.self_attn.in_proj_weight,
                first_layer.self_attn.in_proj_bias,
                first_layer.self_attn.out_proj.weight,
                first_layer.self_attn.out_proj.bias,
                first_layer.norm1.weight,
                first_layer.norm1.bias,
                first_layer.norm2.weight,
                first_layer.norm2.bias,
                first_layer.linear1.weight,
                first_layer.linear1.bias,
                first_layer.linear2.weight,
                first_layer.linear2.bias,
            )
            _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif src.device.type not in _supported_device_type:
                why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
                # Fast path confirmed: fold the padding mask into a nested
                # tensor so the layers skip computation on padded positions.
                convert_to_nested = True
                output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
                # Mask is now encoded in the nested layout; don't pass it again.
                src_key_padding_mask_for_layers = None

        seq_len = _get_seq_len(src, batch_first)
        # Resolve the is_causal hint (None means "detect from mask").
        is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)

        for mod in self.layers:
            output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)

        if convert_to_nested:
            # Restore a regular padded tensor with zeros in padded positions.
            output = output.to_padded_tensor(0., src.size())

        if self.norm is not None:
            output = self.norm(output)

        return output
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class TransformerDecoder(Module):
|
| 427 |
+
r"""TransformerDecoder is a stack of N decoder layers.
|
| 428 |
+
|
| 429 |
+
Args:
|
| 430 |
+
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
|
| 431 |
+
num_layers: the number of sub-decoder-layers in the decoder (required).
|
| 432 |
+
norm: the layer normalization component (optional).
|
| 433 |
+
|
| 434 |
+
Examples::
|
| 435 |
+
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
|
| 436 |
+
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
|
| 437 |
+
>>> memory = torch.rand(10, 32, 512)
|
| 438 |
+
>>> tgt = torch.rand(20, 32, 512)
|
| 439 |
+
>>> out = transformer_decoder(tgt, memory)
|
| 440 |
+
"""
|
| 441 |
+
|
| 442 |
+
__constants__ = ['norm']
|
| 443 |
+
|
| 444 |
+
def __init__(
|
| 445 |
+
self,
|
| 446 |
+
decoder_layer: "TransformerDecoderLayer",
|
| 447 |
+
num_layers: int,
|
| 448 |
+
norm: Optional[Module] = None
|
| 449 |
+
) -> None:
|
| 450 |
+
super().__init__()
|
| 451 |
+
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
|
| 452 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
| 453 |
+
self.num_layers = num_layers
|
| 454 |
+
self.norm = norm
|
| 455 |
+
|
| 456 |
+
def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
|
| 457 |
+
memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
|
| 458 |
+
memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
|
| 459 |
+
memory_is_causal: bool = False) -> Tensor:
|
| 460 |
+
r"""Pass the inputs (and mask) through the decoder layer in turn.
|
| 461 |
+
|
| 462 |
+
Args:
|
| 463 |
+
tgt: the sequence to the decoder (required).
|
| 464 |
+
memory: the sequence from the last layer of the encoder (required).
|
| 465 |
+
tgt_mask: the mask for the tgt sequence (optional).
|
| 466 |
+
memory_mask: the mask for the memory sequence (optional).
|
| 467 |
+
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
| 468 |
+
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
| 469 |
+
tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
|
| 470 |
+
Default: ``None``; try to detect a causal mask.
|
| 471 |
+
Warning:
|
| 472 |
+
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
|
| 473 |
+
the causal mask. Providing incorrect hints can result in
|
| 474 |
+
incorrect execution, including forward and backward
|
| 475 |
+
compatibility.
|
| 476 |
+
memory_is_causal: If specified, applies a causal mask as
|
| 477 |
+
``memory mask``.
|
| 478 |
+
Default: ``False``.
|
| 479 |
+
Warning:
|
| 480 |
+
``memory_is_causal`` provides a hint that
|
| 481 |
+
``memory_mask`` is the causal mask. Providing incorrect
|
| 482 |
+
hints can result in incorrect execution, including
|
| 483 |
+
forward and backward compatibility.
|
| 484 |
+
|
| 485 |
+
Shape:
|
| 486 |
+
see the docs in :class:`~torch.nn.Transformer`.
|
| 487 |
+
"""
|
| 488 |
+
output = tgt
|
| 489 |
+
|
| 490 |
+
seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
|
| 491 |
+
tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
|
| 492 |
+
|
| 493 |
+
for mod in self.layers:
|
| 494 |
+
output = mod(output, memory, tgt_mask=tgt_mask,
|
| 495 |
+
memory_mask=memory_mask,
|
| 496 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
| 497 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
| 498 |
+
tgt_is_causal=tgt_is_causal,
|
| 499 |
+
memory_is_causal=memory_is_causal)
|
| 500 |
+
|
| 501 |
+
if self.norm is not None:
|
| 502 |
+
output = self.norm(output)
|
| 503 |
+
|
| 504 |
+
return output
|
| 505 |
+
|
| 506 |
+
class TransformerEncoderLayer(Module):
|
| 507 |
+
r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
|
| 508 |
+
|
| 509 |
+
This standard encoder layer is based on the paper "Attention Is All You Need".
|
| 510 |
+
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
| 511 |
+
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
| 512 |
+
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
| 513 |
+
in a different way during application.
|
| 514 |
+
|
| 515 |
+
TransformerEncoderLayer can handle either traditional torch.tensor inputs,
|
| 516 |
+
or Nested Tensor inputs. Derived classes are expected to similarly accept
|
| 517 |
+
both input formats. (Not all combinations of inputs are currently
|
| 518 |
+
supported by TransformerEncoderLayer while Nested Tensor is in prototype
|
| 519 |
+
state.)
|
| 520 |
+
|
| 521 |
+
If you are implementing a custom layer, you may derive it either from
|
| 522 |
+
the Module or TransformerEncoderLayer class. If your custom layer
|
| 523 |
+
supports both torch.Tensors and Nested Tensors inputs, make its
|
| 524 |
+
implementation a derived class of TransformerEncoderLayer. If your custom
|
| 525 |
+
Layer supports only torch.Tensor inputs, derive its implementation from
|
| 526 |
+
Module.
|
| 527 |
+
|
| 528 |
+
Args:
|
| 529 |
+
d_model: the number of expected features in the input (required).
|
| 530 |
+
nhead: the number of heads in the multiheadattention models (required).
|
| 531 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
| 532 |
+
dropout: the dropout value (default=0.1).
|
| 533 |
+
activation: the activation function of the intermediate layer, can be a string
|
| 534 |
+
("relu" or "gelu") or a unary callable. Default: relu
|
| 535 |
+
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
| 536 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
| 537 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
| 538 |
+
norm_first: if ``True``, layer norm is done prior to attention and feedforward
|
| 539 |
+
operations, respectively. Otherwise it's done after. Default: ``False`` (after).
|
| 540 |
+
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
|
| 541 |
+
bias. Default: ``True``.
|
| 542 |
+
|
| 543 |
+
Examples::
|
| 544 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
| 545 |
+
>>> src = torch.rand(10, 32, 512)
|
| 546 |
+
>>> out = encoder_layer(src)
|
| 547 |
+
|
| 548 |
+
Alternatively, when ``batch_first`` is ``True``:
|
| 549 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
|
| 550 |
+
>>> src = torch.rand(32, 10, 512)
|
| 551 |
+
>>> out = encoder_layer(src)
|
| 552 |
+
|
| 553 |
+
Fast path:
|
| 554 |
+
forward() will use a special optimized implementation described in
|
| 555 |
+
`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
|
| 556 |
+
conditions are met:
|
| 557 |
+
|
| 558 |
+
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
|
| 559 |
+
argument ``requires_grad``
|
| 560 |
+
- training is disabled (using ``.eval()``)
|
| 561 |
+
- batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
|
| 562 |
+
- activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
|
| 563 |
+
- at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
|
| 564 |
+
- if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
|
| 565 |
+
nor ``src_key_padding_mask`` is passed
|
| 566 |
+
- the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
|
| 567 |
+
unless the caller has manually modified one without modifying the other)
|
| 568 |
+
|
| 569 |
+
If the optimized implementation is in use, a
|
| 570 |
+
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
|
| 571 |
+
passed for ``src`` to represent padding more efficiently than using a padding
|
| 572 |
+
mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
|
| 573 |
+
returned, and an additional speedup proportional to the fraction of the input that
|
| 574 |
+
is padding can be expected.
|
| 575 |
+
|
| 576 |
+
.. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
|
| 577 |
+
https://arxiv.org/abs/2205.14135
|
| 578 |
+
|
| 579 |
+
"""
|
| 580 |
+
|
| 581 |
+
__constants__ = ['norm_first']
|
| 582 |
+
|
| 583 |
+
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
|
| 584 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
| 585 |
+
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
|
| 586 |
+
bias: bool = True, device=None, dtype=None) -> None:
|
| 587 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 588 |
+
super().__init__()
|
| 589 |
+
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
|
| 590 |
+
bias=bias, batch_first=batch_first,
|
| 591 |
+
**factory_kwargs)
|
| 592 |
+
# Implementation of Feedforward model
|
| 593 |
+
self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
|
| 594 |
+
self.dropout = Dropout(dropout)
|
| 595 |
+
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
|
| 596 |
+
|
| 597 |
+
self.norm_first = norm_first
|
| 598 |
+
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
| 599 |
+
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
| 600 |
+
self.dropout1 = Dropout(dropout)
|
| 601 |
+
self.dropout2 = Dropout(dropout)
|
| 602 |
+
|
| 603 |
+
# Legacy string support for activation function.
|
| 604 |
+
if isinstance(activation, str):
|
| 605 |
+
activation = _get_activation_fn(activation)
|
| 606 |
+
|
| 607 |
+
# We can't test self.activation in forward() in TorchScript,
|
| 608 |
+
# so stash some information about it instead.
|
| 609 |
+
if activation is F.relu or isinstance(activation, torch.nn.ReLU):
|
| 610 |
+
self.activation_relu_or_gelu = 1
|
| 611 |
+
elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
|
| 612 |
+
self.activation_relu_or_gelu = 2
|
| 613 |
+
else:
|
| 614 |
+
self.activation_relu_or_gelu = 0
|
| 615 |
+
self.activation = activation
|
| 616 |
+
|
| 617 |
+
def __setstate__(self, state):
|
| 618 |
+
super().__setstate__(state)
|
| 619 |
+
if not hasattr(self, 'activation'):
|
| 620 |
+
self.activation = F.relu
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def forward(
|
| 624 |
+
self,
|
| 625 |
+
src: Tensor,
|
| 626 |
+
src_mask: Optional[Tensor] = None,
|
| 627 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 628 |
+
is_causal: bool = False) -> Tensor:
|
| 629 |
+
r"""Pass the input through the encoder layer.
|
| 630 |
+
|
| 631 |
+
Args:
|
| 632 |
+
src: the sequence to the encoder layer (required).
|
| 633 |
+
src_mask: the mask for the src sequence (optional).
|
| 634 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
| 635 |
+
is_causal: If specified, applies a causal mask as ``src mask``.
|
| 636 |
+
Default: ``False``.
|
| 637 |
+
Warning:
|
| 638 |
+
``is_causal`` provides a hint that ``src_mask`` is the
|
| 639 |
+
causal mask. Providing incorrect hints can result in
|
| 640 |
+
incorrect execution, including forward and backward
|
| 641 |
+
compatibility.
|
| 642 |
+
|
| 643 |
+
Shape:
|
| 644 |
+
see the docs in :class:`~torch.nn.Transformer`.
|
| 645 |
+
"""
|
| 646 |
+
src_key_padding_mask = F._canonical_mask(
|
| 647 |
+
mask=src_key_padding_mask,
|
| 648 |
+
mask_name="src_key_padding_mask",
|
| 649 |
+
other_type=F._none_or_dtype(src_mask),
|
| 650 |
+
other_name="src_mask",
|
| 651 |
+
target_type=src.dtype
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
src_mask = F._canonical_mask(
|
| 655 |
+
mask=src_mask,
|
| 656 |
+
mask_name="src_mask",
|
| 657 |
+
other_type=None,
|
| 658 |
+
other_name="",
|
| 659 |
+
target_type=src.dtype,
|
| 660 |
+
check_other=False,
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
|
| 664 |
+
|
| 665 |
+
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
|
| 666 |
+
why_not_sparsity_fast_path = ''
|
| 667 |
+
if not is_fastpath_enabled:
|
| 668 |
+
why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
|
| 669 |
+
elif not src.dim() == 3:
|
| 670 |
+
why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
|
| 671 |
+
elif self.training:
|
| 672 |
+
why_not_sparsity_fast_path = "training is enabled"
|
| 673 |
+
elif not self.self_attn.batch_first:
|
| 674 |
+
why_not_sparsity_fast_path = "self_attn.batch_first was not True"
|
| 675 |
+
elif self.self_attn.in_proj_bias is None:
|
| 676 |
+
why_not_sparsity_fast_path = "self_attn was passed bias=False"
|
| 677 |
+
elif not self.self_attn._qkv_same_embed_dim:
|
| 678 |
+
why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
|
| 679 |
+
elif not self.activation_relu_or_gelu:
|
| 680 |
+
why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
|
| 681 |
+
elif not (self.norm1.eps == self.norm2.eps):
|
| 682 |
+
why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
|
| 683 |
+
elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
|
| 684 |
+
why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
|
| 685 |
+
elif self.self_attn.num_heads % 2 == 1:
|
| 686 |
+
why_not_sparsity_fast_path = "num_head is odd"
|
| 687 |
+
elif torch.is_autocast_enabled():
|
| 688 |
+
why_not_sparsity_fast_path = "autocast is enabled"
|
| 689 |
+
if not why_not_sparsity_fast_path:
|
| 690 |
+
tensor_args = (
|
| 691 |
+
src,
|
| 692 |
+
self.self_attn.in_proj_weight,
|
| 693 |
+
self.self_attn.in_proj_bias,
|
| 694 |
+
self.self_attn.out_proj.weight,
|
| 695 |
+
self.self_attn.out_proj.bias,
|
| 696 |
+
self.norm1.weight,
|
| 697 |
+
self.norm1.bias,
|
| 698 |
+
self.norm2.weight,
|
| 699 |
+
self.norm2.bias,
|
| 700 |
+
self.linear1.weight,
|
| 701 |
+
self.linear1.bias,
|
| 702 |
+
self.linear2.weight,
|
| 703 |
+
self.linear2.bias,
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
# We have to use list comprehensions below because TorchScript does not support
|
| 707 |
+
# generator expressions.
|
| 708 |
+
_supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
|
| 709 |
+
if torch.overrides.has_torch_function(tensor_args):
|
| 710 |
+
why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
|
| 711 |
+
elif not all((x.device.type in _supported_device_type) for x in tensor_args):
|
| 712 |
+
why_not_sparsity_fast_path = ("some Tensor argument's device is neither one of "
|
| 713 |
+
f"{_supported_device_type}")
|
| 714 |
+
elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
|
| 715 |
+
why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
|
| 716 |
+
"input/output projection weights or biases requires_grad")
|
| 717 |
+
|
| 718 |
+
if not why_not_sparsity_fast_path:
|
| 719 |
+
merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
|
| 720 |
+
return torch._transformer_encoder_layer_fwd(
|
| 721 |
+
src,
|
| 722 |
+
self.self_attn.embed_dim,
|
| 723 |
+
self.self_attn.num_heads,
|
| 724 |
+
self.self_attn.in_proj_weight,
|
| 725 |
+
self.self_attn.in_proj_bias,
|
| 726 |
+
self.self_attn.out_proj.weight,
|
| 727 |
+
self.self_attn.out_proj.bias,
|
| 728 |
+
self.activation_relu_or_gelu == 2,
|
| 729 |
+
self.norm_first,
|
| 730 |
+
self.norm1.eps,
|
| 731 |
+
self.norm1.weight,
|
| 732 |
+
self.norm1.bias,
|
| 733 |
+
self.norm2.weight,
|
| 734 |
+
self.norm2.bias,
|
| 735 |
+
self.linear1.weight,
|
| 736 |
+
self.linear1.bias,
|
| 737 |
+
self.linear2.weight,
|
| 738 |
+
self.linear2.bias,
|
| 739 |
+
merged_mask,
|
| 740 |
+
mask_type,
|
| 741 |
+
)
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
x = src
|
| 745 |
+
if self.norm_first:
|
| 746 |
+
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
|
| 747 |
+
x = x + self._ff_block(self.norm2(x))
|
| 748 |
+
else:
|
| 749 |
+
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
|
| 750 |
+
x = self.norm2(x + self._ff_block(x))
|
| 751 |
+
|
| 752 |
+
return x
|
| 753 |
+
|
| 754 |
+
# self-attention block
|
| 755 |
+
def _sa_block(self, x: Tensor,
|
| 756 |
+
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
|
| 757 |
+
x = self.self_attn(x, x, x,
|
| 758 |
+
attn_mask=attn_mask,
|
| 759 |
+
key_padding_mask=key_padding_mask,
|
| 760 |
+
need_weights=False, is_causal=is_causal)[0]
|
| 761 |
+
return self.dropout1(x)
|
| 762 |
+
|
| 763 |
+
# feed forward block
|
| 764 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
| 765 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 766 |
+
return self.dropout2(x)
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
class TransformerDecoderLayer(Module):
|
| 770 |
+
r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
|
| 771 |
+
|
| 772 |
+
This standard decoder layer is based on the paper "Attention Is All You Need".
|
| 773 |
+
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
| 774 |
+
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
| 775 |
+
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
| 776 |
+
in a different way during application.
|
| 777 |
+
|
| 778 |
+
Args:
|
| 779 |
+
d_model: the number of expected features in the input (required).
|
| 780 |
+
nhead: the number of heads in the multiheadattention models (required).
|
| 781 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
| 782 |
+
dropout: the dropout value (default=0.1).
|
| 783 |
+
activation: the activation function of the intermediate layer, can be a string
|
| 784 |
+
("relu" or "gelu") or a unary callable. Default: relu
|
| 785 |
+
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
| 786 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
| 787 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
| 788 |
+
norm_first: if ``True``, layer norm is done prior to self attention, multihead
|
| 789 |
+
attention and feedforward operations, respectively. Otherwise it's done after.
|
| 790 |
+
Default: ``False`` (after).
|
| 791 |
+
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
|
| 792 |
+
bias. Default: ``True``.
|
| 793 |
+
|
| 794 |
+
Examples::
|
| 795 |
+
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
|
| 796 |
+
>>> memory = torch.rand(10, 32, 512)
|
| 797 |
+
>>> tgt = torch.rand(20, 32, 512)
|
| 798 |
+
>>> out = decoder_layer(tgt, memory)
|
| 799 |
+
|
| 800 |
+
Alternatively, when ``batch_first`` is ``True``:
|
| 801 |
+
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
|
| 802 |
+
>>> memory = torch.rand(32, 10, 512)
|
| 803 |
+
>>> tgt = torch.rand(32, 20, 512)
|
| 804 |
+
>>> out = decoder_layer(tgt, memory)
|
| 805 |
+
"""
|
| 806 |
+
|
| 807 |
+
__constants__ = ['norm_first']
|
| 808 |
+
|
| 809 |
+
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
|
| 810 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
| 811 |
+
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
|
| 812 |
+
bias: bool = True, device=None, dtype=None) -> None:
|
| 813 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 814 |
+
super().__init__()
|
| 815 |
+
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
|
| 816 |
+
bias=bias, **factory_kwargs)
|
| 817 |
+
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
|
| 818 |
+
bias=bias, **factory_kwargs)
|
| 819 |
+
# Implementation of Feedforward model
|
| 820 |
+
self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
|
| 821 |
+
self.dropout = Dropout(dropout)
|
| 822 |
+
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
|
| 823 |
+
|
| 824 |
+
self.norm_first = norm_first
|
| 825 |
+
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
| 826 |
+
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
| 827 |
+
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
| 828 |
+
self.dropout1 = Dropout(dropout)
|
| 829 |
+
self.dropout2 = Dropout(dropout)
|
| 830 |
+
self.dropout3 = Dropout(dropout)
|
| 831 |
+
|
| 832 |
+
# Legacy string support for activation function.
|
| 833 |
+
if isinstance(activation, str):
|
| 834 |
+
self.activation = _get_activation_fn(activation)
|
| 835 |
+
else:
|
| 836 |
+
self.activation = activation
|
| 837 |
+
|
| 838 |
+
def __setstate__(self, state):
|
| 839 |
+
if 'activation' not in state:
|
| 840 |
+
state['activation'] = F.relu
|
| 841 |
+
super().__setstate__(state)
|
| 842 |
+
|
| 843 |
+
def forward(
|
| 844 |
+
self,
|
| 845 |
+
tgt: Tensor,
|
| 846 |
+
memory: Tensor,
|
| 847 |
+
tgt_mask: Optional[Tensor] = None,
|
| 848 |
+
memory_mask: Optional[Tensor] = None,
|
| 849 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
| 850 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
| 851 |
+
tgt_is_causal: bool = False,
|
| 852 |
+
memory_is_causal: bool = False,
|
| 853 |
+
) -> Tensor:
|
| 854 |
+
r"""Pass the inputs (and mask) through the decoder layer.
|
| 855 |
+
|
| 856 |
+
Args:
|
| 857 |
+
tgt: the sequence to the decoder layer (required).
|
| 858 |
+
memory: the sequence from the last layer of the encoder (required).
|
| 859 |
+
tgt_mask: the mask for the tgt sequence (optional).
|
| 860 |
+
memory_mask: the mask for the memory sequence (optional).
|
| 861 |
+
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
| 862 |
+
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
| 863 |
+
tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
|
| 864 |
+
Default: ``False``.
|
| 865 |
+
Warning:
|
| 866 |
+
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
|
| 867 |
+
the causal mask. Providing incorrect hints can result in
|
| 868 |
+
incorrect execution, including forward and backward
|
| 869 |
+
compatibility.
|
| 870 |
+
memory_is_causal: If specified, applies a causal mask as
|
| 871 |
+
``memory mask``.
|
| 872 |
+
Default: ``False``.
|
| 873 |
+
Warning:
|
| 874 |
+
``memory_is_causal`` provides a hint that
|
| 875 |
+
``memory_mask`` is the causal mask. Providing incorrect
|
| 876 |
+
hints can result in incorrect execution, including
|
| 877 |
+
forward and backward compatibility.
|
| 878 |
+
|
| 879 |
+
Shape:
|
| 880 |
+
see the docs in :class:`~torch.nn.Transformer`.
|
| 881 |
+
"""
|
| 882 |
+
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
|
| 883 |
+
|
| 884 |
+
x = tgt
|
| 885 |
+
if self.norm_first:
|
| 886 |
+
x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
|
| 887 |
+
x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
|
| 888 |
+
x = x + self._ff_block(self.norm3(x))
|
| 889 |
+
else:
|
| 890 |
+
x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
|
| 891 |
+
x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
|
| 892 |
+
x = self.norm3(x + self._ff_block(x))
|
| 893 |
+
|
| 894 |
+
return x
|
| 895 |
+
|
| 896 |
+
# self-attention block
|
| 897 |
+
def _sa_block(self, x: Tensor,
|
| 898 |
+
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
|
| 899 |
+
x = self.self_attn(x, x, x,
|
| 900 |
+
attn_mask=attn_mask,
|
| 901 |
+
key_padding_mask=key_padding_mask,
|
| 902 |
+
is_causal=is_causal,
|
| 903 |
+
need_weights=False)[0]
|
| 904 |
+
return self.dropout1(x)
|
| 905 |
+
|
| 906 |
+
# multihead attention block
|
| 907 |
+
def _mha_block(self, x: Tensor, mem: Tensor,
|
| 908 |
+
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
|
| 909 |
+
x = self.multihead_attn(x, mem, mem,
|
| 910 |
+
attn_mask=attn_mask,
|
| 911 |
+
key_padding_mask=key_padding_mask,
|
| 912 |
+
is_causal=is_causal,
|
| 913 |
+
need_weights=False)[0]
|
| 914 |
+
return self.dropout2(x)
|
| 915 |
+
|
| 916 |
+
# feed forward block
|
| 917 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
| 918 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 919 |
+
return self.dropout3(x)
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
def _get_clones(module, N):
|
| 923 |
+
# FIXME: copy.deepcopy() is not defined on nn.module
|
| 924 |
+
return ModuleList([copy.deepcopy(module) for i in range(N)])
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
|
| 928 |
+
if activation == "relu":
|
| 929 |
+
return F.relu
|
| 930 |
+
elif activation == "gelu":
|
| 931 |
+
return F.gelu
|
| 932 |
+
|
| 933 |
+
raise RuntimeError(f"activation should be relu/gelu, not {activation}")
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
def _detect_is_causal_mask(
    mask: Optional[Tensor],
    is_causal: Optional[bool] = None,
    size: Optional[int] = None,
) -> bool:
    """Return whether the given attention mask is causal.

    Warning:
    If ``is_causal`` is not ``None``, its value will be returned as is. If a
    user supplies an incorrect ``is_causal`` hint,

    ``is_causal=False`` when the mask is in fact a causal attention.mask
    may lead to reduced performance relative to what would be achievable
    with ``is_causal=True``;
    ``is_causal=True`` when the mask is in fact not a causal attention.mask
    may lead to incorrect and unpredictable execution - in some scenarios,
    a causal mask may be applied based on the hint, in other execution
    scenarios the specified mask may be used. The choice may not appear
    to be deterministic, in that a number of factors like alignment,
    hardware SKU, etc influence the decision whether to use a mask or
    rely on the hint.
    ``size`` if not None, check whether the mask is a causal mask of the provided size
    Otherwise, checks for any causal mask.
    """
    # Prevent type refinement: bind a plain bool up front instead of refining
    # is_causal's Optional[bool] type (presumably for TorchScript — confirm).
    make_causal = (is_causal is True)

    # Only inspect the mask when the caller gave no explicit hint.
    if is_causal is None and mask is not None:
        # Default to the mask's query dimension when no size is supplied.
        sz = size if size is not None else mask.size(-2)
        causal_comparison = _generate_square_subsequent_mask(
            sz, device=mask.device, dtype=mask.dtype)

        # Do not use `torch.equal` so we handle batched masks by
        # broadcasting the comparison.
        if mask.size() == causal_comparison.size():
            make_causal = bool((mask == causal_comparison).all())
        else:
            # Shapes differ from the canonical causal mask -> not causal.
            make_causal = False

    return make_causal
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.06 kB). View file
|
|
|