koichi12 commited on
Commit
82ed4ab
·
verified ·
1 Parent(s): 0c36bb3

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete list.
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__pycache__/fft.cpython-311.pyc +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-311.pyc +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/graph_matcher.py +460 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/graph_passes.py +950 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_safeguard.py +42 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_tree_utils.py +64 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/__init__.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/common_types.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/cpp.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/init.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/parameter.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__init__.py +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__pycache__/__init__.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__pycache__/thnn.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/thnn.py +4 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/__init__.py +35 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/modules/fused.py +30 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/__init__.py +13 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +5 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py +12 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-311.pyc +0 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__init__.py +68 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/_functions.cpython-311.pyc +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/container.cpython-311.pyc +0 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/conv.cpython-311.pyc +0 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/flatten.cpython-311.pyc +0 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/fold.cpython-311.pyc +0 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/lazy.cpython-311.pyc +0 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/activation.py +1624 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/batchnorm.py +849 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/channelshuffle.py +57 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py +911 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/dropout.py +294 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/flatten.py +144 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/normalization.py +297 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/padding.py +801 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/pixelshuffle.py +113 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/pooling.py +1306 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/transformer.py +975 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__pycache__/fft.cpython-311.pyc ADDED
Binary file (29.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-311.pyc ADDED
Binary file (33 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-311.pyc ADDED
Binary file (1.41 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-311.pyc ADDED
Binary file (8.03 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-311.pyc ADDED
Binary file (10.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-311.pyc ADDED
Binary file (23.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/graph_matcher.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import enum
3
+
4
+ import torch
5
+ toq = torch.ops.quantized
6
+
7
+ from torch.fx import GraphModule
8
+ from torch.fx.graph import Graph, Node
9
+
10
+ from torch.ao.quantization.utils import getattr_from_fqn
11
+ from .ns_types import NSSubgraph, NSNodeTargetType
12
+ from .mappings import (
13
+ get_base_name_to_sets_of_related_ops,
14
+ get_unmatchable_types_map,
15
+ )
16
+ from .pattern_utils import (
17
+ get_type_a_related_to_b,
18
+ get_reversed_fusions,
19
+ end_node_matches_reversed_fusion,
20
+ )
21
+ from torch.ao.quantization import (
22
+ ObserverBase,
23
+ FakeQuantizeBase,
24
+ )
25
+
26
+ from typing import Dict, Tuple, List, Optional, Set, Any
27
+
28
+ def _get_output_nodes(g: Graph) -> List[Node]:
29
+ return [n for n in g.nodes if n.op == 'output']
30
+
31
class _NSGraphMatchableSubgraphsIterator:
    """
    Iterates through the graph of gm, starting with the output nodes
    and continuing backwards.
    1. Returns matchable subgraphs, in order. A subgraph is defined by
       (start_node, end_node).
    2. Skips over non-matchable subgraphs.
    """
    def __init__(
        self,
        gm: GraphModule,
        non_matchable_functions: Set[NSNodeTargetType],
        non_matchable_modules: Set[NSNodeTargetType],
        non_matchable_methods: Set[NSNodeTargetType],
    ):
        self.gm: GraphModule = gm
        self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
        self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
        self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
        # nodes already consumed by a previously emitted subgraph; prevents
        # re-visiting shared producers during the backwards traversal
        self.seen_nodes: Set[Node] = set()
        # DFS worklist; seeded with the graph outputs so that iteration
        # proceeds from the outputs back towards the inputs
        self.stack: List[Node] = []
        for start_node in _get_output_nodes(self.gm.graph):
            self.stack.append(start_node)

    def __iter__(self):
        return self

    def __next__(self) -> NSSubgraph:
        """
        Returns the next matchable subgraph.

        Raises StopIteration once the worklist is exhausted.
        """
        while len(self.stack) > 0:
            cur_end_node = self.stack.pop()
            if cur_end_node in self.seen_nodes:
                continue

            # for subgraphs which are single nodes, start_node == end_node
            # for subgraphs with more than one node, start node != end_node
            cur_start_node = cur_end_node
            # Subgraphs like linear-relu have the base node as the start node.
            # Subgraphs like dequantize-linear-relu-to(torch.float16) have the
            # base node as the second node.
            # The cur_base_op_node var will move to the actual node during
            # the fusion matching later in this code block.
            cur_base_op_node = cur_end_node

            # Check for potential fusions. For now, we are greedy
            # and always skip all non-base nodes of a fusion. For example,
            # if we match linear-relu backwards, we will always skip the
            # relu node and attempt to match the linear node. This can
            # be made configurable later if needed.
            for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
                is_match = end_node_matches_reversed_fusion(
                    cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes)
                if is_match:
                    # navigate to the base node by walking first args backwards
                    for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
                        self.seen_nodes.add(cur_start_node)
                        # for now, assume that there are no other nodes
                        # which need to be added to the stack
                        cur_start_node = cur_start_node.args[0]  # type: ignore[assignment]
                        # if the base op index matches the current node, set it
                        rev_base_op_idx = \
                            len(_reverse_fusion_ops) - 2 - base_op_idx
                        if rev_fusion_idx == rev_base_op_idx:
                            cur_base_op_node = cur_start_node
                    # greedy: stop at the first fusion pattern that matches
                    break

            self.seen_nodes.add(cur_start_node)
            # add args of previous nodes to stack
            for arg in cur_start_node.all_input_nodes:
                self._recursively_add_node_arg_to_stack(arg)

            # skip unmatchable nodes
            # note: this check is done on the start_node, i.e.
            # if we are matching linear-relu in reverse, this would do the matchable
            # check on the linear
            if not self._is_matchable(cur_base_op_node):
                continue

            # If an observer or a fake_quant was not matched as a part of
            # a pattern of multiple nodes, ignore it. One case where this is
            # relevant is an observer on a graph input, which was added because
            # it is necessary for the next node.
            if cur_end_node.op == 'call_module' and cur_start_node is cur_end_node:
                maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target)  # type: ignore[arg-type]
                if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
                    continue

            return NSSubgraph(
                start_node=cur_start_node, end_node=cur_end_node,
                base_op_node=cur_base_op_node)

        raise StopIteration

    def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
        """
        Adds all of the nodes in this arg to the stack, properly navigating
        through list, dicts and tuples.
        """
        if isinstance(arg, Node):
            self.stack.append(arg)
        elif isinstance(arg, torch.fx.immutable_collections.immutable_list) or type(arg) is tuple:
            for inner_arg in arg:
                self._recursively_add_node_arg_to_stack(inner_arg)
        elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
            for value in arg.values():
                self._recursively_add_node_arg_to_stack(value)

    def _is_matchable(self, node: Node) -> bool:
        # A node is matchable when its target is not listed in the
        # corresponding unmatchable set for its op kind; all other op kinds
        # (placeholder, get_attr, output) are never matchable.
        if node.op == 'call_function':
            return node.target not in self.non_matchable_functions
        elif node.op == 'call_module':
            assert isinstance(node.target, str)
            target_mod = getattr_from_fqn(self.gm, node.target)
            return not \
                any(isinstance(target_mod, t)  # type: ignore[arg-type]
                    for t in self.non_matchable_modules)
        elif node.op == 'call_method':
            return node.target not in self.non_matchable_methods
        else:
            return False
153
+
154
class GraphMatchingException(Exception):
    """Raised when the matchable subgraphs of two graphs cannot be paired up."""
159
+
160
class SubgraphTypeRelationship(enum.Enum):
    """How the base-op types of two candidate subgraphs relate."""

    # Identical type, and the type is known to Numeric Suite.
    # Example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d.
    EQUAL = enum.auto()
    # Identical type, but unknown to Numeric Suite (user defined type, etc).
    # NOTE: member name keeps the historical misspelling for compatibility.
    EQUAL_BUT_UKNOWN = enum.auto()
    # Both types are known and live in the same relationship set, but they
    # are not the same type. Example: F.linear and toq.linear.
    RELATED_BUT_NOT_EQUAL = enum.auto()
    # No known relationship between the two types.
    NOT_RELATED = enum.auto()
172
+
173
def _get_subgraph_relationship_type(
    subgraph_a: NSSubgraph,
    subgraph_b: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
) -> SubgraphTypeRelationship:
    """Classify how the base ops of two subgraphs relate to each other.

    The comparison is done on each subgraph's ``base_op_node`` only.
    ``type_a_related_to_b`` is a set of (type_a, type_b) pairs that are
    considered the same mathematical op across quantization flavors.
    """
    node_a = subgraph_a.base_op_node
    node_b = subgraph_b.base_op_node

    # TODO(next): make this code handle matching by what is before the base op
    if node_a.op != node_b.op:
        # Mixed op kinds are only tolerated when both are call_function /
        # call_method; any other mismatch cannot be related.
        if not (
            node_a.op in ('call_function', 'call_method') and
            node_b.op in ('call_function', 'call_method')
        ):
            return SubgraphTypeRelationship.NOT_RELATED

    if node_a.op in ('call_function', 'call_method'):
        key = (node_a.target, node_b.target)

        if key not in type_a_related_to_b:
            if node_a.target == node_b.target:
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            else:
                return SubgraphTypeRelationship.NOT_RELATED
        # after this point, we are dealing with known types

        if node_a.target == node_b.target:
            # NOTE(review): despite the "has_prev" naming, this compares
            # base_op_node == start_node; the result is symmetric either way:
            # only a mixed status between a and b downgrades to
            # RELATED_BUT_NOT_EQUAL, matching statuses yield EQUAL.
            node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
            node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
            if node_a_has_prev and (not node_b_has_prev):
                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
            elif (not node_a_has_prev) and node_b_has_prev:
                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
            elif (not node_a_has_prev) and (not node_b_has_prev):
                return SubgraphTypeRelationship.EQUAL
            else:
                # TODO(future PR): check for matches start_op_node and base_op_node
                return SubgraphTypeRelationship.EQUAL

        # NOTE(review): key is already known to be in type_a_related_to_b at
        # this point (the not-in case returned above), so this branch always
        # yields RELATED_BUT_NOT_EQUAL; the check is kept for clarity.
        if key in type_a_related_to_b:
            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
        else:
            return SubgraphTypeRelationship.NOT_RELATED
    elif node_a.op == 'call_module':
        assert (subgraph_a.base_op_node == subgraph_a.start_node and
                subgraph_b.base_op_node == subgraph_b.start_node), \
            "Matching call_module patterns where base_op_node != start_node is not supported yet"
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        assert isinstance(node_b.target, str)
        mod_b = getattr_from_fqn(gm_b, node_b.target)

        key = (type(mod_a), type(mod_b))

        if key not in type_a_related_to_b:
            if type(mod_a) == type(mod_b):
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            else:
                return SubgraphTypeRelationship.NOT_RELATED
        elif type(mod_a) == type(mod_b):
            return SubgraphTypeRelationship.EQUAL
        else:
            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL

    # any remaining op kind (e.g. get_attr) is not matchable
    return SubgraphTypeRelationship.NOT_RELATED
241
+
242
def _get_name_for_subgraph(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    existing_names: Set[str],
) -> str:
    """
    Returns a unique name for a subgraph, of the form
    ``base_op_<related-set-name>_<counter>``.

    The name is derived from (1) the name of the relationship set containing
    the base op's underlying type, and (2) how many prior subgraphs already
    used that set name. Node names are deliberately not used, because the
    same scheme must produce identical names across two models which are
    iterated separately (fusions may collapse several nodes into one, and
    some Numeric Suite APIs run without both models in memory). As long as
    both graphs are walked in the same order, corresponding subgraphs in
    each model receive the same name without seeing each other.

    Note: ``existing_names`` is mutated — the chosen name is added to it.
    """
    target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
    # Find the relationship-set name for this type; when several sets
    # contain it, the last one iterated wins. Falls back to the string
    # 'None' when the type belongs to no known set.
    target_base_type = None
    for base_name, related_ops in base_name_to_sets_of_related_ops.items():
        if target_type in related_ops:
            target_base_type = base_name
    prefix = 'base_op_' + str(target_base_type)
    # Pick the smallest counter that yields an unused name.
    counter = 0
    while prefix + '_' + str(counter) in existing_names:
        counter += 1
    proposed_name = prefix + '_' + str(counter)
    existing_names.add(proposed_name)
    return proposed_name
293
+
294
def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
    """Return the underlying type of *node*'s op: the target itself for
    function/method calls, the module's Python type for module calls, and
    None for any other op kind."""
    if node.op in ('call_function', 'call_method'):
        return node.target
    if node.op == 'call_module':
        assert isinstance(node.target, str)
        # resolve the module so we can report its class
        return type(getattr_from_fqn(gm, node.target))
    return None
302
+
303
def get_matching_subgraph_pairs(
    gm_a: GraphModule,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
    """
    Matches matchable subgraphs of graph_a to graph_b.

    For a node, "matchable" is defined as a node which is not an observer,
    fake_quants, quant or dequant.

    A subgraph can contain one or more nodes.  A subgraph is matchable if
    at least one node inside of it is matchable.  Currently, all nodes in
    a subgraph must be matchable (because we assume no observers will be
    inserted in the middle of a fusion).

    A subgraph is defined by (start_node, end_node).  We assume that only
    start_node and end_node are linked with the surrounding graph, all other
    nodes in a subgraph are self-contained.

    A pair of nodes is "related" if both nodes represent the same mathematical
    operation across different quantization flavors. For example,
    `F.linear` and `torch.ops.quantized.linear` are related, and
    `F.linear` and `torch.nn.Conv` are not related.

    For each matchable pair of nodes node_a and node_b, they will match
    if node_a and node_b are related.

    For graphs A and B, they will match iff:
    1. the number of matchable subgraphs in A and B is equivalent
    2. when iterating through the matchable subgraphs of A and B in the same order, each
       corresponding pair of base nodes is related.

    This enables us to find the corresponding subgraphs between
    graphs of related models.  For example, if we had two graphs such as:

    graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
             w -/
             b -/

    graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
           packed_params_0 -/

    This function will return the following result:
    {
        'conv_0': (  # the name of the node in graph_b
          (conv_0, conv_0),  # (start_node_a, end_node_a)
          (qconv_0, qconv_0),  # (start_node_b, end_node_b)
        ),
    }

    Or, if we have a fusion pattern,

    graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
             w -/
             b -/

    graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
           packed_params_0 -/

    This function will return the following result:
    {
        'linear_relu_0': (  # the name of the node in graph_b
          (linear_0, relu_0),  # (start_node_a, end_node_a)
          (linear_relu_0, linear_relu_0),  # (start_node_b, end_node_b)
        ),
    }
    """
    if unmatchable_types_map is None:
        unmatchable_types_map = get_unmatchable_types_map()
    non_matchable_functions = unmatchable_types_map['funs_unmatchable']
    non_matchable_modules = unmatchable_types_map['mods_unmatchable']
    non_matchable_methods = unmatchable_types_map['meths_unmatchable']

    # Both graphs are walked from outputs to inputs, in lockstep.
    graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
        gm_a, non_matchable_functions, non_matchable_modules,
        non_matchable_methods)
    graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
        gm_b, non_matchable_functions, non_matchable_modules,
        non_matchable_methods)
    results = collections.OrderedDict()
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # per-model name books; _get_name_for_subgraph mutates these so that
    # the same counter sequence is produced independently for a and b
    existing_names_a: Set[str] = set()
    existing_names_b: Set[str] = set()

    while True:
        # fetch the next subgraphs from a and b
        cur_subgraph_a, cur_subgraph_b = None, None
        try:
            cur_subgraph_a = next(graph_a_iterator)
        except StopIteration:
            pass
        try:
            cur_subgraph_b = next(graph_b_iterator)
        except StopIteration:
            pass

        # look up types of a and b for useful error messages
        type_start_a, type_start_b = None, None
        if cur_subgraph_a is not None:
            type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a)
        if cur_subgraph_b is not None:
            type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b)

        # check for results and determine what to do next
        if cur_subgraph_a is not None and cur_subgraph_b is not None:
            # both nodes were fetched, check for subgraph_relationship
            # note: subgraph_relationship is checked on the start node, i.e.
            # if a linear-relu pattern is checked, we would check for subgraph_relationship
            # of the linear
            subgraph_relationship = _get_subgraph_relationship_type(
                cur_subgraph_a, cur_subgraph_b,
                gm_a, gm_b, type_a_related_to_b)
            if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
                msg = f"""
The subgraphs
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b})
are not related. Please ensure that the two models you pass in have the same number
of subgraphs, and each pair of subgraphs is related to each other."""
                raise GraphMatchingException(msg)
            elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
                # skip matching but unknown types
                continue
            key_name_a = _get_name_for_subgraph(
                cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops,
                existing_names_a)
            key_name_b = _get_name_for_subgraph(
                cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops,
                existing_names_b)
            assert key_name_a == key_name_b, \
                f"Subgraph names {key_name_a} and {key_name_b} do not match"
            results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
            continue
        elif cur_subgraph_a is None and cur_subgraph_b is None:
            # we reached the end of both graphs
            break
        else:
            # only one node was fetched, no match possible, throw error
            msg = f"""
Attempting to match
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b}),
one of which is empty. Please ensure that the two models you pass in have the same number
of subgraphs."""
            raise GraphMatchingException(msg)

    # The subgraph pairs are originally created by traversing the two graphs
    # from the outputs to the inputs. Reverse the results to return the
    # subgraphs in their order of execution.
    results = collections.OrderedDict(reversed(list(results.items())))

    return results
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/graph_passes.py ADDED
@@ -0,0 +1,950 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.fx import GraphModule, map_arg
3
+ from torch.fx.graph import Graph, Node
4
+ from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
5
+
6
+ from .utils import (
7
+ get_node_first_input_and_output_type,
8
+ getattr_from_fqn,
9
+ NodeInputOrOutputType,
10
+ return_first_non_observer_node,
11
+ get_number_of_non_param_args,
12
+ get_target_type_str,
13
+ get_arg_indices_of_inputs_to_log,
14
+ get_node_input_qparams,
15
+ op_type_supports_shadowing,
16
+ get_normalized_nth_input,
17
+ )
18
+
19
+ from .ns_types import (
20
+ NSSingleResultValuesType,
21
+ NSSubgraph,
22
+ NSNodeTargetType,
23
+ )
24
+ from torch.ao.ns.fx.mappings import (
25
+ get_node_type_to_io_type_map,
26
+ )
27
+ from torch.ao.quantization.observer import _is_activation_post_process
28
+
29
+ from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set
30
+
31
def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
    """Look up the fully qualified name recorded for *node* during tracing.

    Returns None when *gm* carries no ``_node_name_to_scope`` mapping.
    Observers / fake-quants have no fqn of their own (they are inserted
    after tracing), so for those modules the fqn of the observed node is
    reported instead.
    """
    if not hasattr(gm, '_node_name_to_scope'):
        return None
    lookup_node = node
    if node.op == 'call_module':
        assert isinstance(node.target, str)
        target_mod = getattr_from_fqn(gm, node.target)
        if _is_activation_post_process(target_mod):
            # report the fqn of the node being observed, not the observer
            lookup_node = get_normalized_nth_input(node, gm, 0)
    return gm._node_name_to_scope[lookup_node.name][0]  # type: ignore[index, return-value]
45
+
46
def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    ref_node_target_type: str,
    results_type: str,
    index_within_arg: int,
    index_of_arg: int,
    fqn: Optional[str],
) -> Node:
    """Attach a new ``logger_cls`` instance directly after *node*.

    Transforms ``prev_node -> node -> next_node`` into
    ``prev_node -> node -> logger -> next_node`` and returns the freshly
    created logger node. The logger object is registered on *gm* under a
    unique attribute name derived from ``node.name`` and the suffix.
    """
    # derive a unique attribute name for the logger on gm
    attr_name = \
        get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(gm)
    target_type = get_target_type_str(node, gm)
    # instantiate the logger with full bookkeeping metadata
    logger_obj = logger_cls(
        ref_node_name, node.name, model_name, ref_name, target_type,
        ref_node_target_type,
        results_type, index_within_arg, index_of_arg, fqn)
    # register it on the owning module, then splice it into the graph
    setattr(gm, attr_name, logger_obj)
    return node.graph.create_node('call_module', attr_name, (node,), {})
84
+
85
def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new
    graph.

    Args:
        gm: the GraphModule to instrument (a new graph is built from its nodes).
        node_to_instrument_inputs_to_ref_node_name: maps nodes whose *inputs*
            should be logged to (ref_name, ref_node_type).
        node_to_instrument_outputs_to_ref_node_name: maps nodes whose *output*
            should be logged to (ref_name, ref_node_type).
        logger_cls: callable used to construct each logger module.
        model_name: name recorded on every logger instance.
    """
    # fix: removed unused local `modules = dict(gm.named_modules())`

    new_graph = Graph()
    # env maps original node name -> corresponding node in new_graph
    env: Dict[str, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
            continue

        if (
            (node in node_to_instrument_inputs_to_ref_node_name) or
            (node in node_to_instrument_outputs_to_ref_node_name)
        ):
            fqn = _maybe_get_fqn(node, gm)

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[node]
                # Ops such add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
                            model_name, ref_name, ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=node_arg_idx,
                            fqn=fqn)
                    elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):  # type: ignore[var-annotated, arg-type]
                            prev_node = env[arg.name]
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node, gm, logger_cls, '_ns_logger_', node.name,
                                model_name, ref_name, ref_node_type,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=node_arg_idx,
                                fqn=fqn)
                    else:
                        # non-node args (constants, etc.) are not logged
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                    model_name, ref_name, ref_node_type,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0, fqn=fqn)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
165
+
166
def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    """
    Copy `scale` and `zero_point` onto `gm_b` as attributes and emit a
    `torch.quantize_per_tensor` call in `graph_c` that quantizes
    `prev_node_c` with those qparams. Returns the new call node.
    """

    def _attach_qparam(value, suffix):
        # store the qparam on gm_b under a fresh unique name and surface
        # it inside graph_c as a get_attr node
        attr_name = get_new_attr_name_with_prefix(node_a.name + suffix)(gm_b)
        setattr(gm_b, attr_name, value)
        return graph_c.create_node('get_attr', attr_name, (), {}, attr_name)

    # copy scale, then zero_point (same order as attribute creation)
    scale_node = _attach_qparam(scale, '_input_scale_')
    zero_point_node = _attach_qparam(zero_point, '_input_zero_point_')
    # create the quantize_per_tensor call
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, torch.quint8), {},
        dtype_cast_name)
194
+
195
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

            dtype_cast
          /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    Returns a single cast node when prev_node_c is a Node, or a list of cast
    nodes (one per element) when prev_node_c is a list.
    """
    # Exactly one of op / mod class / method is selected below, depending on
    # the (input dtype of a, input dtype of c) pair.
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_method = None
    dtype_cast_method_dtype = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(
            node_a, gm_a, logger_cls, node_type_to_io_type_map)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(
            node_c, gm_b, logger_cls, node_type_to_io_type_map)

    if (
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.INT8) or
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP16) or
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)
    ):
        # fp32 shadows a quantized/fp16 op: dequantize before feeding a's copy
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c and
        node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        # same known dtype on both sides: no-op cast
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map)
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
        # NOTE(review): if qparams are None, all selectors stay None and the
        # Identity assert below will fire; callers are expected to pre-check
        # qparams availability — confirm against create_a_shadows_b.
    elif (
        node_input_type_a == NodeInputOrOutputType.FP16 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        dtype_cast_method = 'to'
        dtype_cast_method_dtype = torch.float16
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented")

    if isinstance(prev_node_c, Node):
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                # quantize_per_tensor path: qparams become get_attr nodes
                return _insert_quantize_per_tensor_node(
                    prev_node_c, node_a, gm_b, graph_c, dtype_cast_scale,
                    dtype_cast_zero_point, new_dtype_cast_name)
            else:
                return graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c,), {},
                    new_dtype_cast_name)
        elif dtype_cast_method:
            return graph_c.create_node(
                'call_method', dtype_cast_method,
                (prev_node_c, dtype_cast_method_dtype), {}, new_dtype_cast_name)
        else:
            assert dtype_cast_mod_cls
            dtype_cast_mod = dtype_cast_mod_cls()
            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
            return graph_c.create_node(
                'call_module', new_dtype_cast_name, (prev_node_c,), {},
                new_dtype_cast_name)
    elif isinstance(prev_node_c, list):
        # multi-input case: emit one cast node per element
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = \
                get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
            if dtype_cast_op:
                # TODO(future PR): add handling for quantize_per_tensor
                new_dtype_cast_node = graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
            else:
                assert dtype_cast_mod_cls
                dtype_cast_mod = dtype_cast_mod_cls()
                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
                new_dtype_cast_node = graph_c.create_node(
                    'call_module', new_dtype_cast_name, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
        return results
    else:
        # NOTE(review): message contains a stray 'f' ("type f<...>"); fixing
        # it would change runtime output, so it is left as-is here.
        raise AssertionError(f"type f{type(prev_node_c)} is not handled")
319
+
320
+ # TODO(future PR): look into using copy_node API instead
321
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Simple copy of node_a to graph_c.

    Handles only three node kinds:
    * get_attr: the attribute value is copied onto gm_b (tensors detached)
      under a fresh '<name>_shadow_copy_N' attribute
    * call_method 'dequantize': the single input is copied recursively
    * call_method 'to': the input is copied recursively; the second arg
      (the dtype) is passed through as-is

    Raises AssertionError for any other op/target.
    """
    if node_a.op == 'get_attr':
        node_a_copy_name = \
            get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
        node_a_obj = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
        # detach so the shadow copy does not participate in autograd
        if torch.is_tensor(node_a_obj):
            node_a_obj = node_a_obj.detach()
        setattr(gm_b, node_a_copy_name, node_a_obj)
        node_a_copy = graph_c.create_node(
            node_a.op, node_a_copy_name, (), {}, node_a_copy_name)
        return node_a_copy
    elif node_a.op == 'call_method':
        assert node_a.target in ('dequantize', 'to'), \
            f"target {node_a.target} is not implemented"
        if node_a.target == 'dequantize':
            # recursively copy the dequantize input first
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0),
                gm_a, gm_b, graph_c)  # type: ignore[arg-type]
            node_a_copy_name = \
                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
            node_a_copy = graph_c.create_node(
                node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name)
            return node_a_copy
        else:  # to
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c)  # type: ignore[arg-type]
            node_a_copy_name = \
                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
            # second normalized input is the target dtype of the `to` call
            node_a_copy = graph_c.create_node(
                node_a.op, node_a.target,
                (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)),
                {}, node_a_copy_name)
            return node_a_copy

    else:
        raise AssertionError(
            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented")
366
+
367
def _can_insert_copy_of_subgraph_a(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    num_non_param_args_node_a: int,
) -> bool:
    """
    This function returns `False` if the input subgraph cannot be copied by
    `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
    that there is a corner case logic for which copy is not yet implemented.

    Args:
        subgraph_a: the candidate subgraph (start_node .. end_node chain).
        gm_a: parent GraphModule of subgraph_a.
        num_non_param_args_node_a: number of non-parameter inputs of the
            subgraph's first node (1 or 2); those inputs are stitched from
            the base graph and therefore skipped here.
    """
    # populate the list of nodes we need to check, walking first-input edges
    # from end_node back to start_node
    nodes = []
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        nodes.append(cur_node)
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
    nodes.append(cur_node)
    nodes.reverse()

    def _can_insert(node_a_arg, gm_a):
        # mirrors what `_copy_node_from_a_to_c` knows how to copy
        if isinstance(node_a_arg, Node):
            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
            if arg_a.op == 'call_method':
                return arg_a.target in ('dequantize', 'to')
            elif arg_a.op == 'get_attr':
                return True
            else:
                return False
        elif isinstance(node_a_arg, (list, tuple)):
            for el in node_a_arg:
                if not isinstance(el, Node):
                    return False
            return True
        # fix: previously fell off the end returning None for scalar args;
        # make the falsy result an explicit bool (same truthiness, clearer
        # contract for a function annotated as returning bool)
        return False

    # For each node, check if we handle the copy behavior. This follows the
    # logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
    for node_a in nodes:

        # only the first node of the subgraph has stitched inputs beyond idx 0
        local_num_non_param_args_node_a = num_non_param_args_node_a \
            if node_a is nodes[0] else 1

        norm_args_kwargs = node_a.normalized_arguments(
            gm_a, normalize_to_only_use_kwargs=True)
        if norm_args_kwargs is not None:
            norm_args, norm_kwargs = norm_args_kwargs
        else:
            norm_args, norm_kwargs = node_a.args, node_a.kwargs

        cur_idx = 0

        while cur_idx < len(norm_args):
            if cur_idx == 0:
                pass  # first input is stitched from the base graph
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass  # second non-param input is also stitched
            else:
                if not _can_insert(norm_args[cur_idx], gm_a):
                    return False
            cur_idx += 1

        for kwarg_val in norm_kwargs.values():
            # stitch the inputs from base graph
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(kwarg_val, gm_a):
                    return False
            cur_idx += 1

    return True
439
+
440
def _insert_copy_of_subgraph_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Copy the nodes of `subgraph_a` (start to end) into the graph owning
    `input_node_c`, wiring the first copied node to `input_node_c` (and,
    when provided, `input_node_c_2`). Returns the copy of the subgraph's
    end node.
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    # Walk the first-input edges from end_node back to start_node so the
    # chain can be replayed into graph C in execution order.
    nodes_of_a = [subgraph_a.end_node]
    walker = subgraph_a.end_node
    while walker != subgraph_a.start_node:
        walker = get_normalized_nth_input(walker, gm_a, 0)  # type: ignore[assignment]
        nodes_of_a.insert(0, walker)

    # The first copy consumes the caller-provided inputs; each later copy
    # consumes the copy made just before it.
    prev_copy = _insert_copy_of_node_a_after_input_node_c(
        input_node_c,
        input_node_c_2,
        nodes_of_a[0],
        gm_a,
        gm_b,
        node_name_prefix)
    for node_a in nodes_of_a[1:]:
        prev_copy = _insert_copy_of_node_a_after_input_node_c(
            prev_copy,
            # TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph
            None,
            node_a,
            gm_a,
            gm_b,
            node_name_prefix)
    # the last inserted node shadows the subgraph's end node
    return prev_copy
488
+
489
+
490
def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
      args (input, (input2)?, arg1, ...), and
      kwargs {kw0: kwarg0, ...}

    Note: input2 is optional. If it equals to None, we assume that the op
    has a single non-param input. If it is specified, we assume that the op
    has two non-param inputs.

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg is an observer we navigate up until we find a non-observer parent.

    If node_a is a call_module, points the module pointed to by node_a to gm_b.

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input -------------> node_a
                         / / /
    (input_2)?----------/ / /
                         / /
    weight -> weight_obs  /
                         /
    bias ----------------

    graph C (derived from B):
    =========================

    input_node_c --> node_a_copy
                     / / /
    (input_node_c_2)? / /
                     / /
    weight_copy ----/ /
                     /
    bias_copy ------/
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    norm_args_kwargs = node_a.normalized_arguments(
        gm_a, normalize_to_only_use_kwargs=True)
    if norm_args_kwargs is not None:
        norm_args, norm_kwargs = norm_args_kwargs
    else:
        norm_args, norm_kwargs = node_a.args, node_a.kwargs

    new_args = []
    new_kwargs = {}

    def _copy_arg(arg):
        # copy the other inputs from the other graph
        if isinstance(arg, Node):
            arg = return_first_non_observer_node(arg, gm_a)
            arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
            return arg
        elif isinstance(arg, (int, float, torch.dtype)):
            return arg
        # bugfix: the two branches below previously referenced `kwarg_val`,
        # a name only bound by the kwargs loop further down, instead of this
        # function's own `arg`. Passing a list/tuple positional arg raised
        # NameError (or inspected the wrong value).
        elif isinstance(arg, (list, tuple)):
            for el in arg:
                assert not isinstance(el, Node), \
                    "handling of Node inside list is not implemented"
            return arg
        else:
            raise AssertionError(
                f"handling for arg of type {type(arg)} is not implemented")

    cur_idx = 0

    # positional args: slot 0 (and optionally slot 1) come from graph C,
    # the rest are deep-copied from graph A
    while cur_idx < len(norm_args):
        if cur_idx == 0:
            new_arg = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_arg = input_node_c_2
        else:
            new_arg = _copy_arg(norm_args[cur_idx])
        new_args.append(new_arg)
        cur_idx += 1

    for kwarg_name, kwarg_val in norm_kwargs.items():
        # stitch the inputs from base graph
        if cur_idx == 0:
            new_kwargs[kwarg_name] = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_kwargs[kwarg_name] = input_node_c_2
        else:
            new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
        cur_idx += 1

    new_args = tuple(new_args)  # type: ignore[assignment]

    node_a_shadows_c_name = \
        get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if node_a.op == 'call_module':
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        node_a_shadows_c = graph_c.create_node(
            node_a.op, new_mod_copy_name, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c
    else:
        assert node_a.op in ('call_function', 'call_method')
        node_a_shadows_c = graph_c.create_node(
            node_a.op, node_a.target, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c
619
+
620
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B. For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    if node_type_to_io_type_map is None:
        node_type_to_io_type_map = get_node_type_to_io_type_map()

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    # NOTE(review): `modules` appears unused below — candidate for removal
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    # index the B-side subgraph boundaries so the main loop can detect them
    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            all_op_types_support_shadowing = (
                op_type_supports_shadowing(subgraph_a.start_node) and
                op_type_supports_shadowing(node_b)
            )
            if not all_op_types_support_shadowing:
                # fall back to a plain copy of node_b, without shadow loggers
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unsupported')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(
                    subgraph_a.start_node, gm_a, logger_cls,
                    node_type_to_io_type_map)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(
                    node_b, gm_b, logger_cls,
                    node_type_to_io_type_map)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN and
                node_output_type_a != NodeInputOrOutputType.UNKNOWN and
                node_input_type_b != NodeInputOrOutputType.UNKNOWN and
                node_output_type_b != NodeInputOrOutputType.UNKNOWN
            )
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unknown dtype cast')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # If we are shadowing from fp32 to int8, we need to insert
            # quantize_per_tensor call with qparams from the previous node.
            # Only do this if we are able to infer these qparams from the graph.
            if (
                node_input_type_a == NodeInputOrOutputType.INT8 and
                node_input_type_b == NodeInputOrOutputType.FP32
            ):
                node_a_input_qparams = get_node_input_qparams(
                    subgraph_a.start_node, gm_a, node_type_to_io_type_map)
                if not node_a_input_qparams:
                    print(
                        f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                        f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                        ', unknown input qparams')
                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                    continue

            num_non_param_args_node_a = \
                get_number_of_non_param_args(subgraph_a.start_node, gm_a)
            if not _can_insert_copy_of_subgraph_a(subgraph_a, gm_a, num_non_param_args_node_a):
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unhandled logic in subgraph copy')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)  # type: ignore[possibly-undefined]

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
                    if isinstance(prev_node_b, Node):
                        prev_node_c = env_c[prev_node_b.name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                            node_b.name, name_b, ref_name, ref_node_type_b,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=0,
                            fqn=fqn_base_b)
                    elif isinstance(prev_node_b, list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]

                        for arg_idx, arg in enumerate(prev_node_b):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[prev_node_c.name] = _insert_logger_after_node(
                                prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                                node_b.name, name_b, ref_name, ref_node_type_b,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=0,
                                fqn=fqn_base_b)
                    else:
                        # logging of inputs which are not lists is not supported yet
                        raise AssertionError(f"type {type(prev_node_b)} is not handled yet")
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_start_node:

                    # cast dtype from the dtype of node_c's input to the dtype of
                    # node_a's input (dequant, etc)
                    # prev_node_c = node_c.args[0]
                    prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)  # type: ignore[possibly-undefined]
                    if should_log_inputs:
                        # skip the input logger when inserting a dtype cast
                        if isinstance(prev_node_c, Node):
                            prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
                        elif isinstance(prev_node_c, list):
                            prev_node_c = [get_normalized_nth_input(arg, gm_b, 0) for arg in prev_node_c]
                    dtype_cast_node = _insert_dtype_cast_after_node(
                        subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b, graph_c,
                        node_b.name + '_dtype_cast_', logger_cls,
                        node_type_to_io_type_map)
                    # note: not inserting to env_c because all nodes which use the dtype
                    # casts are copied from graph_a
                    #
                    # subgraph so far:
                    #
                    #           (dtype_cast_node)+
                    #                  /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    # if input logging is enabled, log the input to the subgraph
                    if should_log_inputs:
                        # TODO: explain this
                        ref_node_name = ''
                        if isinstance(dtype_cast_node, Node):
                            dtype_cast_node = _insert_logger_after_node(
                                dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
                                ref_node_name, name_a, ref_name, ref_node_type_a,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=0, index_of_arg=0,
                                fqn=fqn_base_a)
                            input_logger: Union[Node, List[Node]] = dtype_cast_node
                        else:
                            assert isinstance(dtype_cast_node, list)
                            new_loggers = []
                            for dtype_cast_idx, dtype_cast_node_inner in enumerate(dtype_cast_node):
                                dtype_cast_logger = _insert_logger_after_node(
                                    dtype_cast_node_inner, gm_b, logger_cls, '_ns_logger_a_inp_',
                                    ref_node_name, name_a, ref_name, ref_node_type_a,
                                    NSSingleResultValuesType.NODE_INPUT.value,
                                    index_within_arg=dtype_cast_idx,
                                    index_of_arg=0,
                                    fqn=fqn_base_a)
                                new_loggers.append(dtype_cast_logger)
                            dtype_cast_node = new_loggers
                            input_logger = dtype_cast_node
                        # subgraph so far:
                        #
                        #       (dtype_cast_node)+ -> (logger_a_input)?
                        #                  /
                        # prev_node_c -> (logger_c_input)? -> node_start_c

                    # hook up the new mod_a copy to be in the graph, receiving the
                    # same inputs as mod_b does, with dtype cast to match a
                    # Some ops, such as LSTMs, have two non-param inputs. If we have
                    # such an op, pass the second param as well. Note: dtype casting
                    # for the second param is not implemented yet, it can be added
                    # later if there is a use case.
                    node_c_second_non_param_arg = None
                    num_non_param_args_node_a = get_number_of_non_param_args(subgraph_a.start_node, gm_a)
                    if num_non_param_args_node_a == 2:
                        # node_c_second_non_param_arg = node_c.args[1]
                        node_c_second_non_param_arg = get_normalized_nth_input(node_c, gm_b, 1)
                    node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                        dtype_cast_node, node_c_second_non_param_arg,
                        subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
                    env_c[node_a_shadows_c.name] = node_a_shadows_c
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                    #                  /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    if should_log_inputs:
                        # When we created the input logger, we left the ref_node_name
                        # as an empty string, because the subgraph copy did not exist
                        # yet. Now that the subgraph copy exists, we modify this name
                        # to its true value.
                        # Note: the alternative to this is to create the input logger
                        # after creating the subgraph, which is slightly more
                        # complicated. This is the lesser of two evils.
                        # input_logger = env_c[dtype_cast_node.name]
                        # Find the first node in the subgraph
                        cur_node = node_a_shadows_c
                        while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:  # type: ignore[possibly-undefined]
                            cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
                        if isinstance(input_logger, Node):
                            input_logger_mod = getattr(gm_b, input_logger.name)
                            input_logger_mod.ref_node_name = cur_node.name
                        else:
                            assert isinstance(input_logger, list)
                            for input_logger_inner in input_logger:
                                input_logger_mod = getattr(gm_b, input_logger_inner.name)
                                input_logger_mod.ref_node_name = cur_node.name

                    # hook up a logger to the mod_a copy
                    env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                        env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                        node_a_shadows_c.name, name_a, ref_name, ref_node_type_a,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0, index_of_arg=0,
                        fqn=fqn_base_a)
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #                  /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_end_node:

                    # hook up a logger to the mod_b copy
                    env_c[node_b.name] = _insert_logger_after_node(
                        env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
                        node_b.name, name_b, ref_name, ref_node_type_b,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0, index_of_arg=0,
                        fqn=fqn_base_b)
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #                  /
                    # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                    #
                    # Note: node_start_c may be the same node as node_end_c, or they
                    # may have nodes inbetween.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_safeguard.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
3
+ from torch.overrides import TorchFunctionMode
4
+
5
+
6
class AutogradStateOpsFailSafeguard(TorchFunctionMode):
    """
    Detect grad state ops during exporting the graph and fail the process by
    raising an error, to avoid unexpected behavior. Those grad mode ops could be:
    `torch.no_grad`
    `torch.enable_grad`
    `torch.set_grad_enabled`

    Export with predispatch mode is exempted.
    """

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Ops that flip the global autograd state. torch.export does not
        # capture these, so letting one through silently would be unsound.
        _grad_toggle_ops = (torch._C._set_grad_enabled,)

        if func in _grad_toggle_ops:
            assert len(args) == 1
            requested_state = args[0]
            proxy_mode = torch._C._get_dispatch_mode(
                torch._C._TorchDispatchModeKey.PROXY
            )
            # Only guard while actually tracing, i.e. an active PROXY torch
            # dispatch mode. pre_dispatch tracing is exempt: autograd ops such
            # as `torch.no_grad` are legal there. A no-op toggle (requested
            # state equals the current state) is also harmless.
            tracing_without_predispatch = (
                proxy_mode is not None
                and isinstance(proxy_mode, ProxyTorchDispatchMode)
                and not proxy_mode.pre_dispatch
            )
            if tracing_without_predispatch and requested_state != torch._C.is_grad_enabled():
                raise RuntimeError(
                    f"Encountered autograd state manager op {func} trying to change global autograd state "
                    "while exporting. This is unsafe because we don't capture this op in torch.export "
                    "today, hence we can't reflect the user intention soundly."
                )
        return func(*args, **(kwargs or {}))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_tree_utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, Optional
2
+
3
+ from torch.utils._pytree import Context, TreeSpec
4
+
5
+
6
+ def reorder_kwargs(user_kwargs: Dict[str, Any], spec: TreeSpec) -> Dict[str, Any]:
7
+ """Reorder user-provided kwargs to match the order in `spec`. `spec` is
8
+ expected to be the in_spec of an exported program, i.e. the spec that
9
+ results from flattening `(args, kwargs)`.
10
+
11
+ We need this to provide consistent input ordering, such so that users can
12
+ pass in foo(a=a, b=b) OR foo(b=b, a=a) and receive the same result.
13
+ """
14
+ # Make sure that the spec is actually shaped like (args, kwargs)
15
+ assert spec.type is tuple
16
+ assert spec.num_children == 2
17
+ kwargs_spec = spec.children_specs[1]
18
+ assert kwargs_spec.type is dict
19
+
20
+ if set(user_kwargs) != set(kwargs_spec.context):
21
+ raise ValueError(
22
+ f"kwarg key mismatch: "
23
+ f"Got {list(user_kwargs)} but expected {kwargs_spec.context}"
24
+ )
25
+
26
+ reordered_kwargs = {}
27
+ for kw in kwargs_spec.context:
28
+ reordered_kwargs[kw] = user_kwargs[kw]
29
+
30
+ return reordered_kwargs
31
+
32
+
33
def is_equivalent(
    spec1: TreeSpec,
    spec2: TreeSpec,
    equivalence_fn: Callable[[Optional[type], Context, Optional[type], Context], bool],
) -> bool:
    """Customizable equivalence check for two TreeSpecs.

    Arguments:
        spec1: The first TreeSpec to compare
        spec2: The second TreeSpec to compare
        equivalence_fn: A function to determine the equivalence of two
            TreeSpecs by examining their types and contexts. It will be called like:

                equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context)

            This function will be applied recursively to all children.

    Returns:
        True if the two TreeSpecs are equivalent, False otherwise.
    """
    # Compare the roots first; a mismatch short-circuits the whole subtree.
    if not equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context):
        return False

    children1 = spec1.children_specs
    children2 = spec2.children_specs
    if len(children1) != len(children2):
        return False

    # Recurse pairwise over the children.
    return all(
        is_equivalent(child1, child2, equivalence_fn)
        for child1, child2 in zip(children1, children2)
    )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.69 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/common_types.cpython-311.pyc ADDED
Binary file (1.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/cpp.cpython-311.pyc ADDED
Binary file (5.42 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/init.cpython-311.pyc ADDED
Binary file (28.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/parameter.cpython-311.pyc ADDED
Binary file (12.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (218 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/__pycache__/thnn.cpython-311.pyc ADDED
Binary file (346 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/backends/thnn.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # this is for historical pickle deserialization, it is not used otherwise
2
+
3
+ def _get_thnn_function_backend():
4
+ pass
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.ao.nn.intrinsic import ConvBn1d
2
+ from torch.ao.nn.intrinsic import ConvBn2d
3
+ from torch.ao.nn.intrinsic import ConvBn3d
4
+ from torch.ao.nn.intrinsic import ConvBnReLU1d
5
+ from torch.ao.nn.intrinsic import ConvBnReLU2d
6
+ from torch.ao.nn.intrinsic import ConvBnReLU3d
7
+ from torch.ao.nn.intrinsic import ConvReLU1d
8
+ from torch.ao.nn.intrinsic import ConvReLU2d
9
+ from torch.ao.nn.intrinsic import ConvReLU3d
10
+ from torch.ao.nn.intrinsic import LinearReLU
11
+ from torch.ao.nn.intrinsic import BNReLU2d
12
+ from torch.ao.nn.intrinsic import BNReLU3d
13
+ from torch.ao.nn.intrinsic import LinearBn1d
14
+ from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401
15
+
16
+ # Include the subpackages in case user imports from it directly
17
+ from . import modules # noqa: F401
18
+ from . import qat # noqa: F401
19
+ from . import quantized # noqa: F401
20
+
21
+ __all__ = [
22
+ 'ConvBn1d',
23
+ 'ConvBn2d',
24
+ 'ConvBn3d',
25
+ 'ConvBnReLU1d',
26
+ 'ConvBnReLU2d',
27
+ 'ConvBnReLU3d',
28
+ 'ConvReLU1d',
29
+ 'ConvReLU2d',
30
+ 'ConvReLU3d',
31
+ 'LinearReLU',
32
+ 'BNReLU2d',
33
+ 'BNReLU3d',
34
+ 'LinearBn1d',
35
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.21 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/modules/fused.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.ao.nn.intrinsic import BNReLU2d
2
+ from torch.ao.nn.intrinsic import BNReLU3d
3
+ from torch.ao.nn.intrinsic import ConvBn1d
4
+ from torch.ao.nn.intrinsic import ConvBn2d
5
+ from torch.ao.nn.intrinsic import ConvBn3d
6
+ from torch.ao.nn.intrinsic import ConvBnReLU1d
7
+ from torch.ao.nn.intrinsic import ConvBnReLU2d
8
+ from torch.ao.nn.intrinsic import ConvBnReLU3d
9
+ from torch.ao.nn.intrinsic import ConvReLU1d
10
+ from torch.ao.nn.intrinsic import ConvReLU2d
11
+ from torch.ao.nn.intrinsic import ConvReLU3d
12
+ from torch.ao.nn.intrinsic import LinearBn1d
13
+ from torch.ao.nn.intrinsic import LinearReLU
14
+ from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401
15
+
16
+ __all__ = [
17
+ 'BNReLU2d',
18
+ 'BNReLU3d',
19
+ 'ConvBn1d',
20
+ 'ConvBn2d',
21
+ 'ConvBn3d',
22
+ 'ConvBnReLU1d',
23
+ 'ConvBnReLU2d',
24
+ 'ConvBnReLU3d',
25
+ 'ConvReLU1d',
26
+ 'ConvReLU2d',
27
+ 'ConvReLU3d',
28
+ 'LinearBn1d',
29
+ 'LinearReLU',
30
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-311.pyc ADDED
Binary file (1.27 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-311.pyc ADDED
Binary file (712 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .modules import * # noqa: F403
2
+ # to ensure customers can use the module below
3
+ # without importing it directly
4
+ import torch.nn.intrinsic.quantized.dynamic
5
+
6
+ __all__ = [
7
+ 'BNReLU2d',
8
+ 'BNReLU3d',
9
+ 'ConvReLU1d',
10
+ 'ConvReLU2d',
11
+ 'ConvReLU3d',
12
+ 'LinearReLU',
13
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (335 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from torch.ao.nn.intrinsic.quantized.dynamic import LinearReLU
2
+
3
+ __all__ = [
4
+ 'LinearReLU',
5
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .linear_relu import LinearReLU
2
+ from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
3
+ from .bn_relu import BNReLU2d, BNReLU3d
4
+
5
+ __all__ = [
6
+ 'LinearReLU',
7
+ 'ConvReLU1d',
8
+ 'ConvReLU2d',
9
+ 'ConvReLU3d',
10
+ 'BNReLU2d',
11
+ 'BNReLU3d',
12
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (556 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-311.pyc ADDED
Binary file (350 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__init__.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .module import Module
2
+ from .linear import Identity, Linear, Bilinear, LazyLinear
3
+ from .conv import Conv1d, Conv2d, Conv3d, \
4
+ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, \
5
+ LazyConv1d, LazyConv2d, LazyConv3d, LazyConvTranspose1d, LazyConvTranspose2d, LazyConvTranspose3d
6
+ from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
7
+ Softmax, Softmax2d, LogSoftmax, ELU, SELU, CELU, GELU, Hardshrink, LeakyReLU, LogSigmoid, \
8
+ Softplus, Softshrink, MultiheadAttention, PReLU, Softsign, Softmin, Tanhshrink, RReLU, GLU, \
9
+ Hardsigmoid, Hardswish, SiLU, Mish
10
+ from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
11
+ CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
12
+ MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, HuberLoss, \
13
+ SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss, GaussianNLLLoss
14
+ from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
15
+ from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
16
+ MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, LPPool3d, \
17
+ AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
18
+ from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
19
+ LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d
20
+ from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, \
21
+ LazyInstanceNorm1d, LazyInstanceNorm2d, LazyInstanceNorm3d
22
+ from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
23
+ from .dropout import Dropout, Dropout1d, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
24
+ from .padding import ReflectionPad1d, ReflectionPad2d, ReflectionPad3d, ReplicationPad1d, ReplicationPad2d, \
25
+ ReplicationPad3d, ZeroPad1d, ZeroPad2d, ZeroPad3d, ConstantPad1d, ConstantPad2d, ConstantPad3d, \
26
+ CircularPad1d, CircularPad2d, CircularPad3d
27
+ from .sparse import Embedding, EmbeddingBag
28
+ from .rnn import RNNBase, RNN, LSTM, GRU, \
29
+ RNNCellBase, RNNCell, LSTMCell, GRUCell
30
+ from .pixelshuffle import PixelShuffle, PixelUnshuffle
31
+ from .upsampling import UpsamplingNearest2d, UpsamplingBilinear2d, Upsample
32
+ from .distance import PairwiseDistance, CosineSimilarity
33
+ from .fold import Fold, Unfold
34
+ from .adaptive import AdaptiveLogSoftmaxWithLoss
35
+ from .transformer import TransformerEncoder, TransformerDecoder, \
36
+ TransformerEncoderLayer, TransformerDecoderLayer, Transformer
37
+ from .flatten import Flatten, Unflatten
38
+ from .channelshuffle import ChannelShuffle
39
+
40
+ __all__ = [
41
+ 'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
42
+ 'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
43
+ 'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
44
+ 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
45
+ 'Tanhshrink', 'RReLU', 'L1Loss', 'NLLLoss', 'KLDivLoss', 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss',
46
+ 'NLLLoss2d', 'PoissonNLLLoss', 'CosineEmbeddingLoss', 'CTCLoss', 'HingeEmbeddingLoss', 'MarginRankingLoss',
47
+ 'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss', 'GaussianNLLLoss',
48
+ 'HuberLoss', 'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict',
49
+ 'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
50
+ 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
51
+ 'LPPool1d', 'LPPool2d', 'LPPool3d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
52
+ 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
53
+ 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
54
+ 'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
55
+ 'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
56
+ 'LSTMCell', 'GRUCell', 'PixelShuffle', 'PixelUnshuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
57
+ 'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
58
+ 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d',
59
+ 'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
60
+ 'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
61
+ 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
62
+ 'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
63
+ 'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
64
+ 'LazyBatchNorm1d', 'LazyBatchNorm2d', 'LazyBatchNorm3d',
65
+ 'LazyInstanceNorm1d', 'LazyInstanceNorm2d', 'LazyInstanceNorm3d',
66
+ 'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'Mish', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
67
+ 'CircularPad1d', 'CircularPad2d', 'CircularPad3d'
68
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (7.04 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/_functions.cpython-311.pyc ADDED
Binary file (13.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/container.cpython-311.pyc ADDED
Binary file (55.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/conv.cpython-311.pyc ADDED
Binary file (76.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/flatten.cpython-311.pyc ADDED
Binary file (8.13 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/fold.cpython-311.pyc ADDED
Binary file (14.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/lazy.cpython-311.pyc ADDED
Binary file (14.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/activation.py ADDED
@@ -0,0 +1,1624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from .linear import NonDynamicallyQuantizableLinear
7
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
8
+ from torch.nn.parameter import Parameter
9
+ from .module import Module
10
+ from .. import functional as F
11
+
12
+ __all__ = ['Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh',
13
+ 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU',
14
+ 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Tanhshrink',
15
+ 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax']
16
+
17
+
18
class Threshold(Module):
    r"""Thresholds each element of the input Tensor.

    Elements strictly greater than :attr:`threshold` pass through unchanged;
    every other element is replaced by :attr:`value`:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    Args:
        threshold: The value to threshold at
        value: The value to replace with
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Threshold(0.1, 20)
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['threshold', 'value', 'inplace']

    threshold: float
    value: float
    inplace: bool

    def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
        super().__init__()
        self.threshold = threshold
        self.value = value
        self.inplace = inplace
        # TODO: check in THNN (if inplace == True, then assert value <= threshold)

    def forward(self, input: Tensor) -> Tensor:
        # Delegate to the functional implementation.
        return F.threshold(input, self.threshold, self.value, self.inplace)

    def extra_repr(self):
        suffix = ', inplace=True' if self.inplace else ''
        return f'threshold={self.threshold}, value={self.value}{suffix}'
65
+
66
+
67
class ReLU(Module):
    r"""Applies the rectified linear unit function element-wise:
    :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`.

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ReLU.png

    Examples::

        >>> m = nn.ReLU()
        >>> output = m(torch.randn(2))

        An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> m = nn.ReLU()
        >>> input = torch.randn(2).unsqueeze(0)
        >>> output = torch.cat((m(input), m(-input)))
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.relu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        # Only surface the flag when it deviates from the default.
        return 'inplace=True' if self.inplace else ''
108
+
109
+
110
class RReLU(Module):
    r"""Applies the randomized leaky rectified linear unit function, element-wise.

    Method described in the paper:
    `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_.

    .. math::
        \text{RReLU}(x) =
        \begin{cases}
            x & \text{if } x \geq 0 \\
            ax & \text{ otherwise }
        \end{cases}

    During training the negative slope :math:`a` is sampled from the uniform
    distribution :math:`\mathcal{U}(\text{lower}, \text{upper})`; in
    evaluation mode it is fixed to
    :math:`a = \frac{\text{lower} + \text{upper}}{2}`.

    Args:
        lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
        upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.RReLU(0.1, 0.3)
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['lower', 'upper', 'inplace']

    lower: float
    upper: float
    inplace: bool

    def __init__(self, lower: float = 1. / 8, upper: float = 1. / 3,
                 inplace: bool = False):
        super().__init__()
        self.lower = lower
        self.upper = upper
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        # `self.training` selects between sampled and averaged slope.
        return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)

    def extra_repr(self):
        suffix = ', inplace=True' if self.inplace else ''
        return f'lower={self.lower}, upper={self.upper}{suffix}'
171
+
172
+
173
class Hardtanh(Module):
    r"""Applies the HardTanh function element-wise.

    Input values are clamped to the linear region ``[min_val, max_val]``:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            \text{max\_val} & \text{ if } x > \text{ max\_val } \\
            \text{min\_val} & \text{ if } x < \text{ min\_val } \\
            x & \text{ otherwise } \\
        \end{cases}

    Args:
        min_val: minimum value of the linear region range. Default: -1
        max_val: maximum value of the linear region range. Default: 1
        inplace: can optionally do the operation in-place. Default: ``False``

    Keyword arguments :attr:`min_value` and :attr:`max_value`
    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Hardtanh(-2, 2)
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['min_val', 'max_val', 'inplace']

    min_val: float
    max_val: float
    inplace: bool

    def __init__(self, min_val: float = -1., max_val: float = 1.,
                 inplace: bool = False, min_value: Optional[float] = None,
                 max_value: Optional[float] = None) -> None:
        super().__init__()
        # Honor (and warn about) the deprecated keyword spellings.
        if min_value is not None:
            warnings.warn("keyword argument min_value is deprecated and rename to min_val")
            min_val = min_value
        if max_value is not None:
            warnings.warn("keyword argument max_value is deprecated and rename to max_val")
            max_val = max_value

        self.min_val = min_val
        self.max_val = max_val
        self.inplace = inplace
        # NOTE(review): assert is stripped under `python -O`; kept for
        # byte-compatible behavior with the original implementation.
        assert self.max_val > self.min_val

    def forward(self, input: Tensor) -> Tensor:
        return F.hardtanh(input, self.min_val, self.max_val, self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'min_val={self.min_val}, max_val={self.max_val}{suffix}'
239
+
240
+
241
class ReLU6(Hardtanh):
    r"""Applies the ReLU6 function element-wise:
    :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`.

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.ReLU6()
        >>> output = m(torch.randn(2))
    """

    def __init__(self, inplace: bool = False):
        # ReLU6 is simply Hardtanh clamped to [0, 6].
        super().__init__(0., 6., inplace)

    def extra_repr(self) -> str:
        return 'inplace=True' if self.inplace else ''
269
+
270
+
271
class Sigmoid(Module):
    r"""Applies the Sigmoid function element-wise:

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Sigmoid()
        >>> output = m(torch.randn(2))
    """

    def forward(self, input: Tensor) -> Tensor:
        # Stateless module: just the elementwise sigmoid.
        return torch.sigmoid(input)
293
+
294
+
295
class Hardsigmoid(Module):
    r"""Applies the Hardsigmoid function element-wise:

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Hardsigmoid()
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardsigmoid(input, self.inplace)
333
+
334
+
335
class Tanh(Module):
    r"""Applies the Hyperbolic Tangent (Tanh) function element-wise:

    .. math::
        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Tanh()
        >>> output = m(torch.randn(2))
    """

    def forward(self, input: Tensor) -> Tensor:
        # Stateless module: just the elementwise tanh.
        return torch.tanh(input)
358
+
359
class SiLU(Module):
    r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.

    Also known as the swish function:

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. note::
        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
        where the SiLU was experimented with later.

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.SiLU()
        >>> output = m(torch.randn(2))
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.silu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        return 'inplace=True' if self.inplace else ''
401
+
402
class Mish(Module):
    r"""Element-wise Mish activation (a self-regularized, non-monotonic function).

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    .. note::
        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Forwarded to F.mish; in-place mode avoids an output allocation.
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.mish(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        if self.inplace:
            return 'inplace=True'
        return ''
439
+
440
class Hardswish(Module):
    r"""Element-wise Hardswish activation.

    Introduced in `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) /6 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Hardswish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        # Flag is passed straight through to the functional implementation.
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardswish(input, self.inplace)
480
+
481
+
482
class ELU(Module):
    r"""Element-wise Exponential Linear Unit (ELU) activation.

    Described in `Fast and Accurate Deep Network Learning by Exponential Linear
    Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.

    .. math::
        \text{ELU}(x) = \begin{cases}
        x, & \text{ if } x > 0\\
        \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
        \end{cases}

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.ELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.elu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        # Always report alpha; append inplace only when enabled.
        suffix = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{suffix}'
528
+
529
+
530
class CELU(Module):
    r"""Element-wise CELU activation.

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.CELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Continuously Differentiable Exponential Linear Units`:
        https://arxiv.org/abs/1704.07483
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.celu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{suffix}'
573
+
574
+
575
class SELU(Module):
    r"""Element-wise SELU (scaled exponential linear unit) activation.

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    with :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    .. warning::
        When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
        ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
        in order to get `Self-Normalizing Neural Networks`_.
        See :func:`torch.nn.init.calculate_gain` for more information.

    More details can be found in the paper `Self-Normalizing Neural Networks`_ .

    Args:
        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.SELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.selu(input, self.inplace)

    def extra_repr(self) -> str:
        if self.inplace:
            return 'inplace=True'
        return ''
623
+
624
+
625
class GLU(Module):
    r"""Gated linear unit.

    :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
    of the input matrices and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> m = nn.GLU()
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: int

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        # Dimension halved by the gate; must have even size at call time.
        self.dim = dim

    def forward(self, input: Tensor) -> Tensor:
        return F.glu(input, self.dim)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
658
+
659
+
660
class GELU(Module):
    r"""Gaussian Error Linear Unit activation.

    .. math:: \text{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.

    When the approximate argument is 'tanh', Gelu is estimated with:

    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))

    Args:
        approximate (str, optional): the gelu approximation algorithm to use:
            ``'none'`` | ``'tanh'``. Default: ``'none'``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.GELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['approximate']
    approximate: str

    def __init__(self, approximate: str = 'none') -> None:
        super().__init__()
        # Selects the exact ('none') or tanh-based approximation at call time.
        self.approximate = approximate

    def forward(self, input: Tensor) -> Tensor:
        return F.gelu(input, approximate=self.approximate)

    def extra_repr(self) -> str:
        return f'approximate={repr(self.approximate)}'
700
+
701
+
702
class Hardshrink(Module):
    r"""Element-wise hard shrinkage.

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
        x, & \text{ if } x > \lambda \\
        x, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Hardshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        # Values with |x| <= lambd are zeroed by the functional op.
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.hardshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return f'{self.lambd}'
743
+
744
+
745
class LeakyReLU(Module):
    r"""Element-wise leaky rectified linear unit.

    .. math::
        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

    or

    .. math::
        \text{LeakyReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        \text{negative\_slope} \times x, & \text{ otherwise }
        \end{cases}

    Args:
        negative_slope: Controls the angle of the negative slope (which is used for
          negative input values). Default: 1e-2
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Examples::

        >>> m = nn.LeakyReLU(0.1)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace', 'negative_slope']
    inplace: bool
    negative_slope: float

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super().__init__()
        self.negative_slope = negative_slope
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.leaky_relu(input, self.negative_slope, self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'negative_slope={self.negative_slope}{suffix}'
795
+
796
+
797
class LogSigmoid(Module):
    r"""Element-wise log-sigmoid activation.

    .. math::
        \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.LogSigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Numerically stable log(sigmoid(x)) provided by the functional API.
        return F.logsigmoid(input)
818
+
819
+
820
class Softplus(Module):
    r"""Element-wise softplus activation.

    .. math::
        \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))

    SoftPlus is a smooth approximation to the ReLU function and can be used
    to constrain the output of a machine to always be positive.

    For numerical stability the implementation reverts to the linear function
    when :math:`input \times \beta > threshold`.

    Args:
        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
        threshold: values above this revert to a linear function. Default: 20

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Softplus()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['beta', 'threshold']
    beta: float
    threshold: float

    def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None:
        super().__init__()
        self.beta = beta
        self.threshold = threshold

    def forward(self, input: Tensor) -> Tensor:
        return F.softplus(input, self.beta, self.threshold)

    def extra_repr(self) -> str:
        return f'beta={self.beta}, threshold={self.threshold}'
863
+
864
+
865
class Softshrink(Module):
    r"""Element-wise soft shrinkage.

    .. math::
        \text{SoftShrinkage}(x) =
        \begin{cases}
        x - \lambda, & \text{ if } x > \lambda \\
        x + \lambda, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Softshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        # Shrinkage threshold; expected to be non-negative.
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.softshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return str(self.lambd)
904
+
905
+
906
+ def _check_arg_device(x: Optional[torch.Tensor]) -> bool:
907
+ if x is not None:
908
+ return x.device.type in ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
909
+ return True
910
+
911
+
912
+ def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool:
913
+ if x is not None:
914
+ return x.requires_grad
915
+ return False
916
+
917
+
918
+ def _is_make_fx_tracing():
919
+ if not torch.jit.is_scripting():
920
+ torch_dispatch_mode_stack = torch.utils._python_dispatch._get_current_dispatch_mode_stack()
921
+ return any(type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode for x in torch_dispatch_mode_stack)
922
+ else:
923
+ return False
924
+
925
+
926
+ class MultiheadAttention(Module):
927
+ r"""Allows the model to jointly attend to information from different representation subspaces.
928
+
929
+ Method described in the paper:
930
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
931
+
932
+ Multi-Head Attention is defined as:
933
+
934
+ .. math::
935
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
936
+
937
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
938
+
939
+ ``nn.MultiHeadAttention`` will use the optimized implementations of
940
+ ``scaled_dot_product_attention()`` when possible.
941
+
942
+ In addition to support for the new ``scaled_dot_product_attention()``
943
+ function, for speeding up Inference, MHA will use
944
+ fastpath inference with support for Nested Tensors, iff:
945
+
946
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
947
+ - inputs are batched (3D) with ``batch_first==True``
948
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
949
+ - training is disabled (using ``.eval()``)
950
+ - ``add_bias_kv`` is ``False``
951
+ - ``add_zero_attn`` is ``False``
952
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
953
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
954
+ nor ``attn_mask`` is passed
955
+ - autocast is disabled
956
+
957
+ If the optimized inference fastpath implementation is in use, a
958
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
959
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
960
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
961
+ will be returned, and an additional speedup proportional to the fraction of the input
962
+ that is padding can be expected.
963
+
964
+ Args:
965
+ embed_dim: Total dimension of the model.
966
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
967
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
968
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
969
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
970
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
971
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
972
+ Default: ``False``.
973
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
974
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
975
+ batch_first: If ``True``, then the input and output tensors are provided
976
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
977
+
978
+ Examples::
979
+
980
+ >>> # xdoctest: +SKIP
981
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
982
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
983
+
984
+ .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
985
+ https://arxiv.org/abs/2205.14135
986
+
987
+ """
988
+
989
+ __constants__ = ['batch_first']
990
+ bias_k: Optional[torch.Tensor]
991
+ bias_v: Optional[torch.Tensor]
992
+
993
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
994
+ kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
995
+ if embed_dim <= 0 or num_heads <= 0:
996
+ raise ValueError(
997
+ f"embed_dim and num_heads must be greater than 0,"
998
+ f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
999
+ )
1000
+ factory_kwargs = {'device': device, 'dtype': dtype}
1001
+ super().__init__()
1002
+ self.embed_dim = embed_dim
1003
+ self.kdim = kdim if kdim is not None else embed_dim
1004
+ self.vdim = vdim if vdim is not None else embed_dim
1005
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
1006
+
1007
+ self.num_heads = num_heads
1008
+ self.dropout = dropout
1009
+ self.batch_first = batch_first
1010
+ self.head_dim = embed_dim // num_heads
1011
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
1012
+
1013
+ if not self._qkv_same_embed_dim:
1014
+ self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
1015
+ self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
1016
+ self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
1017
+ self.register_parameter('in_proj_weight', None)
1018
+ else:
1019
+ self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
1020
+ self.register_parameter('q_proj_weight', None)
1021
+ self.register_parameter('k_proj_weight', None)
1022
+ self.register_parameter('v_proj_weight', None)
1023
+
1024
+ if bias:
1025
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
1026
+ else:
1027
+ self.register_parameter('in_proj_bias', None)
1028
+ self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
1029
+
1030
+ if add_bias_kv:
1031
+ self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
1032
+ self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
1033
+ else:
1034
+ self.bias_k = self.bias_v = None
1035
+
1036
+ self.add_zero_attn = add_zero_attn
1037
+
1038
+ self._reset_parameters()
1039
+
1040
+ def _reset_parameters(self):
1041
+ if self._qkv_same_embed_dim:
1042
+ xavier_uniform_(self.in_proj_weight)
1043
+ else:
1044
+ xavier_uniform_(self.q_proj_weight)
1045
+ xavier_uniform_(self.k_proj_weight)
1046
+ xavier_uniform_(self.v_proj_weight)
1047
+
1048
+ if self.in_proj_bias is not None:
1049
+ constant_(self.in_proj_bias, 0.)
1050
+ constant_(self.out_proj.bias, 0.)
1051
+ if self.bias_k is not None:
1052
+ xavier_normal_(self.bias_k)
1053
+ if self.bias_v is not None:
1054
+ xavier_normal_(self.bias_v)
1055
+
1056
+ def __setstate__(self, state):
1057
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
1058
+ if '_qkv_same_embed_dim' not in state:
1059
+ state['_qkv_same_embed_dim'] = True
1060
+
1061
+ super().__setstate__(state)
1062
+
1063
+ def forward(
1064
+ self,
1065
+ query: Tensor,
1066
+ key: Tensor,
1067
+ value: Tensor,
1068
+ key_padding_mask: Optional[Tensor] = None,
1069
+ need_weights: bool = True,
1070
+ attn_mask: Optional[Tensor] = None,
1071
+ average_attn_weights: bool = True,
1072
+ is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
1073
+ r"""Compute attention outputs using query, key, and value embeddings.
1074
+
1075
+ Supports optional parameters for padding, masks and attention weights.
1076
+
1077
+ Args:
1078
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
1079
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
1080
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
1081
+ Queries are compared against key-value pairs to produce the output.
1082
+ See "Attention Is All You Need" for more details.
1083
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
1084
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
1085
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
1086
+ See "Attention Is All You Need" for more details.
1087
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
1088
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
1089
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
1090
+ See "Attention Is All You Need" for more details.
1091
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
1092
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
1093
+ Binary and float masks are supported.
1094
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
1095
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
1096
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
1097
+ Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
1098
+ and achieve the best performance for MHA.
1099
+ Default: ``True``.
1100
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
1101
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
1102
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
1103
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
1104
+ Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
1105
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
1106
+ the attention weight.
1107
+ If both attn_mask and key_padding_mask are supplied, their types should match.
1108
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
1109
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
1110
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
1111
+ is_causal: If specified, applies a causal mask as attention mask.
1112
+ Default: ``False``.
1113
+ Warning:
1114
+ ``is_causal`` provides a hint that ``attn_mask`` is the
1115
+ causal mask. Providing incorrect hints can result in
1116
+ incorrect execution, including forward and backward
1117
+ compatibility.
1118
+
1119
+ Outputs:
1120
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
1121
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
1122
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
1123
+ embedding dimension ``embed_dim``.
1124
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
1125
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
1126
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
1127
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
1128
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
1129
+
1130
+ .. note::
1131
+ `batch_first` argument is ignored for unbatched inputs.
1132
+ """
1133
+ why_not_fast_path = ''
1134
+ if ((attn_mask is not None and torch.is_floating_point(attn_mask))
1135
+ or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
1136
+ why_not_fast_path = "floating-point masks are not supported for fast path."
1137
+
1138
+ is_batched = query.dim() == 3
1139
+
1140
+ key_padding_mask = F._canonical_mask(
1141
+ mask=key_padding_mask,
1142
+ mask_name="key_padding_mask",
1143
+ other_type=F._none_or_dtype(attn_mask),
1144
+ other_name="attn_mask",
1145
+ target_type=query.dtype
1146
+ )
1147
+
1148
+ attn_mask = F._canonical_mask(
1149
+ mask=attn_mask,
1150
+ mask_name="attn_mask",
1151
+ other_type=None,
1152
+ other_name="",
1153
+ target_type=query.dtype,
1154
+ check_other=False,
1155
+ )
1156
+
1157
+ is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
1158
+
1159
+ if not is_fastpath_enabled:
1160
+ why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
1161
+ elif not is_batched:
1162
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
1163
+ elif query is not key or key is not value:
1164
+ # When lifting this restriction, don't forget to either
1165
+ # enforce that the dtypes all match or test cases where
1166
+ # they don't!
1167
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
1168
+ elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
1169
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
1170
+ elif self.in_proj_weight is None:
1171
+ why_not_fast_path = "in_proj_weight was None"
1172
+ elif query.dtype != self.in_proj_weight.dtype:
1173
+ # this case will fail anyway, but at least they'll get a useful error message.
1174
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
1175
+ elif self.training:
1176
+ why_not_fast_path = "training is enabled"
1177
+ elif (self.num_heads % 2) != 0:
1178
+ why_not_fast_path = "self.num_heads is not even"
1179
+ elif not self.batch_first:
1180
+ why_not_fast_path = "batch_first was not True"
1181
+ elif self.bias_k is not None:
1182
+ why_not_fast_path = "self.bias_k was not None"
1183
+ elif self.bias_v is not None:
1184
+ why_not_fast_path = "self.bias_v was not None"
1185
+ elif self.add_zero_attn:
1186
+ why_not_fast_path = "add_zero_attn was enabled"
1187
+ elif not self._qkv_same_embed_dim:
1188
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
1189
+ elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
1190
+ why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
1191
+ is not supported with NestedTensor input"
1192
+ elif torch.is_autocast_enabled():
1193
+ why_not_fast_path = "autocast is enabled"
1194
+
1195
+ if not why_not_fast_path:
1196
+ tensor_args = (
1197
+ query,
1198
+ key,
1199
+ value,
1200
+ self.in_proj_weight,
1201
+ self.in_proj_bias,
1202
+ self.out_proj.weight,
1203
+ self.out_proj.bias,
1204
+ )
1205
+ # We have to use list comprehensions below because TorchScript does not support
1206
+ # generator expressions.
1207
+ if torch.overrides.has_torch_function(tensor_args):
1208
+ why_not_fast_path = "some Tensor argument has_torch_function"
1209
+ elif _is_make_fx_tracing():
1210
+ why_not_fast_path = "we are running make_fx tracing"
1211
+ elif not all(_check_arg_device(x) for x in tensor_args):
1212
+ why_not_fast_path = ("some Tensor argument's device is neither one of "
1213
+ f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
1214
+ elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
1215
+ why_not_fast_path = ("grad is enabled and at least one of query or the "
1216
+ "input/output projection weights or biases requires_grad")
1217
+ if not why_not_fast_path:
1218
+ merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
1219
+
1220
+ if self.in_proj_bias is not None and self.in_proj_weight is not None:
1221
+ return torch._native_multi_head_attention(
1222
+ query,
1223
+ key,
1224
+ value,
1225
+ self.embed_dim,
1226
+ self.num_heads,
1227
+ self.in_proj_weight,
1228
+ self.in_proj_bias,
1229
+ self.out_proj.weight,
1230
+ self.out_proj.bias,
1231
+ merged_mask,
1232
+ need_weights,
1233
+ average_attn_weights,
1234
+ mask_type)
1235
+
1236
+ any_nested = query.is_nested or key.is_nested or value.is_nested
1237
+ assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
1238
+ f"The fast path was not hit because {why_not_fast_path}")
1239
+
1240
+ if self.batch_first and is_batched:
1241
+ # make sure that the transpose op does not affect the "is" property
1242
+ if key is value:
1243
+ if query is key:
1244
+ query = key = value = query.transpose(1, 0)
1245
+ else:
1246
+ query, key = (x.transpose(1, 0) for x in (query, key))
1247
+ value = key
1248
+ else:
1249
+ query, key, value = (x.transpose(1, 0) for x in (query, key, value))
1250
+
1251
+ if not self._qkv_same_embed_dim:
1252
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
1253
+ query, key, value, self.embed_dim, self.num_heads,
1254
+ self.in_proj_weight, self.in_proj_bias,
1255
+ self.bias_k, self.bias_v, self.add_zero_attn,
1256
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
1257
+ training=self.training,
1258
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
1259
+ attn_mask=attn_mask,
1260
+ use_separate_proj_weight=True,
1261
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
1262
+ v_proj_weight=self.v_proj_weight,
1263
+ average_attn_weights=average_attn_weights,
1264
+ is_causal=is_causal)
1265
+ else:
1266
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
1267
+ query, key, value, self.embed_dim, self.num_heads,
1268
+ self.in_proj_weight, self.in_proj_bias,
1269
+ self.bias_k, self.bias_v, self.add_zero_attn,
1270
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
1271
+ training=self.training,
1272
+ key_padding_mask=key_padding_mask,
1273
+ need_weights=need_weights,
1274
+ attn_mask=attn_mask,
1275
+ average_attn_weights=average_attn_weights,
1276
+ is_causal=is_causal)
1277
+ if self.batch_first and is_batched:
1278
+ return attn_output.transpose(1, 0), attn_output_weights
1279
+ else:
1280
+ return attn_output, attn_output_weights
1281
+
1282
+ def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
1283
+ query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]:
1284
+ r"""Determine mask type and combine masks if necessary.
1285
+
1286
+ If only one mask is provided, that mask
1287
+ and the corresponding mask type will be returned. If both masks are provided, they will be both
1288
+ expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``
1289
+ and mask type 2 will be returned
1290
+ Args:
1291
+ attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
1292
+ key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
1293
+ query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
1294
+ Returns:
1295
+ merged_mask: merged mask
1296
+ mask_type: merged mask type (0, 1, or 2)
1297
+ """
1298
+ mask_type: Optional[int] = None
1299
+ merged_mask: Optional[Tensor] = None
1300
+
1301
+ if key_padding_mask is not None:
1302
+ mask_type = 1
1303
+ merged_mask = key_padding_mask
1304
+
1305
+ if attn_mask is not None:
1306
+ # In this branch query can't be a nested tensor, so it has a shape
1307
+ batch_size, seq_len, _ = query.shape
1308
+ mask_type = 2
1309
+
1310
+ # Always expands attn_mask to 4D
1311
+ if attn_mask.dim() == 3:
1312
+ attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
1313
+ else: # attn_mask.dim() == 2:
1314
+ attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1)
1315
+ merged_mask = attn_mask_expanded
1316
+
1317
+ if key_padding_mask is not None:
1318
+ key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
1319
+ merged_mask = attn_mask_expanded + key_padding_mask_expanded
1320
+
1321
+ # no attn_mask and no key_padding_mask, returns None, None
1322
+ return merged_mask, mask_type
1323
+
1324
+
1325
+ class PReLU(Module):
1326
+ r"""Applies the element-wise PReLU function.
1327
+
1328
+ .. math::
1329
+ \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
1330
+
1331
+ or
1332
+
1333
+ .. math::
1334
+ \text{PReLU}(x) =
1335
+ \begin{cases}
1336
+ x, & \text{ if } x \geq 0 \\
1337
+ ax, & \text{ otherwise }
1338
+ \end{cases}
1339
+
1340
+ Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
1341
+ parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
1342
+ a separate :math:`a` is used for each input channel.
1343
+
1344
+
1345
+ .. note::
1346
+ weight decay should not be used when learning :math:`a` for good performance.
1347
+
1348
+ .. note::
1349
+ Channel dim is the 2nd dim of input. When input has dims < 2, then there is
1350
+ no channel dim and the number of channels = 1.
1351
+
1352
+ Args:
1353
+ num_parameters (int): number of :math:`a` to learn.
1354
+ Although it takes an int as input, there is only two values are legitimate:
1355
+ 1, or the number of channels at input. Default: 1
1356
+ init (float): the initial value of :math:`a`. Default: 0.25
1357
+
1358
+ Shape:
1359
+ - Input: :math:`( *)` where `*` means, any number of additional
1360
+ dimensions.
1361
+ - Output: :math:`(*)`, same shape as the input.
1362
+
1363
+ Attributes:
1364
+ weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
1365
+
1366
+ .. image:: ../scripts/activation_images/PReLU.png
1367
+
1368
+ Examples::
1369
+
1370
+ >>> m = nn.PReLU()
1371
+ >>> input = torch.randn(2)
1372
+ >>> output = m(input)
1373
+ """
1374
+
1375
+ __constants__ = ['num_parameters']
1376
+ num_parameters: int
1377
+
1378
+ def __init__(self, num_parameters: int = 1, init: float = 0.25,
1379
+ device=None, dtype=None) -> None:
1380
+ factory_kwargs = {'device': device, 'dtype': dtype}
1381
+ self.num_parameters = num_parameters
1382
+ super().__init__()
1383
+ self.init = init
1384
+ self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs))
1385
+ self.reset_parameters()
1386
+
1387
+ def reset_parameters(self):
1388
+ torch.nn.init.constant_(self.weight, self.init)
1389
+
1390
+ def forward(self, input: Tensor) -> Tensor:
1391
+ return F.prelu(input, self.weight)
1392
+
1393
+ def extra_repr(self) -> str:
1394
+ return f'num_parameters={self.num_parameters}'
1395
+
1396
+
1397
+ class Softsign(Module):
1398
+ r"""Applies the element-wise Softsign function.
1399
+
1400
+ .. math::
1401
+ \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
1402
+
1403
+ Shape:
1404
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
1405
+ - Output: :math:`(*)`, same shape as the input.
1406
+
1407
+ .. image:: ../scripts/activation_images/Softsign.png
1408
+
1409
+ Examples::
1410
+
1411
+ >>> m = nn.Softsign()
1412
+ >>> input = torch.randn(2)
1413
+ >>> output = m(input)
1414
+ """
1415
+
1416
+ def forward(self, input: Tensor) -> Tensor:
1417
+ return F.softsign(input)
1418
+
1419
+
1420
+ class Tanhshrink(Module):
1421
+ r"""Applies the element-wise Tanhshrink function.
1422
+
1423
+ .. math::
1424
+ \text{Tanhshrink}(x) = x - \tanh(x)
1425
+
1426
+ Shape:
1427
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
1428
+ - Output: :math:`(*)`, same shape as the input.
1429
+
1430
+ .. image:: ../scripts/activation_images/Tanhshrink.png
1431
+
1432
+ Examples::
1433
+
1434
+ >>> m = nn.Tanhshrink()
1435
+ >>> input = torch.randn(2)
1436
+ >>> output = m(input)
1437
+ """
1438
+
1439
+ def forward(self, input: Tensor) -> Tensor:
1440
+ return F.tanhshrink(input)
1441
+
1442
+
1443
+ class Softmin(Module):
1444
+ r"""Applies the Softmin function to an n-dimensional input Tensor.
1445
+
1446
+ Rescales them so that the elements of the n-dimensional output Tensor
1447
+ lie in the range `[0, 1]` and sum to 1.
1448
+
1449
+ Softmin is defined as:
1450
+
1451
+ .. math::
1452
+ \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
1453
+
1454
+ Shape:
1455
+ - Input: :math:`(*)` where `*` means, any number of additional
1456
+ dimensions
1457
+ - Output: :math:`(*)`, same shape as the input
1458
+
1459
+ Args:
1460
+ dim (int): A dimension along which Softmin will be computed (so every slice
1461
+ along dim will sum to 1).
1462
+
1463
+ Returns:
1464
+ a Tensor of the same dimension and shape as the input, with
1465
+ values in the range [0, 1]
1466
+
1467
+ Examples::
1468
+
1469
+ >>> m = nn.Softmin(dim=1)
1470
+ >>> input = torch.randn(2, 3)
1471
+ >>> output = m(input)
1472
+ """
1473
+
1474
+ __constants__ = ['dim']
1475
+ dim: Optional[int]
1476
+
1477
+ def __init__(self, dim: Optional[int] = None) -> None:
1478
+ super().__init__()
1479
+ self.dim = dim
1480
+
1481
+ def __setstate__(self, state):
1482
+ super().__setstate__(state)
1483
+ if not hasattr(self, 'dim'):
1484
+ self.dim = None
1485
+
1486
+ def forward(self, input: Tensor) -> Tensor:
1487
+ return F.softmin(input, self.dim, _stacklevel=5)
1488
+
1489
+ def extra_repr(self):
1490
+ return f'dim={self.dim}'
1491
+
1492
+ class Softmax(Module):
1493
+ r"""Applies the Softmax function to an n-dimensional input Tensor.
1494
+
1495
+ Rescales them so that the elements of the n-dimensional output Tensor
1496
+ lie in the range [0,1] and sum to 1.
1497
+
1498
+ Softmax is defined as:
1499
+
1500
+ .. math::
1501
+ \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
1502
+
1503
+ When the input Tensor is a sparse tensor then the unspecified
1504
+ values are treated as ``-inf``.
1505
+
1506
+ Shape:
1507
+ - Input: :math:`(*)` where `*` means, any number of additional
1508
+ dimensions
1509
+ - Output: :math:`(*)`, same shape as the input
1510
+
1511
+ Returns:
1512
+ a Tensor of the same dimension and shape as the input with
1513
+ values in the range [0, 1]
1514
+
1515
+ Args:
1516
+ dim (int): A dimension along which Softmax will be computed (so every slice
1517
+ along dim will sum to 1).
1518
+
1519
+ .. note::
1520
+ This module doesn't work directly with NLLLoss,
1521
+ which expects the Log to be computed between the Softmax and itself.
1522
+ Use `LogSoftmax` instead (it's faster and has better numerical properties).
1523
+
1524
+ Examples::
1525
+
1526
+ >>> m = nn.Softmax(dim=1)
1527
+ >>> input = torch.randn(2, 3)
1528
+ >>> output = m(input)
1529
+
1530
+ """
1531
+
1532
+ __constants__ = ['dim']
1533
+ dim: Optional[int]
1534
+
1535
+ def __init__(self, dim: Optional[int] = None) -> None:
1536
+ super().__init__()
1537
+ self.dim = dim
1538
+
1539
+ def __setstate__(self, state):
1540
+ super().__setstate__(state)
1541
+ if not hasattr(self, 'dim'):
1542
+ self.dim = None
1543
+
1544
+ def forward(self, input: Tensor) -> Tensor:
1545
+ return F.softmax(input, self.dim, _stacklevel=5)
1546
+
1547
+ def extra_repr(self) -> str:
1548
+ return f'dim={self.dim}'
1549
+
1550
+
1551
+ class Softmax2d(Module):
1552
+ r"""Applies SoftMax over features to each spatial location.
1553
+
1554
+ When given an image of ``Channels x Height x Width``, it will
1555
+ apply `Softmax` to each location :math:`(Channels, h_i, w_j)`
1556
+
1557
+ Shape:
1558
+ - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
1559
+ - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
1560
+
1561
+ Returns:
1562
+ a Tensor of the same dimension and shape as the input with
1563
+ values in the range [0, 1]
1564
+
1565
+ Examples::
1566
+
1567
+ >>> m = nn.Softmax2d()
1568
+ >>> # you softmax over the 2nd dimension
1569
+ >>> input = torch.randn(2, 3, 12, 13)
1570
+ >>> output = m(input)
1571
+ """
1572
+
1573
+ def forward(self, input: Tensor) -> Tensor:
1574
+ if input.dim() not in (3, 4):
1575
+ raise ValueError(
1576
+ f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead"
1577
+ )
1578
+ return F.softmax(input, -3, _stacklevel=5)
1579
+
1580
+
1581
+ class LogSoftmax(Module):
1582
+ r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor.
1583
+
1584
+ The LogSoftmax formulation can be simplified as:
1585
+
1586
+ .. math::
1587
+ \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
1588
+
1589
+ Shape:
1590
+ - Input: :math:`(*)` where `*` means, any number of additional
1591
+ dimensions
1592
+ - Output: :math:`(*)`, same shape as the input
1593
+
1594
+ Args:
1595
+ dim (int): A dimension along which LogSoftmax will be computed.
1596
+
1597
+ Returns:
1598
+ a Tensor of the same dimension and shape as the input with
1599
+ values in the range [-inf, 0)
1600
+
1601
+ Examples::
1602
+
1603
+ >>> m = nn.LogSoftmax(dim=1)
1604
+ >>> input = torch.randn(2, 3)
1605
+ >>> output = m(input)
1606
+ """
1607
+
1608
+ __constants__ = ['dim']
1609
+ dim: Optional[int]
1610
+
1611
+ def __init__(self, dim: Optional[int] = None) -> None:
1612
+ super().__init__()
1613
+ self.dim = dim
1614
+
1615
+ def __setstate__(self, state):
1616
+ super().__setstate__(state)
1617
+ if not hasattr(self, 'dim'):
1618
+ self.dim = None
1619
+
1620
+ def forward(self, input: Tensor) -> Tensor:
1621
+ return F.log_softmax(input, self.dim, _stacklevel=5)
1622
+
1623
+ def extra_repr(self):
1624
+ return f'dim={self.dim}'
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/batchnorm.py ADDED
@@ -0,0 +1,849 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Any
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer
6
+
7
+ from .. import functional as F
8
+ from .. import init
9
+ from ._functions import SyncBatchNorm as sync_batch_norm
10
+ from .lazy import LazyModuleMixin
11
+ from .module import Module
12
+
13
+ __all__ = ['BatchNorm1d', 'LazyBatchNorm1d', 'BatchNorm2d', 'LazyBatchNorm2d', 'BatchNorm3d',
14
+ 'LazyBatchNorm3d', 'SyncBatchNorm']
15
+
16
+
17
+ class _NormBase(Module):
18
+ """Common base of _InstanceNorm and _BatchNorm."""
19
+
20
+ _version = 2
21
+ __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
22
+ num_features: int
23
+ eps: float
24
+ momentum: float
25
+ affine: bool
26
+ track_running_stats: bool
27
+ # WARNING: weight and bias purposely not defined here.
28
+ # See https://github.com/pytorch/pytorch/issues/39670
29
+
30
+ def __init__(
31
+ self,
32
+ num_features: int,
33
+ eps: float = 1e-5,
34
+ momentum: float = 0.1,
35
+ affine: bool = True,
36
+ track_running_stats: bool = True,
37
+ device=None,
38
+ dtype=None
39
+ ) -> None:
40
+ factory_kwargs = {'device': device, 'dtype': dtype}
41
+ super().__init__()
42
+ self.num_features = num_features
43
+ self.eps = eps
44
+ self.momentum = momentum
45
+ self.affine = affine
46
+ self.track_running_stats = track_running_stats
47
+ if self.affine:
48
+ self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
49
+ self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
50
+ else:
51
+ self.register_parameter("weight", None)
52
+ self.register_parameter("bias", None)
53
+ if self.track_running_stats:
54
+ self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
55
+ self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
56
+ self.running_mean: Optional[Tensor]
57
+ self.running_var: Optional[Tensor]
58
+ self.register_buffer('num_batches_tracked',
59
+ torch.tensor(0, dtype=torch.long,
60
+ **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
61
+ self.num_batches_tracked: Optional[Tensor]
62
+ else:
63
+ self.register_buffer("running_mean", None)
64
+ self.register_buffer("running_var", None)
65
+ self.register_buffer("num_batches_tracked", None)
66
+ self.reset_parameters()
67
+
68
+ def reset_running_stats(self) -> None:
69
+ if self.track_running_stats:
70
+ # running_mean/running_var/num_batches... are registered at runtime depending
71
+ # if self.track_running_stats is on
72
+ self.running_mean.zero_() # type: ignore[union-attr]
73
+ self.running_var.fill_(1) # type: ignore[union-attr]
74
+ self.num_batches_tracked.zero_() # type: ignore[union-attr,operator]
75
+
76
+ def reset_parameters(self) -> None:
77
+ self.reset_running_stats()
78
+ if self.affine:
79
+ init.ones_(self.weight)
80
+ init.zeros_(self.bias)
81
+
82
+ def _check_input_dim(self, input):
83
+ raise NotImplementedError
84
+
85
+ def extra_repr(self):
86
+ return (
87
+ "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
88
+ "track_running_stats={track_running_stats}".format(**self.__dict__)
89
+ )
90
+
91
+ def _load_from_state_dict(
92
+ self,
93
+ state_dict,
94
+ prefix,
95
+ local_metadata,
96
+ strict,
97
+ missing_keys,
98
+ unexpected_keys,
99
+ error_msgs,
100
+ ):
101
+ version = local_metadata.get("version", None)
102
+
103
+ if (version is None or version < 2) and self.track_running_stats:
104
+ # at version 2: added num_batches_tracked buffer
105
+ # this should have a default value of 0
106
+ num_batches_tracked_key = prefix + "num_batches_tracked"
107
+ if num_batches_tracked_key not in state_dict:
108
+ state_dict[num_batches_tracked_key] = (
109
+ self.num_batches_tracked
110
+ if self.num_batches_tracked is not None and self.num_batches_tracked.device != torch.device('meta')
111
+ else torch.tensor(0, dtype=torch.long)
112
+ )
113
+
114
+ super()._load_from_state_dict(
115
+ state_dict,
116
+ prefix,
117
+ local_metadata,
118
+ strict,
119
+ missing_keys,
120
+ unexpected_keys,
121
+ error_msgs,
122
+ )
123
+
124
+
125
+ class _BatchNorm(_NormBase):
126
+ def __init__(
127
+ self,
128
+ num_features: int,
129
+ eps: float = 1e-5,
130
+ momentum: float = 0.1,
131
+ affine: bool = True,
132
+ track_running_stats: bool = True,
133
+ device=None,
134
+ dtype=None
135
+ ) -> None:
136
+ factory_kwargs = {'device': device, 'dtype': dtype}
137
+ super().__init__(
138
+ num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
139
+ )
140
+
141
+ def forward(self, input: Tensor) -> Tensor:
142
+ self._check_input_dim(input)
143
+
144
+ # exponential_average_factor is set to self.momentum
145
+ # (when it is available) only so that it gets updated
146
+ # in ONNX graph when this node is exported to ONNX.
147
+ if self.momentum is None:
148
+ exponential_average_factor = 0.0
149
+ else:
150
+ exponential_average_factor = self.momentum
151
+
152
+ if self.training and self.track_running_stats:
153
+ # TODO: if statement only here to tell the jit to skip emitting this when it is None
154
+ if self.num_batches_tracked is not None: # type: ignore[has-type]
155
+ self.num_batches_tracked.add_(1) # type: ignore[has-type]
156
+ if self.momentum is None: # use cumulative moving average
157
+ exponential_average_factor = 1.0 / float(self.num_batches_tracked)
158
+ else: # use exponential moving average
159
+ exponential_average_factor = self.momentum
160
+
161
+ r"""
162
+ Decide whether the mini-batch stats should be used for normalization rather than the buffers.
163
+ Mini-batch stats are used in training mode, and in eval mode when buffers are None.
164
+ """
165
+ if self.training:
166
+ bn_training = True
167
+ else:
168
+ bn_training = (self.running_mean is None) and (self.running_var is None)
169
+
170
+ r"""
171
+ Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
172
+ passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
173
+ used for normalization (i.e. in eval mode when buffers are not None).
174
+ """
175
+ return F.batch_norm(
176
+ input,
177
+ # If buffers are not to be tracked, ensure that they won't be updated
178
+ self.running_mean
179
+ if not self.training or self.track_running_stats
180
+ else None,
181
+ self.running_var if not self.training or self.track_running_stats else None,
182
+ self.weight,
183
+ self.bias,
184
+ bn_training,
185
+ exponential_average_factor,
186
+ self.eps,
187
+ )
188
+
189
+
190
+ class _LazyNormBase(LazyModuleMixin, _NormBase):
191
+
192
+ weight: UninitializedParameter # type: ignore[assignment]
193
+ bias: UninitializedParameter # type: ignore[assignment]
194
+
195
+ def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
196
+ device=None, dtype=None) -> None:
197
+ factory_kwargs = {'device': device, 'dtype': dtype}
198
+ super().__init__(
199
+ # affine and track_running_stats are hardcoded to False to
200
+ # avoid creating tensors that will soon be overwritten.
201
+ 0,
202
+ eps,
203
+ momentum,
204
+ False,
205
+ False,
206
+ **factory_kwargs,
207
+ )
208
+ self.affine = affine
209
+ self.track_running_stats = track_running_stats
210
+ if self.affine:
211
+ self.weight = UninitializedParameter(**factory_kwargs)
212
+ self.bias = UninitializedParameter(**factory_kwargs)
213
+ if self.track_running_stats:
214
+ self.running_mean = UninitializedBuffer(**factory_kwargs)
215
+ self.running_var = UninitializedBuffer(**factory_kwargs)
216
+ self.num_batches_tracked = torch.tensor(
217
+ 0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
218
+
219
+ def reset_parameters(self) -> None:
220
+ if not self.has_uninitialized_params() and self.num_features != 0:
221
+ super().reset_parameters()
222
+
223
+ def initialize_parameters(self, input) -> None: # type: ignore[override]
224
+ if self.has_uninitialized_params():
225
+ self.num_features = input.shape[1]
226
+ if self.affine:
227
+ assert isinstance(self.weight, UninitializedParameter)
228
+ assert isinstance(self.bias, UninitializedParameter)
229
+ self.weight.materialize((self.num_features,))
230
+ self.bias.materialize((self.num_features,))
231
+ if self.track_running_stats:
232
+ self.running_mean.materialize((self.num_features,)) # type:ignore[union-attr]
233
+ self.running_var.materialize((self.num_features,)) # type:ignore[union-attr]
234
+ self.reset_parameters()
235
+
236
+
237
+ class BatchNorm1d(_BatchNorm):
238
+ r"""Applies Batch Normalization over a 2D or 3D input.
239
+
240
+ Method described in the paper
241
+ `Batch Normalization: Accelerating Deep Network Training by Reducing
242
+ Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
243
+
244
+ .. math::
245
+
246
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
247
+
248
+ The mean and standard-deviation are calculated per-dimension over
249
+ the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
250
+ of size `C` (where `C` is the number of features or channels of the input). By default, the
251
+ elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
252
+ At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
253
+ equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
254
+ moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to
255
+ ``torch.var(input, unbiased=True)``.
256
+
257
+ Also by default, during training this layer keeps running estimates of its
258
+ computed mean and variance, which are then used for normalization during
259
+ evaluation. The running estimates are kept with a default :attr:`momentum`
260
+ of 0.1.
261
+
262
+ If :attr:`track_running_stats` is set to ``False``, this layer then does not
263
+ keep running estimates, and batch statistics are instead used during
264
+ evaluation time as well.
265
+
266
+ .. note::
267
+ This :attr:`momentum` argument is different from one used in optimizer
268
+ classes and the conventional notion of momentum. Mathematically, the
269
+ update rule for running statistics here is
270
+ :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
271
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
272
+ new observed value.
273
+
274
+ Because the Batch Normalization is done over the `C` dimension, computing statistics
275
+ on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.
276
+
277
+ Args:
278
+ num_features: number of features or channels :math:`C` of the input
279
+ eps: a value added to the denominator for numerical stability.
280
+ Default: 1e-5
281
+ momentum: the value used for the running_mean and running_var
282
+ computation. Can be set to ``None`` for cumulative moving average
283
+ (i.e. simple average). Default: 0.1
284
+ affine: a boolean value that when set to ``True``, this module has
285
+ learnable affine parameters. Default: ``True``
286
+ track_running_stats: a boolean value that when set to ``True``, this
287
+ module tracks the running mean and variance, and when set to ``False``,
288
+ this module does not track such statistics, and initializes statistics
289
+ buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
290
+ When these buffers are ``None``, this module always uses batch statistics.
291
+ in both training and eval modes. Default: ``True``
292
+
293
+ Shape:
294
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
295
+ :math:`C` is the number of features or channels, and :math:`L` is the sequence length
296
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
297
+
298
+ Examples::
299
+
300
+ >>> # With Learnable Parameters
301
+ >>> m = nn.BatchNorm1d(100)
302
+ >>> # Without Learnable Parameters
303
+ >>> m = nn.BatchNorm1d(100, affine=False)
304
+ >>> input = torch.randn(20, 100)
305
+ >>> output = m(input)
306
+ """
307
+
308
+ def _check_input_dim(self, input):
309
+ if input.dim() != 2 and input.dim() != 3:
310
+ raise ValueError(
311
+ f"expected 2D or 3D input (got {input.dim()}D input)"
312
+ )
313
+
314
+
315
+ class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
316
+ r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.
317
+
318
+ Lazy initialization based on the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
319
+ from the ``input.size(1)``.
320
+ The attributes that will be lazily initialized are `weight`, `bias`,
321
+ `running_mean` and `running_var`.
322
+
323
+ Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
324
+ on lazy modules and their limitations.
325
+
326
+ Args:
327
+ eps: a value added to the denominator for numerical stability.
328
+ Default: 1e-5
329
+ momentum: the value used for the running_mean and running_var
330
+ computation. Can be set to ``None`` for cumulative moving average
331
+ (i.e. simple average). Default: 0.1
332
+ affine: a boolean value that when set to ``True``, this module has
333
+ learnable affine parameters. Default: ``True``
334
+ track_running_stats: a boolean value that when set to ``True``, this
335
+ module tracks the running mean and variance, and when set to ``False``,
336
+ this module does not track such statistics, and initializes statistics
337
+ buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
338
+ When these buffers are ``None``, this module always uses batch statistics.
339
+ in both training and eval modes. Default: ``True``
340
+ """
341
+
342
+ cls_to_become = BatchNorm1d # type: ignore[assignment]
343
+
344
+ def _check_input_dim(self, input):
345
+ if input.dim() != 2 and input.dim() != 3:
346
+ raise ValueError(
347
+ f"expected 2D or 3D input (got {input.dim()}D input)"
348
+ )
349
+
350
+
351
+ class BatchNorm2d(_BatchNorm):
352
+ r"""Applies Batch Normalization over a 4D input.
353
+
354
+ 4D is a mini-batch of 2D inputs
355
+ with additional channel dimension. Method described in the paper
356
+ `Batch Normalization: Accelerating Deep Network Training by Reducing
357
+ Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
358
+
359
+ .. math::
360
+
361
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
362
+
363
+ The mean and standard-deviation are calculated per-dimension over
364
+ the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
365
+ of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
366
+ to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
367
+ standard-deviation is calculated via the biased estimator, equivalent to
368
+ ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
369
+ standard-deviation is calculated via the unbiased estimator, equivalent to
370
+ ``torch.var(input, unbiased=True)``.
371
+
372
+ Also by default, during training this layer keeps running estimates of its
373
+ computed mean and variance, which are then used for normalization during
374
+ evaluation. The running estimates are kept with a default :attr:`momentum`
375
+ of 0.1.
376
+
377
+ If :attr:`track_running_stats` is set to ``False``, this layer then does not
378
+ keep running estimates, and batch statistics are instead used during
379
+ evaluation time as well.
380
+
381
+ .. note::
382
+ This :attr:`momentum` argument is different from one used in optimizer
383
+ classes and the conventional notion of momentum. Mathematically, the
384
+ update rule for running statistics here is
385
+ :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
386
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
387
+ new observed value.
388
+
389
+ Because the Batch Normalization is done over the `C` dimension, computing statistics
390
+ on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.
391
+
392
+ Args:
393
+ num_features: :math:`C` from an expected input of size
394
+ :math:`(N, C, H, W)`
395
+ eps: a value added to the denominator for numerical stability.
396
+ Default: 1e-5
397
+ momentum: the value used for the running_mean and running_var
398
+ computation. Can be set to ``None`` for cumulative moving average
399
+ (i.e. simple average). Default: 0.1
400
+ affine: a boolean value that when set to ``True``, this module has
401
+ learnable affine parameters. Default: ``True``
402
+ track_running_stats: a boolean value that when set to ``True``, this
403
+ module tracks the running mean and variance, and when set to ``False``,
404
+ this module does not track such statistics, and initializes statistics
405
+ buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
406
+ When these buffers are ``None``, this module always uses batch statistics.
407
+ in both training and eval modes. Default: ``True``
408
+
409
+ Shape:
410
+ - Input: :math:`(N, C, H, W)`
411
+ - Output: :math:`(N, C, H, W)` (same shape as input)
412
+
413
+ Examples::
414
+
415
+ >>> # With Learnable Parameters
416
+ >>> m = nn.BatchNorm2d(100)
417
+ >>> # Without Learnable Parameters
418
+ >>> m = nn.BatchNorm2d(100, affine=False)
419
+ >>> input = torch.randn(20, 100, 35, 45)
420
+ >>> output = m(input)
421
+ """
422
+
423
+ def _check_input_dim(self, input):
424
+ if input.dim() != 4:
425
+ raise ValueError(f"expected 4D input (got {input.dim()}D input)")
426
+
427
+
428
class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    # After ``num_features`` is inferred from the first input, the lazy
    # machinery converts this instance into a regular BatchNorm2d.
    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        # Same contract as BatchNorm2d: input must be (N, C, H, W).
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")
460
+
461
+
462
class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input.

    5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        # BatchNorm3d only accepts (N, C, D, H, W) inputs.
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")
537
+
538
+
539
class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    # After ``num_features`` is inferred from the first input, the lazy
    # machinery converts this instance into a regular BatchNorm3d.
    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        # Same contract as BatchNorm3d: input must be (N, C, D, H, W).
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")
571
+
572
+
573
class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over a N-Dimensional input.

    The N-D input is a mini-batch of [N-2]D inputs with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are sampled from
    :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping
    Network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happen within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Stats are synchronized only within this group; ``None`` means the
        # default group (the whole world) is used in forward().
        self.process_group = process_group

    def _check_input_dim(self, input):
        # Unlike BatchNorm*d, any rank >= 2 is accepted: (N, C, +).
        if input.dim() < 2:
            raise ValueError(
                f"expected at least 2D input (got {input.dim()}D input)"
            )

    def _check_non_zero_input_channels(self, input):
        if input.size(1) == 0:
            raise ValueError(
                "SyncBatchNorm number of input channels should be non-zero"
            )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        # If buffers are not to be tracked, ensure that they won't be updated
        running_mean = (
            self.running_mean if not self.training or self.track_running_stats else None
        )
        running_var = (
            self.running_var if not self.training or self.track_running_stats else None
        )

        # Don't sync batchnorm stats in inference mode (model.eval()).
        need_sync = (bn_training and self.training and
                     torch.distributed.is_available() and torch.distributed.is_initialized())
        if need_sync:
            # currently only GPU/PrivateUse1 input is supported
            if input.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
                raise ValueError("SyncBatchNorm expected input tensor to be on GPU or "
                                 f"{torch._C._get_privateuse1_backend_name()}")

            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            # A single-rank group has nothing to synchronize with.
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self.weight,
                self.bias,
                bn_training,
                exponential_average_factor,
                self.eps,
            )
        else:
            assert bn_training
            return sync_batch_norm.apply(
                input,
                self.weight,
                self.bias,
                running_mean,
                running_var,
                self.eps,
                exponential_average_factor,
                process_group,  # type: ignore[possibly-undefined]
                world_size,  # type: ignore[possibly-undefined]
            )

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100),
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> # xdoctest: +SKIP("distributed")
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            # Replace this layer in place, copying parameters and buffers
            # by reference so existing optimizers keep working.
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            module_output.training = module.training
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        # Recurse into children so nested BatchNorm layers are converted too.
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group)
            )
        del module
        return module_output
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/channelshuffle.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .module import Module
2
+ from .. import functional as F
3
+
4
+ from torch import Tensor
5
+
6
+ __all__ = ['ChannelShuffle']
7
+
8
class ChannelShuffle(Module):
    r"""Shuffle the channels of a tensor across groups.

    Given an input of shape :math:`(*, C, H, W)` and ``groups`` :math:`g`,
    the channels are viewed as :math:`(*, g, \frac{C}{g}, H, W)`, the group
    and per-group dimensions are transposed, and the result is flattened
    back to the original shape. This interleaves channels from different
    groups, as used by architectures such as ShuffleNet.

    Args:
        groups (int): number of groups to divide channels in.

    Examples::

        >>> # xdoctest: +IGNORE_WANT("FIXME: incorrect want")
        >>> channel_shuffle = nn.ChannelShuffle(2)
        >>> input = torch.randn(1, 4, 2, 2)
        >>> print(input)
        [[[[1, 2],
           [3, 4]],
          [[5, 6],
           [7, 8]],
          [[9, 10],
           [11, 12]],
          [[13, 14],
           [15, 16]],
         ]]
        >>> output = channel_shuffle(input)
        >>> print(output)
        [[[[1, 2],
           [3, 4]],
          [[9, 10],
           [11, 12]],
          [[5, 6],
           [7, 8]],
          [[13, 14],
           [15, 16]],
         ]]
    """

    __constants__ = ['groups']
    # Number of channel groups; channels must be divisible by this.
    groups: int

    def __init__(self, groups: int) -> None:
        super().__init__()
        self.groups = groups

    def forward(self, input: Tensor) -> Tensor:
        # Delegate the reshuffle to the functional implementation.
        shuffled = F.channel_shuffle(input, self.groups)
        return shuffled

    def extra_repr(self) -> str:
        return f'groups={self.groups}'
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py ADDED
@@ -0,0 +1,911 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from collections import OrderedDict, abc as container_abcs
3
+ from itertools import chain, islice
4
+ import operator
5
+
6
+ import torch
7
+ from .module import Module
8
+ from ..parameter import Parameter
9
+ from torch._jit_internal import _copy_to_script_wrapper
10
+
11
+ from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
12
+ from typing_extensions import Self
13
+
14
+ __all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict']
15
+
16
+ T = TypeVar('T', bound=Module)
17
+
18
+
19
+ # Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
20
# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
def _addindent(s_, numSpaces):
    """Indent every line of ``s_`` except the first by ``numSpaces`` spaces."""
    lines = s_.split('\n')
    # Single-line strings are returned untouched.
    if len(lines) == 1:
        return s_
    head, *rest = lines
    pad = numSpaces * ' '
    return head + '\n' + '\n'.join(pad + line for line in rest)
30
+
31
+
32
class Container(Module):
    """Deprecated container base class.

    All of its functionality now lives in :class:`~torch.nn.Module`;
    subclass that instead. Each keyword argument is registered as a child
    module under its keyword name.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        # DeprecationWarning is ignored by default <sigh>
        # Fixed grammar in the user-facing message: "it's" -> "its".
        warnings.warn("nn.Container is deprecated. All of its functionality "
                      "is now implemented in nn.Module. Subclass that instead.")
        for key, value in kwargs.items():
            self.add_module(key, value)
41
+
42
+
43
class Sequential(Module):
    r"""A sequential container.

    Modules will be added to it in the order they are passed in the
    constructor. Alternatively, an ``OrderedDict`` of modules can be
    passed in. The ``forward()`` method of ``Sequential`` accepts any
    input and forwards it to the first module it contains. It then
    "chains" outputs to inputs sequentially for each subsequent module,
    finally returning the output of the last module.

    The value a ``Sequential`` provides over manually calling a sequence
    of modules is that it allows treating the whole container as a
    single module, such that performing a transformation on the
    ``Sequential`` applies to each of the modules it stores (which are
    each a registered submodule of the ``Sequential``).

    What's the difference between a ``Sequential`` and a
    :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
    sounds like--a list for storing ``Module`` s! On the other hand,
    the layers in a ``Sequential`` are connected in a cascading way.

    Example::

        # Using Sequential to create a small model. When `model` is run,
        # input will first be passed to `Conv2d(1,20,5)`. The output of
        # `Conv2d(1,20,5)` will be used as the input to the first
        # `ReLU`; the output of the first `ReLU` will become the input
        # for `Conv2d(20,64,5)`. Finally, the output of
        # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
        model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Using Sequential with OrderedDict. This is functionally the
        # same as the above code
        model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
    """

    # Submodules are stored under string keys ("0", "1", ... for the
    # positional constructor) in insertion order.
    _modules: Dict[str, Module]  # type: ignore[assignment]

    @overload
    def __init__(self, *args: Module) -> None:
        ...

    @overload
    def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
        ...

    def __init__(self, *args):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            # Named-module form: keys become submodule names.
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            # Positional form: modules are named by their index.
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx) -> T:  # type: ignore[misc, type-var]
        """Get the idx-th item of the iterator."""
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError(f'index {idx} is out of range')
        # Normalize negative indices, then advance the iterator to position idx.
        idx %= size
        return next(islice(iterator, idx, None))

    @_copy_to_script_wrapper
    def __getitem__(self, idx: Union[slice, int]) -> Union['Sequential', T]:
        if isinstance(idx, slice):
            # Slicing returns a new Sequential holding the selected modules.
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: Module) -> None:
        # Resolve the positional index to the existing string key, then
        # reassign via setattr so Module bookkeeping stays consistent.
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        if isinstance(idx, slice):
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)
        # To preserve numbering
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    def __add__(self, other) -> 'Sequential':
        # Concatenation produces a new Sequential; operands are unchanged.
        if isinstance(other, Sequential):
            ret = Sequential()
            for layer in self:
                ret.append(layer)
            for layer in other:
                ret.append(layer)
            return ret
        else:
            raise ValueError('add operator supports only objects '
                             f'of Sequential class, but {str(type(other))} is given.')

    def pop(self, key: Union[int, slice]) -> Module:
        # Return the removed module(s), mirroring list.pop semantics.
        v = self[key]
        del self[key]
        return v

    def __iadd__(self, other) -> Self:
        # In-place concatenation: append other's modules after our own.
        if isinstance(other, Sequential):
            offset = len(self)
            for i, module in enumerate(other):
                self.add_module(str(i + offset), module)
            return self
        else:
            raise ValueError('add operator supports only objects '
                             f'of Sequential class, but {str(type(other))} is given.')

    def __mul__(self, other: int) -> 'Sequential':
        # Repetition registers the SAME module objects multiple times
        # (parameters are shared, not copied).
        if not isinstance(other, int):
            raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
        elif (other <= 0):
            raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
        else:
            combined = Sequential()
            offset = 0
            for _ in range(other):
                for module in self:
                    combined.add_module(str(offset), module)
                    offset += 1
            return combined

    def __rmul__(self, other: int) -> 'Sequential':
        return self.__mul__(other)

    def __imul__(self, other: int) -> Self:
        if not isinstance(other, int):
            raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
        elif (other <= 0):
            raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
        else:
            # Append (other - 1) additional references to the original modules.
            len_original = len(self)
            offset = len(self)
            for _ in range(other - 1):
                for i in range(len_original):
                    self.add_module(str(i + offset), self._modules[str(i)])
                offset += len_original
            return self

    @_copy_to_script_wrapper
    def __dir__(self):
        # Hide the numeric submodule keys from dir() output.
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    # NB: We can't really type check this function as the type of input
    # may change dynamically (as is tested in
    # TestScript.test_sequential_intermediary_types). Cannot annotate
    # with Any as TorchScript expects a more precise type
    def forward(self, input):
        # Chain each module's output into the next module's input.
        for module in self:
            input = module(input)
        return input

    def append(self, module: Module) -> 'Sequential':
        r"""Append a given module to the end.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def insert(self, index: int, module: Module) -> 'Sequential':
        # Insert at ``index``, shifting subsequent numeric keys up by one.
        if not isinstance(module, Module):
            raise AssertionError(
                f'module should be of type: {Module}')
        n = len(self._modules)
        if not (-n <= index <= n):
            raise IndexError(
                f'Index out of range: {index}')
        if index < 0:
            index += n
        for i in range(n, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module
        return self

    def extend(self, sequential) -> 'Sequential':
        # Append every layer of ``sequential`` (any iterable of modules).
        for layer in sequential:
            self.append(layer)
        return self
248
+
249
+
250
class ModuleList(Module):
    r"""Holds submodules in a list.

    :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
    modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    Args:
        modules (iterable, optional): an iterable of modules to add

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

            def forward(self, x):
                # ModuleList can act as an iterable, or be indexed using ints
                for i, l in enumerate(self.linears):
                    x = self.linears[i // 2](x) + l(x)
                return x
    """

    # Keys are stringified consecutive indices "0" .. "len-1"; see append()
    # and the renumbering in __delitem__.
    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            # `+=` routes through __iadd__ -> extend(), registering each module.
            self += modules

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules."""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f'index {idx} is out of range')
        if idx < 0:
            idx += len(self)
        return str(idx)

    @_copy_to_script_wrapper
    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, 'ModuleList']:
        # Slicing returns a new container of the same (sub)class.
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: Module) -> None:
        # _get_abs_string_index already returns a string key; str() is a no-op.
        idx = self._get_abs_string_index(idx)
        return setattr(self, str(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        if isinstance(idx, slice):
            for k in range(len(self._modules))[idx]:
                delattr(self, str(k))
        else:
            delattr(self, self._get_abs_string_index(idx))
        # To preserve numbering, self._modules is being reconstructed with modules after deletion
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    def __iadd__(self, modules: Iterable[Module]) -> Self:
        return self.extend(modules)

    def __add__(self, other: Iterable[Module]) -> 'ModuleList':
        # Non-mutating concatenation: builds a fresh ModuleList.
        combined = ModuleList()
        for i, module in enumerate(chain(self, other)):
            combined.add_module(str(i), module)
        return combined

    def __repr__(self):
        """Return a custom repr for ModuleList that compresses repeated module representations."""
        list_of_reprs = [repr(item) for item in self]
        if len(list_of_reprs) == 0:
            return self._get_name() + '()'

        # Collapse runs of identical reprs into "(start-end): N x <repr>" lines.
        start_end_indices = [[0, 0]]
        repeated_blocks = [list_of_reprs[0]]
        for i, r in enumerate(list_of_reprs[1:], 1):
            if r == repeated_blocks[-1]:
                start_end_indices[-1][1] += 1
                continue

            start_end_indices.append([i, i])
            repeated_blocks.append(r)

        lines = []
        main_str = self._get_name() + '('
        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
            local_repr = f"({start_id}): {b}"  # default repr

            if start_id != end_id:
                n = end_id - start_id + 1
                local_repr = f"({start_id}-{end_id}): {n} x {b}"

            local_repr = _addindent(local_repr, 2)
            lines.append(local_repr)

        main_str += '\n  ' + '\n  '.join(lines) + '\n'
        main_str += ')'
        return main_str

    @_copy_to_script_wrapper
    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def insert(self, index: int, module: Module) -> None:
        r"""Insert a given module before a given index in the list.

        Args:
            index (int): index to insert.
            module (nn.Module): module to insert
        """
        # NOTE(review): unlike Sequential.insert, there is no bounds/negative
        # index validation here — an out-of-range index silently creates
        # non-contiguous keys; confirm this is intended upstream behavior.
        for i in range(len(self._modules), index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def append(self, module: Module) -> 'ModuleList':
        r"""Append a given module to the end of the list.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def pop(self, key: Union[int, slice]) -> Module:
        """Remove and return the module(s) at ``key``."""
        v = self[key]
        del self[key]
        return v

    def extend(self, modules: Iterable[Module]) -> Self:
        r"""Append modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): iterable of modules to append
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleList.extend should be called with an "
                            "iterable, but got " + type(modules).__name__)
        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self

    # remove forward altogether to fallback on Module's _forward_unimplemented
406
+
407
+
408
class ModuleDict(Module):
    r"""Holds submodules in a dictionary.

    :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
    but modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
      ``OrderedDict``, ``dict`` (started from Python 3.6) or another
      :class:`~torch.nn.ModuleDict` (the argument to
      :meth:`~torch.nn.ModuleDict.update`).

    Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict`` before Python version 3.6) does not
    preserve the order of the merged mapping.

    Args:
        modules (iterable, optional): a mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module)

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.choices = nn.ModuleDict({
                    'conv': nn.Conv2d(10, 10, 3),
                    'pool': nn.MaxPool2d(3)
                })
                self.activations = nn.ModuleDict([
                    ['lrelu', nn.LeakyReLU()],
                    ['prelu', nn.PReLU()]
                ])

            def forward(self, x, choice, act):
                x = self.choices[choice](x)
                x = self.activations[act](x)
                return x
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self.update(modules)

    @_copy_to_script_wrapper
    def __getitem__(self, key: str) -> Module:
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        # add_module registers the module so it is visible to Module methods
        # (parameters(), state_dict(), etc.).
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[str]:
        # Iterates keys, matching plain-dict semantics.
        return iter(self._modules)

    @_copy_to_script_wrapper
    def __contains__(self, key: str) -> bool:
        return key in self._modules

    def clear(self) -> None:
        """Remove all items from the ModuleDict."""
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (str): key to pop from the ModuleDict
        """
        v = self[key]
        del self[key]
        return v

    @_copy_to_script_wrapper
    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ModuleDict keys."""
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs."""
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> Iterable[Module]:
        r"""Return an iterable of the ModuleDict values."""
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(modules).__name__)

        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
        else:
            # modules here can be a list with two items
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    # NOTE(review): message reads "...; is<TypeName>" — missing a
                    # space after "is"; matches upstream text, flagging only.
                    raise TypeError("ModuleDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is" +
                                    type(m).__name__)
                if not len(m) == 2:
                    raise ValueError("ModuleDict update sequence element "
                                     "#" + str(j) + " has length " + str(len(m)) +
                                     "; 2 is required")
                # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
                # that's too cumbersome to type correctly with overloads, so we add an ignore here
                self[m[0]] = m[1]  # type: ignore[assignment]

    # remove forward altogether to fallback on Module's _forward_unimplemented
545
+
546
+
547
class ParameterList(Module):
    r"""Holds parameters in a list.

    :class:`~torch.nn.ParameterList` can be used like a regular Python
    list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
    and will be visible by all :class:`~torch.nn.Module` methods.

    Note that the constructor, assigning an element of the list, the
    :meth:`~torch.nn.ParameterDict.append` method and the :meth:`~torch.nn.ParameterDict.extend`
    method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.

    Args:
        parameters (iterable, optional): an iterable of elements to add to the list.

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

            def forward(self, x):
                # ParameterList can act as an iterable, or be indexed using ints
                for i, p in enumerate(self.params):
                    x = self.params[i // 2].mm(x) + p.mm(x)
                return x
    """

    def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
        super().__init__()
        # Number of entries in the list part; entries are stored as attributes
        # named "0" .. "_size - 1" (see append/__setitem__).
        self._size = 0
        if values is not None:
            # `+=` routes through __iadd__ -> extend().
            self += values

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules."""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f'index {idx} is out of range')
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: int) -> Any:
        ...

    @overload
    def __getitem__(self: T, idx: slice) -> T:
        ...

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            # Build a fresh container of the same class from the sliced range.
            start, stop, step = idx.indices(len(self))
            out = self.__class__()
            for i in range(start, stop, step):
                out.append(self[i])
            return out
        else:
            idx = self._get_abs_string_index(idx)
            return getattr(self, str(idx))

    def __setitem__(self, idx: int, param: Any) -> None:
        # Note that all other function that add an entry to the list part of
        # the ParameterList end up here. So this is the only place where we need
        # to wrap things into Parameter if needed.
        # Objects added via setattr() are not in the list part and thus won't
        # call into this function.
        idx = self._get_abs_string_index(idx)
        if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
            param = Parameter(param)
        return setattr(self, str(idx), param)

    def __len__(self) -> int:
        return self._size

    def __iter__(self) -> Iterator[Any]:
        return iter(self[i] for i in range(len(self)))

    def __iadd__(self, parameters: Iterable[Any]) -> Self:
        return self.extend(parameters)

    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def append(self, value: Any) -> 'ParameterList':
        """Append a given value at the end of the list.

        Args:
            value (Any): value to append
        """
        # Grow _size first so the __setitem__ bounds check accepts the new index.
        new_idx = len(self)
        self._size += 1
        self[new_idx] = value
        return self

    def extend(self, values: Iterable[Any]) -> Self:
        """Append values from a Python iterable to the end of the list.

        Args:
            values (iterable): iterable of values to append
        """
        # Tensor is an iterable but we never want to unpack it here
        if not isinstance(values, container_abcs.Iterable) or isinstance(values, torch.Tensor):
            raise TypeError("ParameterList.extend should be called with an "
                            "iterable, but got " + type(values).__name__)
        for value in values:
            self.append(value)
        return self

    def extra_repr(self) -> str:
        # NOTE(review): prints `p.dtype` here while ParameterDict.extra_repr
        # uses `torch.typename(p)` — confirm the divergence is intentional.
        child_lines = []
        for k, p in enumerate(self):
            if isinstance(p, torch.Tensor):
                size_str = 'x'.join(str(size) for size in p.size())
                if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
                    device_str = f' ({p.device})'
                else:
                    device_str = ''
                parastr = '{} containing: [{} of size {}{}]'.format(
                    "Parameter" if isinstance(p, Parameter) else "Tensor",
                    p.dtype, size_str, device_str)
                child_lines.append('  (' + str(k) + '): ' + parastr)
            else:
                child_lines.append('  (' + str(k) + '): Object of type: ' + type(p).__name__)

        tmpstr = '\n'.join(child_lines)
        return tmpstr

    def __call__(self, *args, **kwargs):
        # A ParameterList is a container, not a computation; forbid call syntax.
        raise RuntimeError('ParameterList should not be called.')
680
+
681
+
682
class ParameterDict(Module):
    r"""Holds parameters in a dictionary.

    ParameterDict can be indexed like a regular Python dictionary, but Parameters it
    contains are properly registered, and will be visible by all Module methods.
    Other objects are treated as would be done by a regular Python dictionary

    :class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
    :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict``) does not preserve the order of the
    merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
    will preserve their ordering.

    Note that the constructor, assigning an element of the dictionary and the
    :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
    :class:`~torch.nn.Parameter`.

    Args:
        values (iterable, optional): a mapping (dictionary) of
            (string : Any) or an iterable of key-value pairs
            of type (string, Any)

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.params = nn.ParameterDict({
                    'left': nn.Parameter(torch.randn(5, 10)),
                    'right': nn.Parameter(torch.randn(5, 10))
                })

            def forward(self, x, choice):
                x = self.params[choice].mm(x)
                return x
    """

    def __init__(self, parameters: Any = None) -> None:
        super().__init__()
        # Insertion-ordered registry of keys; the values themselves live as
        # attributes on the module (so Parameters are registered normally).
        self._keys: Dict[str, None] = {}
        if parameters is not None:
            self.update(parameters)

    def _key_to_attr(self, key: str) -> str:
        if not isinstance(key, str):
            raise TypeError("Index given to ParameterDict cannot be used as a key as it is "
                            f"not a string (type is '{type(key).__name__}'). Open an issue on "
                            "github if you need non-string keys.")
        else:
            # Use the key as-is so that `.named_parameters()` returns the right thing
            return key

    def __getitem__(self, key: str) -> Any:
        attr = self._key_to_attr(key)
        return getattr(self, attr)

    def __setitem__(self, key: str, value: Any) -> None:
        # Note that all other function that add an entry to the dictionary part of
        # the ParameterDict end up here. So this is the only place where we need
        # to wrap things into Parameter if needed.
        # Objects added via setattr() are not in the dictionary part and thus won't
        # call into this function.
        self._keys[key] = None
        attr = self._key_to_attr(key)
        if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
            value = Parameter(value)
        setattr(self, attr, value)

    def __delitem__(self, key: str) -> None:
        del self._keys[key]
        attr = self._key_to_attr(key)
        delattr(self, attr)

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> Iterator[str]:
        return iter(self._keys)

    def __reversed__(self) -> Iterator[str]:
        return reversed(list(self._keys))

    def copy(self) -> 'ParameterDict':
        """Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
        # We have to use an OrderedDict because the ParameterDict constructor
        # behaves differently on plain dict vs OrderedDict
        return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))

    def __contains__(self, key: str) -> bool:
        return key in self._keys

    def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
        """Set the default for a key in the Parameterdict.

        If key is in the ParameterDict, return its value.
        If not, insert `key` with a parameter `default` and return `default`.
        `default` defaults to `None`.

        Args:
            key (str): key to set default for
            default (Any): the parameter set to the key
        """
        if key not in self:
            self[key] = default
        return self[key]

    def clear(self) -> None:
        """Remove all items from the ParameterDict."""
        # Iterate a copy: deleting mutates self._keys.
        for k in self._keys.copy():
            del self[k]

    def pop(self, key: str) -> Any:
        r"""Remove key from the ParameterDict and return its parameter.

        Args:
            key (str): key to pop from the ParameterDict
        """
        v = self[key]
        del self[key]
        return v

    def popitem(self) -> Tuple[str, Any]:
        """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
        k, _ = self._keys.popitem()
        # We need the key in the _keys to be able to access/del
        self._keys[k] = None
        val = self[k]
        del self[k]
        return k, val

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.

        Args:
            key (str): key to get from the ParameterDict
            default (Parameter, optional): value to return if key not present
        """
        return self[key] if key in self else default

    def fromkeys(self, keys: Iterable[str], default: Optional[Any] = None) -> 'ParameterDict':
        r"""Return a new ParameterDict with the keys provided.

        Args:
            keys (iterable, string): keys to make the new ParameterDict from
            default (Parameter, optional): value to set for all keys
        """
        # NOTE(review): instance method that ignores self, unlike
        # dict.fromkeys (a classmethod) — matches upstream signature.
        return ParameterDict((k, default) for k in keys)

    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ParameterDict keys."""
        return self._keys.keys()

    def items(self) -> Iterable[Tuple[str, Any]]:
        r"""Return an iterable of the ParameterDict key/value pairs."""
        return ((k, self[k]) for k in self._keys)

    def values(self) -> Iterable[Any]:
        r"""Return an iterable of the ParameterDict values."""
        return (self[k] for k in self._keys)

    def update(self, parameters: Union[Mapping[str, Any], 'ParameterDict']) -> None:
        r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.

        .. note::
            If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            parameters (iterable): a mapping (dictionary) from string to
                :class:`~torch.nn.Parameter`, or an iterable of
                key-value pairs of type (string, :class:`~torch.nn.Parameter`)
        """
        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError("ParametersDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(parameters).__name__)

        if isinstance(parameters, (OrderedDict, ParameterDict)):
            for key, parameter in parameters.items():
                self[key] = parameter
        elif isinstance(parameters, container_abcs.Mapping):
            # Plain (unordered) mappings are merged in sorted-key order.
            for key, parameter in sorted(parameters.items()):
                self[key] = parameter
        else:
            for j, p in enumerate(parameters):
                if not isinstance(p, container_abcs.Iterable):
                    raise TypeError("ParameterDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is" +
                                    type(p).__name__)
                if not len(p) == 2:
                    raise ValueError("ParameterDict update sequence element "
                                     "#" + str(j) + " has length " + str(len(p)) +
                                     "; 2 is required")
                # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
                self[p[0]] = p[1]  # type: ignore[assignment]

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in self.items():
            if isinstance(p, torch.Tensor):
                size_str = 'x'.join(str(size) for size in p.size())
                if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
                    device_str = f' ({p.device})'
                else:
                    device_str = ''
                parastr = '{} containing: [{} of size {}{}]'.format(
                    "Parameter" if isinstance(p, Parameter) else "Tensor",
                    torch.typename(p), size_str, device_str)
                child_lines.append('  (' + str(k) + '): ' + parastr)
            else:
                child_lines.append('  (' + str(k) + '): Object of type: ' + type(p).__name__)
        tmpstr = '\n'.join(child_lines)
        return tmpstr

    def __call__(self, input):
        # A ParameterDict is a container, not a computation; forbid call syntax.
        raise RuntimeError('ParameterDict should not be called.')

    def __or__(self, other: 'ParameterDict') -> 'ParameterDict':
        copy = self.copy()
        copy.update(other)
        return copy

    def __ror__(self, other: 'ParameterDict') -> 'ParameterDict':
        copy = other.copy()
        copy.update(self)
        return copy

    def __ior__(self, other: 'ParameterDict') -> Self:
        self.update(other)
        return self
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/dropout.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .module import Module
2
+ from .. import functional as F
3
+
4
+ from torch import Tensor
5
+
6
+ __all__ = ['Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout']
7
+
8
class _DropoutNd(Module):
    """Shared base for the dropout modules: validates and stores ``p``/``inplace``."""

    __constants__ = ['p', 'inplace']
    p: float
    inplace: bool

    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        super().__init__()
        # Keep the two-sided comparison form (not `not 0 <= p <= 1`) so a NaN
        # probability is not rejected differently from the original behavior.
        if p < 0 or p > 1:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        self.p = p
        self.inplace = inplace

    def extra_repr(self) -> str:
        """Render the configuration shown inside the module's repr()."""
        return 'p={}, inplace={}'.format(self.p, self.inplace)
22
+
23
+
24
class Dropout(_DropoutNd):
    r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p`.

    The zeroed elements are chosen independently for each forward call and are sampled from a Bernoulli distribution.

    Each channel will be zeroed out independently on every forward call.

    This has proven to be an effective technique for regularization and
    preventing the co-adaptation of neurons as described in the paper
    `Improving neural networks by preventing co-adaptation of feature
    detectors`_ .

    Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
    training. This means that during evaluation the module simply computes an
    identity function.

    Args:
        p: probability of an element to be zeroed. Default: 0.5
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`. Input can be of any shape
        - Output: :math:`(*)`. Output is of the same shape as input

    Examples::

        >>> m = nn.Dropout(p=0.2)
        >>> input = torch.randn(20, 16)
        >>> output = m(input)

    .. _Improving neural networks by preventing co-adaptation of feature
        detectors: https://arxiv.org/abs/1207.0580
    """

    def forward(self, input: Tensor) -> Tensor:
        # self.training gates the behavior: dropout is only applied in train mode.
        return F.dropout(input, self.p, self.training, self.inplace)
60
+
61
+
62
class Dropout1d(_DropoutNd):
    r"""Randomly zero out entire channels.

    A channel is a 1D feature map,
    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 1D tensor :math:`\text{input}[i, j]`.

    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    Usually the input comes from :class:`nn.Conv1d` modules.

    As described in the paper
    `Efficient Object Localization Using Convolutional Networks`_ ,
    if adjacent pixels within feature maps are strongly correlated
    (as is normally the case in early convolution layers) then i.i.d. dropout
    will not regularize the activations and will otherwise just result
    in an effective learning rate decrease.

    In this case, :func:`nn.Dropout1d` will help promote independence between
    feature maps and should be used instead.

    Args:
        p (float, optional): probability of an element to be zero-ed.
        inplace (bool, optional): If set to ``True``, will do this operation
            in-place

    Shape:
        - Input: :math:`(N, C, L)` or :math:`(C, L)`.
        - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).

    Examples::

        >>> m = nn.Dropout1d(p=0.2)
        >>> input = torch.randn(20, 16, 32)
        >>> output = m(input)

    .. _Efficient Object Localization Using Convolutional Networks:
       https://arxiv.org/abs/1411.4280
    """

    def forward(self, input: Tensor) -> Tensor:
        # Channel-wise (not element-wise) zeroing; only active when self.training.
        return F.dropout1d(input, self.p, self.training, self.inplace)
105
+
106
+
107
class Dropout2d(_DropoutNd):
    r"""Randomly zero out entire channels.

    A channel is a 2D feature map,
    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 2D tensor :math:`\text{input}[i, j]`.

    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    Usually the input comes from :class:`nn.Conv2d` modules.

    As described in the paper
    `Efficient Object Localization Using Convolutional Networks`_ ,
    if adjacent pixels within feature maps are strongly correlated
    (as is normally the case in early convolution layers) then i.i.d. dropout
    will not regularize the activations and will otherwise just result
    in an effective learning rate decrease.

    In this case, :func:`nn.Dropout2d` will help promote independence between
    feature maps and should be used instead.

    Args:
        p (float, optional): probability of an element to be zero-ed.
        inplace (bool, optional): If set to ``True``, will do this operation
            in-place

    .. warning ::
        Due to historical reasons, this class will perform 1D channel-wise dropout
        for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT
        support inputs without a batch dimension of shape :math:`(C, H, W)`. This
        behavior will change in a future release to interpret 3D inputs as no-batch-dim
        inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`.

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
        - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).

    Examples::

        >>> m = nn.Dropout2d(p=0.2)
        >>> input = torch.randn(20, 16, 32, 32)
        >>> output = m(input)

    .. _Efficient Object Localization Using Convolutional Networks:
       https://arxiv.org/abs/1411.4280
    """

    def forward(self, input: Tensor) -> Tensor:
        # NOTE: for 3D inputs this performs 1D channel-wise dropout — see the
        # warning in the class docstring.
        return F.dropout2d(input, self.p, self.training, self.inplace)
157
+
158
+
159
class Dropout3d(_DropoutNd):
    r"""Randomly zero out entire channels.

    A channel is a 3D feature map,
    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 3D tensor :math:`\text{input}[i, j]`.

    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    Usually the input comes from :class:`nn.Conv3d` modules.

    As described in the paper
    `Efficient Object Localization Using Convolutional Networks`_ ,
    if adjacent pixels within feature maps are strongly correlated
    (as is normally the case in early convolution layers) then i.i.d. dropout
    will not regularize the activations and will otherwise just result
    in an effective learning rate decrease.

    In this case, :func:`nn.Dropout3d` will help promote independence between
    feature maps and should be used instead.

    Args:
        p (float, optional): probability of an element to be zeroed.
        inplace (bool, optional): If set to ``True``, will do this operation
            in-place

    Shape:
        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).

    Examples::

        >>> m = nn.Dropout3d(p=0.2)
        >>> input = torch.randn(20, 16, 4, 32, 32)
        >>> output = m(input)

    .. _Efficient Object Localization Using Convolutional Networks:
       https://arxiv.org/abs/1411.4280
    """

    def forward(self, input: Tensor) -> Tensor:
        # Channel-wise zeroing over 3D feature maps; only active when self.training.
        return F.dropout3d(input, self.p, self.training, self.inplace)
202
+
203
+
204
class AlphaDropout(_DropoutNd):
    r"""Applies Alpha Dropout over the input.

    Alpha Dropout preserves the self-normalizing property: for an input with
    zero mean and unit standard deviation, the output keeps the original
    mean and standard deviation. It goes hand-in-hand with the SELU
    activation function, which ensures outputs with zero mean and unit
    standard deviation.

    During training, elements of the input tensor are masked with
    probability *p* using samples from a Bernoulli distribution; the masked
    elements are re-randomized on every forward call, and the result is
    scaled and shifted so the zero-mean / unit-std property is maintained.
    During evaluation the module simply computes an identity function.

    More details can be found in the paper
    `Self-Normalizing Neural Networks`_ .

    Args:
        p (float): probability of an element to be dropped. Default: 0.5
        inplace (bool, optional): If set to ``True``, will do this operation
            in-place

    Shape:
        - Input: :math:`(*)`. Input can be of any shape
        - Output: :math:`(*)`. Output is of the same shape as input

    Examples::

        >>> m = nn.AlphaDropout(p=0.2)
        >>> input = torch.randn(20, 16)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """

    def forward(self, input: Tensor) -> Tensor:
        # NOTE: unlike Dropout2d/3d, the functional alpha_dropout call here
        # does not receive `self.inplace`.
        return F.alpha_dropout(input, self.p, self.training)
244
+
245
+
246
class FeatureAlphaDropout(_DropoutNd):
    r"""Randomly masks out entire channels.

    A channel is a feature map: e.g. the :math:`j`-th channel of the
    :math:`i`-th sample of a batched input is the tensor
    :math:`\text{input}[i, j]`. Instead of zeroing activations as in regular
    Dropout, masked activations are set to the negative saturation value of
    the SELU activation function. More details can be found in the paper
    `Self-Normalizing Neural Networks`_ .

    Each channel is masked independently for each sample on every forward
    call with probability :attr:`p`, using samples from a Bernoulli
    distribution. The masked elements are re-randomized per forward call,
    and the output is scaled and shifted to maintain zero mean and unit
    variance.

    The input usually comes from :class:`nn.AlphaDropout` modules. As argued
    in `Efficient Object Localization Using Convolutional Networks`_ ,
    i.i.d. dropout does not regularize strongly-correlated adjacent
    activations (typical of early convolution layers) and merely lowers the
    effective learning rate; in that case channel-wise :func:`nn.AlphaDropout`
    helps promote independence between feature maps and should be used
    instead.

    Args:
        p (float, optional): probability of an element to be zeroed. Default: 0.5
        inplace (bool, optional): If set to ``True``, will do this operation
            in-place

    Shape:
        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).

    Examples::

        >>> m = nn.FeatureAlphaDropout(p=0.2)
        >>> input = torch.randn(20, 16, 4, 32, 32)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    .. _Efficient Object Localization Using Convolutional Networks:
       https://arxiv.org/abs/1411.4280
    """

    def forward(self, input: Tensor) -> Tensor:
        # `self.inplace` is intentionally not forwarded (matches functional API use).
        return F.feature_alpha_dropout(input, self.p, self.training)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/flatten.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .module import Module
2
+
3
+ from typing import Tuple, Union
4
+ from torch import Tensor
5
+ from torch.types import _size
6
+
7
+ __all__ = ['Flatten', 'Unflatten']
8
+
9
class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor.

    For use with :class:`~nn.Sequential`; see :meth:`torch.flatten` for details.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::
        >>> input = torch.randn(32, 1, 5, 5)
        >>> m = nn.Flatten()           # default: collapse dims 1..-1
        >>> m(input).size()
        torch.Size([32, 25])
        >>> m = nn.Flatten(0, 2)       # collapse dims 0..2
        >>> m(input).size()
        torch.Size([160, 5])
    """

    __constants__ = ['start_dim', 'end_dim']
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        # Collapse the [start_dim, end_dim] range into a single dimension.
        return input.flatten(self.start_dim, self.end_dim)

    def extra_repr(self) -> str:
        return 'start_dim={}, end_dim={}'.format(self.start_dim, self.end_dim)
53
+
54
+
55
class Unflatten(Module):
    r"""
    Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
      be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.

    * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
      a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input;  a `NamedShape`
      (tuple of `(name, size)` tuples) for `NamedTensor` input.

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
          :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension

    Examples:
        >>> input = torch.randn(2, 50)
        >>> # With tuple of ints
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, (2, 5, 5))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With torch.Size
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=('N', 'features'))
        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
        >>> output = unflatten(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
    """

    # BUG FIX: this alias used to be ``Tuple[Tuple[str, int]]``, which denotes
    # a *one*-element tuple; a named shape holds arbitrarily many
    # ``(name, size)`` pairs (see the docstring example), so the variadic
    # ``...`` form is required.
    NamedShape = Tuple[Tuple[str, int], ...]

    __constants__ = ['dim', 'unflattened_size']
    dim: Union[int, str]
    unflattened_size: Union[_size, NamedShape]

    def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None:
        super().__init__()

        # Validate that the size specification matches the kind of `dim`:
        # an int dim takes plain sizes, a str (named) dim takes (name, size) pairs.
        if isinstance(dim, int):
            self._require_tuple_int(unflattened_size)
        elif isinstance(dim, str):
            self._require_tuple_tuple(unflattened_size)
        else:
            raise TypeError("invalid argument type for dim parameter")

        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input):
        """Raise TypeError unless `input` is a tuple of (name, size) tuples."""
        if (isinstance(input, tuple)):
            for idx, elem in enumerate(input):
                if not isinstance(elem, tuple):
                    raise TypeError("unflattened_size must be tuple of tuples, " +
                                    f"but found element of type {type(elem).__name__} at pos {idx}")
            return
        raise TypeError("unflattened_size must be a tuple of tuples, " +
                        f"but found type {type(input).__name__}")

    def _require_tuple_int(self, input):
        """Raise TypeError unless `input` is a tuple (or list) of ints."""
        if (isinstance(input, (tuple, list))):
            for idx, elem in enumerate(input):
                if not isinstance(elem, int):
                    raise TypeError("unflattened_size must be tuple of ints, " +
                                    f"but found element of type {type(elem).__name__} at pos {idx}")
            return
        raise TypeError(f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}")

    def forward(self, input: Tensor) -> Tensor:
        # Expand dimension `dim` of `input` into `unflattened_size`.
        return input.unflatten(self.dim, self.unflattened_size)

    def extra_repr(self) -> str:
        return f'dim={self.dim}, unflattened_size={self.unflattened_size}'
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/normalization.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numbers
3
+ from torch.nn.parameter import Parameter
4
+ from .module import Module
5
+ from ._functions import CrossMapLRN2d as _cross_map_lrn2d
6
+ from .. import functional as F
7
+ from .. import init
8
+
9
+ from torch import Tensor, Size
10
+ from typing import Union, List, Tuple
11
+
12
+ __all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm']
13
+
14
class LocalResponseNorm(Module):
    r"""Applies local response normalization over an input signal.

    The input signal consists of several planes, with channels occupying the
    second dimension; normalization is applied across channels:

    .. math::
        b_{c} = a_{c}\left(k + \frac{\alpha}{n}
        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}

    Args:
        size: amount of neighbouring channels used for normalization
        alpha: multiplicative factor. Default: 0.0001
        beta: exponent. Default: 0.75
        k: additive factor. Default: 1

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> lrn = nn.LocalResponseNorm(2)
        >>> signal_2d = torch.randn(32, 5, 24, 24)
        >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7)
        >>> output_2d = lrn(signal_2d)
        >>> output_4d = lrn(signal_4d)

    """

    __constants__ = ['size', 'alpha', 'beta', 'k']
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.) -> None:
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        # Delegate to the functional implementation.
        return F.local_response_norm(
            input, self.size, self.alpha, self.beta, self.k)

    def extra_repr(self):
        return f'{self.size}, alpha={self.alpha}, beta={self.beta}, k={self.k}'
63
+
64
+
65
class CrossMapLRN2d(Module):
    """Cross-map local response normalization, backed by the autograd
    function imported as ``_cross_map_lrn2d``."""

    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1) -> None:
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        # Invoke the custom autograd Function directly.
        return _cross_map_lrn2d.apply(
            input, self.size, self.alpha, self.beta, self.k)

    def extra_repr(self) -> str:
        return f'{self.size}, alpha={self.alpha}, beta={self.beta}, k={self.k}'
84
+
85
+
86
# Shape spec accepted by LayerNorm's `normalized_shape` argument below:
# a single int, a list of ints, or a torch.Size.
_shape_t = Union[int, List[int], Size]
87
+
88
+
89
class LayerNorm(Module):
    r"""Applies Layer Normalization over a mini-batch of inputs.

    Implements the operation described in the paper
    `Layer Normalization <https://arxiv.org/abs/1607.06450>`__:

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are computed over the last `D`
    dimensions, where `D` is the dimension of :attr:`normalized_shape`; e.g.
    for a 2-dimensional ``normalized_shape`` of ``(3, 5)``, statistics are
    taken over the last two input dimensions (``input.mean((-2, -1))``).
    :math:`\gamma` and :math:`\beta` are learnable affine parameters of shape
    :attr:`normalized_shape` when :attr:`elementwise_affine` is ``True``.
    The standard-deviation uses the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply a
        scalar scale and bias per whole channel/plane via the :attr:`affine`
        option, Layer Normalization applies per-element scale and bias via
        :attr:`elementwise_affine`.

    Statistics are computed from the input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.
        bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
            :attr:`elementwise_affine` is ``True``). Default: ``True``.

    Attributes:
        weight: learnable weight of shape :math:`\text{normalized\_shape}`,
            initialized to 1 (present only when :attr:`elementwise_affine` is ``True``).
        bias: learnable bias of shape :math:`\text{normalized\_shape}`,
            initialized to 0 (present only when :attr:`elementwise_affine` is ``True``).

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> # NLP Example
        >>> batch, sentence_length, embedding_dim = 20, 5, 10
        >>> embedding = torch.randn(batch, sentence_length, embedding_dim)
        >>> layer_norm = nn.LayerNorm(embedding_dim)
        >>> layer_norm(embedding)
        >>>
        >>> # Image Example
        >>> N, C, H, W = 20, 5, 10, 10
        >>> input = torch.randn(N, C, H, W)
        >>> # Normalize over the last three (channel and spatial) dimensions
        >>> layer_norm = nn.LayerNorm([C, H, W])
        >>> output = layer_norm(input)
    """

    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        # A bare int means "normalize over one trailing dimension of that size".
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            # No learnable parameters at all.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            if bias:
                self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            else:
                self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reinitialize the affine parameters (weight -> 1, bias -> 0)."""
        if not self.elementwise_affine:
            return
        init.ones_(self.weight)
        if self.bias is not None:
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return (f'{self.normalized_shape}, eps={self.eps}, '
                f'elementwise_affine={self.elementwise_affine}')
207
+
208
+
209
class GroupNorm(Module):
    r"""Applies Group Normalization over a mini-batch of inputs.

    Implements the operation described in the paper
    `Group Normalization <https://arxiv.org/abs/1803.08494>`__:

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The input channels are split into :attr:`num_groups` groups of
    ``num_channels / num_groups`` channels each, so :attr:`num_channels` must
    be divisible by :attr:`num_groups`. Mean and standard-deviation are
    computed separately over each group. :math:`\gamma` and :math:`\beta` are
    learnable per-channel affine parameter vectors of size
    :attr:`num_channels` when :attr:`affine` is ``True``. The
    standard-deviation uses the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    Statistics are computed from the input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: a boolean value that when set to ``True``, this module
            has learnable per-channel affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 6, 10, 10)
        >>> m = nn.GroupNorm(3, 6)   # 3 groups of 2 channels
        >>> m = nn.GroupNorm(6, 6)   # equivalent with InstanceNorm
        >>> m = nn.GroupNorm(1, 6)   # equivalent with LayerNorm
        >>> output = m(input)
    """

    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        # Guard clause: groups must evenly partition the channels.
        if num_channels % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')

        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if not self.affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            self.weight = Parameter(torch.empty(num_channels, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_channels, **factory_kwargs))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reinitialize the affine parameters (weight -> 1, bias -> 0)."""
        if not self.affine:
            return
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return (f'{self.num_groups}, {self.num_channels}, eps={self.eps}, '
                f'affine={self.affine}')
293
+
294
+
295
+ # TODO: ContrastiveNorm2d
296
+ # TODO: DivisiveNorm2d
297
+ # TODO: SubtractiveNorm2d
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/padding.py ADDED
@@ -0,0 +1,801 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .module import Module
2
+ from .utils import _pair, _quadruple, _ntuple
3
+ from .. import functional as F
4
+
5
+ from torch import Tensor
6
+ from ..common_types import _size_2_t, _size_4_t, _size_6_t
7
+ from typing import Sequence, Tuple
8
+
9
+
10
+ # TODO: grad_output size asserts in THNN
11
+
12
+ __all__ = ['CircularPad1d', 'CircularPad2d', 'CircularPad3d', 'ConstantPad1d', 'ConstantPad2d',
13
+ 'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d',
14
+ 'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d', 'ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d']
15
+
16
+
17
+ class _CircularPadNd(Module):
18
+ __constants__ = ['padding']
19
+ padding: Sequence[int]
20
+
21
+ def _check_input_dim(self, input):
22
+ raise NotImplementedError
23
+
24
+ def forward(self, input: Tensor) -> Tensor:
25
+ self._check_input_dim(input)
26
+ return F.pad(input, self.padding, 'circular')
27
+
28
+ def extra_repr(self) -> str:
29
+ return f'{self.padding}'
30
+
31
+
32
class CircularPad1d(_CircularPadNd):
    r"""Pads the input tensor using circular padding of the input boundary.

    Values at the beginning of a dimension are used to pad the end, and
    values at the end pad the beginning; with negative padding, the ends of
    the tensor are trimmed instead.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 2-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)

    Shape:
        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> # xdoctest: +IGNORE_WANT("not sure why xdoctest is choking on this")
        >>> m = nn.CircularPad1d(2)
        >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
        >>> m(input)
        tensor([[[2., 3., 0., 1., 2., 3., 0., 1.],
                 [6., 7., 4., 5., 6., 7., 4., 5.]]])
        >>> # using different paddings for different sides
        >>> m = nn.CircularPad1d((3, 1))
        >>> m(input)
        tensor([[[1., 2., 3., 0., 1., 2., 3., 0.],
                 [5., 6., 7., 4., 5., 6., 7., 4.]]])
    """

    padding: Tuple[int, int]

    def __init__(self, padding: _size_2_t) -> None:
        super().__init__()
        # A scalar expands to symmetric (left, right) padding.
        self.padding = _pair(padding)

    def _check_input_dim(self, input):
        ndim = input.dim()
        if ndim not in (2, 3):
            raise ValueError(
                f"expected 2D or 3D input (got {ndim}D input)"
            )
81
+
82
+
83
class CircularPad2d(_CircularPadNd):
    r"""Pads the input tensor using circular padding of the input boundary.

    Values at the beginning of a dimension are used to pad the end, and
    values at the end pad the beginning; with negative padding, the ends of
    the tensor are trimmed instead.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
            :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.CircularPad2d(2)
        >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
        >>> output = m(input)
        >>> # using different paddings for different sides
        >>> m = nn.CircularPad2d((1, 1, 2, 0))
        >>> output = m(input)
    """

    padding: Tuple[int, int, int, int]

    def __init__(self, padding: _size_4_t) -> None:
        super().__init__()
        # A scalar expands to (left, right, top, bottom) padding.
        self.padding = _quadruple(padding)

    def _check_input_dim(self, input):
        ndim = input.dim()
        if ndim not in (3, 4):
            raise ValueError(
                f"expected 3D or 4D input (got {ndim}D input)"
            )
142
+
143
+
144
class CircularPad3d(_CircularPadNd):
    r"""Pads the input tensor using circular padding of the input boundary.

    Values at the beginning of a dimension are used to pad the end, and
    values at the end pad the beginning; with negative padding, the ends of
    the tensor are trimmed instead.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 6-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
            :math:`\text{padding\_front}`, :math:`\text{padding\_back}`)

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
          where

          :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = nn.CircularPad3d(3)
        >>> input = torch.randn(16, 3, 8, 320, 480)
        >>> output = m(input)
        >>> # using different paddings for different sides
        >>> m = nn.CircularPad3d((3, 3, 6, 6, 1, 1))
        >>> output = m(input)
    """

    padding: Tuple[int, int, int, int, int, int]

    def __init__(self, padding: _size_6_t) -> None:
        super().__init__()
        # A scalar expands to the same padding on all six boundaries.
        self.padding = _ntuple(6)(padding)

    def _check_input_dim(self, input):
        ndim = input.dim()
        if ndim not in (4, 5):
            raise ValueError(
                f"expected 4D or 5D input (got {ndim}D input)"
            )
193
+
194
+
195
+ class _ConstantPadNd(Module):
196
+ __constants__ = ['padding', 'value']
197
+ value: float
198
+ padding: Sequence[int]
199
+
200
+ def __init__(self, value: float) -> None:
201
+ super().__init__()
202
+ self.value = value
203
+
204
+ def forward(self, input: Tensor) -> Tensor:
205
+ return F.pad(input, self.padding, 'constant', self.value)
206
+
207
+ def extra_repr(self) -> str:
208
+ return f'padding={self.padding}, value={self.value}'
209
+
210
+
211
class ConstantPad1d(_ConstantPadNd):
    r"""Pads the input tensor boundaries with a constant value.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in both boundaries. If a 2-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)

    Shape:
        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = nn.ConstantPad1d(2, 3.5)
        >>> input = torch.randn(1, 2, 4)
        >>> output = m(input)           # two 3.5-valued columns on each side
        >>> # using different paddings for different sides
        >>> m = nn.ConstantPad1d((3, 1), 3.5)
        >>> output = m(input)           # three columns left, one right
    """

    padding: Tuple[int, int]

    def __init__(self, padding: _size_2_t, value: float):
        # The fill value lives on the base class; only the padding is local.
        super().__init__(value)
        self.padding = _pair(padding)
260
+
261
+
262
class ConstantPad2d(_ConstantPadNd):
    r"""Pads the input tensor boundaries with a constant value.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
            :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ConstantPad2d(2, 3.5)
        >>> input = torch.randn(1, 2, 2)
        >>> output = m(input)
        >>> # using different paddings for different sides
        >>> m = nn.ConstantPad2d((3, 0, 2, 1), 3.5)
        >>> output = m(input)
    """

    # NOTE: the redundant ``__constants__ = ['padding', 'value']`` redefinition
    # was dropped — _ConstantPadNd already declares the identical list.
    padding: Tuple[int, int, int, int]

    def __init__(self, padding: _size_4_t, value: float) -> None:
        super().__init__(value)
        self.padding = _quadruple(padding)
311
+
312
+
313
class ConstantPad3d(_ConstantPadNd):
    r"""Constant-value padding over the last three dimensions of the input.

    An ``int`` pads every boundary equally; a 6-`tuple` is interpreted as
    (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
    :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
    :math:`\text{padding\_front}`, :math:`\text{padding\_back}`).
    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: same layout with each spatial size grown by the two
          corresponding pad amounts.

    Examples::

        >>> m = nn.ConstantPad3d(3, 3.5)
        >>> output = m(torch.randn(16, 3, 10, 20, 30))
        >>> m = nn.ConstantPad3d((3, 3, 6, 6, 0, 1), 3.5)
    """

    padding: Tuple[int, int, int, int, int, int]

    def __init__(self, padding: _size_6_t, value: float) -> None:
        super().__init__(value)
        expanded = _ntuple(6)(padding)
        self.padding = expanded
351
+
352
+
353
+ class _ReflectionPadNd(Module):
354
+ __constants__ = ['padding']
355
+ padding: Sequence[int]
356
+
357
+ def forward(self, input: Tensor) -> Tensor:
358
+ return F.pad(input, self.padding, 'reflect')
359
+
360
+ def extra_repr(self) -> str:
361
+ return f'{self.padding}'
362
+
363
+
364
class ReflectionPad1d(_ReflectionPadNd):
    r"""Pads the last dimension by reflecting the input about its boundary.

    An ``int`` pads both sides equally; a 2-`tuple` is
    (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`).
    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Shape:
        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
        - Output: width grows to
          :math:`W_{in} + \text{padding\_left} + \text{padding\_right}`.

    Examples::

        >>> m = nn.ReflectionPad1d(2)
        >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
        >>> m(input)
        tensor([[[2., 1., 0., 1., 2., 3., 2., 1.],
                 [6., 5., 4., 5., 6., 7., 6., 5.]]])
    """

    padding: Tuple[int, int]

    def __init__(self, padding: _size_2_t) -> None:
        super().__init__()
        expanded = _pair(padding)
        self.padding = expanded
403
+
404
+
405
class ReflectionPad2d(_ReflectionPadNd):
    r"""Pads the last two dimensions by reflecting the input about its boundary.

    An ``int`` pads every side equally; a 4-`tuple` is
    (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
    :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`).
    Note that padding size should be less than the corresponding input dimension.
    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: height/width each grow by their two pad amounts.

    Examples::

        >>> m = nn.ReflectionPad2d(2)
        >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
        >>> output = m(input)
        >>> # using different paddings for different sides
        >>> m = nn.ReflectionPad2d((1, 1, 2, 0))
        >>> output = m(input)
    """

    padding: Tuple[int, int, int, int]

    def __init__(self, padding: _size_4_t) -> None:
        super().__init__()
        expanded = _quadruple(padding)
        self.padding = expanded
456
+
457
+
458
class ReflectionPad3d(_ReflectionPadNd):
    r"""Pads the last three dimensions by reflecting the input about its boundary.

    An ``int`` pads every side equally; a 6-`tuple` is
    (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
    :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
    :math:`\text{padding\_front}`, :math:`\text{padding\_back}`).
    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: each of D, H, W grows by its two pad amounts.

    Examples::

        >>> m = nn.ReflectionPad3d(1)
        >>> input = torch.arange(8, dtype=torch.float).reshape(1, 1, 2, 2, 2)
        >>> output = m(input)
    """

    padding: Tuple[int, int, int, int, int, int]

    def __init__(self, padding: _size_6_t) -> None:
        super().__init__()
        expanded = _ntuple(6)(padding)
        self.padding = expanded
510
+
511
+
512
+ class _ReplicationPadNd(Module):
513
+ __constants__ = ['padding']
514
+ padding: Sequence[int]
515
+
516
+ def forward(self, input: Tensor) -> Tensor:
517
+ return F.pad(input, self.padding, 'replicate')
518
+
519
+ def extra_repr(self) -> str:
520
+ return f'{self.padding}'
521
+
522
+
523
class ReplicationPad1d(_ReplicationPadNd):
    r"""Pads the last dimension by repeating the edge values of the input.

    An ``int`` pads both sides equally; a 2-`tuple` is
    (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`).
    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Shape:
        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
        - Output: width grows to
          :math:`W_{in} + \text{padding\_left} + \text{padding\_right}`.

    Examples::

        >>> m = nn.ReplicationPad1d(2)
        >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
        >>> m(input)
        tensor([[[0., 0., 0., 1., 2., 3., 3., 3.],
                 [4., 4., 4., 5., 6., 7., 7., 7.]]])
    """

    padding: Tuple[int, int]

    def __init__(self, padding: _size_2_t) -> None:
        super().__init__()
        expanded = _pair(padding)
        self.padding = expanded
562
+
563
+
564
class ReplicationPad2d(_ReplicationPadNd):
    r"""Pads the last two dimensions by repeating the edge values of the input.

    An ``int`` pads every side equally; a 4-`tuple` is
    (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
    :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`).
    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: height/width each grow by their two pad amounts.

    Examples::

        >>> m = nn.ReplicationPad2d(2)
        >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
        >>> output = m(input)
        >>> # using different paddings for different sides
        >>> m = nn.ReplicationPad2d((1, 1, 2, 0))
        >>> output = m(input)
    """

    padding: Tuple[int, int, int, int]

    def __init__(self, padding: _size_4_t) -> None:
        super().__init__()
        expanded = _quadruple(padding)
        self.padding = expanded
614
+
615
+
616
class ReplicationPad3d(_ReplicationPadNd):
    r"""Pads the last three dimensions by repeating the edge values of the input.

    An ``int`` pads every side equally; a 6-`tuple` is
    (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
    :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
    :math:`\text{padding\_front}`, :math:`\text{padding\_back}`).
    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: each of D, H, W grows by its two pad amounts.

    Examples::

        >>> m = nn.ReplicationPad3d(3)
        >>> output = m(torch.randn(16, 3, 8, 320, 480))
        >>> m = nn.ReplicationPad3d((3, 3, 6, 6, 1, 1))
    """

    padding: Tuple[int, int, int, int, int, int]

    def __init__(self, padding: _size_6_t) -> None:
        super().__init__()
        expanded = _ntuple(6)(padding)
        self.padding = expanded
655
+
656
+
657
class ZeroPad1d(ConstantPad1d):
    r"""Pads the last dimension of the input with zeros.

    Thin wrapper over :class:`ConstantPad1d` with the fill value fixed at 0.
    An ``int`` pads both sides equally; a 2-`tuple` is
    (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`).
    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Shape:
        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
        - Output: width grows to
          :math:`W_{in} + \text{padding\_left} + \text{padding\_right}`.

    Examples::

        >>> m = nn.ZeroPad1d(2)
        >>> output = m(torch.randn(1, 2, 4))
        >>> m = nn.ZeroPad1d((3, 1))
    """

    padding: Tuple[int, int]

    def __init__(self, padding: _size_2_t) -> None:
        super().__init__(padding, 0.)

    def extra_repr(self) -> str:
        # Hide the constant value (always 0) from the repr.
        return f'{self.padding}'
708
+
709
class ZeroPad2d(ConstantPad2d):
    r"""Pads the last two dimensions of the input with zeros.

    Thin wrapper over :class:`ConstantPad2d` with the fill value fixed at 0.
    An ``int`` pads every side equally; a 4-`tuple` is
    (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
    :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`).
    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: height/width each grow by their two pad amounts.

    Examples::

        >>> m = nn.ZeroPad2d(2)
        >>> output = m(torch.randn(1, 1, 3, 3))
        >>> m = nn.ZeroPad2d((1, 1, 2, 0))
    """

    padding: Tuple[int, int, int, int]

    def __init__(self, padding: _size_4_t) -> None:
        super().__init__(padding, 0.)

    def extra_repr(self) -> str:
        # Hide the constant value (always 0) from the repr.
        return f'{self.padding}'
761
+
762
class ZeroPad3d(ConstantPad3d):
    r"""Pads the last three dimensions of the input with zeros.

    Thin wrapper over :class:`ConstantPad3d` with the fill value fixed at 0.
    An ``int`` pads every side equally; a 6-`tuple` is
    (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
    :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
    :math:`\text{padding\_front}`, :math:`\text{padding\_back}`).
    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: each of D, H, W grows by its two pad amounts.

    Examples::

        >>> m = nn.ZeroPad3d(3)
        >>> output = m(torch.randn(16, 3, 10, 20, 30))
        >>> m = nn.ZeroPad3d((3, 3, 6, 6, 0, 1))
    """

    padding: Tuple[int, int, int, int, int, int]

    def __init__(self, padding: _size_6_t) -> None:
        super().__init__(padding, 0.)

    def extra_repr(self) -> str:
        # Hide the constant value (always 0) from the repr.
        return f'{self.padding}'
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/pixelshuffle.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .module import Module
2
+ from .. import functional as F
3
+
4
+ from torch import Tensor
5
+
6
+ __all__ = ['PixelShuffle', 'PixelUnshuffle']
7
+
8
class PixelShuffle(Module):
    r"""Rearrange channel blocks of a tensor into higher-resolution spatial maps.

    Maps :math:`(*, C \times r^2, H, W)` to :math:`(*, C, H \times r, W \times r)`
    for upscale factor :math:`r` — the sub-pixel convolution trick from
    `Real-Time Single Image and Video Super-Resolution Using an Efficient
    Sub-Pixel Convolutional Neural Network`_ (Shi et al., 2016).

    Args:
        upscale_factor (int): factor to increase spatial resolution by

    Examples::

        >>> pixel_shuffle = nn.PixelShuffle(3)
        >>> pixel_shuffle(torch.randn(1, 9, 4, 4)).size()
        torch.Size([1, 1, 12, 12])

    .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
        https://arxiv.org/abs/1609.05158
    """

    __constants__ = ['upscale_factor']
    upscale_factor: int

    def __init__(self, upscale_factor: int) -> None:
        super().__init__()
        self.upscale_factor = upscale_factor

    def forward(self, input: Tensor) -> Tensor:
        return F.pixel_shuffle(input, self.upscale_factor)

    def extra_repr(self) -> str:
        return f'upscale_factor={self.upscale_factor}'
61
+
62
+
63
class PixelUnshuffle(Module):
    r"""Inverse of :class:`~torch.nn.PixelShuffle`.

    Maps :math:`(*, C, H \times r, W \times r)` to :math:`(*, C \times r^2, H, W)`
    for downscale factor :math:`r`. See `Real-Time Single Image and Video
    Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural
    Network`_ (Shi et al., 2016).

    Args:
        downscale_factor (int): factor to decrease spatial resolution by

    Examples::

        >>> pixel_unshuffle = nn.PixelUnshuffle(3)
        >>> pixel_unshuffle(torch.randn(1, 1, 12, 12)).size()
        torch.Size([1, 9, 4, 4])

    .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
        https://arxiv.org/abs/1609.05158
    """

    __constants__ = ['downscale_factor']
    downscale_factor: int

    def __init__(self, downscale_factor: int) -> None:
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input: Tensor) -> Tensor:
        return F.pixel_unshuffle(input, self.downscale_factor)

    def extra_repr(self) -> str:
        return f'downscale_factor={self.downscale_factor}'
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/pooling.py ADDED
@@ -0,0 +1,1306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ from torch import Tensor
4
+ from .module import Module
5
+ from .utils import _single, _pair, _triple
6
+ from .. import functional as F
7
+
8
+ from ..common_types import (_size_any_t, _size_1_t, _size_2_t, _size_3_t,
9
+ _ratio_3_t, _ratio_2_t, _size_any_opt_t, _size_2_opt_t, _size_3_opt_t)
10
+
11
+ __all__ = ['MaxPool1d', 'MaxPool2d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
12
+ 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'FractionalMaxPool2d', 'FractionalMaxPool3d', 'LPPool1d',
13
+ 'LPPool2d', 'LPPool3d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
14
+ 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d']
15
+
16
+ class _MaxPoolNd(Module):
17
+ __constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
18
+ 'return_indices', 'ceil_mode']
19
+ return_indices: bool
20
+ ceil_mode: bool
21
+
22
+ def __init__(self, kernel_size: _size_any_t, stride: Optional[_size_any_t] = None,
23
+ padding: _size_any_t = 0, dilation: _size_any_t = 1,
24
+ return_indices: bool = False, ceil_mode: bool = False) -> None:
25
+ super().__init__()
26
+ self.kernel_size = kernel_size
27
+ self.stride = stride if (stride is not None) else kernel_size
28
+ self.padding = padding
29
+ self.dilation = dilation
30
+ self.return_indices = return_indices
31
+ self.ceil_mode = ceil_mode
32
+
33
+ def extra_repr(self) -> str:
34
+ return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \
35
+ ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
36
+
37
+
38
class MaxPool1d(_MaxPoolNd):
    r"""Applies a 1D max pooling over an input signal composed of several input planes.

    For input :math:`(N, C, L)` each output element is the maximum over a
    sliding window:

    .. math::
        out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
                input(N_i, C_j, stride \times k + m)

    Non-zero :attr:`padding` implicitly pads both sides with negative
    infinity; :attr:`dilation` is the stride between elements within the
    window. See this `link`_ for a visualization of the parameters.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if
        they start within the left padding or the input. Sliding windows that
        would start in the right padded region are ignored.

    Args:
        kernel_size: The size of the sliding window, must be > 0.
        stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`.
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        return_indices: If ``True``, will return the argmax along with the max values.
            Useful for :class:`torch.nn.MaxUnpool1d` later
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape.

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
                    \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor

    Examples::

        >>> m = nn.MaxPool1d(3, stride=2)
        >>> output = m(torch.randn(20, 16, 50))

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    kernel_size: _size_1_t
    stride: _size_1_t
    padding: _size_1_t
    dilation: _size_1_t

    def forward(self, input: Tensor):
        # All hyper-parameters are held by _MaxPoolNd; just dispatch.
        return F.max_pool1d(input, self.kernel_size, self.stride,
                            self.padding, self.dilation, ceil_mode=self.ceil_mode,
                            return_indices=self.return_indices)
+ return_indices=self.return_indices)
94
+
95
+
96
class MaxPool2d(_MaxPoolNd):
    r"""Applies a 2D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,
    output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                    & \text{input}(N_i, C_j, \text{stride[0]} \times h + m,
                                                   \text{stride[1]} \times w + n)
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
    for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
    It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Args:
        kernel_size: the size of the window to take a max over
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides
        dilation: a parameter that controls the stride of elements in the window
        return_indices: if ``True``, will return the max indices along with the outputs.
                        Useful for :class:`torch.nn.MaxUnpool2d` later
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]}
                    \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
                    \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor

    Examples::

        >>> # pool of square window of size=3, stride=2
        >>> m = nn.MaxPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = nn.MaxPool2d((3, 2), stride=(2, 1))
        >>> input = torch.randn(20, 16, 50, 32)
        >>> output = m(input)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    # Pooling hyperparameters; each may be an int or an (H, W) pair.
    # Stored and validated by the shared _MaxPoolNd.__init__ (defined above
    # this view), which also holds return_indices and ceil_mode.
    kernel_size: _size_2_t
    stride: _size_2_t
    padding: _size_2_t
    dilation: _size_2_t

    def forward(self, input: Tensor):
        # No return annotation on purpose: returns a single Tensor, or a
        # (values, indices) pair when ``self.return_indices`` is True.
        return F.max_pool2d(input, self.kernel_size, self.stride,
                            self.padding, self.dilation, ceil_mode=self.ceil_mode,
                            return_indices=self.return_indices)
167
+
168
+
169
class MaxPool3d(_MaxPoolNd):
    r"""Applies a 3D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
    output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                              & \text{input}(N_i, C_j, \text{stride[0]} \times d + k,
                                                             \text{stride[1]} \times h + m, \text{stride[2]} \times w + n)
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
    for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
    It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Args:
        kernel_size: the size of the window to take a max over
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on all three sides
        dilation: a parameter that controls the stride of elements in the window
        return_indices: if ``True``, will return the max indices along with the outputs.
                        Useful for :class:`torch.nn.MaxUnpool3d` later
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
                (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times
                (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
                (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples::

        >>> # pool of square window of size=3, stride=2
        >>> m = nn.MaxPool3d(3, stride=2)
        >>> # pool of non-square window
        >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2))
        >>> input = torch.randn(20, 16, 50, 44, 31)
        >>> output = m(input)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """  # noqa: E501

    # Pooling hyperparameters; each may be an int or a (D, H, W) triple.
    # Stored by the shared _MaxPoolNd.__init__ (defined above this view),
    # which also holds return_indices and ceil_mode.
    kernel_size: _size_3_t
    stride: _size_3_t
    padding: _size_3_t
    dilation: _size_3_t

    def forward(self, input: Tensor):
        # No return annotation on purpose: returns a single Tensor, or a
        # (values, indices) pair when ``self.return_indices`` is True.
        return F.max_pool3d(input, self.kernel_size, self.stride,
                            self.padding, self.dilation, ceil_mode=self.ceil_mode,
                            return_indices=self.return_indices)
244
+
245
+
246
+ class _MaxUnpoolNd(Module):
247
+
248
+ def extra_repr(self) -> str:
249
+ return f'kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}'
250
+
251
+
252
class MaxUnpool1d(_MaxUnpoolNd):
    r"""Computes a partial inverse of :class:`MaxPool1d`.

    :class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost.

    :class:`MaxUnpool1d` takes in as input the output of :class:`MaxPool1d`
    including the indices of the maximal values and computes a partial inverse
    in which all non-maximal values are set to zero.

    Note:
        This operation may behave nondeterministically when the input indices has repeat values.
        See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.

    .. note:: :class:`MaxPool1d` can map several input sizes to the same output
              sizes. Hence, the inversion process can get ambiguous.
              To accommodate this, you can provide the needed output size
              as an additional argument :attr:`output_size` in the forward call.
              See the Inputs and Example below.

    Args:
        kernel_size (int or tuple): Size of the max pooling window.
        stride (int or tuple): Stride of the max pooling window.
            It is set to :attr:`kernel_size` by default.
        padding (int or tuple): Padding that was added to the input

    Inputs:
        - `input`: the input Tensor to invert
        - `indices`: the indices given out by :class:`~torch.nn.MaxPool1d`
        - `output_size` (optional): the targeted output size

    Shape:
        - Input: :math:`(N, C, H_{in})` or :math:`(C, H_{in})`.
        - Output: :math:`(N, C, H_{out})` or :math:`(C, H_{out})`, where

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]

          or as given by :attr:`output_size` in the call operator

    Example::

        >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?")
        >>> pool = nn.MaxPool1d(2, stride=2, return_indices=True)
        >>> unpool = nn.MaxUnpool1d(2, stride=2)
        >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8]]])
        >>> output, indices = pool(input)
        >>> unpool(output, indices)
        tensor([[[ 0.,  2.,  0.,  4.,  0.,  6.,  0., 8.]]])

        >>> # Example showcasing the use of output_size
        >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8, 9]]])
        >>> output, indices = pool(input)
        >>> unpool(output, indices, output_size=input.size())
        tensor([[[ 0.,  2.,  0.,  4.,  0.,  6.,  0., 8.,  0.]]])

        >>> unpool(output, indices)
        tensor([[[ 0.,  2.,  0.,  4.,  0.,  6.,  0., 8.]]])
    """

    # Hyperparameters, normalized to one-element tuples by __init__.
    kernel_size: _size_1_t
    stride: _size_1_t
    padding: _size_1_t

    def __init__(self, kernel_size: _size_1_t, stride: Optional[_size_1_t] = None, padding: _size_1_t = 0) -> None:
        super().__init__()
        # _single() normalizes an int or 1-tuple to a canonical 1-tuple.
        self.kernel_size = _single(kernel_size)
        # Mirror the pooling convention: stride defaults to kernel_size.
        self.stride = _single(stride if (stride is not None) else kernel_size)
        self.padding = _single(padding)

    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        return F.max_unpool1d(input, indices, self.kernel_size, self.stride,
                              self.padding, output_size)
324
+
325
+
326
class MaxUnpool2d(_MaxUnpoolNd):
    r"""Computes a partial inverse of :class:`MaxPool2d`.

    :class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost.

    :class:`MaxUnpool2d` takes in as input the output of :class:`MaxPool2d`
    including the indices of the maximal values and computes a partial inverse
    in which all non-maximal values are set to zero.

    Note:
        This operation may behave nondeterministically when the input indices has repeat values.
        See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.

    .. note:: :class:`MaxPool2d` can map several input sizes to the same output
              sizes. Hence, the inversion process can get ambiguous.
              To accommodate this, you can provide the needed output size
              as an additional argument :attr:`output_size` in the forward call.
              See the Inputs and Example below.

    Args:
        kernel_size (int or tuple): Size of the max pooling window.
        stride (int or tuple): Stride of the max pooling window.
            It is set to :attr:`kernel_size` by default.
        padding (int or tuple): Padding that was added to the input

    Inputs:
        - `input`: the input Tensor to invert
        - `indices`: the indices given out by :class:`~torch.nn.MaxPool2d`
        - `output_size` (optional): the targeted output size

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          .. math::
            H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}

          .. math::
            W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}

          or as given by :attr:`output_size` in the call operator

    Example::

        >>> pool = nn.MaxPool2d(2, stride=2, return_indices=True)
        >>> unpool = nn.MaxUnpool2d(2, stride=2)
        >>> input = torch.tensor([[[[ 1.,  2.,  3.,  4.],
                                    [ 5.,  6.,  7.,  8.],
                                    [ 9., 10., 11., 12.],
                                    [13., 14., 15., 16.]]]])
        >>> output, indices = pool(input)
        >>> unpool(output, indices)
        tensor([[[[  0.,   0.,   0.,   0.],
                  [  0.,   6.,   0.,   8.],
                  [  0.,   0.,   0.,   0.],
                  [  0.,  14.,   0.,  16.]]]])
        >>> # Now using output_size to resolve an ambiguous size for the inverse
        >>> input = torch.tensor([[[[ 1.,  2.,  3.,  4.,  5.],
                                    [ 6.,  7.,  8.,  9., 10.],
                                    [11., 12., 13., 14., 15.],
                                    [16., 17., 18., 19., 20.]]]])
        >>> output, indices = pool(input)
        >>> # This call will not work without specifying output_size
        >>> unpool(output, indices, output_size=input.size())
        tensor([[[[ 0.,  0.,  0.,  0.,  0.],
                  [ 0.,  7.,  0.,  9.,  0.],
                  [ 0.,  0.,  0.,  0.,  0.],
                  [ 0., 17.,  0., 19.,  0.]]]])


    """

    # Hyperparameters, normalized to 2-tuples by __init__.
    kernel_size: _size_2_t
    stride: _size_2_t
    padding: _size_2_t

    def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0) -> None:
        super().__init__()
        # _pair() normalizes an int or 2-tuple to a canonical 2-tuple.
        self.kernel_size = _pair(kernel_size)
        # Mirror the pooling convention: stride defaults to kernel_size.
        self.stride = _pair(stride if (stride is not None) else kernel_size)
        self.padding = _pair(padding)

    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        return F.max_unpool2d(input, indices, self.kernel_size, self.stride,
                              self.padding, output_size)
411
+
412
+
413
class MaxUnpool3d(_MaxUnpoolNd):
    r"""Computes a partial inverse of :class:`MaxPool3d`.

    :class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost.
    :class:`MaxUnpool3d` takes in as input the output of :class:`MaxPool3d`
    including the indices of the maximal values and computes a partial inverse
    in which all non-maximal values are set to zero.

    Note:
        This operation may behave nondeterministically when the input indices has repeat values.
        See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.

    .. note:: :class:`MaxPool3d` can map several input sizes to the same output
              sizes. Hence, the inversion process can get ambiguous.
              To accommodate this, you can provide the needed output size
              as an additional argument :attr:`output_size` in the forward call.
              See the Inputs section below.

    Args:
        kernel_size (int or tuple): Size of the max pooling window.
        stride (int or tuple): Stride of the max pooling window.
            It is set to :attr:`kernel_size` by default.
        padding (int or tuple): Padding that was added to the input

    Inputs:
        - `input`: the input Tensor to invert
        - `indices`: the indices given out by :class:`~torch.nn.MaxPool3d`
        - `output_size` (optional): the targeted output size

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where

          .. math::
              D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}

          .. math::
              H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}

          .. math::
              W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]}

          or as given by :attr:`output_size` in the call operator

    Example::

        >>> # pool of square window of size=3, stride=2
        >>> pool = nn.MaxPool3d(3, stride=2, return_indices=True)
        >>> unpool = nn.MaxUnpool3d(3, stride=2)
        >>> output, indices = pool(torch.randn(20, 16, 51, 33, 15))
        >>> unpooled_output = unpool(output, indices)
        >>> unpooled_output.size()
        torch.Size([20, 16, 51, 33, 15])
    """

    # Hyperparameters, normalized to 3-tuples by __init__.
    kernel_size: _size_3_t
    stride: _size_3_t
    padding: _size_3_t

    def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0) -> None:
        super().__init__()
        # _triple() normalizes an int or 3-tuple to a canonical 3-tuple.
        self.kernel_size = _triple(kernel_size)
        # Mirror the pooling convention: stride defaults to kernel_size.
        self.stride = _triple(stride if (stride is not None) else kernel_size)
        self.padding = _triple(padding)

    def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        return F.max_unpool3d(input, indices, self.kernel_size, self.stride,
                              self.padding, output_size)
481
+
482
+
483
+ class _AvgPoolNd(Module):
484
+ __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad']
485
+
486
+ def extra_repr(self) -> str:
487
+ return f'kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}'
488
+
489
+
490
class AvgPool1d(_AvgPoolNd):
    r"""Applies a 1D average pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, L)`,
    output :math:`(N, C, L_{out})` and :attr:`kernel_size` :math:`k`
    can be precisely described as:

    .. math::

        \text{out}(N_i, C_j, l) = \frac{1}{k} \sum_{m=0}^{k-1}
                               \text{input}(N_i, C_j, \text{stride} \times l + m)

    If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
    for :attr:`padding` number of points.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be
    an ``int`` or a one-element tuple.

    Args:
        kernel_size: the size of the window
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: implicit zero padding to be added on both sides
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
        count_include_pad: when True, will include the zero-padding in the averaging calculation

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} +
              2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

          Per the note above, if ``ceil_mode`` is True and :math:`(L_{out} - 1) \times \text{stride} \geq L_{in}
          + \text{padding}`, we skip the last window as it would start in the right padded region, resulting in
          :math:`L_{out}` being reduced by one.

    Examples::

        >>> # pool with window of size=3, stride=2
        >>> m = nn.AvgPool1d(3, stride=2)
        >>> m(torch.tensor([[[1., 2, 3, 4, 5, 6, 7]]]))
        tensor([[[2., 4., 6.]]])
    """

    # Hyperparameters, normalized to one-element tuples by __init__.
    kernel_size: _size_1_t
    stride: _size_1_t
    padding: _size_1_t
    ceil_mode: bool
    count_include_pad: bool

    # Fix: `stride` accepts None (meaning "use kernel_size"), so its annotation
    # must be Optional[_size_1_t] — consistent with AvgPool2d/AvgPool3d and the
    # MaxUnpool modules in this file. Runtime behavior is unchanged.
    def __init__(self, kernel_size: _size_1_t, stride: Optional[_size_1_t] = None, padding: _size_1_t = 0,
                 ceil_mode: bool = False, count_include_pad: bool = True) -> None:
        super().__init__()
        # _single() normalizes an int or 1-tuple to a canonical 1-tuple.
        self.kernel_size = _single(kernel_size)
        # Stride defaults to the window size, i.e. non-overlapping windows.
        self.stride = _single(stride if stride is not None else kernel_size)
        self.padding = _single(padding)
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad

    def forward(self, input: Tensor) -> Tensor:
        return F.avg_pool1d(
            input, self.kernel_size, self.stride, self.padding, self.ceil_mode,
            self.count_include_pad)
558
+
559
+
560
class AvgPool2d(_AvgPoolNd):
    r"""Applies a 2D average pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,
    output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)`
    can be precisely described as:

    .. math::

        out(N_i, C_j, h, w)  = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
                               input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n)

    If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
    for :attr:`padding` number of points.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Args:
        kernel_size: the size of the window
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: implicit zero padding to be added on both sides
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
        count_include_pad: when True, will include the zero-padding in the averaging calculation
        divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used.


    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[0] -
                \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[1] -
                \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          Per the note above, if ``ceil_mode`` is True and :math:`(H_{out} - 1)\times \text{stride}[0]\geq H_{in}
          + \text{padding}[0]`, we skip the last window as it would start in the bottom padded region,
          resulting in :math:`H_{out}` being reduced by one.

          The same applies for :math:`W_{out}`.

    Examples::

        >>> # pool of square window of size=3, stride=2
        >>> m = nn.AvgPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = nn.AvgPool2d((3, 2), stride=(2, 1))
        >>> input = torch.randn(20, 16, 50, 32)
        >>> output = m(input)
    """

    # Redeclared (vs. _AvgPoolNd) to add divisor_override for TorchScript.
    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override']

    kernel_size: _size_2_t
    stride: _size_2_t
    padding: _size_2_t
    ceil_mode: bool
    count_include_pad: bool

    def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0,
                 ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> None:
        super().__init__()
        # Note: unlike the unpool modules, values are stored as given (int or
        # tuple) without _pair() normalization; F.avg_pool2d accepts both.
        self.kernel_size = kernel_size
        # Stride defaults to the window size, i.e. non-overlapping windows.
        self.stride = stride if (stride is not None) else kernel_size
        self.padding = padding
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad
        self.divisor_override = divisor_override

    def forward(self, input: Tensor) -> Tensor:
        return F.avg_pool2d(input, self.kernel_size, self.stride,
                            self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override)
643
+
644
+
645
class AvgPool3d(_AvgPoolNd):
    r"""Applies a 3D average pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
    output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            \text{out}(N_i, C_j, d, h, w) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
                                              & \frac{\text{input}(N_i, C_j, \text{stride}[0] \times d + k,
                                                      \text{stride}[1] \times h + m, \text{stride}[2] \times w + n)}
                                                     {kD \times kH \times kW}
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
    for :attr:`padding` number of points.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride` can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Args:
        kernel_size: the size of the window
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: implicit zero padding to be added on all three sides
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
        count_include_pad: when True, will include the zero-padding in the averaging calculation
        divisor_override: if specified, it will be used as divisor, otherwise :attr:`kernel_size` will be used

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
          :math:`(C, D_{out}, H_{out}, W_{out})`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
                    \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
                    \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
                    \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor

          Per the note above, if ``ceil_mode`` is True and :math:`(D_{out} - 1)\times \text{stride}[0]\geq D_{in}
          + \text{padding}[0]`, we skip the last window as it would start in the padded region,
          resulting in :math:`D_{out}` being reduced by one.

          The same applies for :math:`W_{out}` and :math:`H_{out}`.

    Examples::

        >>> # pool of square window of size=3, stride=2
        >>> m = nn.AvgPool3d(3, stride=2)
        >>> # pool of non-square window
        >>> m = nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2))
        >>> input = torch.randn(20, 16, 50, 44, 31)
        >>> output = m(input)
    """

    # Redeclared (vs. _AvgPoolNd) to add divisor_override for TorchScript.
    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override']

    kernel_size: _size_3_t
    stride: _size_3_t
    padding: _size_3_t
    ceil_mode: bool
    count_include_pad: bool

    def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0,
                 ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> None:
        super().__init__()
        # Values are stored as given (int or tuple) without _triple()
        # normalization; F.avg_pool3d accepts both forms.
        self.kernel_size = kernel_size
        # Stride defaults to the window size, i.e. non-overlapping windows.
        self.stride = stride if (stride is not None) else kernel_size
        self.padding = padding
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad
        self.divisor_override = divisor_override

    def forward(self, input: Tensor) -> Tensor:
        return F.avg_pool3d(input, self.kernel_size, self.stride,
                            self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override)

    def __setstate__(self, d):
        # Backward compatibility: presumably fills in attributes missing from
        # modules serialized by older versions that lacked these fields.
        super().__setstate__(d)
        self.__dict__.setdefault('padding', 0)
        self.__dict__.setdefault('ceil_mode', False)
        self.__dict__.setdefault('count_include_pad', True)
741
+
742
+
743
class FractionalMaxPool2d(Module):
    r"""Applies a 2D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham

    The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined.

    Args:
        kernel_size: the size of the window to take a max over.
                     Can be a single number k (for a square kernel of k x k) or a tuple `(kh, kw)`
        output_size: the target output size of the image of the form `oH x oW`.
                     Can be a tuple `(oH, oW)` or a single number oH for a square image `oH x oH`.
                     Note that we must have :math:`kH + oH - 1 <= H_{in}` and :math:`kW + oW - 1 <= W_{in}`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
                      This has to be a number or tuple in the range (0, 1).
                      Note that we must have :math:`kH + (output\_ratio\_H * H_{in}) - 1 <= H_{in}`
                      and :math:`kW + (output\_ratio\_W * W_{in}) - 1 <= W_{in}`
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to :meth:`nn.MaxUnpool2d`. Default: ``False``

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
          :math:`(H_{out}, W_{out})=\text{output\_size}` or
          :math:`(H_{out}, W_{out})=\text{output\_ratio} \times (H_{in}, W_{in})`.

    Examples:
        >>> # pool of square window of size=3, and target output size 13x12
        >>> m = nn.FractionalMaxPool2d(3, output_size=(13, 12))
        >>> # pool of square window and target output size being half of input image size
        >>> m = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
        >>> input = torch.randn(20, 16, 50, 32)
        >>> output = m(input)

    .. _Fractional MaxPooling:
        https://arxiv.org/abs/1412.6071
    """

    __constants__ = ['kernel_size', 'return_indices', 'output_size',
                     'output_ratio']

    kernel_size: _size_2_t
    return_indices: bool
    output_size: _size_2_t
    output_ratio: _ratio_2_t

    def __init__(self, kernel_size: _size_2_t, output_size: Optional[_size_2_t] = None,
                 output_ratio: Optional[_ratio_2_t] = None,
                 return_indices: bool = False, _random_samples=None) -> None:
        super().__init__()
        self.kernel_size = _pair(kernel_size)
        self.return_indices = return_indices
        # Registered as a buffer so user-supplied samples (used for
        # deterministic testing) move with the module across devices.
        self.register_buffer('_random_samples', _random_samples)
        # Exactly one of output_size / output_ratio may be given; the other
        # stays None and F.fractional_max_pool2d dispatches on that.
        self.output_size = _pair(output_size) if output_size is not None else None
        self.output_ratio = _pair(output_ratio) if output_ratio is not None else None
        if output_size is None and output_ratio is None:
            raise ValueError("FractionalMaxPool2d requires specifying either "
                             "an output size, or a pooling ratio")
        if output_size is not None and output_ratio is not None:
            raise ValueError("only one of output_size and output_ratio may be specified")
        if self.output_ratio is not None:
            # Ratios must lie strictly inside (0, 1) in both dimensions.
            if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1):
                raise ValueError(f"output_ratio must be between 0 and 1 (got {output_ratio})")

    def forward(self, input: Tensor):
        # Returns a Tensor, or (values, indices) when return_indices is True.
        return F.fractional_max_pool2d(
            input, self.kernel_size, self.output_size, self.output_ratio,
            self.return_indices,
            _random_samples=self._random_samples)
816
+
817
+
818
class FractionalMaxPool3d(Module):
    r"""Applies 3D fractional max pooling over an input of several input planes.

    Fractional MaxPooling is described in `Fractional MaxPooling`_ by Ben Graham.
    Max pooling is applied in :math:`kT \times kH \times kW` regions with a
    stochastic step size determined by the target output size; the number of
    output features equals the number of input planes.

    .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined.

    Args:
        kernel_size: size of the pooling window — a single number ``k`` for a
            ``k x k x k`` window or a tuple ``(kt, kh, kw)``.
        output_size: target output size ``oT x oH x oW`` — a tuple
            ``(oT, oH, oW)`` or a single number for a cubic output.
        output_ratio: target output size given as a ratio of the input size;
            a number or tuple with components in (0, 1).
        return_indices: if ``True``, also return the pooling indices
            (useful for :meth:`nn.MaxUnpool3d`). Default: ``False``.

    Shape:
        - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or
          :math:`(C, T_{out}, H_{out}, W_{out})`, where the output size is
          ``output_size`` or ``output_ratio * input_size``.

    Examples:
        >>> # pool of cubic window of size=3, and target output size 13x12x11
        >>> m = nn.FractionalMaxPool3d(3, output_size=(13, 12, 11))
        >>> # pool of cubic window and target output size being half of input size
        >>> m = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5))
        >>> output = m(torch.randn(20, 16, 50, 32, 16))

    .. _Fractional MaxPooling:
        https://arxiv.org/abs/1412.6071
    """

    __constants__ = ['kernel_size', 'return_indices', 'output_size',
                     'output_ratio']
    kernel_size: _size_3_t
    return_indices: bool
    output_size: _size_3_t
    output_ratio: _ratio_3_t

    def __init__(self, kernel_size: _size_3_t, output_size: Optional[_size_3_t] = None,
                 output_ratio: Optional[_ratio_3_t] = None,
                 return_indices: bool = False, _random_samples=None) -> None:
        super().__init__()
        # Exactly one of the two target specifications must be provided.
        if output_size is None and output_ratio is None:
            raise ValueError("FractionalMaxPool3d requires specifying either "
                             "an output size, or a pooling ratio")
        if output_size is not None and output_ratio is not None:
            raise ValueError("only one of output_size and output_ratio may be specified")
        self.kernel_size = _triple(kernel_size)
        self.return_indices = return_indices
        # Buffer so the samples follow device moves and serialization.
        self.register_buffer('_random_samples', _random_samples)
        self.output_size = None if output_size is None else _triple(output_size)
        self.output_ratio = None if output_ratio is None else _triple(output_ratio)
        if self.output_ratio is not None:
            if not all(0 < ratio < 1 for ratio in self.output_ratio):
                raise ValueError(f"output_ratio must be between 0 and 1 (got {output_ratio})")

    def forward(self, input: Tensor):
        """Apply 3D fractional max pooling; also returns indices when
        ``self.return_indices`` is set."""
        samples = self._random_samples
        return F.fractional_max_pool3d(
            input, self.kernel_size, self.output_size, self.output_ratio,
            self.return_indices, _random_samples=samples)
887
+
888
+
889
+ class _LPPoolNd(Module):
890
+ __constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode']
891
+
892
+ norm_type: float
893
+ ceil_mode: bool
894
+
895
+ def __init__(self, norm_type: float, kernel_size: _size_any_t, stride: Optional[_size_any_t] = None,
896
+ ceil_mode: bool = False) -> None:
897
+ super().__init__()
898
+ self.norm_type = norm_type
899
+ self.kernel_size = kernel_size
900
+ self.stride = stride
901
+ self.ceil_mode = ceil_mode
902
+
903
+ def extra_repr(self) -> str:
904
+ return 'norm_type={norm_type}, kernel_size={kernel_size}, stride={stride}, ' \
905
+ 'ceil_mode={ceil_mode}'.format(**self.__dict__)
906
+
907
+
908
class LPPool1d(_LPPoolNd):
    r"""Applies 1D power-average (LP) pooling over an input signal.

    Each window is reduced to

    .. math::
        f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}

    with :math:`p` = :attr:`norm_type`. :math:`p = \infty` is equivalent to max
    pooling and :math:`p = 1` to sum pooling (proportional to average pooling).

    .. note:: If the sum to the power of `p` is zero, the gradient of this
        function is not defined; this implementation sets it to zero there.

    Args:
        kernel_size: a single int, the size of the window.
        stride: a single int, the stride of the window. Default: :attr:`kernel_size`.
        ceil_mode: when True, use `ceil` instead of `floor` for the output shape.

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where

        .. math::
            L_{out} = \left\lfloor\frac{L_{in} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

    Examples::
        >>> # power-2 pool of window of length 3, with stride 2.
        >>> m = nn.LPPool1d(2, 3, stride=2)
        >>> output = m(torch.randn(20, 16, 50))
    """

    kernel_size: _size_1_t
    stride: _size_1_t

    def forward(self, input: Tensor) -> Tensor:
        # norm_type may be stored as an int; the functional expects a float.
        p = float(self.norm_type)
        return F.lp_pool1d(input, p, self.kernel_size, self.stride, self.ceil_mode)
947
+
948
+
949
class LPPool2d(_LPPoolNd):
    r"""Applies 2D power-average (LP) pooling over an input signal.

    Each window is reduced to

    .. math::
        f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}

    with :math:`p` = :attr:`norm_type`. :math:`p = \infty` is equivalent to max
    pooling and :math:`p = 1` to sum pooling (proportional to average pooling).

    :attr:`kernel_size` and :attr:`stride` may each be a single ``int`` (used
    for both height and width) or a tuple ``(height, width)``.

    .. note:: If the sum to the power of `p` is zero, the gradient of this
        function is not defined; this implementation sets it to zero there.

    Args:
        kernel_size: the size of the window.
        stride: the stride of the window. Default: :attr:`kernel_size`.
        ceil_mode: when True, use `ceil` instead of `floor` for the output shape.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

        .. math::
            H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

        .. math::
            W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

    Examples::

        >>> # power-2 pool of square window of size=3, stride=2
        >>> m = nn.LPPool2d(2, 3, stride=2)
        >>> # pool of non-square window of power 1.2
        >>> m = nn.LPPool2d(1.2, (3, 2), stride=(2, 1))
        >>> output = m(torch.randn(20, 16, 50, 32))
    """

    kernel_size: _size_2_t
    stride: _size_2_t

    def forward(self, input: Tensor) -> Tensor:
        # norm_type may be stored as an int; the functional expects a float.
        p = float(self.norm_type)
        return F.lp_pool2d(input, p, self.kernel_size, self.stride, self.ceil_mode)
1001
+
1002
+
1003
+ class LPPool3d(_LPPoolNd):
1004
+ r"""Applies a 3D power-average pooling over an input signal composed of several input planes.
1005
+
1006
+ On each window, the function computed is:
1007
+
1008
+ .. math::
1009
+ f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
1010
+
1011
+ - At p = :math:`\infty`, one gets Max Pooling
1012
+ - At p = 1, one gets Sum Pooling (which is proportional to average pooling)
1013
+
1014
+ The parameters :attr:`kernel_size`, :attr:`stride` can either be:
1015
+
1016
+ - a single ``int`` -- in which case the same value is used for the height, width and depth dimension
1017
+ - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
1018
+ the second `int` for the height dimension and the third `int` for the width dimension
1019
+
1020
+ .. note:: If the sum to the power of `p` is zero, the gradient of this function is
1021
+ not defined. This implementation will set the gradient to zero in this case.
1022
+
1023
+ Args:
1024
+ kernel_size: the size of the window
1025
+ stride: the stride of the window. Default value is :attr:`kernel_size`
1026
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
1027
+
1028
+ Shape:
1029
+ - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
1030
+ - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
1031
+ :math:`(C, D_{out}, H_{out}, W_{out})`, where
1032
+
1033
+ .. math::
1034
+ D_{out} = \left\lfloor\frac{D_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
1035
+
1036
+ .. math::
1037
+ H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
1038
+
1039
+ .. math::
1040
+ W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
1041
+
1042
+ Examples::
1043
+
1044
+ >>> # power-2 pool of square window of size=3, stride=2
1045
+ >>> m = nn.LPPool3d(2, 3, stride=2)
1046
+ >>> # pool of non-square window of power 1.2
1047
+ >>> m = nn.LPPool3d(1.2, (3, 2, 2), stride=(2, 1, 2))
1048
+ >>> input = torch.randn(20, 16, 50, 44, 31)
1049
+ >>> output = m(input)
1050
+
1051
+ """
1052
+
1053
+ kernel_size: _size_3_t
1054
+ stride: _size_3_t
1055
+
1056
+ def forward(self, input: Tensor) -> Tensor:
1057
+ return F.lp_pool3d(input, float(self.norm_type), self.kernel_size,
1058
+ self.stride, self.ceil_mode)
1059
+
1060
+
1061
+ class _AdaptiveMaxPoolNd(Module):
1062
+ __constants__ = ['output_size', 'return_indices']
1063
+ return_indices: bool
1064
+
1065
+ def __init__(self, output_size: _size_any_opt_t, return_indices: bool = False) -> None:
1066
+ super().__init__()
1067
+ self.output_size = output_size
1068
+ self.return_indices = return_indices
1069
+
1070
+ def extra_repr(self) -> str:
1071
+ return f'output_size={self.output_size}'
1072
+
1073
+ # FIXME (by @ssnl): Improve adaptive pooling docs: specify what the input and
1074
+ # output shapes are, and how the operation computes output.
1075
+
1076
+
1077
class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
    r"""Applies 1D adaptive max pooling over an input signal.

    The output size is :math:`L_{out}` for any input size; the number of
    output features equals the number of input planes.

    Args:
        output_size: the target output size :math:`L_{out}`.
        return_indices: if ``True``, also return the pooling indices
            (useful for nn.MaxUnpool1d). Default: ``False``.

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
          :math:`L_{out}=\text{output\_size}`.

    Examples:
        >>> # target output size of 5
        >>> m = nn.AdaptiveMaxPool1d(5)
        >>> output = m(torch.randn(1, 64, 8))
    """

    output_size: _size_1_t

    def forward(self, input: Tensor):
        size = self.output_size
        return F.adaptive_max_pool1d(input, size, self.return_indices)
1105
+
1106
+
1107
class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
    r"""Applies 2D adaptive max pooling over an input signal.

    The output is of size :math:`H_{out} \times W_{out}` for any input size;
    the number of output features equals the number of input planes.

    Args:
        output_size: the target output size :math:`H_{out} \times W_{out}` — a
            tuple :math:`(H_{out}, W_{out})` or a single :math:`H_{out}` for a
            square output. Each component may be an ``int`` or ``None`` (keep
            the corresponding input size).
        return_indices: if ``True``, also return the pooling indices
            (useful for nn.MaxUnpool2d). Default: ``False``.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`,
          where :math:`(H_{out}, W_{out})=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7
        >>> m = nn.AdaptiveMaxPool2d((5, 7))
        >>> output = m(torch.randn(1, 64, 8, 9))
        >>> # target output size of 7x7 (square)
        >>> m = nn.AdaptiveMaxPool2d(7)
        >>> output = m(torch.randn(1, 64, 10, 9))
        >>> # target output size of 10x7
        >>> m = nn.AdaptiveMaxPool2d((None, 7))
        >>> output = m(torch.randn(1, 64, 10, 9))
    """

    output_size: _size_2_opt_t

    def forward(self, input: Tensor):
        size = self.output_size
        return F.adaptive_max_pool2d(input, size, self.return_indices)
1147
+
1148
+
1149
class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
    r"""Applies 3D adaptive max pooling over an input signal.

    The output is of size :math:`D_{out} \times H_{out} \times W_{out}` for any
    input size; the number of output features equals the number of input planes.

    Args:
        output_size: the target output size :math:`D_{out} \times H_{out} \times W_{out}`
            — a tuple :math:`(D_{out}, H_{out}, W_{out})` or a single
            :math:`D_{out}` for a cubic output. Each component may be an ``int``
            or ``None`` (keep the corresponding input size).
        return_indices: if ``True``, also return the pooling indices
            (useful for nn.MaxUnpool3d). Default: ``False``.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
          :math:`(C, D_{out}, H_{out}, W_{out})`, where
          :math:`(D_{out}, H_{out}, W_{out})=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7x9
        >>> m = nn.AdaptiveMaxPool3d((5, 7, 9))
        >>> output = m(torch.randn(1, 64, 8, 9, 10))
        >>> # target output size of 7x7x7 (cube)
        >>> m = nn.AdaptiveMaxPool3d(7)
        >>> output = m(torch.randn(1, 64, 10, 9, 8))
        >>> # target output size of 7x9x8
        >>> m = nn.AdaptiveMaxPool3d((7, None, None))
        >>> output = m(torch.randn(1, 64, 10, 9, 8))
    """

    output_size: _size_3_opt_t

    def forward(self, input: Tensor):
        size = self.output_size
        return F.adaptive_max_pool3d(input, size, self.return_indices)
1190
+
1191
+
1192
+ class _AdaptiveAvgPoolNd(Module):
1193
+ __constants__ = ['output_size']
1194
+
1195
+ def __init__(self, output_size: _size_any_opt_t) -> None:
1196
+ super().__init__()
1197
+ self.output_size = output_size
1198
+
1199
+ def extra_repr(self) -> str:
1200
+ return f'output_size={self.output_size}'
1201
+
1202
+
1203
class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
    r"""Applies 1D adaptive average pooling over an input signal.

    The output size is :math:`L_{out}` for any input size; the number of
    output features equals the number of input planes.

    Args:
        output_size: the target output size :math:`L_{out}`.

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
          :math:`L_{out}=\text{output\_size}`.

    Examples:
        >>> # target output size of 5
        >>> m = nn.AdaptiveAvgPool1d(5)
        >>> output = m(torch.randn(1, 64, 8))
    """

    output_size: _size_1_t

    def forward(self, input: Tensor) -> Tensor:
        size = self.output_size
        return F.adaptive_avg_pool1d(input, size)
1229
+
1230
+
1231
class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
    r"""Applies 2D adaptive average pooling over an input signal.

    The output is of size H x W for any input size; the number of output
    features equals the number of input planes.

    Args:
        output_size: the target output size H x W — a tuple (H, W) or a single
            H for a square output. H and W may each be an ``int`` or ``None``
            (keep the corresponding input size).

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, S_{0}, S_{1})` or :math:`(C, S_{0}, S_{1})`,
          where :math:`S=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7
        >>> m = nn.AdaptiveAvgPool2d((5, 7))
        >>> output = m(torch.randn(1, 64, 8, 9))
        >>> # target output size of 7x7 (square)
        >>> m = nn.AdaptiveAvgPool2d(7)
        >>> output = m(torch.randn(1, 64, 10, 9))
        >>> # target output size of 10x7
        >>> m = nn.AdaptiveAvgPool2d((None, 7))
        >>> output = m(torch.randn(1, 64, 10, 9))
    """

    output_size: _size_2_opt_t

    def forward(self, input: Tensor) -> Tensor:
        size = self.output_size
        return F.adaptive_avg_pool2d(input, size)
1268
+
1269
+
1270
class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
    r"""Applies 3D adaptive average pooling over an input signal.

    The output is of size D x H x W for any input size; the number of output
    features equals the number of input planes.

    Args:
        output_size: the target output size D x H x W — a tuple (D, H, W) or a
            single D for a cubic output. D, H and W may each be an ``int`` or
            ``None`` (keep the corresponding input size).

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, S_{0}, S_{1}, S_{2})` or :math:`(C, S_{0}, S_{1}, S_{2})`,
          where :math:`S=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7x9
        >>> m = nn.AdaptiveAvgPool3d((5, 7, 9))
        >>> output = m(torch.randn(1, 64, 8, 9, 10))
        >>> # target output size of 7x7x7 (cube)
        >>> m = nn.AdaptiveAvgPool3d(7)
        >>> output = m(torch.randn(1, 64, 10, 9, 8))
        >>> # target output size of 7x9x8
        >>> m = nn.AdaptiveAvgPool3d((7, None, None))
        >>> output = m(torch.randn(1, 64, 10, 9, 8))
    """

    output_size: _size_3_opt_t

    def forward(self, input: Tensor) -> Tensor:
        size = self.output_size
        return F.adaptive_avg_pool3d(input, size)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/transformer.py ADDED
@@ -0,0 +1,975 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Optional, Any, Union, Callable
3
+
4
+ import torch
5
+ import warnings
6
+ from torch import Tensor
7
+ from .. import functional as F
8
+ from .module import Module
9
+ from .activation import MultiheadAttention
10
+ from .container import ModuleList
11
+ from ..init import xavier_uniform_
12
+ from .dropout import Dropout
13
+ from .linear import Linear
14
+ from .normalization import LayerNorm
15
+
16
+ __all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer']
17
+
18
+ def _generate_square_subsequent_mask(
19
+ sz: int,
20
+ device: Optional[torch.device] = None,
21
+ dtype: Optional[torch.dtype] = None,
22
+ ) -> Tensor:
23
+ r"""Generate a square causal mask for the sequence.
24
+
25
+ The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
26
+ """
27
+ if device is None:
28
+ device = torch.device('cpu')
29
+ if dtype is None:
30
+ dtype = torch.float32
31
+ return torch.triu(
32
+ torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
33
+ diagonal=1,
34
+ )
35
+
36
+
37
+ def _get_seq_len(
38
+ src: Tensor,
39
+ batch_first: bool
40
+ ) -> Optional[int]:
41
+
42
+ if src.is_nested:
43
+ return None
44
+ else:
45
+ src_size = src.size()
46
+ if len(src_size) == 2:
47
+ # unbatched: S, E
48
+ return src_size[0]
49
+ else:
50
+ # batched: B, S, E if batch_first else S, B, E
51
+ seq_len_pos = 1 if batch_first else 0
52
+ return src_size[seq_len_pos]
53
+
54
+
55
+ class Transformer(Module):
56
+ r"""A transformer model.
57
+
58
+ User is able to modify the attributes as needed. The architecture
59
+ is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
60
+ Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
61
+ Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
62
+ Processing Systems, pages 6000-6010.
63
+
64
+ Args:
65
+ d_model: the number of expected features in the encoder/decoder inputs (default=512).
66
+ nhead: the number of heads in the multiheadattention models (default=8).
67
+ num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
68
+ num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
69
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
70
+ dropout: the dropout value (default=0.1).
71
+ activation: the activation function of encoder/decoder intermediate layer, can be a string
72
+ ("relu" or "gelu") or a unary callable. Default: relu
73
+ custom_encoder: custom encoder (default=None).
74
+ custom_decoder: custom decoder (default=None).
75
+ layer_norm_eps: the eps value in layer normalization components (default=1e-5).
76
+ batch_first: If ``True``, then the input and output tensors are provided
77
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
78
+ norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
79
+ other attention and feedforward operations, otherwise after. Default: ``False`` (after).
80
+ bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
81
+ bias. Default: ``True``.
82
+
83
+ Examples::
84
+ >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
85
+ >>> src = torch.rand((10, 32, 512))
86
+ >>> tgt = torch.rand((20, 32, 512))
87
+ >>> out = transformer_model(src, tgt)
88
+
89
+ Note: A full example to apply nn.Transformer module for the word language model is available in
90
+ https://github.com/pytorch/examples/tree/master/word_language_model
91
+ """
92
+
93
+ def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
94
+ num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
95
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
96
+ custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
97
+ layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
98
+ bias: bool = True, device=None, dtype=None) -> None:
99
+ factory_kwargs = {'device': device, 'dtype': dtype}
100
+ super().__init__()
101
+ torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
102
+
103
+ if custom_encoder is not None:
104
+ self.encoder = custom_encoder
105
+ else:
106
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
107
+ activation, layer_norm_eps, batch_first, norm_first,
108
+ bias, **factory_kwargs)
109
+ encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
110
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
111
+
112
+ if custom_decoder is not None:
113
+ self.decoder = custom_decoder
114
+ else:
115
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
116
+ activation, layer_norm_eps, batch_first, norm_first,
117
+ bias, **factory_kwargs)
118
+ decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
119
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
120
+
121
+ self._reset_parameters()
122
+
123
+ self.d_model = d_model
124
+ self.nhead = nhead
125
+
126
+ self.batch_first = batch_first
127
+
128
+ def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
129
+ memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
130
+ tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
131
+ src_is_causal: Optional[bool] = None, tgt_is_causal: Optional[bool] = None,
132
+ memory_is_causal: bool = False) -> Tensor:
133
+ r"""Take in and process masked source/target sequences.
134
+
135
+ .. note::
136
+
137
+ If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are
138
+ not allowed to participate in the attention,
139
+ which is the opposite of the definition for :attr:`attn_mask`
140
+ in :func:`torch.nn.functional.scaled_dot_product_attention`.
141
+
142
+ Args:
143
+ src: the sequence to the encoder (required).
144
+ tgt: the sequence to the decoder (required).
145
+ src_mask: the additive mask for the src sequence (optional).
146
+ tgt_mask: the additive mask for the tgt sequence (optional).
147
+ memory_mask: the additive mask for the encoder output (optional).
148
+ src_key_padding_mask: the Tensor mask for src keys per batch (optional).
149
+ tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
150
+ memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
151
+ src_is_causal: If specified, applies a causal mask as ``src_mask``.
152
+ Default: ``None``; try to detect a causal mask.
153
+ Warning:
154
+ ``src_is_causal`` provides a hint that ``src_mask`` is
155
+ the causal mask. Providing incorrect hints can result in
156
+ incorrect execution, including forward and backward
157
+ compatibility.
158
+ tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
159
+ Default: ``None``; try to detect a causal mask.
160
+ Warning:
161
+ ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
162
+ the causal mask. Providing incorrect hints can result in
163
+ incorrect execution, including forward and backward
164
+ compatibility.
165
+ memory_is_causal: If specified, applies a causal mask as
166
+ ``memory_mask``.
167
+ Default: ``False``.
168
+ Warning:
169
+ ``memory_is_causal`` provides a hint that
170
+ ``memory_mask`` is the causal mask. Providing incorrect
171
+ hints can result in incorrect execution, including
172
+ forward and backward compatibility.
173
+
174
+ Shape:
175
+ - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
176
+ `(N, S, E)` if `batch_first=True`.
177
+ - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
178
+ `(N, T, E)` if `batch_first=True`.
179
+ - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
180
+ - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
181
+ - memory_mask: :math:`(T, S)`.
182
+ - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
183
+ - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
184
+ - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
185
+
186
+ Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked
187
+ positions. If a BoolTensor is provided, positions with ``True``
188
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
189
+ is provided, it will be added to the attention weight.
190
+ [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
191
+ the attention. If a BoolTensor is provided, the positions with the
192
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
193
+
194
+ - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
195
+ `(N, T, E)` if `batch_first=True`.
196
+
197
+ Note: Due to the multi-head attention architecture in the transformer model,
198
+ the output sequence length of a transformer is same as the input sequence
199
+ (i.e. target) length of the decoder.
200
+
201
+ where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
202
+ batch size, :math:`E` is the feature number
203
+
204
+ Examples:
205
+ >>> # xdoctest: +SKIP
206
+ >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
207
+ """
208
+ is_batched = src.dim() == 3
209
+ if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
210
+ raise RuntimeError("the batch number of src and tgt must be equal")
211
+ elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
212
+ raise RuntimeError("the batch number of src and tgt must be equal")
213
+
214
+ if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
215
+ raise RuntimeError("the feature number of src and tgt must be equal to d_model")
216
+
217
+ memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
218
+ is_causal=src_is_causal)
219
+ output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
220
+ tgt_key_padding_mask=tgt_key_padding_mask,
221
+ memory_key_padding_mask=memory_key_padding_mask,
222
+ tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
223
+ return output
224
+
225
+ @staticmethod
226
+ def generate_square_subsequent_mask(
227
+ sz: int,
228
+ device: Optional[torch.device] = None,
229
+ dtype: Optional[torch.dtype] = None,
230
+ ) -> Tensor:
231
+ r"""Generate a square causal mask for the sequence.
232
+
233
+ The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
234
+ """
235
+ return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)
236
+
237
+ def _reset_parameters(self):
238
+ r"""Initiate parameters in the transformer model."""
239
+ for p in self.parameters():
240
+ if p.dim() > 1:
241
+ xavier_uniform_(p)
242
+
243
+
244
class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers.

    Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, input will automatically convert to nested tensor
            (and convert back on output). This will improve the overall performance of
            TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
        mask_check: if True, ``forward`` verifies that ``src_key_padding_mask`` is
            left-aligned before taking the nested-tensor fast path. Default: ``True``.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ['norm']

    def __init__(
        self,
        encoder_layer: "TransformerEncoderLayer",
        num_layers: int,
        norm: Optional[Module] = None,
        enable_nested_tensor: bool = True,
        mask_check: bool = True
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        # Each layer is an independent deep copy of encoder_layer.
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        # this attribute saves the value provided at object construction
        self.enable_nested_tensor = enable_nested_tensor
        # this attribute controls whether nested tensors are used
        self.use_nested_tensor = enable_nested_tensor
        self.mask_check = mask_check

        # Eagerly validate whether the nested-tensor ("sparsity") fast path can
        # ever apply to this layer configuration; record the first disqualifier.
        enc_layer = "encoder_layer"
        why_not_sparsity_fast_path = ''
        if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
            why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
        elif encoder_layer.norm_first :
            why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
        elif not encoder_layer.self_attn.batch_first:
            why_not_sparsity_fast_path = (f"{enc_layer}.self_attn.batch_first was not True" +
                                          "(use batch_first for better inference performance)")
        elif not encoder_layer.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
        elif encoder_layer.self_attn.in_proj_bias is None:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
        elif not encoder_layer.activation_relu_or_gelu:
            why_not_sparsity_fast_path = f"{enc_layer}.activation_relu_or_gelu was not True"
        elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps) :
            why_not_sparsity_fast_path = f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
        elif encoder_layer.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"

        if enable_nested_tensor and why_not_sparsity_fast_path:
            warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
            self.use_nested_tensor = False


    def forward(
            self,
            src: Tensor,
            mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: Optional[bool] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``is_causal`` provides a hint that ``mask`` is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        # Normalize both masks to a canonical (float) representation and check
        # their dtypes are mutually consistent.
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(mask),
            other_name="mask",
            target_type=src.dtype
        )

        mask = F._canonical_mask(
            mask=mask,
            mask_name="mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        output = src
        convert_to_nested = False
        first_layer = self.layers[0]
        src_key_padding_mask_for_layers = src_key_padding_mask
        why_not_sparsity_fast_path = ''
        str_first_layer = "self.layers[0]"
        batch_first = first_layer.self_attn.batch_first
        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()

        # Runtime (per-call) checks for the nested-tensor fast path; the first
        # failing condition records why the fast path is skipped.
        if not is_fastpath_enabled:
            why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
        elif not hasattr(self, "use_nested_tensor"):
            why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
        elif not self.use_nested_tensor:
            why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True"
        elif first_layer.training:
            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif src_key_padding_mask is None:
            why_not_sparsity_fast_path = "src_key_padding_mask was None"
        elif (((not hasattr(self, "mask_check")) or self.mask_check)
                and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
        elif output.is_nested:
            why_not_sparsity_fast_path = "NestedTensor input is not supported"
        elif mask is not None:
            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            # Tensor-level checks: no torch_function overrides, a supported
            # device, and no autograd requirement on any involved tensor.
            tensor_args = (
                src,
                first_layer.self_attn.in_proj_weight,
                first_layer.self_attn.in_proj_bias,
                first_layer.self_attn.out_proj.weight,
                first_layer.self_attn.out_proj.bias,
                first_layer.norm1.weight,
                first_layer.norm1.bias,
                first_layer.norm2.weight,
                first_layer.norm2.bias,
                first_layer.linear1.weight,
                first_layer.linear1.bias,
                first_layer.linear2.weight,
                first_layer.linear2.bias,
            )
            _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif src.device.type not in _supported_device_type:
                why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
                # Fast path taken: fold the padding mask into a NestedTensor and
                # stop passing the mask to individual layers.
                convert_to_nested = True
                output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
                src_key_padding_mask_for_layers = None

        seq_len = _get_seq_len(src, batch_first)
        # Resolve the is_causal hint (None means "auto-detect from mask").
        is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)

        for mod in self.layers:
            output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)

        if convert_to_nested:
            # Convert back from NestedTensor to a regular padded tensor.
            output = output.to_padded_tensor(0., src.size())

        if self.norm is not None:
            output = self.norm(output)

        return output
424
+
425
+
426
+ class TransformerDecoder(Module):
427
+ r"""TransformerDecoder is a stack of N decoder layers.
428
+
429
+ Args:
430
+ decoder_layer: an instance of the TransformerDecoderLayer() class (required).
431
+ num_layers: the number of sub-decoder-layers in the decoder (required).
432
+ norm: the layer normalization component (optional).
433
+
434
+ Examples::
435
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
436
+ >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
437
+ >>> memory = torch.rand(10, 32, 512)
438
+ >>> tgt = torch.rand(20, 32, 512)
439
+ >>> out = transformer_decoder(tgt, memory)
440
+ """
441
+
442
+ __constants__ = ['norm']
443
+
444
+ def __init__(
445
+ self,
446
+ decoder_layer: "TransformerDecoderLayer",
447
+ num_layers: int,
448
+ norm: Optional[Module] = None
449
+ ) -> None:
450
+ super().__init__()
451
+ torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
452
+ self.layers = _get_clones(decoder_layer, num_layers)
453
+ self.num_layers = num_layers
454
+ self.norm = norm
455
+
456
+ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
457
+ memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
458
+ memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
459
+ memory_is_causal: bool = False) -> Tensor:
460
+ r"""Pass the inputs (and mask) through the decoder layer in turn.
461
+
462
+ Args:
463
+ tgt: the sequence to the decoder (required).
464
+ memory: the sequence from the last layer of the encoder (required).
465
+ tgt_mask: the mask for the tgt sequence (optional).
466
+ memory_mask: the mask for the memory sequence (optional).
467
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
468
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
469
+ tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
470
+ Default: ``None``; try to detect a causal mask.
471
+ Warning:
472
+ ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
473
+ the causal mask. Providing incorrect hints can result in
474
+ incorrect execution, including forward and backward
475
+ compatibility.
476
+ memory_is_causal: If specified, applies a causal mask as
477
+ ``memory mask``.
478
+ Default: ``False``.
479
+ Warning:
480
+ ``memory_is_causal`` provides a hint that
481
+ ``memory_mask`` is the causal mask. Providing incorrect
482
+ hints can result in incorrect execution, including
483
+ forward and backward compatibility.
484
+
485
+ Shape:
486
+ see the docs in :class:`~torch.nn.Transformer`.
487
+ """
488
+ output = tgt
489
+
490
+ seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
491
+ tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
492
+
493
+ for mod in self.layers:
494
+ output = mod(output, memory, tgt_mask=tgt_mask,
495
+ memory_mask=memory_mask,
496
+ tgt_key_padding_mask=tgt_key_padding_mask,
497
+ memory_key_padding_mask=memory_key_padding_mask,
498
+ tgt_is_causal=tgt_is_causal,
499
+ memory_is_causal=memory_is_causal)
500
+
501
+ if self.norm is not None:
502
+ output = self.norm(output)
503
+
504
+ return output
505
+
506
class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    TransformerEncoderLayer can handle either traditional torch.tensor inputs,
    or Nested Tensor inputs. Derived classes are expected to similarly accept
    both input formats. (Not all combinations of inputs are currently
    supported by TransformerEncoderLayer while Nested Tensor is in prototype
    state.)

    If you are implementing a custom layer, you may derive it either from
    the Module or TransformerEncoderLayer class. If your custom layer
    supports both torch.Tensors and Nested Tensors inputs, make its
    implementation a derived class of TransformerEncoderLayer. If your custom
    Layer supports only torch.Tensor inputs, derive its implementation from
    Module.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)

    Fast path:
        forward() will use a special optimized implementation described in
        `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
        conditions are met:

        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
          argument ``requires_grad``
        - training is disabled (using ``.eval()``)
        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
          nor ``src_key_padding_mask`` is passed
        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)

        If the optimized implementation is in use, a
        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
        passed for ``src`` to represent padding more efficiently than using a padding
        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
        returned, and an additional speedup proportional to the fraction of the input that
        is padding can be expected.

    .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
     https://arxiv.org/abs/2205.14135

    """

    __constants__ = ['norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                                            bias=bias, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        # 1 == relu, 2 == gelu, 0 == other (other disables the fast path).
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super().__setstate__(state)
        # Backward compatibility: older checkpoints may lack 'activation'.
        if not hasattr(self, 'activation'):
            self.activation = F.relu


    def forward(
            self,
            src: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: bool = False) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``src mask``.
                Default: ``False``.
                Warning:
                ``is_causal`` provides a hint that ``src_mask`` is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        # Normalize both masks to a canonical representation with consistent dtypes.
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        # Gate the fused-kernel fast path; the first failing condition records why
        # the slow (eager) path is taken instead.
        why_not_sparsity_fast_path = ''
        if not is_fastpath_enabled:
            why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first:
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif self.self_attn.in_proj_bias is None:
            why_not_sparsity_fast_path = "self_attn was passed bias=False"
        elif not self.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
            why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"
        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all((x.device.type in _supported_device_type) for x in tensor_args):
                why_not_sparsity_fast_path = ("some Tensor argument's device is neither one of "
                                              f"{_supported_device_type}")
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if not why_not_sparsity_fast_path:
                # Fast path: a single fused C++ kernel runs the whole layer.
                merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
                return torch._transformer_encoder_layer_fwd(
                    src,
                    self.self_attn.embed_dim,
                    self.self_attn.num_heads,
                    self.self_attn.in_proj_weight,
                    self.self_attn.in_proj_bias,
                    self.self_attn.out_proj.weight,
                    self.self_attn.out_proj.bias,
                    self.activation_relu_or_gelu == 2,
                    self.norm_first,
                    self.norm1.eps,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.norm2.weight,
                    self.norm2.bias,
                    self.linear1.weight,
                    self.linear1.bias,
                    self.linear2.weight,
                    self.linear2.bias,
                    merged_mask,
                    mask_type,
                )


        # Slow (eager) path: pre-norm or post-norm residual blocks.
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False, is_causal=is_causal)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
767
+
768
+
769
+ class TransformerDecoderLayer(Module):
770
+ r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
771
+
772
+ This standard decoder layer is based on the paper "Attention Is All You Need".
773
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
774
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
775
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
776
+ in a different way during application.
777
+
778
+ Args:
779
+ d_model: the number of expected features in the input (required).
780
+ nhead: the number of heads in the multiheadattention models (required).
781
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
782
+ dropout: the dropout value (default=0.1).
783
+ activation: the activation function of the intermediate layer, can be a string
784
+ ("relu" or "gelu") or a unary callable. Default: relu
785
+ layer_norm_eps: the eps value in layer normalization components (default=1e-5).
786
+ batch_first: If ``True``, then the input and output tensors are provided
787
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
788
+ norm_first: if ``True``, layer norm is done prior to self attention, multihead
789
+ attention and feedforward operations, respectively. Otherwise it's done after.
790
+ Default: ``False`` (after).
791
+ bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
792
+ bias. Default: ``True``.
793
+
794
+ Examples::
795
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
796
+ >>> memory = torch.rand(10, 32, 512)
797
+ >>> tgt = torch.rand(20, 32, 512)
798
+ >>> out = decoder_layer(tgt, memory)
799
+
800
+ Alternatively, when ``batch_first`` is ``True``:
801
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
802
+ >>> memory = torch.rand(32, 10, 512)
803
+ >>> tgt = torch.rand(32, 20, 512)
804
+ >>> out = decoder_layer(tgt, memory)
805
+ """
806
+
807
+ __constants__ = ['norm_first']
808
+
809
+ def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
810
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
811
+ layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
812
+ bias: bool = True, device=None, dtype=None) -> None:
813
+ factory_kwargs = {'device': device, 'dtype': dtype}
814
+ super().__init__()
815
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
816
+ bias=bias, **factory_kwargs)
817
+ self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
818
+ bias=bias, **factory_kwargs)
819
+ # Implementation of Feedforward model
820
+ self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
821
+ self.dropout = Dropout(dropout)
822
+ self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
823
+
824
+ self.norm_first = norm_first
825
+ self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
826
+ self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
827
+ self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
828
+ self.dropout1 = Dropout(dropout)
829
+ self.dropout2 = Dropout(dropout)
830
+ self.dropout3 = Dropout(dropout)
831
+
832
+ # Legacy string support for activation function.
833
+ if isinstance(activation, str):
834
+ self.activation = _get_activation_fn(activation)
835
+ else:
836
+ self.activation = activation
837
+
838
+ def __setstate__(self, state):
839
+ if 'activation' not in state:
840
+ state['activation'] = F.relu
841
+ super().__setstate__(state)
842
+
843
+ def forward(
844
+ self,
845
+ tgt: Tensor,
846
+ memory: Tensor,
847
+ tgt_mask: Optional[Tensor] = None,
848
+ memory_mask: Optional[Tensor] = None,
849
+ tgt_key_padding_mask: Optional[Tensor] = None,
850
+ memory_key_padding_mask: Optional[Tensor] = None,
851
+ tgt_is_causal: bool = False,
852
+ memory_is_causal: bool = False,
853
+ ) -> Tensor:
854
+ r"""Pass the inputs (and mask) through the decoder layer.
855
+
856
+ Args:
857
+ tgt: the sequence to the decoder layer (required).
858
+ memory: the sequence from the last layer of the encoder (required).
859
+ tgt_mask: the mask for the tgt sequence (optional).
860
+ memory_mask: the mask for the memory sequence (optional).
861
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
862
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
863
+ tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
864
+ Default: ``False``.
865
+ Warning:
866
+ ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
867
+ the causal mask. Providing incorrect hints can result in
868
+ incorrect execution, including forward and backward
869
+ compatibility.
870
+ memory_is_causal: If specified, applies a causal mask as
871
+ ``memory mask``.
872
+ Default: ``False``.
873
+ Warning:
874
+ ``memory_is_causal`` provides a hint that
875
+ ``memory_mask`` is the causal mask. Providing incorrect
876
+ hints can result in incorrect execution, including
877
+ forward and backward compatibility.
878
+
879
+ Shape:
880
+ see the docs in :class:`~torch.nn.Transformer`.
881
+ """
882
+ # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
883
+
884
+ x = tgt
885
+ if self.norm_first:
886
+ x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
887
+ x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
888
+ x = x + self._ff_block(self.norm3(x))
889
+ else:
890
+ x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
891
+ x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
892
+ x = self.norm3(x + self._ff_block(x))
893
+
894
+ return x
895
+
896
+ # self-attention block
897
+ def _sa_block(self, x: Tensor,
898
+ attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
899
+ x = self.self_attn(x, x, x,
900
+ attn_mask=attn_mask,
901
+ key_padding_mask=key_padding_mask,
902
+ is_causal=is_causal,
903
+ need_weights=False)[0]
904
+ return self.dropout1(x)
905
+
906
+ # multihead attention block
907
+ def _mha_block(self, x: Tensor, mem: Tensor,
908
+ attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
909
+ x = self.multihead_attn(x, mem, mem,
910
+ attn_mask=attn_mask,
911
+ key_padding_mask=key_padding_mask,
912
+ is_causal=is_causal,
913
+ need_weights=False)[0]
914
+ return self.dropout2(x)
915
+
916
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
    """Position-wise feed-forward sub-block: linear1 -> activation -> dropout -> linear2 -> dropout3."""
    hidden = self.activation(self.linear1(x))
    hidden = self.dropout(hidden)
    return self.dropout3(self.linear2(hidden))
920
+
921
+
922
+ def _get_clones(module, N):
923
+ # FIXME: copy.deepcopy() is not defined on nn.module
924
+ return ModuleList([copy.deepcopy(module) for i in range(N)])
925
+
926
+
927
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
928
+ if activation == "relu":
929
+ return F.relu
930
+ elif activation == "gelu":
931
+ return F.gelu
932
+
933
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}")
934
+
935
+
936
def _detect_is_causal_mask(
    mask: Optional[Tensor],
    is_causal: Optional[bool] = None,
    size: Optional[int] = None,
) -> bool:
    """Return whether ``mask`` is (or is hinted to be) a causal attention mask.

    Warning:
        If ``is_causal`` is not ``None``, its value is trusted and returned
        as is, without inspecting ``mask``. An incorrect hint can hurt:

        ``is_causal=False`` on a mask that really is causal may only reduce
        performance relative to passing ``is_causal=True``;
        ``is_causal=True`` on a mask that is not causal may lead to incorrect
        and unpredictable execution — depending on factors such as alignment
        and hardware SKU, either a causal mask or the supplied mask may end up
        being applied, so the outcome may not appear deterministic.

    ``size``, if not ``None``, checks for a causal mask of exactly that size;
    otherwise the mask's own penultimate dimension determines the size.
    """
    # An explicit hint short-circuits detection; only a literal True counts
    # (guards against type refinement of truthy non-bool values).
    if is_causal is not None:
        return is_causal is True

    if mask is None:
        return False

    sz = mask.size(-2) if size is None else size
    reference = _generate_square_subsequent_mask(
        sz, device=mask.device, dtype=mask.dtype)

    # Sizes must match exactly; the comparison deliberately avoids
    # `torch.equal` so an elementwise `==` could broadcast over batched masks.
    if mask.size() != reference.size():
        return False
    return bool((mask == reference).all())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.06 kB). View file