koichi12 committed
Commit f196197 · verified · 1 Parent(s): 5fce27e

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/torchvision/image.so +3 -0
  3. .venv/lib/python3.11/site-packages/torchvision/models/__init__.py +23 -0
  4. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/__init__.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_api.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_meta.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_utils.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/alexnet.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/convnext.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/densenet.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/efficientnet.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/feature_extraction.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/googlenet.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/inception.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/maxvit.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mnasnet.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenet.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenetv2.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenetv3.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/regnet.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/resnet.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/shufflenetv2.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/squeezenet.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/swin_transformer.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/vgg.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/torchvision/models/__pycache__/vision_transformer.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torchvision/models/_api.py +277 -0
  28. .venv/lib/python3.11/site-packages/torchvision/models/_meta.py +1554 -0
  29. .venv/lib/python3.11/site-packages/torchvision/models/_utils.py +256 -0
  30. .venv/lib/python3.11/site-packages/torchvision/models/feature_extraction.py +572 -0
  31. .venv/lib/python3.11/site-packages/torchvision/models/inception.py +478 -0
  32. .venv/lib/python3.11/site-packages/torchvision/models/mnasnet.py +434 -0
  33. .venv/lib/python3.11/site-packages/torchvision/models/mobilenet.py +6 -0
  34. .venv/lib/python3.11/site-packages/torchvision/models/mobilenetv2.py +260 -0
  35. .venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__init__.py +1 -0
  36. .venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/__init__.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/_utils.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/raft.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/torchvision/models/optical_flow/_utils.py +48 -0
  40. .venv/lib/python3.11/site-packages/torchvision/models/optical_flow/raft.py +947 -0
  41. .venv/lib/python3.11/site-packages/torchvision/models/quantization/__init__.py +5 -0
  42. .venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/__init__.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/googlenet.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/inception.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenet.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenetv2.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenetv3.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/resnet.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/shufflenetv2.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/utils.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -344,3 +344,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
 .venv/lib/python3.11/site-packages/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
 .venv/lib/python3.11/site-packages/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
 .venv/lib/python3.11/site-packages/multidict/_multidict.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+.venv/lib/python3.11/site-packages/torchvision/image.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/torchvision/image.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c82377c2c2be60cedf80c171874d8d50d8b09102fe42c20b3a426b7715a1fc4d
+size 667281
.venv/lib/python3.11/site-packages/torchvision/models/__init__.py ADDED
@@ -0,0 +1,23 @@
+from .alexnet import *
+from .convnext import *
+from .densenet import *
+from .efficientnet import *
+from .googlenet import *
+from .inception import *
+from .mnasnet import *
+from .mobilenet import *
+from .regnet import *
+from .resnet import *
+from .shufflenetv2 import *
+from .squeezenet import *
+from .vgg import *
+from .vision_transformer import *
+from .swin_transformer import *
+from .maxvit import *
+from . import detection, optical_flow, quantization, segmentation, video
+
+# The Weights and WeightsEnum are developer-facing utils that we make public for
+# downstream libs like torchgeo https://github.com/pytorch/vision/issues/7094
+# TODO: we could / should document them publicly, but it's not clear where, as
+# they're not intended for end users.
+from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum
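The comment above re-exports Weights and WeightsEnum for downstream libraries. Below is a minimal, hypothetical sketch of how such a library might define its own weights enum on top of them; the class name, URL, transform, and metadata are illustrative only and are not part of this commit.

```python
# Hedged sketch: how a downstream library (e.g. torchgeo) might reuse the re-exported
# Weights / WeightsEnum. URL, transform, and metadata below are purely illustrative.
from functools import partial

from torchvision.models import Weights, WeightsEnum
from torchvision.transforms import CenterCrop


class MyBackbone_Weights(WeightsEnum):
    EXAMPLE_V1 = Weights(
        url="https://example.com/my_backbone_v1.pth",  # hypothetical checkpoint location
        transforms=partial(CenterCrop, 224),           # constructor; instantiated lazily by callers
        meta={"num_classes": 10, "categories": ["class_0", "class_1"]},
    )


print(MyBackbone_Weights.EXAMPLE_V1.meta["num_classes"])  # 10
```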
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.08 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_api.cpython-311.pyc ADDED
Binary file (14.2 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_meta.cpython-311.pyc ADDED
Binary file (19.6 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (14.2 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/alexnet.cpython-311.pyc ADDED
Binary file (6.81 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/convnext.cpython-311.pyc ADDED
Binary file (20.5 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/densenet.cpython-311.pyc ADDED
Binary file (22.6 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/efficientnet.cpython-311.pyc ADDED
Binary file (45.1 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/feature_extraction.cpython-311.pyc ADDED
Binary file (31.3 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/googlenet.cpython-311.pyc ADDED
Binary file (18.5 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/inception.cpython-311.pyc ADDED
Binary file (27.9 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/maxvit.cpython-311.pyc ADDED
Binary file (38.9 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mnasnet.cpython-311.pyc ADDED
Binary file (21.9 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenet.cpython-311.pyc ADDED
Binary file (373 Bytes).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenetv2.cpython-311.pyc ADDED
Binary file (12.3 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenetv3.cpython-311.pyc ADDED
Binary file (19.6 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/regnet.cpython-311.pyc ADDED
Binary file (57.2 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/resnet.cpython-311.pyc ADDED
Binary file (39.4 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/shufflenetv2.cpython-311.pyc ADDED
Binary file (19 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/squeezenet.cpython-311.pyc ADDED
Binary file (12.1 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/swin_transformer.cpython-311.pyc ADDED
Binary file (49 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/vgg.cpython-311.pyc ADDED
Binary file (22.5 kB).
.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/vision_transformer.cpython-311.pyc ADDED
Binary file (33.4 kB).
.venv/lib/python3.11/site-packages/torchvision/models/_api.py ADDED
@@ -0,0 +1,277 @@
+import fnmatch
+import importlib
+import inspect
+import sys
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial
+from inspect import signature
+from types import ModuleType
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
+
+from torch import nn
+
+from .._internally_replaced_utils import load_state_dict_from_url
+
+
+__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"]
+
+
+@dataclass
+class Weights:
+    """
+    This class is used to group important attributes associated with the pre-trained weights.
+
+    Args:
+        url (str): The location where we find the weights.
+        transforms (Callable): A callable that constructs the preprocessing method (or validation preset transforms)
+            needed to use the model. The reason we attach a constructor method rather than an already constructed
+            object is because the specific object might have memory and thus we want to delay initialization until
+            needed.
+        meta (Dict[str, Any]): Stores meta-data related to the weights of the model and its configuration. These can be
+            informative attributes (for example the number of parameters/flops, recipe link/methods used in training
+            etc), configuration parameters (for example the `num_classes`) needed to construct the model or important
+            meta-data (for example the `classes` of a classification model) needed to use the model.
+    """
+
+    url: str
+    transforms: Callable
+    meta: Dict[str, Any]
+
+    def __eq__(self, other: Any) -> bool:
+        # We need this custom implementation for correct deep-copy and deserialization behavior.
+        # TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it,
+        # involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often
+        # defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling
+        # for it, the check against the defined members would fail and effectively prevent the weights from being
+        # deep-copied or deserialized.
+        # See https://github.com/pytorch/vision/pull/7107 for details.
+        if not isinstance(other, Weights):
+            return NotImplemented
+
+        if self.url != other.url:
+            return False
+
+        if self.meta != other.meta:
+            return False
+
+        if isinstance(self.transforms, partial) and isinstance(other.transforms, partial):
+            return (
+                self.transforms.func == other.transforms.func
+                and self.transforms.args == other.transforms.args
+                and self.transforms.keywords == other.transforms.keywords
+            )
+        else:
+            return self.transforms == other.transforms
+
+
+class WeightsEnum(Enum):
+    """
+    This class is the parent class of all model weights. Each model building method receives an optional `weights`
+    parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
+    `Weights`.
+
+    Args:
+        value (Weights): The data class entry with the weight information.
+    """
+
+    @classmethod
+    def verify(cls, obj: Any) -> Any:
+        if obj is not None:
+            if type(obj) is str:
+                obj = cls[obj.replace(cls.__name__ + ".", "")]
+            elif not isinstance(obj, cls):
+                raise TypeError(
+                    f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
+                )
+        return obj
+
+    def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
+        return load_state_dict_from_url(self.url, *args, **kwargs)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}.{self._name_}"
+
+    @property
+    def url(self):
+        return self.value.url
+
+    @property
+    def transforms(self):
+        return self.value.transforms
+
+    @property
+    def meta(self):
+        return self.value.meta
+
+
+def get_weight(name: str) -> WeightsEnum:
+    """
+    Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"
+
+    Args:
+        name (str): The name of the weight enum entry.
+
+    Returns:
+        WeightsEnum: The requested weight enum.
+    """
+    try:
+        enum_name, value_name = name.split(".")
+    except ValueError:
+        raise ValueError(f"Invalid weight name provided: '{name}'.")
+
+    base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
+    base_module = importlib.import_module(base_module_name)
+    model_modules = [base_module] + [
+        x[1]
+        for x in inspect.getmembers(base_module, inspect.ismodule)
+        if x[1].__file__.endswith("__init__.py")  # type: ignore[union-attr]
+    ]
+
+    weights_enum = None
+    for m in model_modules:
+        potential_class = m.__dict__.get(enum_name, None)
+        if potential_class is not None and issubclass(potential_class, WeightsEnum):
+            weights_enum = potential_class
+            break
+
+    if weights_enum is None:
+        raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")
+
+    return weights_enum[value_name]
+
+
+def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]:
+    """
+    Returns the weights enum class associated to the given model.
+
+    Args:
+        name (callable or str): The model builder function or the name under which it is registered.
+
+    Returns:
+        weights_enum (WeightsEnum): The weights enum class associated with the model.
+    """
+    model = get_model_builder(name) if isinstance(name, str) else name
+    return _get_enum_from_fn(model)
+
+
+def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
+    """
+    Internal method that gets the weight enum of a specific model builder method.
+
+    Args:
+        fn (Callable): The builder method used to create the model.
+    Returns:
+        WeightsEnum: The requested weight enum.
+    """
+    sig = signature(fn)
+    if "weights" not in sig.parameters:
+        raise ValueError("The method is missing the 'weights' argument.")
+
+    ann = signature(fn).parameters["weights"].annotation
+    weights_enum = None
+    if isinstance(ann, type) and issubclass(ann, WeightsEnum):
+        weights_enum = ann
+    else:
+        # handle cases like Union[Optional, T]
+        # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
+        for t in ann.__args__:  # type: ignore[union-attr]
+            if isinstance(t, type) and issubclass(t, WeightsEnum):
+                weights_enum = t
+                break
+
+    if weights_enum is None:
+        raise ValueError(
+            "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
+        )
+
+    return weights_enum
+
+
+M = TypeVar("M", bound=nn.Module)
+
+BUILTIN_MODELS = {}
+
+
+def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
+    def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
+        key = name if name is not None else fn.__name__
+        if key in BUILTIN_MODELS:
+            raise ValueError(f"An entry is already registered under the name '{key}'.")
+        BUILTIN_MODELS[key] = fn
+        return fn
+
+    return wrapper
+
+
+def list_models(
+    module: Optional[ModuleType] = None,
+    include: Union[Iterable[str], str, None] = None,
+    exclude: Union[Iterable[str], str, None] = None,
+) -> List[str]:
+    """
+    Returns a list with the names of registered models.
+
+    Args:
+        module (ModuleType, optional): The module from which we want to extract the available models.
+        include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models.
+            Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
+            wildcards. In case of many filters, the results is the union of individual filters.
+        exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models.
+            Filter are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
+            wildcards. In case of many filters, the results is removal of all the models that match any individual filter.
+
+    Returns:
+        models (list): A list with the names of available models.
+    """
+    all_models = {
+        k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
+    }
+    if include:
+        models: Set[str] = set()
+        if isinstance(include, str):
+            include = [include]
+        for include_filter in include:
+            models = models | set(fnmatch.filter(all_models, include_filter))
+    else:
+        models = all_models
+
+    if exclude:
+        if isinstance(exclude, str):
+            exclude = [exclude]
+        for exclude_filter in exclude:
+            models = models - set(fnmatch.filter(all_models, exclude_filter))
+    return sorted(models)
+
+
+def get_model_builder(name: str) -> Callable[..., nn.Module]:
+    """
+    Gets the model name and returns the model builder method.
+
+    Args:
+        name (str): The name under which the model is registered.
+
+    Returns:
+        fn (Callable): The model builder method.
+    """
+    name = name.lower()
+    try:
+        fn = BUILTIN_MODELS[name]
+    except KeyError:
+        raise ValueError(f"Unknown model {name}")
+    return fn
+
+
+def get_model(name: str, **config: Any) -> nn.Module:
+    """
+    Gets the model name and configuration and returns an instantiated model.
+
+    Args:
+        name (str): The name under which the model is registered.
+        **config (Any): parameters passed to the model builder method.
+
+    Returns:
+        model (nn.Module): The initialized model.
+    """
+    fn = get_model_builder(name)
+    return fn(**config)
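The _api.py module above implements the model registry: register_model populates BUILTIN_MODELS, which list_models, get_model_builder, and get_model then query. Below is a minimal sketch of how those pieces fit together, assuming this torchvision build is importable; the tiny_net builder is a hypothetical example, not a model shipped by torchvision.

```python
# Hedged sketch of the registration mechanism defined in _api.py above.
# `tiny_net` is a made-up model name used only for illustration.
import torch
from torch import nn

from torchvision.models import get_model, list_models
from torchvision.models._api import register_model


@register_model(name="tiny_net")
def tiny_net(*, weights=None, num_classes: int = 10) -> nn.Module:
    # Builders normally accept a `weights` argument; None means random initialization.
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, num_classes))


print("tiny_net" in list_models())             # True once the builder is registered
model = get_model("tiny_net", num_classes=10)  # looked up via BUILTIN_MODELS
print(model(torch.rand(2, 3, 32, 32)).shape)   # torch.Size([2, 10])
```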
.venv/lib/python3.11/site-packages/torchvision/models/_meta.py ADDED
@@ -0,0 +1,1554 @@
1
+ """
2
+ This file is part of the private API. Please do not refer to any variables defined here directly as they will be
3
+ removed on future versions without warning.
4
+ """
5
+
6
+ # This will eventually be replaced with a call at torchvision.datasets.info("imagenet").categories
7
+ _IMAGENET_CATEGORIES = [
8
+ "tench",
9
+ "goldfish",
10
+ "great white shark",
11
+ "tiger shark",
12
+ "hammerhead",
13
+ "electric ray",
14
+ "stingray",
15
+ "cock",
16
+ "hen",
17
+ "ostrich",
18
+ "brambling",
19
+ "goldfinch",
20
+ "house finch",
21
+ "junco",
22
+ "indigo bunting",
23
+ "robin",
24
+ "bulbul",
25
+ "jay",
26
+ "magpie",
27
+ "chickadee",
28
+ "water ouzel",
29
+ "kite",
30
+ "bald eagle",
31
+ "vulture",
32
+ "great grey owl",
33
+ "European fire salamander",
34
+ "common newt",
35
+ "eft",
36
+ "spotted salamander",
37
+ "axolotl",
38
+ "bullfrog",
39
+ "tree frog",
40
+ "tailed frog",
41
+ "loggerhead",
42
+ "leatherback turtle",
43
+ "mud turtle",
44
+ "terrapin",
45
+ "box turtle",
46
+ "banded gecko",
47
+ "common iguana",
48
+ "American chameleon",
49
+ "whiptail",
50
+ "agama",
51
+ "frilled lizard",
52
+ "alligator lizard",
53
+ "Gila monster",
54
+ "green lizard",
55
+ "African chameleon",
56
+ "Komodo dragon",
57
+ "African crocodile",
58
+ "American alligator",
59
+ "triceratops",
60
+ "thunder snake",
61
+ "ringneck snake",
62
+ "hognose snake",
63
+ "green snake",
64
+ "king snake",
65
+ "garter snake",
66
+ "water snake",
67
+ "vine snake",
68
+ "night snake",
69
+ "boa constrictor",
70
+ "rock python",
71
+ "Indian cobra",
72
+ "green mamba",
73
+ "sea snake",
74
+ "horned viper",
75
+ "diamondback",
76
+ "sidewinder",
77
+ "trilobite",
78
+ "harvestman",
79
+ "scorpion",
80
+ "black and gold garden spider",
81
+ "barn spider",
82
+ "garden spider",
83
+ "black widow",
84
+ "tarantula",
85
+ "wolf spider",
86
+ "tick",
87
+ "centipede",
88
+ "black grouse",
89
+ "ptarmigan",
90
+ "ruffed grouse",
91
+ "prairie chicken",
92
+ "peacock",
93
+ "quail",
94
+ "partridge",
95
+ "African grey",
96
+ "macaw",
97
+ "sulphur-crested cockatoo",
98
+ "lorikeet",
99
+ "coucal",
100
+ "bee eater",
101
+ "hornbill",
102
+ "hummingbird",
103
+ "jacamar",
104
+ "toucan",
105
+ "drake",
106
+ "red-breasted merganser",
107
+ "goose",
108
+ "black swan",
109
+ "tusker",
110
+ "echidna",
111
+ "platypus",
112
+ "wallaby",
113
+ "koala",
114
+ "wombat",
115
+ "jellyfish",
116
+ "sea anemone",
117
+ "brain coral",
118
+ "flatworm",
119
+ "nematode",
120
+ "conch",
121
+ "snail",
122
+ "slug",
123
+ "sea slug",
124
+ "chiton",
125
+ "chambered nautilus",
126
+ "Dungeness crab",
127
+ "rock crab",
128
+ "fiddler crab",
129
+ "king crab",
130
+ "American lobster",
131
+ "spiny lobster",
132
+ "crayfish",
133
+ "hermit crab",
134
+ "isopod",
135
+ "white stork",
136
+ "black stork",
137
+ "spoonbill",
138
+ "flamingo",
139
+ "little blue heron",
140
+ "American egret",
141
+ "bittern",
142
+ "crane bird",
143
+ "limpkin",
144
+ "European gallinule",
145
+ "American coot",
146
+ "bustard",
147
+ "ruddy turnstone",
148
+ "red-backed sandpiper",
149
+ "redshank",
150
+ "dowitcher",
151
+ "oystercatcher",
152
+ "pelican",
153
+ "king penguin",
154
+ "albatross",
155
+ "grey whale",
156
+ "killer whale",
157
+ "dugong",
158
+ "sea lion",
159
+ "Chihuahua",
160
+ "Japanese spaniel",
161
+ "Maltese dog",
162
+ "Pekinese",
163
+ "Shih-Tzu",
164
+ "Blenheim spaniel",
165
+ "papillon",
166
+ "toy terrier",
167
+ "Rhodesian ridgeback",
168
+ "Afghan hound",
169
+ "basset",
170
+ "beagle",
171
+ "bloodhound",
172
+ "bluetick",
173
+ "black-and-tan coonhound",
174
+ "Walker hound",
175
+ "English foxhound",
176
+ "redbone",
177
+ "borzoi",
178
+ "Irish wolfhound",
179
+ "Italian greyhound",
180
+ "whippet",
181
+ "Ibizan hound",
182
+ "Norwegian elkhound",
183
+ "otterhound",
184
+ "Saluki",
185
+ "Scottish deerhound",
186
+ "Weimaraner",
187
+ "Staffordshire bullterrier",
188
+ "American Staffordshire terrier",
189
+ "Bedlington terrier",
190
+ "Border terrier",
191
+ "Kerry blue terrier",
192
+ "Irish terrier",
193
+ "Norfolk terrier",
194
+ "Norwich terrier",
195
+ "Yorkshire terrier",
196
+ "wire-haired fox terrier",
197
+ "Lakeland terrier",
198
+ "Sealyham terrier",
199
+ "Airedale",
200
+ "cairn",
201
+ "Australian terrier",
202
+ "Dandie Dinmont",
203
+ "Boston bull",
204
+ "miniature schnauzer",
205
+ "giant schnauzer",
206
+ "standard schnauzer",
207
+ "Scotch terrier",
208
+ "Tibetan terrier",
209
+ "silky terrier",
210
+ "soft-coated wheaten terrier",
211
+ "West Highland white terrier",
212
+ "Lhasa",
213
+ "flat-coated retriever",
214
+ "curly-coated retriever",
215
+ "golden retriever",
216
+ "Labrador retriever",
217
+ "Chesapeake Bay retriever",
218
+ "German short-haired pointer",
219
+ "vizsla",
220
+ "English setter",
221
+ "Irish setter",
222
+ "Gordon setter",
223
+ "Brittany spaniel",
224
+ "clumber",
225
+ "English springer",
226
+ "Welsh springer spaniel",
227
+ "cocker spaniel",
228
+ "Sussex spaniel",
229
+ "Irish water spaniel",
230
+ "kuvasz",
231
+ "schipperke",
232
+ "groenendael",
233
+ "malinois",
234
+ "briard",
235
+ "kelpie",
236
+ "komondor",
237
+ "Old English sheepdog",
238
+ "Shetland sheepdog",
239
+ "collie",
240
+ "Border collie",
241
+ "Bouvier des Flandres",
242
+ "Rottweiler",
243
+ "German shepherd",
244
+ "Doberman",
245
+ "miniature pinscher",
246
+ "Greater Swiss Mountain dog",
247
+ "Bernese mountain dog",
248
+ "Appenzeller",
249
+ "EntleBucher",
250
+ "boxer",
251
+ "bull mastiff",
252
+ "Tibetan mastiff",
253
+ "French bulldog",
254
+ "Great Dane",
255
+ "Saint Bernard",
256
+ "Eskimo dog",
257
+ "malamute",
258
+ "Siberian husky",
259
+ "dalmatian",
260
+ "affenpinscher",
261
+ "basenji",
262
+ "pug",
263
+ "Leonberg",
264
+ "Newfoundland",
265
+ "Great Pyrenees",
266
+ "Samoyed",
267
+ "Pomeranian",
268
+ "chow",
269
+ "keeshond",
270
+ "Brabancon griffon",
271
+ "Pembroke",
272
+ "Cardigan",
273
+ "toy poodle",
274
+ "miniature poodle",
275
+ "standard poodle",
276
+ "Mexican hairless",
277
+ "timber wolf",
278
+ "white wolf",
279
+ "red wolf",
280
+ "coyote",
281
+ "dingo",
282
+ "dhole",
283
+ "African hunting dog",
284
+ "hyena",
285
+ "red fox",
286
+ "kit fox",
287
+ "Arctic fox",
288
+ "grey fox",
289
+ "tabby",
290
+ "tiger cat",
291
+ "Persian cat",
292
+ "Siamese cat",
293
+ "Egyptian cat",
294
+ "cougar",
295
+ "lynx",
296
+ "leopard",
297
+ "snow leopard",
298
+ "jaguar",
299
+ "lion",
300
+ "tiger",
301
+ "cheetah",
302
+ "brown bear",
303
+ "American black bear",
304
+ "ice bear",
305
+ "sloth bear",
306
+ "mongoose",
307
+ "meerkat",
308
+ "tiger beetle",
309
+ "ladybug",
310
+ "ground beetle",
311
+ "long-horned beetle",
312
+ "leaf beetle",
313
+ "dung beetle",
314
+ "rhinoceros beetle",
315
+ "weevil",
316
+ "fly",
317
+ "bee",
318
+ "ant",
319
+ "grasshopper",
320
+ "cricket",
321
+ "walking stick",
322
+ "cockroach",
323
+ "mantis",
324
+ "cicada",
325
+ "leafhopper",
326
+ "lacewing",
327
+ "dragonfly",
328
+ "damselfly",
329
+ "admiral",
330
+ "ringlet",
331
+ "monarch",
332
+ "cabbage butterfly",
333
+ "sulphur butterfly",
334
+ "lycaenid",
335
+ "starfish",
336
+ "sea urchin",
337
+ "sea cucumber",
338
+ "wood rabbit",
339
+ "hare",
340
+ "Angora",
341
+ "hamster",
342
+ "porcupine",
343
+ "fox squirrel",
344
+ "marmot",
345
+ "beaver",
346
+ "guinea pig",
347
+ "sorrel",
348
+ "zebra",
349
+ "hog",
350
+ "wild boar",
351
+ "warthog",
352
+ "hippopotamus",
353
+ "ox",
354
+ "water buffalo",
355
+ "bison",
356
+ "ram",
357
+ "bighorn",
358
+ "ibex",
359
+ "hartebeest",
360
+ "impala",
361
+ "gazelle",
362
+ "Arabian camel",
363
+ "llama",
364
+ "weasel",
365
+ "mink",
366
+ "polecat",
367
+ "black-footed ferret",
368
+ "otter",
369
+ "skunk",
370
+ "badger",
371
+ "armadillo",
372
+ "three-toed sloth",
373
+ "orangutan",
374
+ "gorilla",
375
+ "chimpanzee",
376
+ "gibbon",
377
+ "siamang",
378
+ "guenon",
379
+ "patas",
380
+ "baboon",
381
+ "macaque",
382
+ "langur",
383
+ "colobus",
384
+ "proboscis monkey",
385
+ "marmoset",
386
+ "capuchin",
387
+ "howler monkey",
388
+ "titi",
389
+ "spider monkey",
390
+ "squirrel monkey",
391
+ "Madagascar cat",
392
+ "indri",
393
+ "Indian elephant",
394
+ "African elephant",
395
+ "lesser panda",
396
+ "giant panda",
397
+ "barracouta",
398
+ "eel",
399
+ "coho",
400
+ "rock beauty",
401
+ "anemone fish",
402
+ "sturgeon",
403
+ "gar",
404
+ "lionfish",
405
+ "puffer",
406
+ "abacus",
407
+ "abaya",
408
+ "academic gown",
409
+ "accordion",
410
+ "acoustic guitar",
411
+ "aircraft carrier",
412
+ "airliner",
413
+ "airship",
414
+ "altar",
415
+ "ambulance",
416
+ "amphibian",
417
+ "analog clock",
418
+ "apiary",
419
+ "apron",
420
+ "ashcan",
421
+ "assault rifle",
422
+ "backpack",
423
+ "bakery",
424
+ "balance beam",
425
+ "balloon",
426
+ "ballpoint",
427
+ "Band Aid",
428
+ "banjo",
429
+ "bannister",
430
+ "barbell",
431
+ "barber chair",
432
+ "barbershop",
433
+ "barn",
434
+ "barometer",
435
+ "barrel",
436
+ "barrow",
437
+ "baseball",
438
+ "basketball",
439
+ "bassinet",
440
+ "bassoon",
441
+ "bathing cap",
442
+ "bath towel",
443
+ "bathtub",
444
+ "beach wagon",
445
+ "beacon",
446
+ "beaker",
447
+ "bearskin",
448
+ "beer bottle",
449
+ "beer glass",
450
+ "bell cote",
451
+ "bib",
452
+ "bicycle-built-for-two",
453
+ "bikini",
454
+ "binder",
455
+ "binoculars",
456
+ "birdhouse",
457
+ "boathouse",
458
+ "bobsled",
459
+ "bolo tie",
460
+ "bonnet",
461
+ "bookcase",
462
+ "bookshop",
463
+ "bottlecap",
464
+ "bow",
465
+ "bow tie",
466
+ "brass",
467
+ "brassiere",
468
+ "breakwater",
469
+ "breastplate",
470
+ "broom",
471
+ "bucket",
472
+ "buckle",
473
+ "bulletproof vest",
474
+ "bullet train",
475
+ "butcher shop",
476
+ "cab",
477
+ "caldron",
478
+ "candle",
479
+ "cannon",
480
+ "canoe",
481
+ "can opener",
482
+ "cardigan",
483
+ "car mirror",
484
+ "carousel",
485
+ "carpenter's kit",
486
+ "carton",
487
+ "car wheel",
488
+ "cash machine",
489
+ "cassette",
490
+ "cassette player",
491
+ "castle",
492
+ "catamaran",
493
+ "CD player",
494
+ "cello",
495
+ "cellular telephone",
496
+ "chain",
497
+ "chainlink fence",
498
+ "chain mail",
499
+ "chain saw",
500
+ "chest",
501
+ "chiffonier",
502
+ "chime",
503
+ "china cabinet",
504
+ "Christmas stocking",
505
+ "church",
506
+ "cinema",
507
+ "cleaver",
508
+ "cliff dwelling",
509
+ "cloak",
510
+ "clog",
511
+ "cocktail shaker",
512
+ "coffee mug",
513
+ "coffeepot",
514
+ "coil",
515
+ "combination lock",
516
+ "computer keyboard",
517
+ "confectionery",
518
+ "container ship",
519
+ "convertible",
520
+ "corkscrew",
521
+ "cornet",
522
+ "cowboy boot",
523
+ "cowboy hat",
524
+ "cradle",
525
+ "crane",
526
+ "crash helmet",
527
+ "crate",
528
+ "crib",
529
+ "Crock Pot",
530
+ "croquet ball",
531
+ "crutch",
532
+ "cuirass",
533
+ "dam",
534
+ "desk",
535
+ "desktop computer",
536
+ "dial telephone",
537
+ "diaper",
538
+ "digital clock",
539
+ "digital watch",
540
+ "dining table",
541
+ "dishrag",
542
+ "dishwasher",
543
+ "disk brake",
544
+ "dock",
545
+ "dogsled",
546
+ "dome",
547
+ "doormat",
548
+ "drilling platform",
549
+ "drum",
550
+ "drumstick",
551
+ "dumbbell",
552
+ "Dutch oven",
553
+ "electric fan",
554
+ "electric guitar",
555
+ "electric locomotive",
556
+ "entertainment center",
557
+ "envelope",
558
+ "espresso maker",
559
+ "face powder",
560
+ "feather boa",
561
+ "file",
562
+ "fireboat",
563
+ "fire engine",
564
+ "fire screen",
565
+ "flagpole",
566
+ "flute",
567
+ "folding chair",
568
+ "football helmet",
569
+ "forklift",
570
+ "fountain",
571
+ "fountain pen",
572
+ "four-poster",
573
+ "freight car",
574
+ "French horn",
575
+ "frying pan",
576
+ "fur coat",
577
+ "garbage truck",
578
+ "gasmask",
579
+ "gas pump",
580
+ "goblet",
581
+ "go-kart",
582
+ "golf ball",
583
+ "golfcart",
584
+ "gondola",
585
+ "gong",
586
+ "gown",
587
+ "grand piano",
588
+ "greenhouse",
589
+ "grille",
590
+ "grocery store",
591
+ "guillotine",
592
+ "hair slide",
593
+ "hair spray",
594
+ "half track",
595
+ "hammer",
596
+ "hamper",
597
+ "hand blower",
598
+ "hand-held computer",
599
+ "handkerchief",
600
+ "hard disc",
601
+ "harmonica",
602
+ "harp",
603
+ "harvester",
604
+ "hatchet",
605
+ "holster",
606
+ "home theater",
607
+ "honeycomb",
608
+ "hook",
609
+ "hoopskirt",
610
+ "horizontal bar",
611
+ "horse cart",
612
+ "hourglass",
613
+ "iPod",
614
+ "iron",
615
+ "jack-o'-lantern",
616
+ "jean",
617
+ "jeep",
618
+ "jersey",
619
+ "jigsaw puzzle",
620
+ "jinrikisha",
621
+ "joystick",
622
+ "kimono",
623
+ "knee pad",
624
+ "knot",
625
+ "lab coat",
626
+ "ladle",
627
+ "lampshade",
628
+ "laptop",
629
+ "lawn mower",
630
+ "lens cap",
631
+ "letter opener",
632
+ "library",
633
+ "lifeboat",
634
+ "lighter",
635
+ "limousine",
636
+ "liner",
637
+ "lipstick",
638
+ "Loafer",
639
+ "lotion",
640
+ "loudspeaker",
641
+ "loupe",
642
+ "lumbermill",
643
+ "magnetic compass",
644
+ "mailbag",
645
+ "mailbox",
646
+ "maillot",
647
+ "maillot tank suit",
648
+ "manhole cover",
649
+ "maraca",
650
+ "marimba",
651
+ "mask",
652
+ "matchstick",
653
+ "maypole",
654
+ "maze",
655
+ "measuring cup",
656
+ "medicine chest",
657
+ "megalith",
658
+ "microphone",
659
+ "microwave",
660
+ "military uniform",
661
+ "milk can",
662
+ "minibus",
663
+ "miniskirt",
664
+ "minivan",
665
+ "missile",
666
+ "mitten",
667
+ "mixing bowl",
668
+ "mobile home",
669
+ "Model T",
670
+ "modem",
671
+ "monastery",
672
+ "monitor",
673
+ "moped",
674
+ "mortar",
675
+ "mortarboard",
676
+ "mosque",
677
+ "mosquito net",
678
+ "motor scooter",
679
+ "mountain bike",
680
+ "mountain tent",
681
+ "mouse",
682
+ "mousetrap",
683
+ "moving van",
684
+ "muzzle",
685
+ "nail",
686
+ "neck brace",
687
+ "necklace",
688
+ "nipple",
689
+ "notebook",
690
+ "obelisk",
691
+ "oboe",
692
+ "ocarina",
693
+ "odometer",
694
+ "oil filter",
695
+ "organ",
696
+ "oscilloscope",
697
+ "overskirt",
698
+ "oxcart",
699
+ "oxygen mask",
700
+ "packet",
701
+ "paddle",
702
+ "paddlewheel",
703
+ "padlock",
704
+ "paintbrush",
705
+ "pajama",
706
+ "palace",
707
+ "panpipe",
708
+ "paper towel",
709
+ "parachute",
710
+ "parallel bars",
711
+ "park bench",
712
+ "parking meter",
713
+ "passenger car",
714
+ "patio",
715
+ "pay-phone",
716
+ "pedestal",
717
+ "pencil box",
718
+ "pencil sharpener",
719
+ "perfume",
720
+ "Petri dish",
721
+ "photocopier",
722
+ "pick",
723
+ "pickelhaube",
724
+ "picket fence",
725
+ "pickup",
726
+ "pier",
727
+ "piggy bank",
728
+ "pill bottle",
729
+ "pillow",
730
+ "ping-pong ball",
731
+ "pinwheel",
732
+ "pirate",
733
+ "pitcher",
734
+ "plane",
735
+ "planetarium",
736
+ "plastic bag",
737
+ "plate rack",
738
+ "plow",
739
+ "plunger",
740
+ "Polaroid camera",
741
+ "pole",
742
+ "police van",
743
+ "poncho",
744
+ "pool table",
745
+ "pop bottle",
746
+ "pot",
747
+ "potter's wheel",
748
+ "power drill",
749
+ "prayer rug",
750
+ "printer",
751
+ "prison",
752
+ "projectile",
753
+ "projector",
754
+ "puck",
755
+ "punching bag",
756
+ "purse",
757
+ "quill",
758
+ "quilt",
759
+ "racer",
760
+ "racket",
761
+ "radiator",
762
+ "radio",
763
+ "radio telescope",
764
+ "rain barrel",
765
+ "recreational vehicle",
766
+ "reel",
767
+ "reflex camera",
768
+ "refrigerator",
769
+ "remote control",
770
+ "restaurant",
771
+ "revolver",
772
+ "rifle",
773
+ "rocking chair",
774
+ "rotisserie",
775
+ "rubber eraser",
776
+ "rugby ball",
777
+ "rule",
778
+ "running shoe",
779
+ "safe",
780
+ "safety pin",
781
+ "saltshaker",
782
+ "sandal",
783
+ "sarong",
784
+ "sax",
785
+ "scabbard",
786
+ "scale",
787
+ "school bus",
788
+ "schooner",
789
+ "scoreboard",
790
+ "screen",
791
+ "screw",
792
+ "screwdriver",
793
+ "seat belt",
794
+ "sewing machine",
795
+ "shield",
796
+ "shoe shop",
797
+ "shoji",
798
+ "shopping basket",
799
+ "shopping cart",
800
+ "shovel",
801
+ "shower cap",
802
+ "shower curtain",
803
+ "ski",
804
+ "ski mask",
805
+ "sleeping bag",
806
+ "slide rule",
807
+ "sliding door",
808
+ "slot",
809
+ "snorkel",
810
+ "snowmobile",
811
+ "snowplow",
812
+ "soap dispenser",
813
+ "soccer ball",
814
+ "sock",
815
+ "solar dish",
816
+ "sombrero",
817
+ "soup bowl",
818
+ "space bar",
819
+ "space heater",
820
+ "space shuttle",
821
+ "spatula",
822
+ "speedboat",
823
+ "spider web",
824
+ "spindle",
825
+ "sports car",
826
+ "spotlight",
827
+ "stage",
828
+ "steam locomotive",
829
+ "steel arch bridge",
830
+ "steel drum",
831
+ "stethoscope",
832
+ "stole",
833
+ "stone wall",
834
+ "stopwatch",
835
+ "stove",
836
+ "strainer",
837
+ "streetcar",
838
+ "stretcher",
839
+ "studio couch",
840
+ "stupa",
841
+ "submarine",
842
+ "suit",
843
+ "sundial",
844
+ "sunglass",
845
+ "sunglasses",
846
+ "sunscreen",
847
+ "suspension bridge",
848
+ "swab",
849
+ "sweatshirt",
850
+ "swimming trunks",
851
+ "swing",
852
+ "switch",
853
+ "syringe",
854
+ "table lamp",
855
+ "tank",
856
+ "tape player",
857
+ "teapot",
858
+ "teddy",
859
+ "television",
860
+ "tennis ball",
861
+ "thatch",
862
+ "theater curtain",
863
+ "thimble",
864
+ "thresher",
865
+ "throne",
866
+ "tile roof",
867
+ "toaster",
868
+ "tobacco shop",
869
+ "toilet seat",
870
+ "torch",
871
+ "totem pole",
872
+ "tow truck",
873
+ "toyshop",
874
+ "tractor",
875
+ "trailer truck",
876
+ "tray",
877
+ "trench coat",
878
+ "tricycle",
879
+ "trimaran",
880
+ "tripod",
881
+ "triumphal arch",
882
+ "trolleybus",
883
+ "trombone",
884
+ "tub",
885
+ "turnstile",
886
+ "typewriter keyboard",
887
+ "umbrella",
888
+ "unicycle",
889
+ "upright",
890
+ "vacuum",
891
+ "vase",
892
+ "vault",
893
+ "velvet",
894
+ "vending machine",
895
+ "vestment",
896
+ "viaduct",
897
+ "violin",
898
+ "volleyball",
899
+ "waffle iron",
900
+ "wall clock",
901
+ "wallet",
902
+ "wardrobe",
903
+ "warplane",
904
+ "washbasin",
905
+ "washer",
906
+ "water bottle",
907
+ "water jug",
908
+ "water tower",
909
+ "whiskey jug",
910
+ "whistle",
911
+ "wig",
912
+ "window screen",
913
+ "window shade",
914
+ "Windsor tie",
915
+ "wine bottle",
916
+ "wing",
917
+ "wok",
918
+ "wooden spoon",
919
+ "wool",
920
+ "worm fence",
921
+ "wreck",
922
+ "yawl",
923
+ "yurt",
924
+ "web site",
925
+ "comic book",
926
+ "crossword puzzle",
927
+ "street sign",
928
+ "traffic light",
929
+ "book jacket",
930
+ "menu",
931
+ "plate",
932
+ "guacamole",
933
+ "consomme",
934
+ "hot pot",
935
+ "trifle",
936
+ "ice cream",
937
+ "ice lolly",
938
+ "French loaf",
939
+ "bagel",
940
+ "pretzel",
941
+ "cheeseburger",
942
+ "hotdog",
943
+ "mashed potato",
944
+ "head cabbage",
945
+ "broccoli",
946
+ "cauliflower",
947
+ "zucchini",
948
+ "spaghetti squash",
949
+ "acorn squash",
950
+ "butternut squash",
951
+ "cucumber",
952
+ "artichoke",
953
+ "bell pepper",
954
+ "cardoon",
955
+ "mushroom",
956
+ "Granny Smith",
957
+ "strawberry",
958
+ "orange",
959
+ "lemon",
960
+ "fig",
961
+ "pineapple",
962
+ "banana",
963
+ "jackfruit",
964
+ "custard apple",
965
+ "pomegranate",
966
+ "hay",
967
+ "carbonara",
968
+ "chocolate sauce",
969
+ "dough",
970
+ "meat loaf",
971
+ "pizza",
972
+ "potpie",
973
+ "burrito",
974
+ "red wine",
975
+ "espresso",
976
+ "cup",
977
+ "eggnog",
978
+ "alp",
979
+ "bubble",
980
+ "cliff",
981
+ "coral reef",
982
+ "geyser",
983
+ "lakeside",
984
+ "promontory",
985
+ "sandbar",
986
+ "seashore",
987
+ "valley",
988
+ "volcano",
989
+ "ballplayer",
990
+ "groom",
991
+ "scuba diver",
992
+ "rapeseed",
993
+ "daisy",
994
+ "yellow lady's slipper",
995
+ "corn",
996
+ "acorn",
997
+ "hip",
998
+ "buckeye",
999
+ "coral fungus",
1000
+ "agaric",
1001
+ "gyromitra",
1002
+ "stinkhorn",
1003
+ "earthstar",
1004
+ "hen-of-the-woods",
1005
+ "bolete",
1006
+ "ear",
1007
+ "toilet tissue",
1008
+ ]
1009
+
1010
+ # To be replaced with torchvision.datasets.info("coco").categories
1011
+ _COCO_CATEGORIES = [
1012
+ "__background__",
1013
+ "person",
1014
+ "bicycle",
1015
+ "car",
1016
+ "motorcycle",
1017
+ "airplane",
1018
+ "bus",
1019
+ "train",
1020
+ "truck",
1021
+ "boat",
1022
+ "traffic light",
1023
+ "fire hydrant",
1024
+ "N/A",
1025
+ "stop sign",
1026
+ "parking meter",
1027
+ "bench",
1028
+ "bird",
1029
+ "cat",
1030
+ "dog",
1031
+ "horse",
1032
+ "sheep",
1033
+ "cow",
1034
+ "elephant",
1035
+ "bear",
1036
+ "zebra",
1037
+ "giraffe",
1038
+ "N/A",
1039
+ "backpack",
1040
+ "umbrella",
1041
+ "N/A",
1042
+ "N/A",
1043
+ "handbag",
1044
+ "tie",
1045
+ "suitcase",
1046
+ "frisbee",
1047
+ "skis",
1048
+ "snowboard",
1049
+ "sports ball",
1050
+ "kite",
1051
+ "baseball bat",
1052
+ "baseball glove",
1053
+ "skateboard",
1054
+ "surfboard",
1055
+ "tennis racket",
1056
+ "bottle",
1057
+ "N/A",
1058
+ "wine glass",
1059
+ "cup",
1060
+ "fork",
1061
+ "knife",
1062
+ "spoon",
1063
+ "bowl",
1064
+ "banana",
1065
+ "apple",
1066
+ "sandwich",
1067
+ "orange",
1068
+ "broccoli",
1069
+ "carrot",
1070
+ "hot dog",
1071
+ "pizza",
1072
+ "donut",
1073
+ "cake",
1074
+ "chair",
1075
+ "couch",
1076
+ "potted plant",
1077
+ "bed",
1078
+ "N/A",
1079
+ "dining table",
1080
+ "N/A",
1081
+ "N/A",
1082
+ "toilet",
1083
+ "N/A",
1084
+ "tv",
1085
+ "laptop",
1086
+ "mouse",
1087
+ "remote",
1088
+ "keyboard",
1089
+ "cell phone",
1090
+ "microwave",
1091
+ "oven",
1092
+ "toaster",
1093
+ "sink",
1094
+ "refrigerator",
1095
+ "N/A",
1096
+ "book",
1097
+ "clock",
1098
+ "vase",
1099
+ "scissors",
1100
+ "teddy bear",
1101
+ "hair drier",
1102
+ "toothbrush",
1103
+ ]
1104
+
1105
+ # To be replaced with torchvision.datasets.info("coco_kp")
1106
+ _COCO_PERSON_CATEGORIES = ["no person", "person"]
1107
+ _COCO_PERSON_KEYPOINT_NAMES = [
1108
+ "nose",
1109
+ "left_eye",
1110
+ "right_eye",
1111
+ "left_ear",
1112
+ "right_ear",
1113
+ "left_shoulder",
1114
+ "right_shoulder",
1115
+ "left_elbow",
1116
+ "right_elbow",
1117
+ "left_wrist",
1118
+ "right_wrist",
1119
+ "left_hip",
1120
+ "right_hip",
1121
+ "left_knee",
1122
+ "right_knee",
1123
+ "left_ankle",
1124
+ "right_ankle",
1125
+ ]
1126
+
1127
+ # To be replaced with torchvision.datasets.info("voc").categories
1128
+ _VOC_CATEGORIES = [
1129
+ "__background__",
1130
+ "aeroplane",
1131
+ "bicycle",
1132
+ "bird",
1133
+ "boat",
1134
+ "bottle",
1135
+ "bus",
1136
+ "car",
1137
+ "cat",
1138
+ "chair",
1139
+ "cow",
1140
+ "diningtable",
1141
+ "dog",
1142
+ "horse",
1143
+ "motorbike",
1144
+ "person",
1145
+ "pottedplant",
1146
+ "sheep",
1147
+ "sofa",
1148
+ "train",
1149
+ "tvmonitor",
1150
+ ]
1151
+
1152
+ # To be replaced with torchvision.datasets.info("kinetics400").categories
1153
+ _KINETICS400_CATEGORIES = [
1154
+ "abseiling",
1155
+ "air drumming",
1156
+ "answering questions",
1157
+ "applauding",
1158
+ "applying cream",
1159
+ "archery",
1160
+ "arm wrestling",
1161
+ "arranging flowers",
1162
+ "assembling computer",
1163
+ "auctioning",
1164
+ "baby waking up",
1165
+ "baking cookies",
1166
+ "balloon blowing",
1167
+ "bandaging",
1168
+ "barbequing",
1169
+ "bartending",
1170
+ "beatboxing",
1171
+ "bee keeping",
1172
+ "belly dancing",
1173
+ "bench pressing",
1174
+ "bending back",
1175
+ "bending metal",
1176
+ "biking through snow",
1177
+ "blasting sand",
1178
+ "blowing glass",
1179
+ "blowing leaves",
1180
+ "blowing nose",
1181
+ "blowing out candles",
1182
+ "bobsledding",
1183
+ "bookbinding",
1184
+ "bouncing on trampoline",
1185
+ "bowling",
1186
+ "braiding hair",
1187
+ "breading or breadcrumbing",
1188
+ "breakdancing",
1189
+ "brush painting",
1190
+ "brushing hair",
1191
+ "brushing teeth",
1192
+ "building cabinet",
1193
+ "building shed",
1194
+ "bungee jumping",
1195
+ "busking",
1196
+ "canoeing or kayaking",
1197
+ "capoeira",
1198
+ "carrying baby",
1199
+ "cartwheeling",
1200
+ "carving pumpkin",
1201
+ "catching fish",
1202
+ "catching or throwing baseball",
1203
+ "catching or throwing frisbee",
1204
+ "catching or throwing softball",
1205
+ "celebrating",
1206
+ "changing oil",
1207
+ "changing wheel",
1208
+ "checking tires",
1209
+ "cheerleading",
1210
+ "chopping wood",
1211
+ "clapping",
1212
+ "clay pottery making",
1213
+ "clean and jerk",
1214
+ "cleaning floor",
1215
+ "cleaning gutters",
1216
+ "cleaning pool",
1217
+ "cleaning shoes",
1218
+ "cleaning toilet",
1219
+ "cleaning windows",
1220
+ "climbing a rope",
1221
+ "climbing ladder",
1222
+ "climbing tree",
1223
+ "contact juggling",
1224
+ "cooking chicken",
1225
+ "cooking egg",
1226
+ "cooking on campfire",
1227
+ "cooking sausages",
1228
+ "counting money",
1229
+ "country line dancing",
1230
+ "cracking neck",
1231
+ "crawling baby",
1232
+ "crossing river",
1233
+ "crying",
1234
+ "curling hair",
1235
+ "cutting nails",
1236
+ "cutting pineapple",
1237
+ "cutting watermelon",
1238
+ "dancing ballet",
1239
+ "dancing charleston",
1240
+ "dancing gangnam style",
1241
+ "dancing macarena",
1242
+ "deadlifting",
1243
+ "decorating the christmas tree",
1244
+ "digging",
1245
+ "dining",
1246
+ "disc golfing",
1247
+ "diving cliff",
1248
+ "dodgeball",
1249
+ "doing aerobics",
1250
+ "doing laundry",
1251
+ "doing nails",
1252
+ "drawing",
1253
+ "dribbling basketball",
1254
+ "drinking",
1255
+ "drinking beer",
1256
+ "drinking shots",
1257
+ "driving car",
1258
+ "driving tractor",
1259
+ "drop kicking",
1260
+ "drumming fingers",
1261
+ "dunking basketball",
1262
+ "dying hair",
1263
+ "eating burger",
1264
+ "eating cake",
1265
+ "eating carrots",
1266
+ "eating chips",
1267
+ "eating doughnuts",
1268
+ "eating hotdog",
1269
+ "eating ice cream",
1270
+ "eating spaghetti",
1271
+ "eating watermelon",
1272
+ "egg hunting",
1273
+ "exercising arm",
1274
+ "exercising with an exercise ball",
1275
+ "extinguishing fire",
1276
+ "faceplanting",
1277
+ "feeding birds",
1278
+ "feeding fish",
1279
+ "feeding goats",
1280
+ "filling eyebrows",
1281
+ "finger snapping",
1282
+ "fixing hair",
1283
+ "flipping pancake",
1284
+ "flying kite",
1285
+ "folding clothes",
1286
+ "folding napkins",
1287
+ "folding paper",
1288
+ "front raises",
1289
+ "frying vegetables",
1290
+ "garbage collecting",
1291
+ "gargling",
1292
+ "getting a haircut",
1293
+ "getting a tattoo",
1294
+ "giving or receiving award",
1295
+ "golf chipping",
1296
+ "golf driving",
1297
+ "golf putting",
1298
+ "grinding meat",
1299
+ "grooming dog",
1300
+ "grooming horse",
1301
+ "gymnastics tumbling",
1302
+ "hammer throw",
1303
+ "headbanging",
1304
+ "headbutting",
1305
+ "high jump",
1306
+ "high kick",
1307
+ "hitting baseball",
1308
+ "hockey stop",
1309
+ "holding snake",
1310
+ "hopscotch",
1311
+ "hoverboarding",
1312
+ "hugging",
1313
+ "hula hooping",
1314
+ "hurdling",
1315
+ "hurling (sport)",
1316
+ "ice climbing",
1317
+ "ice fishing",
1318
+ "ice skating",
1319
+ "ironing",
1320
+ "javelin throw",
1321
+ "jetskiing",
1322
+ "jogging",
1323
+ "juggling balls",
1324
+ "juggling fire",
1325
+ "juggling soccer ball",
1326
+ "jumping into pool",
1327
+ "jumpstyle dancing",
1328
+ "kicking field goal",
1329
+ "kicking soccer ball",
1330
+ "kissing",
1331
+ "kitesurfing",
1332
+ "knitting",
1333
+ "krumping",
1334
+ "laughing",
1335
+ "laying bricks",
1336
+ "long jump",
1337
+ "lunge",
1338
+ "making a cake",
1339
+ "making a sandwich",
1340
+ "making bed",
1341
+ "making jewelry",
1342
+ "making pizza",
1343
+ "making snowman",
1344
+ "making sushi",
1345
+ "making tea",
1346
+ "marching",
1347
+ "massaging back",
1348
+ "massaging feet",
1349
+ "massaging legs",
1350
+ "massaging person's head",
1351
+ "milking cow",
1352
+ "mopping floor",
1353
+ "motorcycling",
1354
+ "moving furniture",
1355
+ "mowing lawn",
1356
+ "news anchoring",
1357
+ "opening bottle",
1358
+ "opening present",
1359
+ "paragliding",
1360
+ "parasailing",
1361
+ "parkour",
1362
+ "passing American football (in game)",
1363
+ "passing American football (not in game)",
1364
+ "peeling apples",
1365
+ "peeling potatoes",
1366
+ "petting animal (not cat)",
1367
+ "petting cat",
1368
+ "picking fruit",
1369
+ "planting trees",
1370
+ "plastering",
1371
+ "playing accordion",
1372
+ "playing badminton",
1373
+ "playing bagpipes",
1374
+ "playing basketball",
1375
+ "playing bass guitar",
1376
+ "playing cards",
1377
+ "playing cello",
1378
+ "playing chess",
1379
+ "playing clarinet",
1380
+ "playing controller",
1381
+ "playing cricket",
1382
+ "playing cymbals",
1383
+ "playing didgeridoo",
1384
+ "playing drums",
1385
+ "playing flute",
1386
+ "playing guitar",
1387
+ "playing harmonica",
1388
+ "playing harp",
1389
+ "playing ice hockey",
1390
+ "playing keyboard",
1391
+ "playing kickball",
1392
+ "playing monopoly",
1393
+ "playing organ",
1394
+ "playing paintball",
1395
+ "playing piano",
1396
+ "playing poker",
1397
+ "playing recorder",
1398
+ "playing saxophone",
1399
+ "playing squash or racquetball",
1400
+ "playing tennis",
1401
+ "playing trombone",
1402
+ "playing trumpet",
1403
+ "playing ukulele",
1404
+ "playing violin",
1405
+ "playing volleyball",
1406
+ "playing xylophone",
1407
+ "pole vault",
1408
+ "presenting weather forecast",
1409
+ "pull ups",
1410
+ "pumping fist",
1411
+ "pumping gas",
1412
+ "punching bag",
1413
+ "punching person (boxing)",
1414
+ "push up",
1415
+ "pushing car",
1416
+ "pushing cart",
1417
+ "pushing wheelchair",
1418
+ "reading book",
1419
+ "reading newspaper",
1420
+ "recording music",
1421
+ "riding a bike",
1422
+ "riding camel",
1423
+ "riding elephant",
1424
+ "riding mechanical bull",
1425
+ "riding mountain bike",
1426
+ "riding mule",
1427
+ "riding or walking with horse",
1428
+ "riding scooter",
1429
+ "riding unicycle",
1430
+ "ripping paper",
1431
+ "robot dancing",
1432
+ "rock climbing",
1433
+ "rock scissors paper",
1434
+ "roller skating",
1435
+ "running on treadmill",
1436
+ "sailing",
1437
+ "salsa dancing",
1438
+ "sanding floor",
1439
+ "scrambling eggs",
1440
+ "scuba diving",
1441
+ "setting table",
1442
+ "shaking hands",
1443
+ "shaking head",
1444
+ "sharpening knives",
1445
+ "sharpening pencil",
1446
+ "shaving head",
1447
+ "shaving legs",
1448
+ "shearing sheep",
1449
+ "shining shoes",
1450
+ "shooting basketball",
1451
+ "shooting goal (soccer)",
1452
+ "shot put",
1453
+ "shoveling snow",
1454
+ "shredding paper",
1455
+ "shuffling cards",
1456
+ "side kick",
1457
+ "sign language interpreting",
1458
+ "singing",
1459
+ "situp",
1460
+ "skateboarding",
1461
+ "ski jumping",
1462
+ "skiing (not slalom or crosscountry)",
1463
+ "skiing crosscountry",
1464
+ "skiing slalom",
1465
+ "skipping rope",
1466
+ "skydiving",
1467
+ "slacklining",
1468
+ "slapping",
1469
+ "sled dog racing",
1470
+ "smoking",
1471
+ "smoking hookah",
1472
+ "snatch weight lifting",
1473
+ "sneezing",
1474
+ "sniffing",
1475
+ "snorkeling",
1476
+ "snowboarding",
1477
+ "snowkiting",
1478
+ "snowmobiling",
1479
+ "somersaulting",
1480
+ "spinning poi",
1481
+ "spray painting",
1482
+ "spraying",
1483
+ "springboard diving",
1484
+ "squat",
1485
+ "sticking tongue out",
1486
+ "stomping grapes",
1487
+ "stretching arm",
1488
+ "stretching leg",
1489
+ "strumming guitar",
1490
+ "surfing crowd",
1491
+ "surfing water",
1492
+ "sweeping floor",
1493
+ "swimming backstroke",
1494
+ "swimming breast stroke",
1495
+ "swimming butterfly stroke",
1496
+ "swing dancing",
1497
+ "swinging legs",
1498
+ "swinging on something",
1499
+ "sword fighting",
1500
+ "tai chi",
1501
+ "taking a shower",
1502
+ "tango dancing",
1503
+ "tap dancing",
1504
+ "tapping guitar",
1505
+ "tapping pen",
1506
+ "tasting beer",
1507
+ "tasting food",
1508
+ "testifying",
1509
+ "texting",
1510
+ "throwing axe",
1511
+ "throwing ball",
1512
+ "throwing discus",
1513
+ "tickling",
1514
+ "tobogganing",
1515
+ "tossing coin",
1516
+ "tossing salad",
1517
+ "training dog",
1518
+ "trapezing",
1519
+ "trimming or shaving beard",
1520
+ "trimming trees",
1521
+ "triple jump",
1522
+ "tying bow tie",
1523
+ "tying knot (not on a tie)",
1524
+ "tying tie",
1525
+ "unboxing",
1526
+ "unloading truck",
1527
+ "using computer",
1528
+ "using remote controller (not gaming)",
1529
+ "using segway",
1530
+ "vault",
1531
+ "waiting in line",
1532
+ "walking the dog",
1533
+ "washing dishes",
1534
+ "washing feet",
1535
+ "washing hair",
1536
+ "washing hands",
1537
+ "water skiing",
1538
+ "water sliding",
1539
+ "watering plants",
1540
+ "waxing back",
1541
+ "waxing chest",
1542
+ "waxing eyebrows",
1543
+ "waxing legs",
1544
+ "weaving basket",
1545
+ "welding",
1546
+ "whistling",
1547
+ "windsurfing",
1548
+ "wrapping present",
1549
+ "wrestling",
1550
+ "writing",
1551
+ "yawning",
1552
+ "yoga",
1553
+ "zumba",
1554
+ ]
.venv/lib/python3.11/site-packages/torchvision/models/_utils.py ADDED
@@ -0,0 +1,256 @@
1
+ import functools
2
+ import inspect
3
+ import warnings
4
+ from collections import OrderedDict
5
+ from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
6
+
7
+ from torch import nn
8
+
9
+ from .._utils import sequence_to_str
10
+ from ._api import WeightsEnum
11
+
12
+
13
+ class IntermediateLayerGetter(nn.ModuleDict):
14
+ """
15
+ Module wrapper that returns intermediate layers from a model
16
+
17
+ It has a strong assumption that the modules have been registered
18
+ into the model in the same order as they are used.
19
+ This means that one should **not** reuse the same nn.Module
20
+ twice in the forward if you want this to work.
21
+
22
+ Additionally, it is only able to query submodules that are directly
23
+ assigned to the model. So if `model` is passed, `model.feature1` can
24
+ be returned, but not `model.feature1.layer2`.
25
+
26
+ Args:
27
+ model (nn.Module): model on which we will extract the features
28
+ return_layers (Dict[name, new_name]): a dict containing the names
29
+ of the modules for which the activations will be returned as
30
+ the key of the dict, and the value of the dict is the name
31
+ of the returned activation (which the user can specify).
32
+
33
+ Examples::
34
+
35
+ >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
36
+ >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
37
+ >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
38
+ >>> {'layer1': 'feat1', 'layer3': 'feat2'})
39
+ >>> out = new_m(torch.rand(1, 3, 224, 224))
40
+ >>> print([(k, v.shape) for k, v in out.items()])
41
+ >>> [('feat1', torch.Size([1, 64, 56, 56])),
42
+ >>> ('feat2', torch.Size([1, 256, 14, 14]))]
43
+ """
44
+
45
+ _version = 2
46
+ __annotations__ = {
47
+ "return_layers": Dict[str, str],
48
+ }
49
+
50
+ def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
51
+ if not set(return_layers).issubset([name for name, _ in model.named_children()]):
52
+ raise ValueError("return_layers are not present in model")
53
+ orig_return_layers = return_layers
54
+ return_layers = {str(k): str(v) for k, v in return_layers.items()}
55
+ layers = OrderedDict()
56
+ for name, module in model.named_children():
57
+ layers[name] = module
58
+ if name in return_layers:
59
+ del return_layers[name]
60
+ if not return_layers:
61
+ break
62
+
63
+ super().__init__(layers)
64
+ self.return_layers = orig_return_layers
65
+
66
+ def forward(self, x):
67
+ out = OrderedDict()
68
+ for name, module in self.items():
69
+ x = module(x)
70
+ if name in self.return_layers:
71
+ out_name = self.return_layers[name]
72
+ out[out_name] = x
73
+ return out
74
+
75
+
76
+ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
77
+ """
78
+ This function is taken from the original tf repo.
79
+ It ensures that all layers have a channel number that is divisible by 8
80
+ It can be seen here:
81
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
82
+ """
83
+ if min_value is None:
84
+ min_value = divisor
85
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
86
+ # Make sure that round down does not go down by more than 10%.
87
+ if new_v < 0.9 * v:
88
+ new_v += divisor
89
+ return new_v
90
+
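A minimal sketch of the rounding rule above (values chosen only for illustration): widths are snapped to a multiple of `divisor`, and bumped up one step whenever plain rounding would drop more than 10% of the requested width.

    _make_divisible(37.5, 8)  # -> 40, nearest multiple of 8
    _make_divisible(32.0, 8)  # -> 32, already divisible
    _make_divisible(10.0, 8)  # -> 16, rounding down to 8 would lose >10%, so it is bumped up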
91
+
92
+ D = TypeVar("D")
93
+
94
+
95
+ def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
96
+ """Decorates a function that uses keyword only parameters to also allow them being passed as positionals.
97
+
98
+ For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:
99
+
100
+ .. code::
101
+
102
+ def old_fn(foo, bar, baz=None):
103
+ ...
104
+
105
+ def new_fn(foo, *, bar, baz=None):
106
+ ...
107
+
108
+ Calling ``old_fn("foo", "bar", "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC
109
+ and at the same time warn the user of the deprecation, this decorator can be used:
110
+
111
+ .. code::
112
+
113
+ @kwonly_to_pos_or_kw
114
+ def new_fn(foo, *, bar, baz=None):
115
+ ...
116
+
117
+ new_fn("foo", "bar", "baz")
118
+ """
119
+ params = inspect.signature(fn).parameters
120
+
121
+ try:
122
+ keyword_only_start_idx = next(
123
+ idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
124
+ )
125
+ except StopIteration:
126
+ raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None
127
+
128
+ keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:]
129
+
130
+ @functools.wraps(fn)
131
+ def wrapper(*args: Any, **kwargs: Any) -> D:
132
+ args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
133
+ if keyword_only_args:
134
+ keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
135
+ warnings.warn(
136
+ f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
137
+ f"parameter(s) is deprecated since 0.13 and may be removed in the future. Please use keyword parameter(s) "
138
+ f"instead."
139
+ )
140
+ kwargs.update(keyword_only_kwargs)
141
+
142
+ return fn(*args, **kwargs)
143
+
144
+ return wrapper
145
+
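A minimal usage sketch for the decorator above, assuming a hypothetical `make_widget` function (invented here purely for illustration):

    @kwonly_to_pos_or_kw
    def make_widget(size, *, color=None):
        return size, color

    make_widget(3, color="red")  # new-style keyword call, no warning
    make_widget(3, "red")        # legacy positional call still works, but emits a warning asking for keyword use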
146
+
147
+ W = TypeVar("W", bound=WeightsEnum)
148
+ M = TypeVar("M", bound=nn.Module)
149
+ V = TypeVar("V")
150
+
151
+
152
+ def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
153
+ """Decorates a model builder with the new interface to make it compatible with the old.
154
+
155
+ In particular this handles two things:
156
+
157
+ 1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
158
+ :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
159
+ 2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
160
+ ``weights=Weights`` and emits a deprecation warning with instructions for the new interface.
161
+
162
+ Args:
163
+ **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
164
+ name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
165
+ case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in
166
+ the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters
167
+ should be accessed with :meth:`~dict.get`.
168
+ """
169
+
170
+ def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
171
+ @kwonly_to_pos_or_kw
172
+ @functools.wraps(builder)
173
+ def inner_wrapper(*args: Any, **kwargs: Any) -> M:
174
+ for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr]
175
+ # If neither the weights nor the pretrained parameter was passed, or the weights argument already uses
176
+ # the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the
177
+ # weight argument, since it is a valid value.
178
+ sentinel = object()
179
+ weights_arg = kwargs.get(weights_param, sentinel)
180
+ if (
181
+ (weights_param not in kwargs and pretrained_param not in kwargs)
182
+ or isinstance(weights_arg, WeightsEnum)
183
+ or (isinstance(weights_arg, str) and weights_arg != "legacy")
184
+ or weights_arg is None
185
+ ):
186
+ continue
187
+
188
+ # If the pretrained parameter was passed as positional argument, it is now mapped to
189
+ # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current
190
+ # signature to infer the names of positionally passed arguments and thus has no knowledge that there
191
+ # used to be a pretrained parameter.
192
+ pretrained_positional = weights_arg is not sentinel
193
+ if pretrained_positional:
194
+ # We put the pretrained argument under its legacy name in the keyword argument dictionary to have
195
+ # unified access to the value if the default value is a callable.
196
+ kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
197
+ else:
198
+ pretrained_arg = kwargs[pretrained_param]
199
+
200
+ if pretrained_arg:
201
+ default_weights_arg = default(kwargs) if callable(default) else default
202
+ if not isinstance(default_weights_arg, WeightsEnum):
203
+ raise ValueError(f"No weights available for model {builder.__name__}")
204
+ else:
205
+ default_weights_arg = None
206
+
207
+ if not pretrained_positional:
208
+ warnings.warn(
209
+ f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
210
+ f"please use '{weights_param}' instead."
211
+ )
212
+
213
+ msg = (
214
+ f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated since 0.13 and "
215
+ f"may be removed in the future. "
216
+ f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
217
+ )
218
+ if pretrained_arg:
219
+ msg = (
220
+ f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
221
+ f"to get the most up-to-date weights."
222
+ )
223
+ warnings.warn(msg)
224
+
225
+ del kwargs[pretrained_param]
226
+ kwargs[weights_param] = default_weights_arg
227
+
228
+ return builder(*args, **kwargs)
229
+
230
+ return inner_wrapper
231
+
232
+ return outer_wrapper
233
+
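A sketch of how a model builder is typically decorated with `handle_legacy_interface`; `MyModel_Weights` and `my_model` are invented names for illustration, while a real usage appears later in this commit (e.g. `inception_v3`):

    @handle_legacy_interface(weights=("pretrained", MyModel_Weights.IMAGENET1K_V1))
    def my_model(*, weights=None, progress=True, **kwargs):
        ...

    my_model(pretrained=True)                        # warns, translated to weights=MyModel_Weights.IMAGENET1K_V1
    my_model(weights=MyModel_Weights.IMAGENET1K_V1)  # new-style call, no warning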
234
+
235
+ def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
236
+ if param in kwargs:
237
+ if kwargs[param] != new_value:
238
+ raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
239
+ else:
240
+ kwargs[param] = new_value
241
+
242
+
243
+ def _ovewrite_value_param(param: str, actual: Optional[V], expected: V) -> V:
244
+ if actual is not None:
245
+ if actual != expected:
246
+ raise ValueError(f"The parameter '{param}' expected value {expected} but got {actual} instead.")
247
+ return expected
248
+
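A small illustration of the two overwrite helpers above (values are arbitrary): they force a parameter to an expected value and raise only when a conflicting value was passed explicitly.

    kwargs = {}
    _ovewrite_named_param(kwargs, "num_classes", 1000)  # sets kwargs["num_classes"] = 1000
    _ovewrite_named_param(kwargs, "num_classes", 1000)  # same value again, no error
    # _ovewrite_named_param(kwargs, "num_classes", 10)  # would raise ValueError: conflicting value
    n = _ovewrite_value_param("num_classes", None, 1000)  # None means "not set", so 1000 is returned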
249
+
250
+ class _ModelURLs(dict):
251
+ def __getitem__(self, item):
252
+ warnings.warn(
253
+ "Accessing the model URLs via the internal dictionary of the module is deprecated since 0.13 and may "
254
+ "be removed in the future. Please access them via the appropriate Weights Enum instead."
255
+ )
256
+ return super().__getitem__(item)
.venv/lib/python3.11/site-packages/torchvision/models/feature_extraction.py ADDED
@@ -0,0 +1,572 @@
1
+ import inspect
2
+ import math
3
+ import re
4
+ import warnings
5
+ from collections import OrderedDict
6
+ from copy import deepcopy
7
+ from itertools import chain
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torchvision
12
+ from torch import fx, nn
13
+ from torch.fx.graph_module import _copy_attr
14
+
15
+
16
+ __all__ = ["create_feature_extractor", "get_graph_node_names"]
17
+
18
+
19
+ class LeafModuleAwareTracer(fx.Tracer):
20
+ """
21
+ An fx.Tracer that allows the user to specify a set of leaf modules, i.e.
22
+ modules that are not to be traced through. The resulting graph ends up
23
+ having single nodes referencing calls to the leaf modules' forward methods.
24
+ """
25
+
26
+ def __init__(self, *args, **kwargs):
27
+ self.leaf_modules = {}
28
+ if "leaf_modules" in kwargs:
29
+ leaf_modules = kwargs.pop("leaf_modules")
30
+ self.leaf_modules = leaf_modules
31
+ super().__init__(*args, **kwargs)
32
+
33
+ def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool:
34
+ if isinstance(m, tuple(self.leaf_modules)):
35
+ return True
36
+ return super().is_leaf_module(m, module_qualname)
37
+
38
+
39
+ class NodePathTracer(LeafModuleAwareTracer):
40
+ """
41
+ NodePathTracer is an FX tracer that, for each operation, also records the
42
+ name of the Node from which the operation originated. A node name here is
43
+ a `.` separated path walking the hierarchy from top level module down to
44
+ leaf operation or leaf module. The name of the top level module is not
45
+ included as part of the node name. For example, if we trace a module whose
46
+ forward method applies a ReLU module, the name for that node will simply
47
+ be 'relu'.
48
+
49
+ Some notes on the specifics:
50
+ - Nodes are recorded to `self.node_to_qualname` which is a dictionary
51
+ mapping a given Node object to its node name.
52
+ - Nodes are recorded in the order which they are executed during
53
+ tracing.
54
+ - When a duplicate node name is encountered, a suffix of the form
55
+ _{int} is added. The counter starts from 1.
56
+ """
57
+
58
+ def __init__(self, *args, **kwargs):
59
+ super().__init__(*args, **kwargs)
60
+ # Track the qualified name of the Node being traced
61
+ self.current_module_qualname = ""
62
+ # A map from FX Node to the qualified name
63
+ # NOTE: This is loosely like the "qualified name" mentioned in the
64
+ # torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted
65
+ # for the purposes of the torchvision feature extractor
66
+ self.node_to_qualname = OrderedDict()
67
+
68
+ def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs):
69
+ """
70
+ Override of `fx.Tracer.call_module`
71
+ This override:
72
+ 1) Stores away the qualified name of the caller for restoration later
73
+ 2) Adds the qualified name of the caller to
74
+ `current_module_qualname` for retrieval by `create_proxy`
75
+ 3) Once a leaf module is reached, calls `create_proxy`
76
+ 4) Restores the caller's qualified name into current_module_qualname
77
+ """
78
+ old_qualname = self.current_module_qualname
79
+ try:
80
+ module_qualname = self.path_of_module(m)
81
+ self.current_module_qualname = module_qualname
82
+ if not self.is_leaf_module(m, module_qualname):
83
+ out = forward(*args, **kwargs)
84
+ return out
85
+ return self.create_proxy("call_module", module_qualname, args, kwargs)
86
+ finally:
87
+ self.current_module_qualname = old_qualname
88
+
89
+ def create_proxy(
90
+ self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None, *_
91
+ ) -> fx.proxy.Proxy:
92
+ """
93
+ Override of `Tracer.create_proxy`. This override intercepts the recording
94
+ of every operation and stores away the current traced module's qualified
95
+ name in `node_to_qualname`
96
+ """
97
+ proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
98
+ self.node_to_qualname[proxy.node] = self._get_node_qualname(self.current_module_qualname, proxy.node)
99
+ return proxy
100
+
101
+ def _get_node_qualname(self, module_qualname: str, node: fx.node.Node) -> str:
102
+ node_qualname = module_qualname
103
+
104
+ if node.op != "call_module":
105
+ # In this case module_qualname from torch.fx doesn't go all the
106
+ # way to the leaf function/op, so we need to append it
107
+ if len(node_qualname) > 0:
108
+ # Only append '.' if we are deeper than the top level module
109
+ node_qualname += "."
110
+ node_qualname += str(node)
111
+
112
+ # Now we need to add an _{index} postfix on any repeated node names
113
+ # For modules we do this from scratch
114
+ # But for anything else, torch.fx already has a globally scoped
115
+ # _{index} postfix. But we want it locally (relative to direct parent)
116
+ # scoped. So first we need to undo the torch.fx postfix
117
+ if re.match(r".+_[0-9]+$", node_qualname) is not None:
118
+ node_qualname = node_qualname.rsplit("_", 1)[0]
119
+
120
+ # ... and now we add on our own postfix
121
+ for existing_qualname in reversed(self.node_to_qualname.values()):
122
+ # Check to see if existing_qualname is of the form
123
+ # {node_qualname} or {node_qualname}_{int}
124
+ if re.match(rf"{node_qualname}(_[0-9]+)?$", existing_qualname) is not None:
125
+ postfix = existing_qualname.replace(node_qualname, "")
126
+ if len(postfix):
127
+ # existing_qualname is of the form {node_qualname}_{int}
128
+ next_index = int(postfix[1:]) + 1
129
+ else:
130
+ # existing_qualname is of the form {node_qualname}
131
+ next_index = 1
132
+ node_qualname += f"_{next_index}"
133
+ break
134
+
135
+ return node_qualname
136
+
137
+
138
+ def _is_subseq(x, y):
139
+ """Check if y is a subsequence of x
140
+ https://stackoverflow.com/a/24017747/4391249
141
+ """
142
+ iter_x = iter(x)
143
+ return all(any(x_item == y_item for x_item in iter_x) for y_item in y)
144
+
145
+
146
+ def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
147
+ """
148
+ Utility function for warning the user if there are differences between
149
+ the train graph nodes and the eval graph nodes.
150
+ """
151
+ train_nodes = list(train_tracer.node_to_qualname.values())
152
+ eval_nodes = list(eval_tracer.node_to_qualname.values())
153
+
154
+ if len(train_nodes) == len(eval_nodes) and all(t == e for t, e in zip(train_nodes, eval_nodes)):
155
+ return
156
+
157
+ suggestion_msg = (
158
+ "When choosing nodes for feature extraction, you may need to specify "
159
+ "output nodes for train and eval mode separately."
160
+ )
161
+
162
+ if _is_subseq(train_nodes, eval_nodes):
163
+ msg = (
164
+ "NOTE: The nodes obtained by tracing the model in eval mode "
165
+ "are a subsequence of those obtained in train mode. "
166
+ )
167
+ elif _is_subseq(eval_nodes, train_nodes):
168
+ msg = (
169
+ "NOTE: The nodes obtained by tracing the model in train mode "
170
+ "are a subsequence of those obtained in eval mode. "
171
+ )
172
+ else:
173
+ msg = "The nodes obtained by tracing the model in train mode are different to those obtained in eval mode. "
174
+ warnings.warn(msg + suggestion_msg)
175
+
176
+
177
+ def _get_leaf_modules_for_ops() -> List[type]:
178
+ members = inspect.getmembers(torchvision.ops)
179
+ result = []
180
+ for _, obj in members:
181
+ if inspect.isclass(obj) and issubclass(obj, torch.nn.Module):
182
+ result.append(obj)
183
+ return result
184
+
185
+
186
+ def _set_default_tracer_kwargs(original_tr_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
187
+ default_autowrap_modules = (math, torchvision.ops)
188
+ default_leaf_modules = _get_leaf_modules_for_ops()
189
+ result_tracer_kwargs = {} if original_tr_kwargs is None else original_tr_kwargs
190
+ result_tracer_kwargs["autowrap_modules"] = (
191
+ tuple(set(result_tracer_kwargs["autowrap_modules"] + default_autowrap_modules))
192
+ if "autowrap_modules" in result_tracer_kwargs
193
+ else default_autowrap_modules
194
+ )
195
+ result_tracer_kwargs["leaf_modules"] = (
196
+ list(set(result_tracer_kwargs["leaf_modules"] + default_leaf_modules))
197
+ if "leaf_modules" in result_tracer_kwargs
198
+ else default_leaf_modules
199
+ )
200
+ return result_tracer_kwargs
201
+
202
+
203
+ def get_graph_node_names(
204
+ model: nn.Module,
205
+ tracer_kwargs: Optional[Dict[str, Any]] = None,
206
+ suppress_diff_warning: bool = False,
207
+ concrete_args: Optional[Dict[str, Any]] = None,
208
+ ) -> Tuple[List[str], List[str]]:
209
+ """
210
+ Dev utility to return node names in order of execution. See note on node
211
+ names under :func:`create_feature_extractor`. Useful for seeing which node
212
+ names are available for feature extraction. There are two reasons that
213
+ node names can't easily be read directly from the code for a model:
214
+
215
+ 1. Not all submodules are traced through. Modules from ``torch.nn`` all
216
+ fall within this category.
217
+ 2. Nodes representing the repeated application of the same operation
218
+ or leaf module get a ``_{counter}`` postfix.
219
+
220
+ The model is traced twice: once in train mode, and once in eval mode. Both
221
+ sets of node names are returned.
222
+
223
+ For more details on the node naming conventions used here, please see the
224
+ :ref:`relevant subheading <about-node-names>` in the
225
+ `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
226
+
227
+ Args:
228
+ model (nn.Module): model for which we'd like to print node names
229
+ tracer_kwargs (dict, optional): a dictionary of keyword arguments for
230
+ ``NodePathTracer`` (they are eventually passed onto
231
+ `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
232
+ By default, it will be set to wrap and make leaf nodes all torchvision ops:
233
+ {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
234
+ WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
235
+ provided dictionary.
236
+ suppress_diff_warning (bool, optional): whether to suppress a warning
237
+ when there are discrepancies between the train and eval version of
238
+ the graph. Defaults to False.
239
+ concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
240
+ not be treated as Proxies. According to the `Pytorch docs
241
+ <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.trace>`_,
242
+ this parameter's API may not be guaranteed.
243
+
244
+ Returns:
245
+ tuple(list, list): a list of node names from tracing the model in
246
+ train mode, and another from tracing the model in eval mode.
247
+
248
+ Examples::
249
+
250
+ >>> model = torchvision.models.resnet18()
251
+ >>> train_nodes, eval_nodes = get_graph_node_names(model)
252
+ """
253
+ tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
254
+ is_training = model.training
255
+ train_tracer = NodePathTracer(**tracer_kwargs)
256
+ train_tracer.trace(model.train(), concrete_args=concrete_args)
257
+ eval_tracer = NodePathTracer(**tracer_kwargs)
258
+ eval_tracer.trace(model.eval(), concrete_args=concrete_args)
259
+ train_nodes = list(train_tracer.node_to_qualname.values())
260
+ eval_nodes = list(eval_tracer.node_to_qualname.values())
261
+ if not suppress_diff_warning:
262
+ _warn_graph_differences(train_tracer, eval_tracer)
263
+ # Restore training state
264
+ model.train(is_training)
265
+ return train_nodes, eval_nodes
266
+
267
+
268
+ class DualGraphModule(fx.GraphModule):
269
+ """
270
+ A derivative of `fx.GraphModule`. Differs in the following ways:
271
+ - Requires a train and eval version of the underlying graph
272
+ - Copies submodules according to the nodes of both train and eval graphs.
273
+ - Calling train(mode) switches between train graph and eval graph.
274
+ """
275
+
276
+ def __init__(
277
+ self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule"
278
+ ):
279
+ """
280
+ Args:
281
+ root (nn.Module): module from which the copied module hierarchy is
282
+ built
283
+ train_graph (fx.Graph): the graph that should be used in train mode
284
+ eval_graph (fx.Graph): the graph that should be used in eval mode
285
+ """
286
+ super(fx.GraphModule, self).__init__()
287
+
288
+ self.__class__.__name__ = class_name
289
+
290
+ self.train_graph = train_graph
291
+ self.eval_graph = eval_graph
292
+
293
+ # Copy all get_attr and call_module ops (indicated by BOTH train and
294
+ # eval graphs)
295
+ for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
296
+ if node.op in ["get_attr", "call_module"]:
297
+ if not isinstance(node.target, str):
298
+ raise TypeError(f"node.target should be of type str instead of {type(node.target)}")
299
+ _copy_attr(root, self, node.target)
300
+
301
+ # train mode by default
302
+ self.train()
303
+ self.graph = train_graph
304
+
305
+ # (borrowed from fx.GraphModule):
306
+ # Store the Tracer class responsible for creating a Graph separately as part of the
307
+ # GraphModule state, except when the Tracer is defined in a local namespace.
308
+ # Locally defined Tracers are not pickleable. This is needed because torch.package will
309
+ # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
310
+ # to re-create the Graph during deserialization.
311
+ if self.eval_graph._tracer_cls != self.train_graph._tracer_cls:
312
+ raise TypeError(
313
+ f"Train mode and eval mode should use the same tracer class. Instead got {self.eval_graph._tracer_cls} for eval vs {self.train_graph._tracer_cls} for train"
314
+ )
315
+ self._tracer_cls = None
316
+ if self.graph._tracer_cls and "<locals>" not in self.graph._tracer_cls.__qualname__:
317
+ self._tracer_cls = self.graph._tracer_cls
318
+
319
+ def train(self, mode=True):
320
+ """
321
+ Swap out the graph depending on the selected training mode.
322
+ NOTE this should be safe when calling model.eval() because that just
323
+ calls this with mode == False.
324
+ """
325
+ # NOTE: Only set self.graph if the current graph is not the desired
326
+ # one. This saves us from recompiling the graph where not necessary.
327
+ if mode and not self.training:
328
+ self.graph = self.train_graph
329
+ elif not mode and self.training:
330
+ self.graph = self.eval_graph
331
+ return super().train(mode=mode)
332
+
333
+
334
+ def create_feature_extractor(
335
+ model: nn.Module,
336
+ return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
337
+ train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
338
+ eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
339
+ tracer_kwargs: Optional[Dict[str, Any]] = None,
340
+ suppress_diff_warning: bool = False,
341
+ concrete_args: Optional[Dict[str, Any]] = None,
342
+ ) -> fx.GraphModule:
343
+ """
344
+ Creates a new graph module that returns intermediate nodes from a given
345
+ model as dictionary with user specified keys as strings, and the requested
346
+ outputs as values. This is achieved by re-writing the computation graph of
347
+ the model via FX to return the desired nodes as outputs. All unused nodes
348
+ are removed, together with their corresponding parameters.
349
+
350
+ Desired output nodes must be specified as a ``.`` separated
351
+ path walking the module hierarchy from top level module down to leaf
352
+ operation or leaf module. For more details on the node naming conventions
353
+ used here, please see the :ref:`relevant subheading <about-node-names>`
354
+ in the `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
355
+
356
+ Not all models will be FX traceable, although with some massaging they can
357
+ be made to cooperate. Here's a (not exhaustive) list of tips:
358
+
359
+ - If you don't need to trace through a particular, problematic
360
+ sub-module, turn it into a "leaf module" by passing a list of
361
+ ``leaf_modules`` as one of the ``tracer_kwargs`` (see example below).
362
+ It will not be traced through, but rather, the resulting graph will
363
+ hold a reference to that module's forward method.
364
+ - Likewise, you may turn functions into leaf functions by passing a
365
+ list of ``autowrap_functions`` as one of the ``tracer_kwargs`` (see
366
+ example below).
367
+ - Some inbuilt Python functions can be problematic. For instance,
368
+ ``int`` will raise an error during tracing. You may wrap them in your
369
+ own function and then pass that in ``autowrap_functions`` as one of
370
+ the ``tracer_kwargs``.
371
+
372
+ For further information on FX see the
373
+ `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.
374
+
375
+ Args:
376
+ model (nn.Module): model on which we will extract the features
377
+ return_nodes (list or dict, optional): either a ``List`` or a ``Dict``
378
+ containing the names (or partial names - see note above)
379
+ of the nodes for which the activations will be returned. If it is
380
+ a ``Dict``, the keys are the node names, and the values
381
+ are the user-specified keys for the graph module's returned
382
+ dictionary. If it is a ``List``, it is treated as a ``Dict`` mapping
383
+ node specification strings directly to output names. In the case
384
+ that ``train_return_nodes`` and ``eval_return_nodes`` are specified,
385
+ this should not be specified.
386
+ train_return_nodes (list or dict, optional): similar to
387
+ ``return_nodes``. This can be used if the return nodes
388
+ for train mode are different than those from eval mode.
389
+ If this is specified, ``eval_return_nodes`` must also be specified,
390
+ and ``return_nodes`` should not be specified.
391
+ eval_return_nodes (list or dict, optional): similar to
392
+ ``return_nodes``. This can be used if the return nodes
393
+ for train mode are different than those from eval mode.
394
+ If this is specified, ``train_return_nodes`` must also be specified,
395
+ and `return_nodes` should not be specified.
396
+ tracer_kwargs (dict, optional): a dictionary of keyword arguments for
397
+ ``NodePathTracer`` (which passes them onto its parent class
398
+ `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
399
+ By default, it will be set to wrap and make leaf nodes all torchvision ops:
400
+ {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
401
+ WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
402
+ provided dictionary.
403
+ suppress_diff_warning (bool, optional): whether to suppress a warning
404
+ when there are discrepancies between the train and eval version of
405
+ the graph. Defaults to False.
406
+ concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
407
+ not be treated as Proxies. According to the `Pytorch docs
408
+ <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.trace>`_,
409
+ this parameter's API may not be guaranteed.
410
+
411
+ Examples::
412
+
413
+ >>> # Feature extraction with resnet
414
+ >>> model = torchvision.models.resnet18()
415
+ >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
416
+ >>> model = create_feature_extractor(
417
+ >>> model, {'layer1': 'feat1', 'layer3': 'feat2'})
418
+ >>> out = model(torch.rand(1, 3, 224, 224))
419
+ >>> print([(k, v.shape) for k, v in out.items()])
420
+ >>> [('feat1', torch.Size([1, 64, 56, 56])),
421
+ >>> ('feat2', torch.Size([1, 256, 14, 14]))]
422
+
423
+ >>> # Specifying leaf modules and leaf functions
424
+ >>> def leaf_function(x):
425
+ >>> # This would raise a TypeError if traced through
426
+ >>> return int(x)
427
+ >>>
428
+ >>> class LeafModule(torch.nn.Module):
429
+ >>> def forward(self, x):
430
+ >>> # This would raise a TypeError if traced through
431
+ >>> int(x.shape[0])
432
+ >>> return torch.nn.functional.relu(x + 4)
433
+ >>>
434
+ >>> class MyModule(torch.nn.Module):
435
+ >>> def __init__(self):
436
+ >>> super().__init__()
437
+ >>> self.conv = torch.nn.Conv2d(3, 1, 3)
438
+ >>> self.leaf_module = LeafModule()
439
+ >>>
440
+ >>> def forward(self, x):
441
+ >>> leaf_function(x.shape[0])
442
+ >>> x = self.conv(x)
443
+ >>> return self.leaf_module(x)
444
+ >>>
445
+ >>> model = create_feature_extractor(
446
+ >>> MyModule(), return_nodes=['leaf_module'],
447
+ >>> tracer_kwargs={'leaf_modules': [LeafModule],
448
+ >>> 'autowrap_functions': [leaf_function]})
449
+
450
+ """
451
+ tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
452
+ is_training = model.training
453
+
454
+ if all(arg is None for arg in [return_nodes, train_return_nodes, eval_return_nodes]):
455
+
456
+ raise ValueError(
457
+ "Either `return_nodes` or `train_return_nodes` and `eval_return_nodes` together, should be specified"
458
+ )
459
+
460
+ if (train_return_nodes is None) ^ (eval_return_nodes is None):
461
+ raise ValueError(
462
+ "If any of `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified"
463
+ )
464
+
465
+ if not ((return_nodes is None) ^ (train_return_nodes is None)):
466
+ raise ValueError("If `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified")
467
+
468
+ # Put *_return_nodes into Dict[str, str] format
469
+ def to_strdict(n) -> Dict[str, str]:
470
+ if isinstance(n, list):
471
+ return {str(i): str(i) for i in n}
472
+ return {str(k): str(v) for k, v in n.items()}
473
+
474
+ if train_return_nodes is None:
475
+ return_nodes = to_strdict(return_nodes)
476
+ train_return_nodes = deepcopy(return_nodes)
477
+ eval_return_nodes = deepcopy(return_nodes)
478
+ else:
479
+ train_return_nodes = to_strdict(train_return_nodes)
480
+ eval_return_nodes = to_strdict(eval_return_nodes)
481
+
482
+ # Repeat the tracing and graph rewriting for train and eval mode
483
+ tracers = {}
484
+ graphs = {}
485
+ mode_return_nodes: Dict[str, Dict[str, str]] = {"train": train_return_nodes, "eval": eval_return_nodes}
486
+ for mode in ["train", "eval"]:
487
+ if mode == "train":
488
+ model.train()
489
+ elif mode == "eval":
490
+ model.eval()
491
+
492
+ # Instantiate our NodePathTracer and use that to trace the model
493
+ tracer = NodePathTracer(**tracer_kwargs)
494
+ graph = tracer.trace(model, concrete_args=concrete_args)
495
+
496
+ name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
497
+ graph_module = fx.GraphModule(tracer.root, graph, name)
498
+
499
+ available_nodes = list(tracer.node_to_qualname.values())
500
+ # FIXME We don't know if we should expect this to happen
501
+ if len(set(available_nodes)) != len(available_nodes):
502
+ raise ValueError(
503
+ "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
504
+ )
505
+ # Check that all outputs in return_nodes are present in the model
506
+ for query in mode_return_nodes[mode].keys():
507
+ # To check if a query is available we need to check that at least
508
+ # one of the available names starts with it up to a .
509
+ if not any([re.match(rf"^{query}(\.|$)", n) is not None for n in available_nodes]):
510
+ raise ValueError(
511
+ f"node: '{query}' is not present in model. Hint: use "
512
+ "`get_graph_node_names` to make sure the "
513
+ "`return_nodes` you specified are present. It may even "
514
+ "be that you need to specify `train_return_nodes` and "
515
+ "`eval_return_nodes` separately."
516
+ )
517
+
518
+ # Remove existing output nodes (train mode)
519
+ orig_output_nodes = []
520
+ for n in reversed(graph_module.graph.nodes):
521
+ if n.op == "output":
522
+ orig_output_nodes.append(n)
523
+ if not orig_output_nodes:
524
+ raise ValueError("No output nodes found in graph_module.graph.nodes")
525
+
526
+ for n in orig_output_nodes:
527
+ graph_module.graph.erase_node(n)
528
+
529
+ # Find nodes corresponding to return_nodes and make them into output_nodes
530
+ nodes = [n for n in graph_module.graph.nodes]
531
+ output_nodes = OrderedDict()
532
+ for n in reversed(nodes):
533
+ module_qualname = tracer.node_to_qualname.get(n)
534
+ if module_qualname is None:
535
+ # NOTE - Known cases where this happens:
536
+ # - Node representing creation of a tensor constant - probably
537
+ # not interesting as a return node
538
+ # - When packing outputs into a named tuple like in InceptionV3
539
+ continue
540
+ for query in mode_return_nodes[mode]:
541
+ depth = query.count(".")
542
+ if ".".join(module_qualname.split(".")[: depth + 1]) == query:
543
+ output_nodes[mode_return_nodes[mode][query]] = n
544
+ mode_return_nodes[mode].pop(query)
545
+ break
546
+ output_nodes = OrderedDict(reversed(list(output_nodes.items())))
547
+
548
+ # And add them in the end of the graph
549
+ with graph_module.graph.inserting_after(nodes[-1]):
550
+ graph_module.graph.output(output_nodes)
551
+
552
+ # Remove unused modules / parameters
553
+ graph_module.graph.eliminate_dead_code()
554
+ graph_module.recompile()
555
+
556
+ # Keep track of the tracer and graph, so we can choose the main one
557
+ tracers[mode] = tracer
558
+ graphs[mode] = graph
559
+
560
+ # Warn user if there are any discrepancies between the graphs of the
561
+ # train and eval modes
562
+ if not suppress_diff_warning:
563
+ _warn_graph_differences(tracers["train"], tracers["eval"])
564
+
565
+ # Build the final graph module
566
+ graph_module = DualGraphModule(model, graphs["train"], graphs["eval"], class_name=name)
567
+
568
+ # Restore original training mode
569
+ model.train(is_training)
570
+ graph_module.train(is_training)
571
+
572
+ return graph_module
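A sketch of the per-mode variant described in the docstring, assuming a ResNet-18 backbone (node names are illustrative; check them with `get_graph_node_names` first):

    extractor = create_feature_extractor(
        torchvision.models.resnet18(),
        train_return_nodes={"layer1": "feat1", "layer4": "feat4"},
        eval_return_nodes={"layer4": "feat4"},
    )
    extractor.eval()
    out = extractor(torch.rand(1, 3, 224, 224))  # {'feat4': tensor of shape [1, 512, 7, 7]}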
.venv/lib/python3.11/site-packages/torchvision/models/inception.py ADDED
@@ -0,0 +1,478 @@
1
+ import warnings
2
+ from collections import namedtuple
3
+ from functools import partial
4
+ from typing import Any, Callable, List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn, Tensor
9
+
10
+ from ..transforms._presets import ImageClassification
11
+ from ..utils import _log_api_usage_once
12
+ from ._api import register_model, Weights, WeightsEnum
13
+ from ._meta import _IMAGENET_CATEGORIES
14
+ from ._utils import _ovewrite_named_param, handle_legacy_interface
15
+
16
+
17
+ __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"]
18
+
19
+
20
+ InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
21
+ InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}
22
+
23
+ # Script annotations failed with _GoogleNetOutputs = namedtuple ...
24
+ # _InceptionOutputs set here for backwards compat
25
+ _InceptionOutputs = InceptionOutputs
26
+
27
+
28
+ class Inception3(nn.Module):
29
+ def __init__(
30
+ self,
31
+ num_classes: int = 1000,
32
+ aux_logits: bool = True,
33
+ transform_input: bool = False,
34
+ inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
35
+ init_weights: Optional[bool] = None,
36
+ dropout: float = 0.5,
37
+ ) -> None:
38
+ super().__init__()
39
+ _log_api_usage_once(self)
40
+ if inception_blocks is None:
41
+ inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
42
+ if init_weights is None:
43
+ warnings.warn(
44
+ "The default weight initialization of inception_v3 will be changed in future releases of "
45
+ "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
46
+ " due to scipy/scipy#11299), please set init_weights=True.",
47
+ FutureWarning,
48
+ )
49
+ init_weights = True
50
+ if len(inception_blocks) != 7:
51
+ raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}")
52
+ conv_block = inception_blocks[0]
53
+ inception_a = inception_blocks[1]
54
+ inception_b = inception_blocks[2]
55
+ inception_c = inception_blocks[3]
56
+ inception_d = inception_blocks[4]
57
+ inception_e = inception_blocks[5]
58
+ inception_aux = inception_blocks[6]
59
+
60
+ self.aux_logits = aux_logits
61
+ self.transform_input = transform_input
62
+ self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
63
+ self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
64
+ self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
65
+ self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
66
+ self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
67
+ self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
68
+ self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
69
+ self.Mixed_5b = inception_a(192, pool_features=32)
70
+ self.Mixed_5c = inception_a(256, pool_features=64)
71
+ self.Mixed_5d = inception_a(288, pool_features=64)
72
+ self.Mixed_6a = inception_b(288)
73
+ self.Mixed_6b = inception_c(768, channels_7x7=128)
74
+ self.Mixed_6c = inception_c(768, channels_7x7=160)
75
+ self.Mixed_6d = inception_c(768, channels_7x7=160)
76
+ self.Mixed_6e = inception_c(768, channels_7x7=192)
77
+ self.AuxLogits: Optional[nn.Module] = None
78
+ if aux_logits:
79
+ self.AuxLogits = inception_aux(768, num_classes)
80
+ self.Mixed_7a = inception_d(768)
81
+ self.Mixed_7b = inception_e(1280)
82
+ self.Mixed_7c = inception_e(2048)
83
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
84
+ self.dropout = nn.Dropout(p=dropout)
85
+ self.fc = nn.Linear(2048, num_classes)
86
+ if init_weights:
87
+ for m in self.modules():
88
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
89
+ stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore
90
+ torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
91
+ elif isinstance(m, nn.BatchNorm2d):
92
+ nn.init.constant_(m.weight, 1)
93
+ nn.init.constant_(m.bias, 0)
94
+
95
+ def _transform_input(self, x: Tensor) -> Tensor:
96
+ if self.transform_input:
97
+ x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
98
+ x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
99
+ x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
100
+ x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
101
+ return x
102
+
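A side note on the per-channel affine above: assuming the input was normalized with the usual ImageNet statistics, x = (x_raw - mean) / std, the transform re-expresses it as

    x * (std / 0.5) + (mean - 0.5) / 0.5 = (x_raw - 0.5) / 0.5

i.e. the (x_raw - 0.5) / 0.5 scaling that the originally ported Google weights presumably expect.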
103
+ def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
104
+ # N x 3 x 299 x 299
105
+ x = self.Conv2d_1a_3x3(x)
106
+ # N x 32 x 149 x 149
107
+ x = self.Conv2d_2a_3x3(x)
108
+ # N x 32 x 147 x 147
109
+ x = self.Conv2d_2b_3x3(x)
110
+ # N x 64 x 147 x 147
111
+ x = self.maxpool1(x)
112
+ # N x 64 x 73 x 73
113
+ x = self.Conv2d_3b_1x1(x)
114
+ # N x 80 x 73 x 73
115
+ x = self.Conv2d_4a_3x3(x)
116
+ # N x 192 x 71 x 71
117
+ x = self.maxpool2(x)
118
+ # N x 192 x 35 x 35
119
+ x = self.Mixed_5b(x)
120
+ # N x 256 x 35 x 35
121
+ x = self.Mixed_5c(x)
122
+ # N x 288 x 35 x 35
123
+ x = self.Mixed_5d(x)
124
+ # N x 288 x 35 x 35
125
+ x = self.Mixed_6a(x)
126
+ # N x 768 x 17 x 17
127
+ x = self.Mixed_6b(x)
128
+ # N x 768 x 17 x 17
129
+ x = self.Mixed_6c(x)
130
+ # N x 768 x 17 x 17
131
+ x = self.Mixed_6d(x)
132
+ # N x 768 x 17 x 17
133
+ x = self.Mixed_6e(x)
134
+ # N x 768 x 17 x 17
135
+ aux: Optional[Tensor] = None
136
+ if self.AuxLogits is not None:
137
+ if self.training:
138
+ aux = self.AuxLogits(x)
139
+ # N x 768 x 17 x 17
140
+ x = self.Mixed_7a(x)
141
+ # N x 1280 x 8 x 8
142
+ x = self.Mixed_7b(x)
143
+ # N x 2048 x 8 x 8
144
+ x = self.Mixed_7c(x)
145
+ # N x 2048 x 8 x 8
146
+ # Adaptive average pooling
147
+ x = self.avgpool(x)
148
+ # N x 2048 x 1 x 1
149
+ x = self.dropout(x)
150
+ # N x 2048 x 1 x 1
151
+ x = torch.flatten(x, 1)
152
+ # N x 2048
153
+ x = self.fc(x)
154
+ # N x 1000 (num_classes)
155
+ return x, aux
156
+
157
+ @torch.jit.unused
158
+ def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
159
+ if self.training and self.aux_logits:
160
+ return InceptionOutputs(x, aux)
161
+ else:
162
+ return x # type: ignore[return-value]
163
+
164
+ def forward(self, x: Tensor) -> InceptionOutputs:
165
+ x = self._transform_input(x)
166
+ x, aux = self._forward(x)
167
+ aux_defined = self.training and self.aux_logits
168
+ if torch.jit.is_scripting():
169
+ if not aux_defined:
170
+ warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
171
+ return InceptionOutputs(x, aux)
172
+ else:
173
+ return self.eager_outputs(x, aux)
174
+
175
+
176
+ class InceptionA(nn.Module):
177
+ def __init__(
178
+ self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
179
+ ) -> None:
180
+ super().__init__()
181
+ if conv_block is None:
182
+ conv_block = BasicConv2d
183
+ self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
184
+
185
+ self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
186
+ self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
187
+
188
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
189
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
190
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
191
+
192
+ self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
193
+
194
+ def _forward(self, x: Tensor) -> List[Tensor]:
195
+ branch1x1 = self.branch1x1(x)
196
+
197
+ branch5x5 = self.branch5x5_1(x)
198
+ branch5x5 = self.branch5x5_2(branch5x5)
199
+
200
+ branch3x3dbl = self.branch3x3dbl_1(x)
201
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
202
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
203
+
204
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
205
+ branch_pool = self.branch_pool(branch_pool)
206
+
207
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
208
+ return outputs
209
+
210
+ def forward(self, x: Tensor) -> Tensor:
211
+ outputs = self._forward(x)
212
+ return torch.cat(outputs, 1)
213
+
214
+
215
+ class InceptionB(nn.Module):
216
+ def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
217
+ super().__init__()
218
+ if conv_block is None:
219
+ conv_block = BasicConv2d
220
+ self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
221
+
222
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
223
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
224
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
225
+
226
+ def _forward(self, x: Tensor) -> List[Tensor]:
227
+ branch3x3 = self.branch3x3(x)
228
+
229
+ branch3x3dbl = self.branch3x3dbl_1(x)
230
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
231
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
232
+
233
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
234
+
235
+ outputs = [branch3x3, branch3x3dbl, branch_pool]
236
+ return outputs
237
+
238
+ def forward(self, x: Tensor) -> Tensor:
239
+ outputs = self._forward(x)
240
+ return torch.cat(outputs, 1)
241
+
242
+
243
+ class InceptionC(nn.Module):
244
+ def __init__(
245
+ self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
246
+ ) -> None:
247
+ super().__init__()
248
+ if conv_block is None:
249
+ conv_block = BasicConv2d
250
+ self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
251
+
252
+ c7 = channels_7x7
253
+ self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
254
+ self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
255
+ self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
256
+
257
+ self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
258
+ self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
259
+ self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
260
+ self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
261
+ self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
262
+
263
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
264
+
265
+ def _forward(self, x: Tensor) -> List[Tensor]:
266
+ branch1x1 = self.branch1x1(x)
267
+
268
+ branch7x7 = self.branch7x7_1(x)
269
+ branch7x7 = self.branch7x7_2(branch7x7)
270
+ branch7x7 = self.branch7x7_3(branch7x7)
271
+
272
+ branch7x7dbl = self.branch7x7dbl_1(x)
273
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
274
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
275
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
276
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
277
+
278
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
279
+ branch_pool = self.branch_pool(branch_pool)
280
+
281
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
282
+ return outputs
283
+
284
+ def forward(self, x: Tensor) -> Tensor:
285
+ outputs = self._forward(x)
286
+ return torch.cat(outputs, 1)
287
+
288
+
289
+ class InceptionD(nn.Module):
290
+ def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
291
+ super().__init__()
292
+ if conv_block is None:
293
+ conv_block = BasicConv2d
294
+ self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
295
+ self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
296
+
297
+ self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
298
+ self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
299
+ self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
300
+ self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
301
+
302
+ def _forward(self, x: Tensor) -> List[Tensor]:
303
+ branch3x3 = self.branch3x3_1(x)
304
+ branch3x3 = self.branch3x3_2(branch3x3)
305
+
306
+ branch7x7x3 = self.branch7x7x3_1(x)
307
+ branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
308
+ branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
309
+ branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
310
+
311
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
312
+ outputs = [branch3x3, branch7x7x3, branch_pool]
313
+ return outputs
314
+
315
+ def forward(self, x: Tensor) -> Tensor:
316
+ outputs = self._forward(x)
317
+ return torch.cat(outputs, 1)
318
+
319
+
320
+ class InceptionE(nn.Module):
321
+ def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
322
+ super().__init__()
323
+ if conv_block is None:
324
+ conv_block = BasicConv2d
325
+ self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
326
+
327
+ self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
328
+ self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
329
+ self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
330
+
331
+ self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
332
+ self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
333
+ self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
334
+ self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
335
+
336
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
337
+
338
+ def _forward(self, x: Tensor) -> List[Tensor]:
339
+ branch1x1 = self.branch1x1(x)
340
+
341
+ branch3x3 = self.branch3x3_1(x)
342
+ branch3x3 = [
343
+ self.branch3x3_2a(branch3x3),
344
+ self.branch3x3_2b(branch3x3),
345
+ ]
346
+ branch3x3 = torch.cat(branch3x3, 1)
347
+
348
+ branch3x3dbl = self.branch3x3dbl_1(x)
349
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
350
+ branch3x3dbl = [
351
+ self.branch3x3dbl_3a(branch3x3dbl),
352
+ self.branch3x3dbl_3b(branch3x3dbl),
353
+ ]
354
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
355
+
356
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
357
+ branch_pool = self.branch_pool(branch_pool)
358
+
359
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
360
+ return outputs
361
+
362
+ def forward(self, x: Tensor) -> Tensor:
363
+ outputs = self._forward(x)
364
+ return torch.cat(outputs, 1)
365
+
366
+
367
+ class InceptionAux(nn.Module):
368
+ def __init__(
369
+ self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
370
+ ) -> None:
371
+ super().__init__()
372
+ if conv_block is None:
373
+ conv_block = BasicConv2d
374
+ self.conv0 = conv_block(in_channels, 128, kernel_size=1)
375
+ self.conv1 = conv_block(128, 768, kernel_size=5)
376
+ self.conv1.stddev = 0.01 # type: ignore[assignment]
377
+ self.fc = nn.Linear(768, num_classes)
378
+ self.fc.stddev = 0.001 # type: ignore[assignment]
379
+
380
+ def forward(self, x: Tensor) -> Tensor:
381
+ # N x 768 x 17 x 17
382
+ x = F.avg_pool2d(x, kernel_size=5, stride=3)
383
+ # N x 768 x 5 x 5
384
+ x = self.conv0(x)
385
+ # N x 128 x 5 x 5
386
+ x = self.conv1(x)
387
+ # N x 768 x 1 x 1
388
+ # Adaptive average pooling
389
+ x = F.adaptive_avg_pool2d(x, (1, 1))
390
+ # N x 768 x 1 x 1
391
+ x = torch.flatten(x, 1)
392
+ # N x 768
393
+ x = self.fc(x)
394
+ # N x 1000
395
+ return x
396
+
397
+
398
+ class BasicConv2d(nn.Module):
399
+ def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
400
+ super().__init__()
401
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
402
+ self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
403
+
404
+ def forward(self, x: Tensor) -> Tensor:
405
+ x = self.conv(x)
406
+ x = self.bn(x)
407
+ return F.relu(x, inplace=True)
408
+
409
+
410
+ class Inception_V3_Weights(WeightsEnum):
411
+ IMAGENET1K_V1 = Weights(
412
+ url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
413
+ transforms=partial(ImageClassification, crop_size=299, resize_size=342),
414
+ meta={
415
+ "num_params": 27161264,
416
+ "min_size": (75, 75),
417
+ "categories": _IMAGENET_CATEGORIES,
418
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
419
+ "_metrics": {
420
+ "ImageNet-1K": {
421
+ "acc@1": 77.294,
422
+ "acc@5": 93.450,
423
+ }
424
+ },
425
+ "_ops": 5.713,
426
+ "_file_size": 103.903,
427
+ "_docs": """These weights are ported from the original paper.""",
428
+ },
429
+ )
430
+ DEFAULT = IMAGENET1K_V1
431
+
432
+
433
+ @register_model()
434
+ @handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1))
435
+ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
436
+ """
437
+ Inception v3 model architecture from
438
+ `Rethinking the Inception Architecture for Computer Vision <http://arxiv.org/abs/1512.00567>`_.
439
+
440
+ .. note::
441
+ **Important**: In contrast to the other models, ``inception_v3`` expects tensors with a size of
442
+ N x 3 x 299 x 299, so ensure your images are sized accordingly.
443
+
444
+ Args:
445
+ weights (:class:`~torchvision.models.Inception_V3_Weights`, optional): The
446
+ pretrained weights for the model. See
447
+ :class:`~torchvision.models.Inception_V3_Weights` below for
448
+ more details, and possible values. By default, no pre-trained
449
+ weights are used.
450
+ progress (bool, optional): If True, displays a progress bar of the
451
+ download to stderr. Default is True.
452
+ **kwargs: parameters passed to the ``torchvision.models.Inception3``
453
+ base class. Please refer to the `source code
454
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py>`_
455
+ for more details about this class.
456
+
457
+ .. autoclass:: torchvision.models.Inception_V3_Weights
458
+ :members:
459
+ """
460
+ weights = Inception_V3_Weights.verify(weights)
461
+
462
+ original_aux_logits = kwargs.get("aux_logits", True)
463
+ if weights is not None:
464
+ if "transform_input" not in kwargs:
465
+ _ovewrite_named_param(kwargs, "transform_input", True)
466
+ _ovewrite_named_param(kwargs, "aux_logits", True)
467
+ _ovewrite_named_param(kwargs, "init_weights", False)
468
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
469
+
470
+ model = Inception3(**kwargs)
471
+
472
+ if weights is not None:
473
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
474
+ if not original_aux_logits:
475
+ model.aux_logits = False
476
+ model.AuxLogits = None
477
+
478
+ return model
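For orientation, a minimal usage sketch of the builder above (not part of the diff; it assumes torchvision is importable and the checkpoint URL is reachable):

import torch
from torchvision.models import inception_v3, Inception_V3_Weights

weights = Inception_V3_Weights.IMAGENET1K_V1
model = inception_v3(weights=weights).eval()

# Per the docstring note, inception_v3 expects N x 3 x 299 x 299 inputs.
x = torch.rand(1, 3, 299, 299)
with torch.no_grad():
    logits = model(x)  # shape (1, 1000) in eval mode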
.venv/lib/python3.11/site-packages/torchvision/models/mnasnet.py ADDED
@@ -0,0 +1,434 @@
1
+ import warnings
2
+ from functools import partial
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import Tensor
8
+
9
+ from ..transforms._presets import ImageClassification
10
+ from ..utils import _log_api_usage_once
11
+ from ._api import register_model, Weights, WeightsEnum
12
+ from ._meta import _IMAGENET_CATEGORIES
13
+ from ._utils import _ovewrite_named_param, handle_legacy_interface
14
+
15
+
16
+ __all__ = [
17
+ "MNASNet",
18
+ "MNASNet0_5_Weights",
19
+ "MNASNet0_75_Weights",
20
+ "MNASNet1_0_Weights",
21
+ "MNASNet1_3_Weights",
22
+ "mnasnet0_5",
23
+ "mnasnet0_75",
24
+ "mnasnet1_0",
25
+ "mnasnet1_3",
26
+ ]
27
+
28
+
29
+ # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
30
+ # 1.0 - (the TensorFlow value), i.e. 1 - 0.9997.
31
+ _BN_MOMENTUM = 1 - 0.9997
32
+
33
+
34
+ class _InvertedResidual(nn.Module):
35
+ def __init__(
36
+ self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
37
+ ) -> None:
38
+ super().__init__()
39
+ if stride not in [1, 2]:
40
+ raise ValueError(f"stride should be 1 or 2 instead of {stride}")
41
+ if kernel_size not in [3, 5]:
42
+ raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}")
43
+ mid_ch = in_ch * expansion_factor
44
+ self.apply_residual = in_ch == out_ch and stride == 1
45
+ self.layers = nn.Sequential(
46
+ # Pointwise
47
+ nn.Conv2d(in_ch, mid_ch, 1, bias=False),
48
+ nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
49
+ nn.ReLU(inplace=True),
50
+ # Depthwise
51
+ nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False),
52
+ nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
53
+ nn.ReLU(inplace=True),
54
+ # Linear pointwise. Note that there's no activation.
55
+ nn.Conv2d(mid_ch, out_ch, 1, bias=False),
56
+ nn.BatchNorm2d(out_ch, momentum=bn_momentum),
57
+ )
58
+
59
+ def forward(self, input: Tensor) -> Tensor:
60
+ if self.apply_residual:
61
+ return self.layers(input) + input
62
+ else:
63
+ return self.layers(input)
64
+
65
+
66
+ def _stack(
67
+ in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
68
+ ) -> nn.Sequential:
69
+ """Creates a stack of inverted residuals."""
70
+ if repeats < 1:
71
+ raise ValueError(f"repeats should be >= 1, instead got {repeats}")
72
+ # First one has no skip, because feature map size changes.
73
+ first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)
74
+ remaining = []
75
+ for _ in range(1, repeats):
76
+ remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum))
77
+ return nn.Sequential(first, *remaining)
78
+
79
+
80
+ def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
81
+ """Asymmetric rounding to make `val` divisible by `divisor`. With default
82
+ bias, will round up, unless the number is no more than 10% greater than the
83
+ smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88."""
84
+ if not 0.0 < round_up_bias < 1.0:
85
+ raise ValueError(f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}")
86
+ new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
87
+ return new_val if new_val >= round_up_bias * val else new_val + divisor
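A quick doctest-style check of the example in the docstring (illustrative only; the values follow directly from the formula above):

>>> _round_to_multiple_of(83, 8)   # int(83 + 4) // 8 * 8 = 80, and 80 >= 0.9 * 83, so keep 80
80
>>> _round_to_multiple_of(84, 8)   # int(84 + 4) // 8 * 8 = 88, so the value rounds up
88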
88
+
89
+
90
+ def _get_depths(alpha: float) -> List[int]:
91
+ """Scales tensor depths as in reference MobileNet code, prefers rounding up
92
+ rather than down."""
93
+ depths = [32, 16, 24, 40, 80, 96, 192, 320]
94
+ return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
95
+
96
+
97
+ class MNASNet(torch.nn.Module):
98
+ """MNASNet, as described in https://arxiv.org/abs/1807.11626. This
99
+ implements the B1 variant of the model.
100
+ >>> model = MNASNet(1.0, num_classes=1000)
101
+ >>> x = torch.rand(1, 3, 224, 224)
102
+ >>> y = model(x)
103
+ >>> y.dim()
104
+ 2
105
+ >>> y.nelement()
106
+ 1000
107
+ """
108
+
109
+ # Version 2 adds depth scaling in the initial stages of the network.
110
+ _version = 2
111
+
112
+ def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
113
+ super().__init__()
114
+ _log_api_usage_once(self)
115
+ if alpha <= 0.0:
116
+ raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}")
117
+ self.alpha = alpha
118
+ self.num_classes = num_classes
119
+ depths = _get_depths(alpha)
120
+ layers = [
121
+ # First layer: regular conv.
122
+ nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
123
+ nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
124
+ nn.ReLU(inplace=True),
125
+ # Depthwise separable, no skip.
126
+ nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
127
+ nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
128
+ nn.ReLU(inplace=True),
129
+ nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
130
+ nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
131
+ # MNASNet blocks: stacks of inverted residuals.
132
+ _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
133
+ _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
134
+ _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
135
+ _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
136
+ _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
137
+ _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
138
+ # Final mapping to classifier input.
139
+ nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
140
+ nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
141
+ nn.ReLU(inplace=True),
142
+ ]
143
+ self.layers = nn.Sequential(*layers)
144
+ self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))
145
+
146
+ for m in self.modules():
147
+ if isinstance(m, nn.Conv2d):
148
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
149
+ if m.bias is not None:
150
+ nn.init.zeros_(m.bias)
151
+ elif isinstance(m, nn.BatchNorm2d):
152
+ nn.init.ones_(m.weight)
153
+ nn.init.zeros_(m.bias)
154
+ elif isinstance(m, nn.Linear):
155
+ nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
156
+ nn.init.zeros_(m.bias)
157
+
158
+ def forward(self, x: Tensor) -> Tensor:
159
+ x = self.layers(x)
160
+ # Equivalent to global avgpool and removing H and W dimensions.
161
+ x = x.mean([2, 3])
162
+ return self.classifier(x)
163
+
164
+ def _load_from_state_dict(
165
+ self,
166
+ state_dict: Dict,
167
+ prefix: str,
168
+ local_metadata: Dict,
169
+ strict: bool,
170
+ missing_keys: List[str],
171
+ unexpected_keys: List[str],
172
+ error_msgs: List[str],
173
+ ) -> None:
174
+ version = local_metadata.get("version", None)
175
+ if version not in [1, 2]:
176
+ raise ValueError(f"version should be set to 1 or 2 instead of {version}")
177
+
178
+ if version == 1 and not self.alpha == 1.0:
179
+ # In the initial version of the model (v1), stem was fixed-size.
180
+ # All other layer configurations were the same. This will patch
181
+ # the model so that it's identical to v1. Model with alpha 1.0 is
182
+ # unaffected.
183
+ depths = _get_depths(self.alpha)
184
+ v1_stem = [
185
+ nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
186
+ nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
187
+ nn.ReLU(inplace=True),
188
+ nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
189
+ nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
190
+ nn.ReLU(inplace=True),
191
+ nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
192
+ nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
193
+ _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
194
+ ]
195
+ for idx, layer in enumerate(v1_stem):
196
+ self.layers[idx] = layer
197
+
198
+ # The model is now identical to v1, and must be saved as such.
199
+ self._version = 1
200
+ warnings.warn(
201
+ "A new version of MNASNet model has been implemented. "
202
+ "Your checkpoint was saved using the previous version. "
203
+ "This checkpoint will load and work as before, but "
204
+ "you may want to upgrade by training a newer model or "
205
+ "transfer learning from an updated ImageNet checkpoint.",
206
+ UserWarning,
207
+ )
208
+
209
+ super()._load_from_state_dict(
210
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
211
+ )
212
+
213
+
214
+ _COMMON_META = {
215
+ "min_size": (1, 1),
216
+ "categories": _IMAGENET_CATEGORIES,
217
+ "recipe": "https://github.com/1e100/mnasnet_trainer",
218
+ }
219
+
220
+
221
+ class MNASNet0_5_Weights(WeightsEnum):
222
+ IMAGENET1K_V1 = Weights(
223
+ url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
224
+ transforms=partial(ImageClassification, crop_size=224),
225
+ meta={
226
+ **_COMMON_META,
227
+ "num_params": 2218512,
228
+ "_metrics": {
229
+ "ImageNet-1K": {
230
+ "acc@1": 67.734,
231
+ "acc@5": 87.490,
232
+ }
233
+ },
234
+ "_ops": 0.104,
235
+ "_file_size": 8.591,
236
+ "_docs": """These weights reproduce closely the results of the paper.""",
237
+ },
238
+ )
239
+ DEFAULT = IMAGENET1K_V1
240
+
241
+
242
+ class MNASNet0_75_Weights(WeightsEnum):
243
+ IMAGENET1K_V1 = Weights(
244
+ url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth",
245
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
246
+ meta={
247
+ **_COMMON_META,
248
+ "recipe": "https://github.com/pytorch/vision/pull/6019",
249
+ "num_params": 3170208,
250
+ "_metrics": {
251
+ "ImageNet-1K": {
252
+ "acc@1": 71.180,
253
+ "acc@5": 90.496,
254
+ }
255
+ },
256
+ "_ops": 0.215,
257
+ "_file_size": 12.303,
258
+ "_docs": """
259
+ These weights were trained from scratch by using TorchVision's `new training recipe
260
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
261
+ """,
262
+ },
263
+ )
264
+ DEFAULT = IMAGENET1K_V1
265
+
266
+
267
+ class MNASNet1_0_Weights(WeightsEnum):
268
+ IMAGENET1K_V1 = Weights(
269
+ url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
270
+ transforms=partial(ImageClassification, crop_size=224),
271
+ meta={
272
+ **_COMMON_META,
273
+ "num_params": 4383312,
274
+ "_metrics": {
275
+ "ImageNet-1K": {
276
+ "acc@1": 73.456,
277
+ "acc@5": 91.510,
278
+ }
279
+ },
280
+ "_ops": 0.314,
281
+ "_file_size": 16.915,
282
+ "_docs": """These weights reproduce closely the results of the paper.""",
283
+ },
284
+ )
285
+ DEFAULT = IMAGENET1K_V1
286
+
287
+
288
+ class MNASNet1_3_Weights(WeightsEnum):
289
+ IMAGENET1K_V1 = Weights(
290
+ url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth",
291
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
292
+ meta={
293
+ **_COMMON_META,
294
+ "recipe": "https://github.com/pytorch/vision/pull/6019",
295
+ "num_params": 6282256,
296
+ "_metrics": {
297
+ "ImageNet-1K": {
298
+ "acc@1": 76.506,
299
+ "acc@5": 93.522,
300
+ }
301
+ },
302
+ "_ops": 0.526,
303
+ "_file_size": 24.246,
304
+ "_docs": """
305
+ These weights were trained from scratch by using TorchVision's `new training recipe
306
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
307
+ """,
308
+ },
309
+ )
310
+ DEFAULT = IMAGENET1K_V1
311
+
312
+
313
+ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
314
+ if weights is not None:
315
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
316
+
317
+ model = MNASNet(alpha, **kwargs)
318
+
319
+ if weights:
320
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
321
+
322
+ return model
323
+
324
+
325
+ @register_model()
326
+ @handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
327
+ def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
328
+ """MNASNet with depth multiplier of 0.5 from
329
+ `MnasNet: Platform-Aware Neural Architecture Search for Mobile
330
+ <https://arxiv.org/abs/1807.11626>`_ paper.
331
+
332
+ Args:
333
+ weights (:class:`~torchvision.models.MNASNet0_5_Weights`, optional): The
334
+ pretrained weights to use. See
335
+ :class:`~torchvision.models.MNASNet0_5_Weights` below for
336
+ more details, and possible values. By default, no pre-trained
337
+ weights are used.
338
+ progress (bool, optional): If True, displays a progress bar of the
339
+ download to stderr. Default is True.
340
+ **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
341
+ base class. Please refer to the `source code
342
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
343
+ for more details about this class.
344
+
345
+ .. autoclass:: torchvision.models.MNASNet0_5_Weights
346
+ :members:
347
+ """
348
+ weights = MNASNet0_5_Weights.verify(weights)
349
+
350
+ return _mnasnet(0.5, weights, progress, **kwargs)
351
+
352
+
353
+ @register_model()
354
+ @handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
355
+ def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
356
+ """MNASNet with depth multiplier of 0.75 from
357
+ `MnasNet: Platform-Aware Neural Architecture Search for Mobile
358
+ <https://arxiv.org/abs/1807.11626>`_ paper.
359
+
360
+ Args:
361
+ weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The
362
+ pretrained weights to use. See
363
+ :class:`~torchvision.models.MNASNet0_75_Weights` below for
364
+ more details, and possible values. By default, no pre-trained
365
+ weights are used.
366
+ progress (bool, optional): If True, displays a progress bar of the
367
+ download to stderr. Default is True.
368
+ **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
369
+ base class. Please refer to the `source code
370
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
371
+ for more details about this class.
372
+
373
+ .. autoclass:: torchvision.models.MNASNet0_75_Weights
374
+ :members:
375
+ """
376
+ weights = MNASNet0_75_Weights.verify(weights)
377
+
378
+ return _mnasnet(0.75, weights, progress, **kwargs)
379
+
380
+
381
+ @register_model()
382
+ @handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
383
+ def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
384
+ """MNASNet with depth multiplier of 1.0 from
385
+ `MnasNet: Platform-Aware Neural Architecture Search for Mobile
386
+ <https://arxiv.org/abs/1807.11626>`_ paper.
387
+
388
+ Args:
389
+ weights (:class:`~torchvision.models.MNASNet1_0_Weights`, optional): The
390
+ pretrained weights to use. See
391
+ :class:`~torchvision.models.MNASNet1_0_Weights` below for
392
+ more details, and possible values. By default, no pre-trained
393
+ weights are used.
394
+ progress (bool, optional): If True, displays a progress bar of the
395
+ download to stderr. Default is True.
396
+ **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
397
+ base class. Please refer to the `source code
398
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
399
+ for more details about this class.
400
+
401
+ .. autoclass:: torchvision.models.MNASNet1_0_Weights
402
+ :members:
403
+ """
404
+ weights = MNASNet1_0_Weights.verify(weights)
405
+
406
+ return _mnasnet(1.0, weights, progress, **kwargs)
407
+
408
+
409
+ @register_model()
410
+ @handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
411
+ def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
412
+ """MNASNet with depth multiplier of 1.3 from
413
+ `MnasNet: Platform-Aware Neural Architecture Search for Mobile
414
+ <https://arxiv.org/abs/1807.11626>`_ paper.
415
+
416
+ Args:
417
+ weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The
418
+ pretrained weights to use. See
419
+ :class:`~torchvision.models.MNASNet1_3_Weights` below for
420
+ more details, and possible values. By default, no pre-trained
421
+ weights are used.
422
+ progress (bool, optional): If True, displays a progress bar of the
423
+ download to stderr. Default is True.
424
+ **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
425
+ base class. Please refer to the `source code
426
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
427
+ for more details about this class.
428
+
429
+ .. autoclass:: torchvision.models.MNASNet1_3_Weights
430
+ :members:
431
+ """
432
+ weights = MNASNet1_3_Weights.verify(weights)
433
+
434
+ return _mnasnet(1.3, weights, progress, **kwargs)
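To round off this file, a minimal sketch of the public builders defined above (hedged: `weights.transforms()` is the standard accessor of the weights API, and the checkpoint must be downloadable):

from torchvision.models import mnasnet1_0, MNASNet1_0_Weights

weights = MNASNet1_0_Weights.DEFAULT        # currently IMAGENET1K_V1
model = mnasnet1_0(weights=weights).eval()
preprocess = weights.transforms()           # ImageClassification preset, crop_size=224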
.venv/lib/python3.11/site-packages/torchvision/models/mobilenet.py ADDED
@@ -0,0 +1,6 @@
1
+ from .mobilenetv2 import * # noqa: F401, F403
2
+ from .mobilenetv3 import * # noqa: F401, F403
3
+ from .mobilenetv2 import __all__ as mv2_all
4
+ from .mobilenetv3 import __all__ as mv3_all
5
+
6
+ __all__ = mv2_all + mv3_all
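This module only re-exports the two MobileNet generations, so (as a sketch) the following imports resolve to the same callable:

from torchvision.models.mobilenet import mobilenet_v2     # via the re-export
from torchvision.models.mobilenetv2 import mobilenet_v2   # direct import; same function object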
.venv/lib/python3.11/site-packages/torchvision/models/mobilenetv2.py ADDED
@@ -0,0 +1,260 @@
1
+ from functools import partial
2
+ from typing import Any, Callable, List, Optional
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+
7
+ from ..ops.misc import Conv2dNormActivation
8
+ from ..transforms._presets import ImageClassification
9
+ from ..utils import _log_api_usage_once
10
+ from ._api import register_model, Weights, WeightsEnum
11
+ from ._meta import _IMAGENET_CATEGORIES
12
+ from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
13
+
14
+
15
+ __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
16
+
17
+
18
+ # necessary for backwards compatibility
19
+ class InvertedResidual(nn.Module):
20
+ def __init__(
21
+ self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
22
+ ) -> None:
23
+ super().__init__()
24
+ self.stride = stride
25
+ if stride not in [1, 2]:
26
+ raise ValueError(f"stride should be 1 or 2 instead of {stride}")
27
+
28
+ if norm_layer is None:
29
+ norm_layer = nn.BatchNorm2d
30
+
31
+ hidden_dim = int(round(inp * expand_ratio))
32
+ self.use_res_connect = self.stride == 1 and inp == oup
33
+
34
+ layers: List[nn.Module] = []
35
+ if expand_ratio != 1:
36
+ # pw
37
+ layers.append(
38
+ Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
39
+ )
40
+ layers.extend(
41
+ [
42
+ # dw
43
+ Conv2dNormActivation(
44
+ hidden_dim,
45
+ hidden_dim,
46
+ stride=stride,
47
+ groups=hidden_dim,
48
+ norm_layer=norm_layer,
49
+ activation_layer=nn.ReLU6,
50
+ ),
51
+ # pw-linear
52
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
53
+ norm_layer(oup),
54
+ ]
55
+ )
56
+ self.conv = nn.Sequential(*layers)
57
+ self.out_channels = oup
58
+ self._is_cn = stride > 1
59
+
60
+ def forward(self, x: Tensor) -> Tensor:
61
+ if self.use_res_connect:
62
+ return x + self.conv(x)
63
+ else:
64
+ return self.conv(x)
65
+
66
+
67
+ class MobileNetV2(nn.Module):
68
+ def __init__(
69
+ self,
70
+ num_classes: int = 1000,
71
+ width_mult: float = 1.0,
72
+ inverted_residual_setting: Optional[List[List[int]]] = None,
73
+ round_nearest: int = 8,
74
+ block: Optional[Callable[..., nn.Module]] = None,
75
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
76
+ dropout: float = 0.2,
77
+ ) -> None:
78
+ """
79
+ MobileNet V2 main class
80
+
81
+ Args:
82
+ num_classes (int): Number of classes
83
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
84
+ inverted_residual_setting: Network structure
85
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
86
+ Set to 1 to turn off rounding
87
+ block: Module specifying inverted residual building block for mobilenet
88
+ norm_layer: Module specifying the normalization layer to use
89
+ dropout (float): The dropout probability
90
+
91
+ """
92
+ super().__init__()
93
+ _log_api_usage_once(self)
94
+
95
+ if block is None:
96
+ block = InvertedResidual
97
+
98
+ if norm_layer is None:
99
+ norm_layer = nn.BatchNorm2d
100
+
101
+ input_channel = 32
102
+ last_channel = 1280
103
+
104
+ if inverted_residual_setting is None:
105
+ inverted_residual_setting = [
106
+ # t, c, n, s
107
+ [1, 16, 1, 1],
108
+ [6, 24, 2, 2],
109
+ [6, 32, 3, 2],
110
+ [6, 64, 4, 2],
111
+ [6, 96, 3, 1],
112
+ [6, 160, 3, 2],
113
+ [6, 320, 1, 1],
114
+ ]
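As a reading aid for the table above (a sketch, not part of the upstream file), each row is unpacked as (t, c, n, s) by the loop further down:

t, c, n, s = [6, 24, 2, 2]
# -> n = 2 InvertedResidual blocks with expand_ratio t = 6 and c = 24 output channels;
#    the first block uses stride s = 2 (downsampling), the remaining ones use stride 1.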
115
+
116
+ # only check the first element, assuming user knows t,c,n,s are required
117
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
118
+ raise ValueError(
119
+ f"inverted_residual_setting should be a non-empty list of 4-element lists, got {inverted_residual_setting}"
120
+ )
121
+
122
+ # building first layer
123
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
124
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
125
+ features: List[nn.Module] = [
126
+ Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
127
+ ]
128
+ # building inverted residual blocks
129
+ for t, c, n, s in inverted_residual_setting:
130
+ output_channel = _make_divisible(c * width_mult, round_nearest)
131
+ for i in range(n):
132
+ stride = s if i == 0 else 1
133
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
134
+ input_channel = output_channel
135
+ # building last several layers
136
+ features.append(
137
+ Conv2dNormActivation(
138
+ input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
139
+ )
140
+ )
141
+ # make it nn.Sequential
142
+ self.features = nn.Sequential(*features)
143
+
144
+ # building classifier
145
+ self.classifier = nn.Sequential(
146
+ nn.Dropout(p=dropout),
147
+ nn.Linear(self.last_channel, num_classes),
148
+ )
149
+
150
+ # weight initialization
151
+ for m in self.modules():
152
+ if isinstance(m, nn.Conv2d):
153
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
154
+ if m.bias is not None:
155
+ nn.init.zeros_(m.bias)
156
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
157
+ nn.init.ones_(m.weight)
158
+ nn.init.zeros_(m.bias)
159
+ elif isinstance(m, nn.Linear):
160
+ nn.init.normal_(m.weight, 0, 0.01)
161
+ nn.init.zeros_(m.bias)
162
+
163
+ def _forward_impl(self, x: Tensor) -> Tensor:
164
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
165
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
166
+ x = self.features(x)
167
+ # Cannot use "squeeze" as batch-size can be 1
168
+ x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
169
+ x = torch.flatten(x, 1)
170
+ x = self.classifier(x)
171
+ return x
172
+
173
+ def forward(self, x: Tensor) -> Tensor:
174
+ return self._forward_impl(x)
175
+
176
+
177
+ _COMMON_META = {
178
+ "num_params": 3504872,
179
+ "min_size": (1, 1),
180
+ "categories": _IMAGENET_CATEGORIES,
181
+ }
182
+
183
+
184
+ class MobileNet_V2_Weights(WeightsEnum):
185
+ IMAGENET1K_V1 = Weights(
186
+ url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
187
+ transforms=partial(ImageClassification, crop_size=224),
188
+ meta={
189
+ **_COMMON_META,
190
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
191
+ "_metrics": {
192
+ "ImageNet-1K": {
193
+ "acc@1": 71.878,
194
+ "acc@5": 90.286,
195
+ }
196
+ },
197
+ "_ops": 0.301,
198
+ "_file_size": 13.555,
199
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
200
+ },
201
+ )
202
+ IMAGENET1K_V2 = Weights(
203
+ url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
204
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
205
+ meta={
206
+ **_COMMON_META,
207
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
208
+ "_metrics": {
209
+ "ImageNet-1K": {
210
+ "acc@1": 72.154,
211
+ "acc@5": 90.822,
212
+ }
213
+ },
214
+ "_ops": 0.301,
215
+ "_file_size": 13.598,
216
+ "_docs": """
217
+ These weights improve upon the results of the original paper by using a modified version of TorchVision's
218
+ `new training recipe
219
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
220
+ """,
221
+ },
222
+ )
223
+ DEFAULT = IMAGENET1K_V2
224
+
225
+
226
+ @register_model()
227
+ @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
228
+ def mobilenet_v2(
229
+ *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any
230
+ ) -> MobileNetV2:
231
+ """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
232
+ Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper.
233
+
234
+ Args:
235
+ weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
236
+ pretrained weights to use. See
237
+ :class:`~torchvision.models.MobileNet_V2_Weights` below for
238
+ more details, and possible values. By default, no pre-trained
239
+ weights are used.
240
+ progress (bool, optional): If True, displays a progress bar of the
241
+ download to stderr. Default is True.
242
+ **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
243
+ base class. Please refer to the `source code
244
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_
245
+ for more details about this class.
246
+
247
+ .. autoclass:: torchvision.models.MobileNet_V2_Weights
248
+ :members:
249
+ """
250
+ weights = MobileNet_V2_Weights.verify(weights)
251
+
252
+ if weights is not None:
253
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
254
+
255
+ model = MobileNetV2(**kwargs)
256
+
257
+ if weights is not None:
258
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
259
+
260
+ return model
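A minimal end-to-end sketch of the builder above (the random tensor stands in for a real image; it assumes the IMAGENET1K_V2 checkpoint can be downloaded):

import torch
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

weights = MobileNet_V2_Weights.IMAGENET1K_V2
model = mobilenet_v2(weights=weights).eval()
preprocess = weights.transforms()          # resize to 232, center-crop to 224, normalize

img = torch.rand(3, 256, 256)              # stand-in for a real image tensor
batch = preprocess(img).unsqueeze(0)
with torch.no_grad():
    label = model(batch).argmax(dim=1).item()
print(weights.meta["categories"][label])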
.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .raft import *
.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (233 Bytes).
 
.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (3.59 kB).
 
.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/raft.cpython-311.pyc ADDED
Binary file (44.4 kB).
 
.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/_utils.py ADDED
@@ -0,0 +1,48 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import Tensor
6
+
7
+
8
+ def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None):
9
+ """Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates."""
10
+ h, w = img.shape[-2:]
11
+
12
+ xgrid, ygrid = absolute_grid.split([1, 1], dim=-1)
13
+ xgrid = 2 * xgrid / (w - 1) - 1
14
+ # Only normalize ygrid when h > 1, so that this function can also be reused in RAFT-Stereo
15
+ if h > 1:
16
+ ygrid = 2 * ygrid / (h - 1) - 1
17
+ normalized_grid = torch.cat([xgrid, ygrid], dim=-1)
18
+
19
+ return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)
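A brief sanity check of the normalization above (a sketch): an absolute x of 0 maps to -1 and x = w - 1 maps to +1, matching what F.grid_sample expects with align_corners=True.

w = 5
print([2 * x / (w - 1) - 1 for x in (0, 2, 4)])   # [-1.0, 0.0, 1.0]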
20
+
21
+
22
+ def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"):
23
+ device = torch.device(device)
24
+ coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
25
+ coords = torch.stack(coords[::-1], dim=0).float()
26
+ return coords[None].repeat(batch_size, 1, 1, 1)
27
+
28
+
29
+ def upsample_flow(flow, up_mask: Optional[Tensor] = None, factor: int = 8):
30
+ """Upsample flow by the input factor (default 8).
31
+
32
+ If up_mask is None we just interpolate.
33
+ If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B.
34
+ Note that in appendix B the picture assumes a downsample factor of 4 instead of 8.
35
+ """
36
+ batch_size, num_channels, h, w = flow.shape
37
+ new_h, new_w = h * factor, w * factor
38
+
39
+ if up_mask is None:
40
+ return factor * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True)
41
+
42
+ up_mask = up_mask.view(batch_size, 1, 9, factor, factor, h, w)
43
+ up_mask = torch.softmax(up_mask, dim=2) # "convex" == weights sum to 1
44
+
45
+ upsampled_flow = F.unfold(factor * flow, kernel_size=3, padding=1).view(batch_size, num_channels, 9, 1, 1, h, w)
46
+ upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2)
47
+
48
+ return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, num_channels, new_h, new_w)
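To make the tensor bookkeeping above concrete, a small shape sketch (sizes are illustrative; `upsample_flow` is the function defined here):

import torch

flow = torch.rand(2, 2, 30, 40)              # (N, 2, H, W) flow at 1/8 resolution
up_mask = torch.rand(2, 8 * 8 * 9, 30, 40)   # 9 convex weights per pixel of each 8x8 output patch
out = upsample_flow(flow, up_mask)           # -> (2, 2, 240, 320)
out_interp = upsample_flow(flow, None)       # bilinear fallback, same output shape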
.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/raft.py ADDED
@@ -0,0 +1,947 @@
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+ from torch.nn.modules.batchnorm import BatchNorm2d
8
+ from torch.nn.modules.instancenorm import InstanceNorm2d
9
+ from torchvision.ops import Conv2dNormActivation
10
+
11
+ from ...transforms._presets import OpticalFlow
12
+ from ...utils import _log_api_usage_once
13
+ from .._api import register_model, Weights, WeightsEnum
14
+ from .._utils import handle_legacy_interface
15
+ from ._utils import grid_sample, make_coords_grid, upsample_flow
16
+
17
+
18
+ __all__ = (
19
+ "RAFT",
20
+ "raft_large",
21
+ "raft_small",
22
+ "Raft_Large_Weights",
23
+ "Raft_Small_Weights",
24
+ )
25
+
26
+
27
+ class ResidualBlock(nn.Module):
28
+ """Slightly modified Residual block with extra relu and biases."""
29
+
30
+ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1, always_project: bool = False):
31
+ super().__init__()
32
+
33
+ # Note regarding bias=True:
34
+ # Usually we can pass bias=False in conv layers followed by a norm layer.
35
+ # But in the RAFT training reference, the BatchNorm2d layers are only activated for the first dataset,
36
+ # and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful
37
+ # for the rest of the datasets. Technically, we could remove the bias for other norm layers like Instance norm
38
+ # because these aren't frozen, but we don't bother (also, we wouldn't be able to load the original weights).
39
+ self.convnormrelu1 = Conv2dNormActivation(
40
+ in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
41
+ )
42
+ self.convnormrelu2 = Conv2dNormActivation(
43
+ out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True
44
+ )
45
+
46
+ # make mypy happy
47
+ self.downsample: nn.Module
48
+
49
+ if stride == 1 and not always_project:
50
+ self.downsample = nn.Identity()
51
+ else:
52
+ self.downsample = Conv2dNormActivation(
53
+ in_channels,
54
+ out_channels,
55
+ norm_layer=norm_layer,
56
+ kernel_size=1,
57
+ stride=stride,
58
+ bias=True,
59
+ activation_layer=None,
60
+ )
61
+
62
+ self.relu = nn.ReLU(inplace=True)
63
+
64
+ def forward(self, x):
65
+ y = x
66
+ y = self.convnormrelu1(y)
67
+ y = self.convnormrelu2(y)
68
+
69
+ x = self.downsample(x)
70
+
71
+ return self.relu(x + y)
72
+
73
+
74
+ class BottleneckBlock(nn.Module):
75
+ """Slightly modified BottleNeck block (extra relu and biases)"""
76
+
77
+ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
78
+ super().__init__()
79
+
80
+ # See note in ResidualBlock for the reason behind bias=True
81
+ self.convnormrelu1 = Conv2dNormActivation(
82
+ in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True
83
+ )
84
+ self.convnormrelu2 = Conv2dNormActivation(
85
+ out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
86
+ )
87
+ self.convnormrelu3 = Conv2dNormActivation(
88
+ out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True
89
+ )
90
+ self.relu = nn.ReLU(inplace=True)
91
+
92
+ if stride == 1:
93
+ self.downsample = nn.Identity()
94
+ else:
95
+ self.downsample = Conv2dNormActivation(
96
+ in_channels,
97
+ out_channels,
98
+ norm_layer=norm_layer,
99
+ kernel_size=1,
100
+ stride=stride,
101
+ bias=True,
102
+ activation_layer=None,
103
+ )
104
+
105
+ def forward(self, x):
106
+ y = x
107
+ y = self.convnormrelu1(y)
108
+ y = self.convnormrelu2(y)
109
+ y = self.convnormrelu3(y)
110
+
111
+ x = self.downsample(x)
112
+
113
+ return self.relu(x + y)
114
+
115
+
116
+ class FeatureEncoder(nn.Module):
117
+ """The feature encoder, used both as the actual feature encoder, and as the context encoder.
118
+
119
+ It must downsample its input by 8.
120
+ """
121
+
122
+ def __init__(
123
+ self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), strides=(2, 1, 2, 2), norm_layer=nn.BatchNorm2d
124
+ ):
125
+ super().__init__()
126
+
127
+ if len(layers) != 5:
128
+ raise ValueError(f"The expected number of layers is 5, instead got {len(layers)}")
129
+
130
+ # See note in ResidualBlock for the reason behind bias=True
131
+ self.convnormrelu = Conv2dNormActivation(
132
+ 3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=strides[0], bias=True
133
+ )
134
+
135
+ self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=strides[1])
136
+ self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=strides[2])
137
+ self.layer3 = self._make_2_blocks(block, layers[2], layers[3], norm_layer=norm_layer, first_stride=strides[3])
138
+
139
+ self.conv = nn.Conv2d(layers[3], layers[4], kernel_size=1)
140
+
141
+ for m in self.modules():
142
+ if isinstance(m, nn.Conv2d):
143
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
144
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
145
+ if m.weight is not None:
146
+ nn.init.constant_(m.weight, 1)
147
+ if m.bias is not None:
148
+ nn.init.constant_(m.bias, 0)
149
+
150
+ num_downsamples = len(list(filter(lambda s: s == 2, strides)))
151
+ self.output_dim = layers[-1]
152
+ self.downsample_factor = 2**num_downsamples
153
+
154
+ def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride):
155
+ block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride)
156
+ block2 = block(out_channels, out_channels, norm_layer=norm_layer, stride=1)
157
+ return nn.Sequential(block1, block2)
158
+
159
+ def forward(self, x):
160
+ x = self.convnormrelu(x)
161
+
162
+ x = self.layer1(x)
163
+ x = self.layer2(x)
164
+ x = self.layer3(x)
165
+
166
+ x = self.conv(x)
167
+
168
+ return x
169
+
170
+
171
+ class MotionEncoder(nn.Module):
172
+ """The motion encoder, part of the update block.
173
+
174
+ Takes the current predicted flow and the correlation features as input and returns an encoded version of these.
175
+ """
176
+
177
+ def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128, 64), out_channels=128):
178
+ super().__init__()
179
+
180
+ if len(flow_layers) != 2:
181
+ raise ValueError(f"The expected number of flow_layers is 2, instead got {len(flow_layers)}")
182
+ if len(corr_layers) not in (1, 2):
183
+ raise ValueError(f"The number of corr_layers should be 1 or 2, instead got {len(corr_layers)}")
184
+
185
+ self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1)
186
+ if len(corr_layers) == 2:
187
+ self.convcorr2 = Conv2dNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3)
188
+ else:
189
+ self.convcorr2 = nn.Identity()
190
+
191
+ self.convflow1 = Conv2dNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7)
192
+ self.convflow2 = Conv2dNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3)
193
+
194
+ # out_channels - 2 because we cat the flow (2 channels) at the end
195
+ self.conv = Conv2dNormActivation(
196
+ corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3
197
+ )
198
+
199
+ self.out_channels = out_channels
200
+
201
+ def forward(self, flow, corr_features):
202
+ corr = self.convcorr1(corr_features)
203
+ corr = self.convcorr2(corr)
204
+
205
+ flow_orig = flow
206
+ flow = self.convflow1(flow)
207
+ flow = self.convflow2(flow)
208
+
209
+ corr_flow = torch.cat([corr, flow], dim=1)
210
+ corr_flow = self.conv(corr_flow)
211
+ return torch.cat([corr_flow, flow_orig], dim=1)
212
+
213
+
214
+ class ConvGRU(nn.Module):
215
+ """Convolutional Gru unit."""
216
+
217
+ def __init__(self, *, input_size, hidden_size, kernel_size, padding):
218
+ super().__init__()
219
+ self.convz = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)
220
+ self.convr = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)
221
+ self.convq = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)
222
+
223
+ def forward(self, h, x):
224
+ hx = torch.cat([h, x], dim=1)
225
+ z = torch.sigmoid(self.convz(hx))
226
+ r = torch.sigmoid(self.convr(hx))
227
+ q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
228
+ h = (1 - z) * h + z * q
229
+ return h
230
+
231
+
232
+ def _pass_through_h(h, _):
233
+ # Declared here for torchscript
234
+ return h
235
+
236
+
237
+ class RecurrentBlock(nn.Module):
238
+ """Recurrent block, part of the update block.
239
+
240
+ Takes the current hidden state and the concatenation of (motion encoder output, context) as input.
241
+ Returns an updated hidden state.
242
+ """
243
+
244
+ def __init__(self, *, input_size, hidden_size, kernel_size=((1, 5), (5, 1)), padding=((0, 2), (2, 0))):
245
+ super().__init__()
246
+
247
+ if len(kernel_size) != len(padding):
248
+ raise ValueError(
249
+ f"kernel_size should have the same length as padding, instead got len(kernel_size) = {len(kernel_size)} and len(padding) = {len(padding)}"
250
+ )
251
+ if len(kernel_size) not in (1, 2):
252
+ raise ValueError(f"kernel_size should have length 1 or 2, instead got {len(kernel_size)}")
253
+
254
+ self.convgru1 = ConvGRU(
255
+ input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[0], padding=padding[0]
256
+ )
257
+ if len(kernel_size) == 2:
258
+ self.convgru2 = ConvGRU(
259
+ input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[1], padding=padding[1]
260
+ )
261
+ else:
262
+ self.convgru2 = _pass_through_h
263
+
264
+ self.hidden_size = hidden_size
265
+
266
+ def forward(self, h, x):
267
+ h = self.convgru1(h, x)
268
+ h = self.convgru2(h, x)
269
+ return h
270
+
271
+
272
+ class FlowHead(nn.Module):
273
+ """Flow head, part of the update block.
274
+
275
+ Takes the hidden state of the recurrent unit as input, and outputs the predicted "delta flow".
276
+ """
277
+
278
+ def __init__(self, *, in_channels, hidden_size):
279
+ super().__init__()
280
+ self.conv1 = nn.Conv2d(in_channels, hidden_size, 3, padding=1)
281
+ self.conv2 = nn.Conv2d(hidden_size, 2, 3, padding=1)
282
+ self.relu = nn.ReLU(inplace=True)
283
+
284
+ def forward(self, x):
285
+ return self.conv2(self.relu(self.conv1(x)))
286
+
287
+
288
+ class UpdateBlock(nn.Module):
289
+ """The update block which contains the motion encoder, the recurrent block, and the flow head.
290
+
291
+ It must expose a ``hidden_state_size`` attribute which is the hidden state size of its recurrent block.
292
+ """
293
+
294
+ def __init__(self, *, motion_encoder, recurrent_block, flow_head):
295
+ super().__init__()
296
+ self.motion_encoder = motion_encoder
297
+ self.recurrent_block = recurrent_block
298
+ self.flow_head = flow_head
299
+
300
+ self.hidden_state_size = recurrent_block.hidden_size
301
+
302
+ def forward(self, hidden_state, context, corr_features, flow):
303
+ motion_features = self.motion_encoder(flow, corr_features)
304
+ x = torch.cat([context, motion_features], dim=1)
305
+
306
+ hidden_state = self.recurrent_block(hidden_state, x)
307
+ delta_flow = self.flow_head(hidden_state)
308
+ return hidden_state, delta_flow
309
+
310
+
311
+ class MaskPredictor(nn.Module):
312
+ """Mask predictor to be used when upsampling the predicted flow.
313
+
314
+ It takes the hidden state of the recurrent unit as input and outputs the mask.
315
+ This is not used in the raft-small model.
316
+ """
317
+
318
+ def __init__(self, *, in_channels, hidden_size, multiplier=0.25):
319
+ super().__init__()
320
+ self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3)
321
+ # 8 * 8 * 9 because the predicted flow is downsampled by 8, from the downsampling of the initial FeatureEncoder,
322
+ # and we interpolate with all 9 surrounding neighbors. See paper and appendix B.
323
+ self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0)
324
+
325
+ # In the original code, they use a factor of 0.25 to "downweight the gradients" of that branch.
326
+ # See e.g. https://github.com/princeton-vl/RAFT/issues/119#issuecomment-953950419
327
+ # or https://github.com/princeton-vl/RAFT/issues/24.
328
+ # It doesn't seem to affect epe significantly and can likely be set to 1.
329
+ self.multiplier = multiplier
330
+
331
+ def forward(self, x):
332
+ x = self.convrelu(x)
333
+ x = self.conv(x)
334
+ return self.multiplier * x
335
+
336
+
337
+ class CorrBlock(nn.Module):
338
+ """The correlation block.
339
+
340
+ Creates a correlation pyramid with ``num_levels`` levels from the outputs of the feature encoder,
341
+ and then indexes from this pyramid to create correlation features.
342
+ The "indexing" of a given centroid pixel x' is done by concatenating its surrounding neighbors that
343
+ are within a ``radius``, according to the infinity norm (see paper section 3.2).
344
+ Note: typo in the paper, it should be infinity norm, not 1-norm.
345
+ """
346
+
347
+ def __init__(self, *, num_levels: int = 4, radius: int = 4):
348
+ super().__init__()
349
+ self.num_levels = num_levels
350
+ self.radius = radius
351
+
352
+ self.corr_pyramid: List[Tensor] = [torch.tensor(0)] # useless, but torchscript is otherwise confused :')
353
+
354
+ # The neighborhood of a centroid pixel x' is {x' + delta, ||delta||_inf <= radius}
355
+ # so it's a square surrounding x', and its sides have a length of 2 * radius + 1
356
+ # The paper claims that it's ||.||_1 instead of ||.||_inf but it's a typo:
357
+ # https://github.com/princeton-vl/RAFT/issues/122
358
+ self.out_channels = num_levels * (2 * radius + 1) ** 2
359
+
360
+ def build_pyramid(self, fmap1, fmap2):
361
+ """Build the correlation pyramid from two feature maps.
362
+
363
+ The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2)
364
+ The last 2 dimensions of the correlation volume are then pooled num_levels times at different resolutions
365
+ to build the correlation pyramid.
366
+ """
367
+
368
+ if fmap1.shape != fmap2.shape:
369
+ raise ValueError(
370
+ f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)"
371
+ )
372
+
373
+ # Explaining min_fmap_size below: the fmaps are down-sampled (num_levels - 1) times by a factor of 2.
374
+ # The last corr_volume most have at least 2 values (hence the 2* factor), otherwise grid_sample() would
375
+ # produce nans in its output.
376
+ min_fmap_size = 2 * (2 ** (self.num_levels - 1))
377
+ if any(fmap_size < min_fmap_size for fmap_size in fmap1.shape[-2:]):
378
+ raise ValueError(
379
+ "Feature maps are too small to be down-sampled by the correlation pyramid. "
380
+ f"H and W of feature maps should be at least {min_fmap_size}; got: {fmap1.shape[-2:]}. "
381
+ "Remember that input images to the model are downsampled by 8, so that means their "
382
+ f"dimensions should be at least 8 * {min_fmap_size} = {8 * min_fmap_size}."
383
+ )
384
+
385
+ corr_volume = self._compute_corr_volume(fmap1, fmap2)
386
+
387
+ batch_size, h, w, num_channels, _, _ = corr_volume.shape # _, _ = h, w
388
+ corr_volume = corr_volume.reshape(batch_size * h * w, num_channels, h, w)
389
+ self.corr_pyramid = [corr_volume]
390
+ for _ in range(self.num_levels - 1):
391
+ corr_volume = F.avg_pool2d(corr_volume, kernel_size=2, stride=2)
392
+ self.corr_pyramid.append(corr_volume)
393
+
394
+ def index_pyramid(self, centroids_coords):
395
+ """Return correlation features by indexing from the pyramid."""
396
+ neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels
397
+ di = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
398
+ dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
399
+ delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device)
400
+ delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2)
401
+
402
+ batch_size, _, h, w = centroids_coords.shape # _ = 2
403
+ centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2)
404
+
405
+ indexed_pyramid = []
406
+ for corr_volume in self.corr_pyramid:
407
+ sampling_coords = centroids_coords + delta # end shape is (batch_size * h * w, side_len, side_len, 2)
408
+ indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view(
409
+ batch_size, h, w, -1
410
+ )
411
+ indexed_pyramid.append(indexed_corr_volume)
412
+ centroids_coords = centroids_coords / 2
413
+
414
+ corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()
415
+
416
+ expected_output_shape = (batch_size, self.out_channels, h, w)
417
+ if corr_features.shape != expected_output_shape:
418
+ raise ValueError(
419
+ f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}"
420
+ )
421
+
422
+ return corr_features
423
+
424
+ def _compute_corr_volume(self, fmap1, fmap2):
425
+ batch_size, num_channels, h, w = fmap1.shape
426
+ fmap1 = fmap1.view(batch_size, num_channels, h * w)
427
+ fmap2 = fmap2.view(batch_size, num_channels, h * w)
428
+
429
+ corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
430
+ corr = corr.view(batch_size, h, w, 1, h, w)
431
+ return corr / torch.sqrt(torch.tensor(num_channels))
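A worked shape example for the correlation pyramid (a sketch with illustrative sizes only):

# fmap1, fmap2: (1, 256, 46, 62)  ->  corr volume: (1, 46, 62, 1, 46, 62)
# reshaped to (1 * 46 * 62, 1, 46, 62) and average-pooled num_levels - 1 = 3 times;
# index_pyramid then returns (1, num_levels * (2 * radius + 1) ** 2, 46, 62)
# = (1, 324, 46, 62) correlation features, matching self.out_channels.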
432
+
433
+
434
+ class RAFT(nn.Module):
435
+ def __init__(self, *, feature_encoder, context_encoder, corr_block, update_block, mask_predictor=None):
436
+ """RAFT model from
437
+ `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
438
+
439
+ args:
440
+ feature_encoder (nn.Module): The feature encoder. It must downsample the input by 8.
441
+ Its input is the concatenation of ``image1`` and ``image2``.
442
+ context_encoder (nn.Module): The context encoder. It must downsample the input by 8.
443
+ Its input is ``image1``. As in the original implementation, its output will be split into 2 parts:
444
+
445
+ - one part will be used as the actual "context", passed to the recurrent unit of the ``update_block``
446
+ - one part will be used to initialize the hidden state of the recurrent unit of
447
+ the ``update_block``
448
+
449
+ These 2 parts are split according to the ``hidden_state_size`` of the ``update_block``, so the output
450
+ of the ``context_encoder`` must be strictly greater than ``hidden_state_size``.
451
+
452
+ corr_block (nn.Module): The correlation block, which creates a correlation pyramid from the output of the
453
+ ``feature_encoder``, and then indexes from this pyramid to create correlation features. It must expose
454
+ 2 methods:
455
+
456
+ - a ``build_pyramid`` method that takes ``feature_map_1`` and ``feature_map_2`` as input (these are the
457
+ output of the ``feature_encoder``).
458
+ - a ``index_pyramid`` method that takes the coordinates of the centroid pixels as input, and returns
459
+ the correlation features. See paper section 3.2.
460
+
461
+ It must expose an ``out_channels`` attribute.
462
+
463
+ update_block (nn.Module): The update block, which contains the motion encoder, the recurrent unit, and the
464
+ flow head. It takes as input the hidden state of its recurrent unit, the context, the correlation
465
+ features, and the current predicted flow. It outputs an updated hidden state, and the ``delta_flow``
466
+ prediction (see paper appendix A). It must expose a ``hidden_state_size`` attribute.
467
+ mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
468
+ The output channel must be 8 * 8 * 9 - see paper section 3.3, and Appendix B.
469
+ If ``None`` (default), the flow is upsampled using interpolation.
470
+ """
471
+ super().__init__()
472
+ _log_api_usage_once(self)
473
+
474
+ self.feature_encoder = feature_encoder
475
+ self.context_encoder = context_encoder
476
+ self.corr_block = corr_block
477
+ self.update_block = update_block
478
+
479
+ self.mask_predictor = mask_predictor
480
+
481
+ if not hasattr(self.update_block, "hidden_state_size"):
482
+ raise ValueError("The update_block parameter should expose a 'hidden_state_size' attribute.")
483
+
484
+ def forward(self, image1, image2, num_flow_updates: int = 12):
485
+
486
+ batch_size, _, h, w = image1.shape
487
+ if (h, w) != image2.shape[-2:]:
488
+ raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
489
+ if not ((h % 8 == 0) and (w % 8 == 0)):
490
+ raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
491
+
492
+ fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
493
+ fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
494
+ if fmap1.shape[-2:] != (h // 8, w // 8):
495
+ raise ValueError("The feature encoder should downsample H and W by 8")
496
+
497
+ self.corr_block.build_pyramid(fmap1, fmap2)
498
+
499
+ context_out = self.context_encoder(image1)
500
+ if context_out.shape[-2:] != (h // 8, w // 8):
501
+ raise ValueError("The context encoder should downsample H and W by 8")
502
+
503
+ # As in the original paper, the actual output of the context encoder is split in 2 parts:
504
+ # - one part is used to initialize the hidden state of the recurrent units of the update block
505
+ # - the rest is the "actual" context.
506
+ hidden_state_size = self.update_block.hidden_state_size
507
+ out_channels_context = context_out.shape[1] - hidden_state_size
508
+ if out_channels_context <= 0:
509
+ raise ValueError(
510
+ f"The context encoder outputs {context_out.shape[1]} channels, but it should output strictly more than hidden_state_size={hidden_state_size} channels"
511
+ )
512
+ hidden_state, context = torch.split(context_out, [hidden_state_size, out_channels_context], dim=1)
513
+ hidden_state = torch.tanh(hidden_state)
514
+ context = F.relu(context)
515
+
516
+ coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
517
+ coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
518
+
519
+ flow_predictions = []
520
+ for _ in range(num_flow_updates):
521
+ coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper
522
+ corr_features = self.corr_block.index_pyramid(centroids_coords=coords1)
523
+
524
+ flow = coords1 - coords0
525
+ hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow)
526
+
527
+ coords1 = coords1 + delta_flow
528
+
529
+ up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
530
+ upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask)
531
+ flow_predictions.append(upsampled_flow)
532
+
533
+ return flow_predictions
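A minimal sketch of running the full model (hedged: `raft_large` and the `DEFAULT` weights alias are defined further down in this file and in the weights enum; real inputs should be preprocessed with the OpticalFlow preset, and H and W must be divisible by 8):

import torch
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

model = raft_large(weights=Raft_Large_Weights.DEFAULT).eval()
img1 = torch.rand(1, 3, 360, 520)    # stand-ins for two consecutive, preprocessed frames
img2 = torch.rand(1, 3, 360, 520)
with torch.no_grad():
    flow_predictions = model(img1, img2, num_flow_updates=12)
flow = flow_predictions[-1]          # (1, 2, 360, 520): the most refined estimate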
534
+
535
+
536
+ _COMMON_META = {
537
+ "min_size": (128, 128),
538
+ }
539
+
540
+
541
+ class Raft_Large_Weights(WeightsEnum):
542
+ """The metrics reported here are as follows.
543
+
544
+ ``epe`` is the "end-point-error" and indicates how far (in pixels) the
545
+ predicted flow is from its true value. This is averaged over all pixels
546
+ of all images. ``per_image_epe`` is similar, but the average is different:
547
+ the epe is first computed on each image independently, and then averaged
548
+ over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
549
+ in the original paper, and it's only used on Kitti. ``fl-all`` is also a
550
+ Kitti-specific metric, defined by the author of the dataset and used for the
551
+ Kitti leaderboard. It corresponds to the average of pixels whose epe is
552
+ either <3px, or <5% of flow's 2-norm.
553
+ """
554
+
555
+ C_T_V1 = Weights(
556
+ # Weights ported from https://github.com/princeton-vl/RAFT
557
+ url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
558
+ transforms=OpticalFlow,
559
+ meta={
560
+ **_COMMON_META,
561
+ "num_params": 5257536,
562
+ "recipe": "https://github.com/princeton-vl/RAFT",
563
+ "_metrics": {
564
+ "Sintel-Train-Cleanpass": {"epe": 1.4411},
565
+ "Sintel-Train-Finalpass": {"epe": 2.7894},
566
+ "Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506},
567
+ },
568
+ "_ops": 211.007,
569
+ "_file_size": 20.129,
570
+ "_docs": """These weights were ported from the original paper. They
571
+ are trained on :class:`~torchvision.datasets.FlyingChairs` +
572
+ :class:`~torchvision.datasets.FlyingThings3D`.""",
573
+ },
574
+ )
575
+
576
+ C_T_V2 = Weights(
577
+ url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
578
+ transforms=OpticalFlow,
579
+ meta={
580
+ **_COMMON_META,
581
+ "num_params": 5257536,
582
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
583
+ "_metrics": {
584
+ "Sintel-Train-Cleanpass": {"epe": 1.3822},
585
+ "Sintel-Train-Finalpass": {"epe": 2.7161},
586
+ "Kitti-Train": {"per_image_epe": 4.5118, "fl_all": 16.0679},
587
+ },
588
+ "_ops": 211.007,
589
+ "_file_size": 20.129,
590
+ "_docs": """These weights were trained from scratch on
591
+ :class:`~torchvision.datasets.FlyingChairs` +
592
+ :class:`~torchvision.datasets.FlyingThings3D`.""",
593
+ },
594
+ )
595
+
596
+ C_T_SKHT_V1 = Weights(
597
+ # Weights ported from https://github.com/princeton-vl/RAFT
598
+ url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth",
599
+ transforms=OpticalFlow,
600
+ meta={
601
+ **_COMMON_META,
602
+ "num_params": 5257536,
603
+ "recipe": "https://github.com/princeton-vl/RAFT",
604
+ "_metrics": {
605
+ "Sintel-Test-Cleanpass": {"epe": 1.94},
606
+ "Sintel-Test-Finalpass": {"epe": 3.18},
607
+ },
608
+ "_ops": 211.007,
609
+ "_file_size": 20.129,
610
+ "_docs": """
611
+ These weights were ported from the original paper. They are
612
+ trained on :class:`~torchvision.datasets.FlyingChairs` +
613
+ :class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on
614
+ Sintel. The Sintel fine-tuning step is a combination of
615
+ :class:`~torchvision.datasets.Sintel`,
616
+ :class:`~torchvision.datasets.KittiFlow`,
617
+ :class:`~torchvision.datasets.HD1K`, and
618
+ :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
619
+ """,
620
+ },
621
+ )
622
+
623
+ C_T_SKHT_V2 = Weights(
624
+ url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
625
+ transforms=OpticalFlow,
626
+ meta={
627
+ **_COMMON_META,
628
+ "num_params": 5257536,
629
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
630
+ "_metrics": {
631
+ "Sintel-Test-Cleanpass": {"epe": 1.819},
632
+ "Sintel-Test-Finalpass": {"epe": 3.067},
633
+ },
634
+ "_ops": 211.007,
635
+ "_file_size": 20.129,
636
+ "_docs": """
637
+ These weights were trained from scratch. They are
638
+ pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
639
+ :class:`~torchvision.datasets.FlyingThings3D` and then
640
+ fine-tuned on Sintel. The Sintel fine-tuning step is a
641
+ combination of :class:`~torchvision.datasets.Sintel`,
642
+ :class:`~torchvision.datasets.KittiFlow`,
643
+ :class:`~torchvision.datasets.HD1K`, and
644
+ :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
645
+ """,
646
+ },
647
+ )
648
+
649
+ C_T_SKHT_K_V1 = Weights(
650
+ # Weights ported from https://github.com/princeton-vl/RAFT
651
+ url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth",
652
+ transforms=OpticalFlow,
653
+ meta={
654
+ **_COMMON_META,
655
+ "num_params": 5257536,
656
+ "recipe": "https://github.com/princeton-vl/RAFT",
657
+ "_metrics": {
658
+ "Kitti-Test": {"fl_all": 5.10},
659
+ },
660
+ "_ops": 211.007,
661
+ "_file_size": 20.129,
662
+ "_docs": """
663
+ These weights were ported from the original paper. They are
664
+ pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
665
+ :class:`~torchvision.datasets.FlyingThings3D`,
666
+ fine-tuned on Sintel, and then fine-tuned on
667
+ :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
668
+ step was described above.
669
+ """,
670
+ },
671
+ )
672
+
673
+ C_T_SKHT_K_V2 = Weights(
674
+ url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth",
675
+ transforms=OpticalFlow,
676
+ meta={
677
+ **_COMMON_META,
678
+ "num_params": 5257536,
679
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
680
+ "_metrics": {
681
+ "Kitti-Test": {"fl_all": 5.19},
682
+ },
683
+ "_ops": 211.007,
684
+ "_file_size": 20.129,
685
+ "_docs": """
686
+ These weights were trained from scratch. They are
687
+ pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
688
+ :class:`~torchvision.datasets.FlyingThings3D`,
689
+ fine-tuned on Sintel, and then fine-tuned on
690
+ :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
691
+ step was described above.
692
+ """,
693
+ },
694
+ )
695
+
696
+ DEFAULT = C_T_SKHT_V2
697
+
698
+
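The metric definitions in the docstring above can be written down in a few lines; this is a hedged sketch of those definitions (not the exact evaluation code used in the torchvision references), with ``fl_all`` expressed as the outlier percentage:

import torch

def end_point_error(flow_pred, flow_gt):
    # per-pixel L2 distance between predicted and ground-truth flow, both (N, 2, H, W)
    return torch.sum((flow_pred - flow_gt) ** 2, dim=1).sqrt()

def raft_metrics(flow_pred, flow_gt):
    err = end_point_error(flow_pred, flow_gt)        # (N, H, W)
    gt_norm = torch.sum(flow_gt ** 2, dim=1).sqrt()  # 2-norm of the ground-truth flow
    outliers = (err > 3) & (err > 0.05 * gt_norm)    # Kitti outlier definition
    return {
        "epe": err.mean().item(),                                    # averaged over all pixels of all images
        "per_image_epe": err.flatten(1).mean(dim=1).mean().item(),   # averaged per image, then over images
        "fl_all": outliers.float().mean().item() * 100,              # percentage of outlier pixels
    }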
699
+ class Raft_Small_Weights(WeightsEnum):
700
+ """The metrics reported here are as follows.
701
+
702
+ ``epe`` is the "end-point-error" and indicates how far (in pixels) the
703
+ predicted flow is from its true value. This is averaged over all pixels
704
+ of all images. ``per_image_epe`` is similar, but the average is different:
705
+ the epe is first computed on each image independently, and then averaged
706
+ over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
707
+ in the original paper, and it's only used on Kitti. ``fl-all`` is also a
708
+ Kitti-specific metric, defined by the author of the dataset and used for the
709
+ Kitti leaderboard. It corresponds to the percentage of outlier pixels, i.e.
710
+ pixels whose epe is both >3px and >5% of the flow's 2-norm (lower is better).
711
+ """
712
+
713
+ C_T_V1 = Weights(
714
+ # Weights ported from https://github.com/princeton-vl/RAFT
715
+ url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
716
+ transforms=OpticalFlow,
717
+ meta={
718
+ **_COMMON_META,
719
+ "num_params": 990162,
720
+ "recipe": "https://github.com/princeton-vl/RAFT",
721
+ "_metrics": {
722
+ "Sintel-Train-Cleanpass": {"epe": 2.1231},
723
+ "Sintel-Train-Finalpass": {"epe": 3.2790},
724
+ "Kitti-Train": {"per_image_epe": 7.6557, "fl_all": 25.2801},
725
+ },
726
+ "_ops": 47.655,
727
+ "_file_size": 3.821,
728
+ "_docs": """These weights were ported from the original paper. They
729
+ are trained on :class:`~torchvision.datasets.FlyingChairs` +
730
+ :class:`~torchvision.datasets.FlyingThings3D`.""",
731
+ },
732
+ )
733
+ C_T_V2 = Weights(
734
+ url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
735
+ transforms=OpticalFlow,
736
+ meta={
737
+ **_COMMON_META,
738
+ "num_params": 990162,
739
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
740
+ "_metrics": {
741
+ "Sintel-Train-Cleanpass": {"epe": 1.9901},
742
+ "Sintel-Train-Finalpass": {"epe": 3.2831},
743
+ "Kitti-Train": {"per_image_epe": 7.5978, "fl_all": 25.2369},
744
+ },
745
+ "_ops": 47.655,
746
+ "_file_size": 3.821,
747
+ "_docs": """These weights were trained from scratch on
748
+ :class:`~torchvision.datasets.FlyingChairs` +
749
+ :class:`~torchvision.datasets.FlyingThings3D`.""",
750
+ },
751
+ )
752
+
753
+ DEFAULT = C_T_V2
754
+
755
+
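Each ``Weights`` entry above uses the ``OpticalFlow`` preset as its ``transforms``; a small sketch of preparing an image pair with it, assuming the usual torchvision preset behaviour (uint8 tensors in, normalized float tensors out):

import torch
from torchvision.models.optical_flow import Raft_Small_Weights, raft_small

weights = Raft_Small_Weights.DEFAULT                # i.e. C_T_V2
preprocess = weights.transforms()                   # OpticalFlow preset instance
model = raft_small(weights=weights).eval()

img1 = torch.randint(0, 256, (1, 3, 360, 520), dtype=torch.uint8)  # stand-ins for real frames
img2 = torch.randint(0, 256, (1, 3, 360, 520), dtype=torch.uint8)
img1, img2 = preprocess(img1, img2)                 # the preset takes and returns the pair
with torch.no_grad():
    flow = model(img1, img2)[-1]                    # (1, 2, 360, 520)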
756
+ def _raft(
757
+ *,
758
+ weights=None,
759
+ progress=False,
760
+ # Feature encoder
761
+ feature_encoder_layers,
762
+ feature_encoder_block,
763
+ feature_encoder_norm_layer,
764
+ # Context encoder
765
+ context_encoder_layers,
766
+ context_encoder_block,
767
+ context_encoder_norm_layer,
768
+ # Correlation block
769
+ corr_block_num_levels,
770
+ corr_block_radius,
771
+ # Motion encoder
772
+ motion_encoder_corr_layers,
773
+ motion_encoder_flow_layers,
774
+ motion_encoder_out_channels,
775
+ # Recurrent block
776
+ recurrent_block_hidden_state_size,
777
+ recurrent_block_kernel_size,
778
+ recurrent_block_padding,
779
+ # Flow Head
780
+ flow_head_hidden_size,
781
+ # Mask predictor
782
+ use_mask_predictor,
783
+ **kwargs,
784
+ ):
785
+ feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder(
786
+ block=feature_encoder_block, layers=feature_encoder_layers, norm_layer=feature_encoder_norm_layer
787
+ )
788
+ context_encoder = kwargs.pop("context_encoder", None) or FeatureEncoder(
789
+ block=context_encoder_block, layers=context_encoder_layers, norm_layer=context_encoder_norm_layer
790
+ )
791
+
792
+ corr_block = kwargs.pop("corr_block", None) or CorrBlock(num_levels=corr_block_num_levels, radius=corr_block_radius)
793
+
794
+ update_block = kwargs.pop("update_block", None)
795
+ if update_block is None:
796
+ motion_encoder = MotionEncoder(
797
+ in_channels_corr=corr_block.out_channels,
798
+ corr_layers=motion_encoder_corr_layers,
799
+ flow_layers=motion_encoder_flow_layers,
800
+ out_channels=motion_encoder_out_channels,
801
+ )
802
+
803
+ # See comments in forward pass of RAFT class about why we split the output of the context encoder
804
+ out_channels_context = context_encoder_layers[-1] - recurrent_block_hidden_state_size
805
+ recurrent_block = RecurrentBlock(
806
+ input_size=motion_encoder.out_channels + out_channels_context,
807
+ hidden_size=recurrent_block_hidden_state_size,
808
+ kernel_size=recurrent_block_kernel_size,
809
+ padding=recurrent_block_padding,
810
+ )
811
+
812
+ flow_head = FlowHead(in_channels=recurrent_block_hidden_state_size, hidden_size=flow_head_hidden_size)
813
+
814
+ update_block = UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head)
815
+
816
+ mask_predictor = kwargs.pop("mask_predictor", None)
817
+ if mask_predictor is None and use_mask_predictor:
818
+ mask_predictor = MaskPredictor(
819
+ in_channels=recurrent_block_hidden_state_size,
820
+ hidden_size=256,
821
+ multiplier=0.25, # See comment in MaskPredictor about this
822
+ )
823
+
824
+ model = RAFT(
825
+ feature_encoder=feature_encoder,
826
+ context_encoder=context_encoder,
827
+ corr_block=corr_block,
828
+ update_block=update_block,
829
+ mask_predictor=mask_predictor,
830
+ **kwargs, # not really needed, all params should be consumed by now
831
+ )
832
+
833
+ if weights is not None:
834
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
835
+
836
+ return model
837
+
838
+
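Because ``_raft`` pops fully-built sub-modules out of ``kwargs`` before constructing the defaults, the public builders can in principle accept replacement blocks; a hedged sketch, assuming the replacement exposes the same interface (e.g. an ``out_channels`` attribute on the correlation block), and with ``weights=None`` since pretrained weights only match the default shapes:

from torchvision.models.optical_flow import raft_large
from torchvision.models.optical_flow.raft import CorrBlock

custom_corr = CorrBlock(num_levels=4, radius=5)          # wider lookup than the default radius=4
model = raft_large(weights=None, corr_block=custom_corr)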
839
+ @register_model()
840
+ @handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
841
+ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT:
842
+ """RAFT model from
843
+ `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
844
+
845
+ Please see the example below for a tutorial on how to use this model.
846
+
847
+ Args:
848
+ weights (:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The
849
+ pretrained weights to use. See
850
+ :class:`~torchvision.models.optical_flow.Raft_Large_Weights`
851
+ below for more details, and possible values. By default, no
852
+ pre-trained weights are used.
853
+ progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
854
+ **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
855
+ base class. Please refer to the `source code
856
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
857
+ for more details about this class.
858
+
859
+ .. autoclass:: torchvision.models.optical_flow.Raft_Large_Weights
860
+ :members:
861
+ """
862
+
863
+ weights = Raft_Large_Weights.verify(weights)
864
+
865
+ return _raft(
866
+ weights=weights,
867
+ progress=progress,
868
+ # Feature encoder
869
+ feature_encoder_layers=(64, 64, 96, 128, 256),
870
+ feature_encoder_block=ResidualBlock,
871
+ feature_encoder_norm_layer=InstanceNorm2d,
872
+ # Context encoder
873
+ context_encoder_layers=(64, 64, 96, 128, 256),
874
+ context_encoder_block=ResidualBlock,
875
+ context_encoder_norm_layer=BatchNorm2d,
876
+ # Correlation block
877
+ corr_block_num_levels=4,
878
+ corr_block_radius=4,
879
+ # Motion encoder
880
+ motion_encoder_corr_layers=(256, 192),
881
+ motion_encoder_flow_layers=(128, 64),
882
+ motion_encoder_out_channels=128,
883
+ # Recurrent block
884
+ recurrent_block_hidden_state_size=128,
885
+ recurrent_block_kernel_size=((1, 5), (5, 1)),
886
+ recurrent_block_padding=((0, 2), (2, 0)),
887
+ # Flow head
888
+ flow_head_hidden_size=256,
889
+ # Mask predictor
890
+ use_mask_predictor=True,
891
+ **kwargs,
892
+ )
893
+
894
+
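Since the builder is decorated with ``register_model()``, it is also reachable through the torchvision model registry; a small sketch of the name-based API:

import torchvision
from torchvision.models import get_model, get_model_weights, list_models

print(list_models(module=torchvision.models.optical_flow))     # ['raft_large', 'raft_small']
weights_enum = get_model_weights("raft_large")                  # Raft_Large_Weights
for w in weights_enum:
    print(w, w.meta["_metrics"])
model = get_model("raft_large", weights=weights_enum.DEFAULT)   # same as raft_large(weights=C_T_SKHT_V2)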
895
+ @register_model()
896
+ @handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
897
+ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT:
898
+ """RAFT "small" model from
899
+ `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`__.
900
+
901
+ Please see the example below for a tutorial on how to use this model.
902
+
903
+ Args:
904
+ weights (:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The
905
+ pretrained weights to use. See
906
+ :class:`~torchvision.models.optical_flow.Raft_Small_Weights`
907
+ below for more details, and possible values. By default, no
908
+ pre-trained weights are used.
909
+ progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
910
+ **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
911
+ base class. Please refer to the `source code
912
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
913
+ for more details about this class.
914
+
915
+ .. autoclass:: torchvision.models.optical_flow.Raft_Small_Weights
916
+ :members:
917
+ """
918
+ weights = Raft_Small_Weights.verify(weights)
919
+
920
+ return _raft(
921
+ weights=weights,
922
+ progress=progress,
923
+ # Feature encoder
924
+ feature_encoder_layers=(32, 32, 64, 96, 128),
925
+ feature_encoder_block=BottleneckBlock,
926
+ feature_encoder_norm_layer=InstanceNorm2d,
927
+ # Context encoder
928
+ context_encoder_layers=(32, 32, 64, 96, 160),
929
+ context_encoder_block=BottleneckBlock,
930
+ context_encoder_norm_layer=None,
931
+ # Correlation block
932
+ corr_block_num_levels=4,
933
+ corr_block_radius=3,
934
+ # Motion encoder
935
+ motion_encoder_corr_layers=(96,),
936
+ motion_encoder_flow_layers=(64, 32),
937
+ motion_encoder_out_channels=82,
938
+ # Recurrent block
939
+ recurrent_block_hidden_state_size=96,
940
+ recurrent_block_kernel_size=(3,),
941
+ recurrent_block_padding=(1,),
942
+ # Flow head
943
+ flow_head_hidden_size=128,
944
+ # Mask predictor
945
+ use_mask_predictor=False,
946
+ **kwargs,
947
+ )
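A quick sanity check of the two configurations against the ``num_params`` values reported in the weight metadata above (a sketch; it instantiates the architectures without downloading any weights):

from torchvision.models.optical_flow import (
    Raft_Large_Weights, Raft_Small_Weights, raft_large, raft_small,
)

for builder, weights in ((raft_large, Raft_Large_Weights.DEFAULT), (raft_small, Raft_Small_Weights.DEFAULT)):
    model = builder(weights=None)                        # architecture only
    n_params = sum(p.numel() for p in model.parameters())
    print(builder.__name__, n_params, weights.meta["num_params"])  # the two counts should agree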
.venv/lib/python3.11/site-packages/torchvision/models/quantization/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .googlenet import *
2
+ from .inception import *
3
+ from .mobilenet import *
4
+ from .resnet import *
5
+ from .shufflenetv2 import *
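The quantization subpackage whose re-exports are added here provides ``quantize=True`` variants of several classification models; a hedged usage sketch of eager-mode int8 inference on CPU:

import torch
from torchvision.models.quantization import ResNet18_QuantizedWeights, resnet18

weights = ResNet18_QuantizedWeights.DEFAULT
model = resnet18(weights=weights, quantize=True).eval()   # int8 weights, CPU inference
x = torch.rand(1, 3, 224, 224)                            # stand-in for a real image batch
with torch.no_grad():
    logits = model(weights.transforms()(x))
print(logits.shape)                                       # torch.Size([1, 1000])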
.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (362 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/googlenet.cpython-311.pyc ADDED
Binary file (12.4 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/inception.cpython-311.pyc ADDED
Binary file (16.8 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenet.cpython-311.pyc ADDED
Binary file (386 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenetv2.cpython-311.pyc ADDED
Binary file (9.19 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenetv3.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/resnet.cpython-311.pyc ADDED
Binary file (22.2 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/shufflenetv2.cpython-311.pyc ADDED
Binary file (19.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/utils.cpython-311.pyc ADDED
Binary file (3.39 kB). View file