diff --git a/.gitattributes b/.gitattributes index 7abf9a54e71c696141e88d26a2689df12bfca4e4..ad8b6d0f7298a35b092e6e2056874de1293f1033 100644 --- a/.gitattributes +++ b/.gitattributes @@ -344,3 +344,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/ .venv/lib/python3.11/site-packages/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/multidict/_multidict.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torchvision/image.so filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/torchvision/image.so b/.venv/lib/python3.11/site-packages/torchvision/image.so new file mode 100644 index 0000000000000000000000000000000000000000..722b0f77b14ba65293b3180dc85d4f3ddd4f4148 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/image.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c82377c2c2be60cedf80c171874d8d50d8b09102fe42c20b3a426b7715a1fc4d +size 667281 diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea0a1f7178b6ca03776d58c17411a8ff483f8b2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/__init__.py @@ -0,0 +1,23 @@ +from .alexnet import * +from .convnext import * +from .densenet import * +from .efficientnet import * +from .googlenet import * +from .inception import * +from .mnasnet import * +from .mobilenet import * +from .regnet import * +from .resnet import * +from .shufflenetv2 import * +from .squeezenet import * +from .vgg import * +from .vision_transformer import * +from .swin_transformer import * +from .maxvit import * +from . 
import detection, optical_flow, quantization, segmentation, video + +# The Weights and WeightsEnum are developer-facing utils that we make public for +# downstream libs like torchgeo https://github.com/pytorch/vision/issues/7094 +# TODO: we could / should document them publicly, but it's not clear where, as +# they're not intended for end users. +from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db2f1df8813f3bfbc6fe2d62a356a48552e1a8fc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_api.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acceea034e5b835a6043824cb4e619a7f7016d37 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_api.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_meta.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_meta.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d678ae2d42cc1e814f02284dd4bb6327f867493 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_meta.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_utils.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..dffeae85ea53bb9f86cc84908656ea1c0148f1f9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/alexnet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/alexnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..457246aae3ae10ec880224333179d960eb325b3f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/alexnet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/convnext.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/convnext.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1e4c3db06c461e432fc31880c742a6526d7f778 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/convnext.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/densenet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/densenet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..156276f318461b2129a967f95c745a9b3fe40158 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/densenet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/efficientnet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/efficientnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1821ff1a7e0ae892623e9ebb1dd10d5195c5a2d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/efficientnet.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/feature_extraction.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/feature_extraction.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f2befd80b3aa5bc4dd69e9c1b013f3c3d1bbb54 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/feature_extraction.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/googlenet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/googlenet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..550eb7b836948868e11d5814d3e0c7bbc4be6320 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/googlenet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/inception.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/inception.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5db591433426abee9603087e3fab65bc96f7a640 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/inception.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/maxvit.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/maxvit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3eb6f8ebfa2691a390ee1ca4424dc091be22da8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/maxvit.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mnasnet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mnasnet.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0f2fd0ab2847b9c3a5f2b2b692f087838822ce57 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mnasnet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba855e0f57d04a08609fc6508fef5afc6d3dd3ec Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenetv2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenetv2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8738c2cf43c6b034c2b56066c07a88391b740817 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenetv2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenetv3.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenetv3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..193297b47987796778d1f56f81ca56b548f017b4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/mobilenetv3.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/regnet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/regnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64bfafbf4c7e0e7b0d32edfee4813fa278cf31bc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/regnet.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/resnet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/resnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8b4e3097efd9abfdad6b9c1ee109d437d7ef50b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/resnet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/shufflenetv2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/shufflenetv2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..928b5b1b849af816475f5cf968b34bfb4085b685 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/shufflenetv2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/squeezenet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/squeezenet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bd9dc9d87ac321cc1b285a472f5c9aef46e3738 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/squeezenet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/swin_transformer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/swin_transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe35e095b9fceb110143a390ba0a5df37c683294 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/swin_transformer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/vgg.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/vgg.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..426f932d4387e9723929a12cbb02b4bf20d8136e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/vgg.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/vision_transformer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/vision_transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bf2b5b36340ebe3d29711bf05350b514ad919c1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/__pycache__/vision_transformer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/_api.py b/.venv/lib/python3.11/site-packages/torchvision/models/_api.py new file mode 100644 index 0000000000000000000000000000000000000000..0999bf7ba6beba91acc0b3374c2307cc8137847c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/_api.py @@ -0,0 +1,277 @@ +import fnmatch +import importlib +import inspect +import sys +from dataclasses import dataclass +from enum import Enum +from functools import partial +from inspect import signature +from types import ModuleType +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union + +from torch import nn + +from .._internally_replaced_utils import load_state_dict_from_url + + +__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"] + + +@dataclass +class Weights: + """ + This class is used to group important attributes associated with the pre-trained weights. + + Args: + url (str): The location where we find the weights. + transforms (Callable): A callable that constructs the preprocessing method (or validation preset transforms) + needed to use the model. 
The reason we attach a constructor method rather than an already constructed + object is because the specific object might have memory and thus we want to delay initialization until + needed. + meta (Dict[str, Any]): Stores meta-data related to the weights of the model and its configuration. These can be + informative attributes (for example the number of parameters/flops, recipe link/methods used in training + etc), configuration parameters (for example the `num_classes`) needed to construct the model or important + meta-data (for example the `classes` of a classification model) needed to use the model. + """ + + url: str + transforms: Callable + meta: Dict[str, Any] + + def __eq__(self, other: Any) -> bool: + # We need this custom implementation for correct deep-copy and deserialization behavior. + # TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it, + # involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often + # defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling + # for it, the check against the defined members would fail and effectively prevent the weights from being + # deep-copied or deserialized. + # See https://github.com/pytorch/vision/pull/7107 for details. + if not isinstance(other, Weights): + return NotImplemented + + if self.url != other.url: + return False + + if self.meta != other.meta: + return False + + if isinstance(self.transforms, partial) and isinstance(other.transforms, partial): + return ( + self.transforms.func == other.transforms.func + and self.transforms.args == other.transforms.args + and self.transforms.keywords == other.transforms.keywords + ) + else: + return self.transforms == other.transforms + + +class WeightsEnum(Enum): + """ + This class is the parent class of all model weights. 
Each model building method receives an optional `weights` + parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type + `Weights`. + + Args: + value (Weights): The data class entry with the weight information. + """ + + @classmethod + def verify(cls, obj: Any) -> Any: + if obj is not None: + if type(obj) is str: + obj = cls[obj.replace(cls.__name__ + ".", "")] + elif not isinstance(obj, cls): + raise TypeError( + f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}." + ) + return obj + + def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]: + return load_state_dict_from_url(self.url, *args, **kwargs) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}.{self._name_}" + + @property + def url(self): + return self.value.url + + @property + def transforms(self): + return self.value.transforms + + @property + def meta(self): + return self.value.meta + + +def get_weight(name: str) -> WeightsEnum: + """ + Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1" + + Args: + name (str): The name of the weight enum entry. + + Returns: + WeightsEnum: The requested weight enum. 
+ """ + try: + enum_name, value_name = name.split(".") + except ValueError: + raise ValueError(f"Invalid weight name provided: '{name}'.") + + base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1]) + base_module = importlib.import_module(base_module_name) + model_modules = [base_module] + [ + x[1] + for x in inspect.getmembers(base_module, inspect.ismodule) + if x[1].__file__.endswith("__init__.py") # type: ignore[union-attr] + ] + + weights_enum = None + for m in model_modules: + potential_class = m.__dict__.get(enum_name, None) + if potential_class is not None and issubclass(potential_class, WeightsEnum): + weights_enum = potential_class + break + + if weights_enum is None: + raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.") + + return weights_enum[value_name] + + +def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]: + """ + Returns the weights enum class associated to the given model. + + Args: + name (callable or str): The model builder function or the name under which it is registered. + + Returns: + weights_enum (WeightsEnum): The weights enum class associated with the model. + """ + model = get_model_builder(name) if isinstance(name, str) else name + return _get_enum_from_fn(model) + + +def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]: + """ + Internal method that gets the weight enum of a specific model builder method. + + Args: + fn (Callable): The builder method used to create the model. + Returns: + WeightsEnum: The requested weight enum. 
+ """ + sig = signature(fn) + if "weights" not in sig.parameters: + raise ValueError("The method is missing the 'weights' argument.") + + ann = signature(fn).parameters["weights"].annotation + weights_enum = None + if isinstance(ann, type) and issubclass(ann, WeightsEnum): + weights_enum = ann + else: + # handle cases like Union[Optional, T] + # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 + for t in ann.__args__: # type: ignore[union-attr] + if isinstance(t, type) and issubclass(t, WeightsEnum): + weights_enum = t + break + + if weights_enum is None: + raise ValueError( + "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct." + ) + + return weights_enum + + +M = TypeVar("M", bound=nn.Module) + +BUILTIN_MODELS = {} + + +def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]: + def wrapper(fn: Callable[..., M]) -> Callable[..., M]: + key = name if name is not None else fn.__name__ + if key in BUILTIN_MODELS: + raise ValueError(f"An entry is already registered under the name '{key}'.") + BUILTIN_MODELS[key] = fn + return fn + + return wrapper + + +def list_models( + module: Optional[ModuleType] = None, + include: Union[Iterable[str], str, None] = None, + exclude: Union[Iterable[str], str, None] = None, +) -> List[str]: + """ + Returns a list with the names of registered models. + + Args: + module (ModuleType, optional): The module from which we want to extract the available models. + include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models. + Filters are passed to `fnmatch `__ to match Unix shell-style + wildcards. In case of many filters, the results is the union of individual filters. + exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models. + Filter are passed to `fnmatch `__ to match Unix shell-style + wildcards. 
In case of many filters, the results is removal of all the models that match any individual filter. + + Returns: + models (list): A list with the names of available models. + """ + all_models = { + k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__ + } + if include: + models: Set[str] = set() + if isinstance(include, str): + include = [include] + for include_filter in include: + models = models | set(fnmatch.filter(all_models, include_filter)) + else: + models = all_models + + if exclude: + if isinstance(exclude, str): + exclude = [exclude] + for exclude_filter in exclude: + models = models - set(fnmatch.filter(all_models, exclude_filter)) + return sorted(models) + + +def get_model_builder(name: str) -> Callable[..., nn.Module]: + """ + Gets the model name and returns the model builder method. + + Args: + name (str): The name under which the model is registered. + + Returns: + fn (Callable): The model builder method. + """ + name = name.lower() + try: + fn = BUILTIN_MODELS[name] + except KeyError: + raise ValueError(f"Unknown model {name}") + return fn + + +def get_model(name: str, **config: Any) -> nn.Module: + """ + Gets the model name and configuration and returns an instantiated model. + + Args: + name (str): The name under which the model is registered. + **config (Any): parameters passed to the model builder method. + + Returns: + model (nn.Module): The initialized model. + """ + fn = get_model_builder(name) + return fn(**config) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/_meta.py b/.venv/lib/python3.11/site-packages/torchvision/models/_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..e66f411c287e0f456448315ba4fd0bfcce281d2b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/_meta.py @@ -0,0 +1,1554 @@ +""" +This file is part of the private API. 
Please do not refer to any variables defined here directly as they will be +removed on future versions without warning. +""" + +# This will eventually be replaced with a call at torchvision.datasets.info("imagenet").categories +_IMAGENET_CATEGORIES = [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead", + "electric ray", + "stingray", + "cock", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "water ouzel", + "kite", + "bald eagle", + "vulture", + "great grey owl", + "European fire salamander", + "common newt", + "eft", + "spotted salamander", + "axolotl", + "bullfrog", + "tree frog", + "tailed frog", + "loggerhead", + "leatherback turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "common iguana", + "American chameleon", + "whiptail", + "agama", + "frilled lizard", + "alligator lizard", + "Gila monster", + "green lizard", + "African chameleon", + "Komodo dragon", + "African crocodile", + "American alligator", + "triceratops", + "thunder snake", + "ringneck snake", + "hognose snake", + "green snake", + "king snake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "rock python", + "Indian cobra", + "green mamba", + "sea snake", + "horned viper", + "diamondback", + "sidewinder", + "trilobite", + "harvestman", + "scorpion", + "black and gold garden spider", + "barn spider", + "garden spider", + "black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie chicken", + "peacock", + "quail", + "partridge", + "African grey", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "drake", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + 
"wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "American egret", + "bittern", + "crane bird", + "limpkin", + "European gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "red-backed sandpiper", + "redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese spaniel", + "Maltese dog", + "Pekinese", + "Shih-Tzu", + "Blenheim spaniel", + "papillon", + "toy terrier", + "Rhodesian ridgeback", + "Afghan hound", + "basset", + "beagle", + "bloodhound", + "bluetick", + "black-and-tan coonhound", + "Walker hound", + "English foxhound", + "redbone", + "borzoi", + "Irish wolfhound", + "Italian greyhound", + "whippet", + "Ibizan hound", + "Norwegian elkhound", + "otterhound", + "Saluki", + "Scottish deerhound", + "Weimaraner", + "Staffordshire bullterrier", + "American Staffordshire terrier", + "Bedlington terrier", + "Border terrier", + "Kerry blue terrier", + "Irish terrier", + "Norfolk terrier", + "Norwich terrier", + "Yorkshire terrier", + "wire-haired fox terrier", + "Lakeland terrier", + "Sealyham terrier", + "Airedale", + "cairn", + "Australian terrier", + "Dandie Dinmont", + "Boston bull", + "miniature schnauzer", + "giant schnauzer", + "standard schnauzer", + "Scotch terrier", + "Tibetan terrier", + "silky terrier", + "soft-coated wheaten terrier", + "West Highland white terrier", + "Lhasa", + "flat-coated retriever", + "curly-coated retriever", + "golden retriever", + "Labrador retriever", + "Chesapeake Bay retriever", + "German short-haired pointer", + "vizsla", + "English setter", 
+ "Irish setter", + "Gordon setter", + "Brittany spaniel", + "clumber", + "English springer", + "Welsh springer spaniel", + "cocker spaniel", + "Sussex spaniel", + "Irish water spaniel", + "kuvasz", + "schipperke", + "groenendael", + "malinois", + "briard", + "kelpie", + "komondor", + "Old English sheepdog", + "Shetland sheepdog", + "collie", + "Border collie", + "Bouvier des Flandres", + "Rottweiler", + "German shepherd", + "Doberman", + "miniature pinscher", + "Greater Swiss Mountain dog", + "Bernese mountain dog", + "Appenzeller", + "EntleBucher", + "boxer", + "bull mastiff", + "Tibetan mastiff", + "French bulldog", + "Great Dane", + "Saint Bernard", + "Eskimo dog", + "malamute", + "Siberian husky", + "dalmatian", + "affenpinscher", + "basenji", + "pug", + "Leonberg", + "Newfoundland", + "Great Pyrenees", + "Samoyed", + "Pomeranian", + "chow", + "keeshond", + "Brabancon griffon", + "Pembroke", + "Cardigan", + "toy poodle", + "miniature poodle", + "standard poodle", + "Mexican hairless", + "timber wolf", + "white wolf", + "red wolf", + "coyote", + "dingo", + "dhole", + "African hunting dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian cat", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "ice bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "long-horned beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket", + "walking stick", + "cockroach", + "mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "admiral", + "ringlet", + "monarch", + "cabbage butterfly", + "sulphur butterfly", + "lycaenid", + "starfish", + "sea urchin", + "sea cucumber", + "wood rabbit", + "hare", + "Angora", + "hamster", + "porcupine", + "fox 
squirrel", + "marmot", + "beaver", + "guinea pig", + "sorrel", + "zebra", + "hog", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram", + "bighorn", + "ibex", + "hartebeest", + "impala", + "gazelle", + "Arabian camel", + "llama", + "weasel", + "mink", + "polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas", + "baboon", + "macaque", + "langur", + "colobus", + "proboscis monkey", + "marmoset", + "capuchin", + "howler monkey", + "titi", + "spider monkey", + "squirrel monkey", + "Madagascar cat", + "indri", + "Indian elephant", + "African elephant", + "lesser panda", + "giant panda", + "barracouta", + "eel", + "coho", + "rock beauty", + "anemone fish", + "sturgeon", + "gar", + "lionfish", + "puffer", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibian", + "analog clock", + "apiary", + "apron", + "ashcan", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint", + "Band Aid", + "banjo", + "bannister", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "barrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "bathing cap", + "bath towel", + "bathtub", + "beach wagon", + "beacon", + "beaker", + "bearskin", + "beer bottle", + "beer glass", + "bell cote", + "bib", + "bicycle-built-for-two", + "bikini", + "binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsled", + "bolo tie", + "bonnet", + "bookcase", + "bookshop", + "bottlecap", + "bow", + "bow tie", + "brass", + "brassiere", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "bullet train", + "butcher shop", + "cab", + "caldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car 
mirror", + "carousel", + "carpenter's kit", + "carton", + "car wheel", + "cash machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "cellular telephone", + "chain", + "chainlink fence", + "chain mail", + "chain saw", + "chest", + "chiffonier", + "chime", + "china cabinet", + "Christmas stocking", + "church", + "cinema", + "cleaver", + "cliff dwelling", + "cloak", + "clog", + "cocktail shaker", + "coffee mug", + "coffeepot", + "coil", + "combination lock", + "computer keyboard", + "confectionery", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "crane", + "crash helmet", + "crate", + "crib", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishrag", + "dishwasher", + "disk brake", + "dock", + "dogsled", + "dome", + "doormat", + "drilling platform", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso maker", + "face powder", + "feather boa", + "file", + "fireboat", + "fire engine", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gasmask", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golfcart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "grille", + "grocery store", + "guillotine", + "hair slide", + "hair spray", + "half track", + "hammer", + "hamper", + "hand blower", + "hand-held computer", + "handkerchief", + "hard disc", + "harmonica", + "harp", + "harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoopskirt", + "horizontal bar", + "horse cart", + 
"hourglass", + "iPod", + "iron", + "jack-o'-lantern", + "jean", + "jeep", + "jersey", + "jigsaw puzzle", + "jinrikisha", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "liner", + "lipstick", + "Loafer", + "lotion", + "loudspeaker", + "loupe", + "lumbermill", + "magnetic compass", + "mailbag", + "mailbox", + "maillot", + "maillot tank suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine chest", + "megalith", + "microphone", + "microwave", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "Model T", + "modem", + "monastery", + "monitor", + "moped", + "mortar", + "mortarboard", + "mosque", + "mosquito net", + "motor scooter", + "mountain bike", + "mountain tent", + "mouse", + "mousetrap", + "moving van", + "muzzle", + "nail", + "neck brace", + "necklace", + "nipple", + "notebook", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "organ", + "oscilloscope", + "overskirt", + "oxcart", + "oxygen mask", + "packet", + "paddle", + "paddlewheel", + "padlock", + "paintbrush", + "pajama", + "palace", + "panpipe", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "passenger car", + "patio", + "pay-phone", + "pedestal", + "pencil box", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "pick", + "pickelhaube", + "picket fence", + "pickup", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate", + "pitcher", + "plane", + "planetarium", + "plastic bag", + "plate rack", + "plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "pop bottle", + "pot", + "potter's wheel", + "power drill", + "prayer rug", 
+ "printer", + "prison", + "projectile", + "projector", + "puck", + "punching bag", + "purse", + "quill", + "quilt", + "racer", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "rubber eraser", + "rugby ball", + "rule", + "running shoe", + "safe", + "safety pin", + "saltshaker", + "sandal", + "sarong", + "sax", + "scabbard", + "scale", + "school bus", + "schooner", + "scoreboard", + "screen", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe shop", + "shoji", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar dish", + "sombrero", + "soup bowl", + "space bar", + "space heater", + "space shuttle", + "spatula", + "speedboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "steel arch bridge", + "steel drum", + "stethoscope", + "stole", + "stone wall", + "stopwatch", + "stove", + "strainer", + "streetcar", + "stretcher", + "studio couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglass", + "sunglasses", + "sunscreen", + "suspension bridge", + "swab", + "sweatshirt", + "swimming trunks", + "swing", + "switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy", + "television", + "tennis ball", + "thatch", + "theater curtain", + "thimble", + "thresher", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toyshop", + "tractor", + "trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "tub", + 
"turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright", + "vacuum", + "vase", + "vault", + "velvet", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "warplane", + "washbasin", + "washer", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "wing", + "wok", + "wooden spoon", + "wool", + "worm fence", + "wreck", + "yawl", + "yurt", + "web site", + "comic book", + "crossword puzzle", + "street sign", + "traffic light", + "book jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "ice lolly", + "French loaf", + "bagel", + "pretzel", + "cheeseburger", + "hotdog", + "mashed potato", + "head cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "custard apple", + "pomegranate", + "hay", + "carbonara", + "chocolate sauce", + "dough", + "meat loaf", + "pizza", + "potpie", + "burrito", + "red wine", + "espresso", + "cup", + "eggnog", + "alp", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeside", + "promontory", + "sandbar", + "seashore", + "valley", + "volcano", + "ballplayer", + "groom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "hip", + "buckeye", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn", + "earthstar", + "hen-of-the-woods", + "bolete", + "ear", + "toilet tissue", +] + +# To be replaced with torchvision.datasets.info("coco").categories +_COCO_CATEGORIES = [ + "__background__", + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + 
"traffic light", + "fire hydrant", + "N/A", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "N/A", + "backpack", + "umbrella", + "N/A", + "N/A", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "N/A", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "N/A", + "dining table", + "N/A", + "N/A", + "toilet", + "N/A", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "N/A", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", +] + +# To be replaced with torchvision.datasets.info("coco_kp") +_COCO_PERSON_CATEGORIES = ["no person", "person"] +_COCO_PERSON_KEYPOINT_NAMES = [ + "nose", + "left_eye", + "right_eye", + "left_ear", + "right_ear", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_hip", + "right_hip", + "left_knee", + "right_knee", + "left_ankle", + "right_ankle", +] + +# To be replaced with torchvision.datasets.info("voc").categories +_VOC_CATEGORIES = [ + "__background__", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", +] + +# To be replaced with torchvision.datasets.info("kinetics400").categories +_KINETICS400_CATEGORIES = [ + "abseiling", + "air drumming", + "answering questions", + "applauding", + "applying cream", + "archery", + "arm 
wrestling", + "arranging flowers", + "assembling computer", + "auctioning", + "baby waking up", + "baking cookies", + "balloon blowing", + "bandaging", + "barbequing", + "bartending", + "beatboxing", + "bee keeping", + "belly dancing", + "bench pressing", + "bending back", + "bending metal", + "biking through snow", + "blasting sand", + "blowing glass", + "blowing leaves", + "blowing nose", + "blowing out candles", + "bobsledding", + "bookbinding", + "bouncing on trampoline", + "bowling", + "braiding hair", + "breading or breadcrumbing", + "breakdancing", + "brush painting", + "brushing hair", + "brushing teeth", + "building cabinet", + "building shed", + "bungee jumping", + "busking", + "canoeing or kayaking", + "capoeira", + "carrying baby", + "cartwheeling", + "carving pumpkin", + "catching fish", + "catching or throwing baseball", + "catching or throwing frisbee", + "catching or throwing softball", + "celebrating", + "changing oil", + "changing wheel", + "checking tires", + "cheerleading", + "chopping wood", + "clapping", + "clay pottery making", + "clean and jerk", + "cleaning floor", + "cleaning gutters", + "cleaning pool", + "cleaning shoes", + "cleaning toilet", + "cleaning windows", + "climbing a rope", + "climbing ladder", + "climbing tree", + "contact juggling", + "cooking chicken", + "cooking egg", + "cooking on campfire", + "cooking sausages", + "counting money", + "country line dancing", + "cracking neck", + "crawling baby", + "crossing river", + "crying", + "curling hair", + "cutting nails", + "cutting pineapple", + "cutting watermelon", + "dancing ballet", + "dancing charleston", + "dancing gangnam style", + "dancing macarena", + "deadlifting", + "decorating the christmas tree", + "digging", + "dining", + "disc golfing", + "diving cliff", + "dodgeball", + "doing aerobics", + "doing laundry", + "doing nails", + "drawing", + "dribbling basketball", + "drinking", + "drinking beer", + "drinking shots", + "driving car", + "driving tractor", + "drop 
kicking", + "drumming fingers", + "dunking basketball", + "dying hair", + "eating burger", + "eating cake", + "eating carrots", + "eating chips", + "eating doughnuts", + "eating hotdog", + "eating ice cream", + "eating spaghetti", + "eating watermelon", + "egg hunting", + "exercising arm", + "exercising with an exercise ball", + "extinguishing fire", + "faceplanting", + "feeding birds", + "feeding fish", + "feeding goats", + "filling eyebrows", + "finger snapping", + "fixing hair", + "flipping pancake", + "flying kite", + "folding clothes", + "folding napkins", + "folding paper", + "front raises", + "frying vegetables", + "garbage collecting", + "gargling", + "getting a haircut", + "getting a tattoo", + "giving or receiving award", + "golf chipping", + "golf driving", + "golf putting", + "grinding meat", + "grooming dog", + "grooming horse", + "gymnastics tumbling", + "hammer throw", + "headbanging", + "headbutting", + "high jump", + "high kick", + "hitting baseball", + "hockey stop", + "holding snake", + "hopscotch", + "hoverboarding", + "hugging", + "hula hooping", + "hurdling", + "hurling (sport)", + "ice climbing", + "ice fishing", + "ice skating", + "ironing", + "javelin throw", + "jetskiing", + "jogging", + "juggling balls", + "juggling fire", + "juggling soccer ball", + "jumping into pool", + "jumpstyle dancing", + "kicking field goal", + "kicking soccer ball", + "kissing", + "kitesurfing", + "knitting", + "krumping", + "laughing", + "laying bricks", + "long jump", + "lunge", + "making a cake", + "making a sandwich", + "making bed", + "making jewelry", + "making pizza", + "making snowman", + "making sushi", + "making tea", + "marching", + "massaging back", + "massaging feet", + "massaging legs", + "massaging person's head", + "milking cow", + "mopping floor", + "motorcycling", + "moving furniture", + "mowing lawn", + "news anchoring", + "opening bottle", + "opening present", + "paragliding", + "parasailing", + "parkour", + "passing American football (in 
game)", + "passing American football (not in game)", + "peeling apples", + "peeling potatoes", + "petting animal (not cat)", + "petting cat", + "picking fruit", + "planting trees", + "plastering", + "playing accordion", + "playing badminton", + "playing bagpipes", + "playing basketball", + "playing bass guitar", + "playing cards", + "playing cello", + "playing chess", + "playing clarinet", + "playing controller", + "playing cricket", + "playing cymbals", + "playing didgeridoo", + "playing drums", + "playing flute", + "playing guitar", + "playing harmonica", + "playing harp", + "playing ice hockey", + "playing keyboard", + "playing kickball", + "playing monopoly", + "playing organ", + "playing paintball", + "playing piano", + "playing poker", + "playing recorder", + "playing saxophone", + "playing squash or racquetball", + "playing tennis", + "playing trombone", + "playing trumpet", + "playing ukulele", + "playing violin", + "playing volleyball", + "playing xylophone", + "pole vault", + "presenting weather forecast", + "pull ups", + "pumping fist", + "pumping gas", + "punching bag", + "punching person (boxing)", + "push up", + "pushing car", + "pushing cart", + "pushing wheelchair", + "reading book", + "reading newspaper", + "recording music", + "riding a bike", + "riding camel", + "riding elephant", + "riding mechanical bull", + "riding mountain bike", + "riding mule", + "riding or walking with horse", + "riding scooter", + "riding unicycle", + "ripping paper", + "robot dancing", + "rock climbing", + "rock scissors paper", + "roller skating", + "running on treadmill", + "sailing", + "salsa dancing", + "sanding floor", + "scrambling eggs", + "scuba diving", + "setting table", + "shaking hands", + "shaking head", + "sharpening knives", + "sharpening pencil", + "shaving head", + "shaving legs", + "shearing sheep", + "shining shoes", + "shooting basketball", + "shooting goal (soccer)", + "shot put", + "shoveling snow", + "shredding paper", + "shuffling cards", + "side 
kick", + "sign language interpreting", + "singing", + "situp", + "skateboarding", + "ski jumping", + "skiing (not slalom or crosscountry)", + "skiing crosscountry", + "skiing slalom", + "skipping rope", + "skydiving", + "slacklining", + "slapping", + "sled dog racing", + "smoking", + "smoking hookah", + "snatch weight lifting", + "sneezing", + "sniffing", + "snorkeling", + "snowboarding", + "snowkiting", + "snowmobiling", + "somersaulting", + "spinning poi", + "spray painting", + "spraying", + "springboard diving", + "squat", + "sticking tongue out", + "stomping grapes", + "stretching arm", + "stretching leg", + "strumming guitar", + "surfing crowd", + "surfing water", + "sweeping floor", + "swimming backstroke", + "swimming breast stroke", + "swimming butterfly stroke", + "swing dancing", + "swinging legs", + "swinging on something", + "sword fighting", + "tai chi", + "taking a shower", + "tango dancing", + "tap dancing", + "tapping guitar", + "tapping pen", + "tasting beer", + "tasting food", + "testifying", + "texting", + "throwing axe", + "throwing ball", + "throwing discus", + "tickling", + "tobogganing", + "tossing coin", + "tossing salad", + "training dog", + "trapezing", + "trimming or shaving beard", + "trimming trees", + "triple jump", + "tying bow tie", + "tying knot (not on a tie)", + "tying tie", + "unboxing", + "unloading truck", + "using computer", + "using remote controller (not gaming)", + "using segway", + "vault", + "waiting in line", + "walking the dog", + "washing dishes", + "washing feet", + "washing hair", + "washing hands", + "water skiing", + "water sliding", + "watering plants", + "waxing back", + "waxing chest", + "waxing eyebrows", + "waxing legs", + "weaving basket", + "welding", + "whistling", + "windsurfing", + "wrapping present", + "wrestling", + "writing", + "yawning", + "yoga", + "zumba", +] diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/_utils.py b/.venv/lib/python3.11/site-packages/torchvision/models/_utils.py 
new file mode 100644 index 0000000000000000000000000000000000000000..d59a6220b91f0f2fe85b6bf01948ee5a506cd82a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/_utils.py @@ -0,0 +1,256 @@ +import functools +import inspect +import warnings +from collections import OrderedDict +from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union + +from torch import nn + +from .._utils import sequence_to_str +from ._api import WeightsEnum + + +class IntermediateLayerGetter(nn.ModuleDict): + """ + Module wrapper that returns intermediate layers from a model + + It has a strong assumption that the modules have been registered + into the model in the same order as they are used. + This means that one should **not** reuse the same nn.Module + twice in the forward if you want this to work. + + Additionally, it is only able to query submodules that are directly + assigned to the model. So if `model` is passed, `model.feature1` can + be returned, but not `model.feature1.layer2`. + + Args: + model (nn.Module): model on which we will extract the features + return_layers (Dict[name, new_name]): a dict containing the names + of the modules for which the activations will be returned as + the key of the dict, and the value of the dict is the name + of the returned activation (which the user can specify). 
+ + Examples:: + + >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT) + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, + >>> {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = new_m(torch.rand(1, 3, 224, 224)) + >>> print([(k, v.shape) for k, v in out.items()]) + >>> [('feat1', torch.Size([1, 64, 56, 56])), + >>> ('feat2', torch.Size([1, 256, 14, 14]))] + """ + + _version = 2 + __annotations__ = { + "return_layers": Dict[str, str], + } + + def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None: + if not set(return_layers).issubset([name for name, _ in model.named_children()]): + raise ValueError("return_layers are not present in model") + orig_return_layers = return_layers + return_layers = {str(k): str(v) for k, v in return_layers.items()} + layers = OrderedDict() + for name, module in model.named_children(): + layers[name] = module + if name in return_layers: + del return_layers[name] + if not return_layers: + break + + super().__init__(layers) + self.return_layers = orig_return_layers + + def forward(self, x): + out = OrderedDict() + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + out_name = self.return_layers[name] + out[out_name] = x + return out + + +def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +D = TypeVar("D") + + +def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]: + """Decorates a function that uses keyword only parameters to also allow them being passed as positionals. + + For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``: + + .. code:: + + def old_fn(foo, bar, baz=None): + ... + + def new_fn(foo, *, bar, baz=None): + ... + + Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC + and at the same time warn the user of the deprecation, this decorator can be used: + + .. code:: + + @kwonly_to_pos_or_kw + def new_fn(foo, *, bar, baz=None): + ... + + new_fn("foo", "bar, "baz") + """ + params = inspect.signature(fn).parameters + + try: + keyword_only_start_idx = next( + idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY + ) + except StopIteration: + raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None + + keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:] + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> D: + args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:] + if keyword_only_args: + keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args)) + warnings.warn( + f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional " + f"parameter(s) is deprecated since 0.13 and may be removed in the future. Please use keyword parameter(s) " + f"instead." 
+ ) + kwargs.update(keyword_only_kwargs) + + return fn(*args, **kwargs) + + return wrapper + + +W = TypeVar("W", bound=WeightsEnum) +M = TypeVar("M", bound=nn.Module) +V = TypeVar("V") + + +def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): + """Decorates a model builder with the new interface to make it compatible with the old. + + In particular this handles two things: + + 1. Allows positional parameters again, but emits a deprecation warning in case they are used. See + :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details. + 2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to + ``weights=Weights`` and emits a deprecation warning with instructions for the new interface. + + Args: + **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter + name and default value for the legacy ``pretrained=True``. The default value can be a callable in which + case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in + the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters + should be accessed with :meth:`~dict.get`. + """ + + def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]: + @kwonly_to_pos_or_kw + @functools.wraps(builder) + def inner_wrapper(*args: Any, **kwargs: Any) -> M: + for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr] + # If neither the weights nor the pretrained parameter as passed, or the weights argument already use + # the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the + # weight argument, since it is a valid value. 
+ sentinel = object() + weights_arg = kwargs.get(weights_param, sentinel) + if ( + (weights_param not in kwargs and pretrained_param not in kwargs) + or isinstance(weights_arg, WeightsEnum) + or (isinstance(weights_arg, str) and weights_arg != "legacy") + or weights_arg is None + ): + continue + + # If the pretrained parameter was passed as positional argument, it is now mapped to + # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current + # signature to infer the names of positionally passed arguments and thus has no knowledge that there + # used to be a pretrained parameter. + pretrained_positional = weights_arg is not sentinel + if pretrained_positional: + # We put the pretrained argument under its legacy name in the keyword argument dictionary to have + # unified access to the value if the default value is a callable. + kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param) + else: + pretrained_arg = kwargs[pretrained_param] + + if pretrained_arg: + default_weights_arg = default(kwargs) if callable(default) else default + if not isinstance(default_weights_arg, WeightsEnum): + raise ValueError(f"No weights available for model {builder.__name__}") + else: + default_weights_arg = None + + if not pretrained_positional: + warnings.warn( + f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, " + f"please use '{weights_param}' instead." + ) + + msg = ( + f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated since 0.13 and " + f"may be removed in the future. " + f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`." + ) + if pretrained_arg: + msg = ( + f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` " + f"to get the most up-to-date weights." 
+ ) + warnings.warn(msg) + + del kwargs[pretrained_param] + kwargs[weights_param] = default_weights_arg + + return builder(*args, **kwargs) + + return inner_wrapper + + return outer_wrapper + + +def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None: + if param in kwargs: + if kwargs[param] != new_value: + raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.") + else: + kwargs[param] = new_value + + +def _ovewrite_value_param(param: str, actual: Optional[V], expected: V) -> V: + if actual is not None: + if actual != expected: + raise ValueError(f"The parameter '{param}' expected value {expected} but got {actual} instead.") + return expected + + +class _ModelURLs(dict): + def __getitem__(self, item): + warnings.warn( + "Accessing the model URLs via the internal dictionary of the module is deprecated since 0.13 and may " + "be removed in the future. Please access them via the appropriate Weights Enum instead." + ) + return super().__getitem__(item) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/feature_extraction.py b/.venv/lib/python3.11/site-packages/torchvision/models/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..f42bc124c7b92cddeed2c161e3ead3a3c8963295 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/feature_extraction.py @@ -0,0 +1,572 @@ +import inspect +import math +import re +import warnings +from collections import OrderedDict +from copy import deepcopy +from itertools import chain +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torchvision +from torch import fx, nn +from torch.fx.graph_module import _copy_attr + + +__all__ = ["create_feature_extractor", "get_graph_node_names"] + + +class LeafModuleAwareTracer(fx.Tracer): + """ + An fx.Tracer that allows the user to specify a set of leaf modules, i.e. + modules that are not to be traced through. 
The resulting graph ends up + having single nodes referencing calls to the leaf modules' forward methods. + """ + + def __init__(self, *args, **kwargs): + self.leaf_modules = {} + if "leaf_modules" in kwargs: + leaf_modules = kwargs.pop("leaf_modules") + self.leaf_modules = leaf_modules + super().__init__(*args, **kwargs) + + def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool: + if isinstance(m, tuple(self.leaf_modules)): + return True + return super().is_leaf_module(m, module_qualname) + + +class NodePathTracer(LeafModuleAwareTracer): + """ + NodePathTracer is an FX tracer that, for each operation, also records the + name of the Node from which the operation originated. A node name here is + a `.` separated path walking the hierarchy from top level module down to + leaf operation or leaf module. The name of the top level module is not + included as part of the node name. For example, if we trace a module whose + forward method applies a ReLU module, the name for that node will simply + be 'relu'. + + Some notes on the specifics: + - Nodes are recorded to `self.node_to_qualname` which is a dictionary + mapping a given Node object to its node name. + - Nodes are recorded in the order which they are executed during + tracing. + - When a duplicate node name is encountered, a suffix of the form + _{int} is added. The counter starts from 1. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Track the qualified name of the Node being traced + self.current_module_qualname = "" + # A map from FX Node to the qualified name\# + # NOTE: This is loosely like the "qualified name" mentioned in the + # torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted + # for the purposes of the torchvision feature extractor + self.node_to_qualname = OrderedDict() + + def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): + """ + Override of `fx.Tracer.call_module` + This override: + 1) Stores away the qualified name of the caller for restoration later + 2) Adds the qualified name of the caller to + `current_module_qualname` for retrieval by `create_proxy` + 3) Once a leaf module is reached, calls `create_proxy` + 4) Restores the caller's qualified name into current_module_qualname + """ + old_qualname = self.current_module_qualname + try: + module_qualname = self.path_of_module(m) + self.current_module_qualname = module_qualname + if not self.is_leaf_module(m, module_qualname): + out = forward(*args, **kwargs) + return out + return self.create_proxy("call_module", module_qualname, args, kwargs) + finally: + self.current_module_qualname = old_qualname + + def create_proxy( + self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None, *_ + ) -> fx.proxy.Proxy: + """ + Override of `Tracer.create_proxy`. 
This override intercepts the recording + of every operation and stores away the current traced module's qualified + name in `node_to_qualname` + """ + proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr) + self.node_to_qualname[proxy.node] = self._get_node_qualname(self.current_module_qualname, proxy.node) + return proxy + + def _get_node_qualname(self, module_qualname: str, node: fx.node.Node) -> str: + node_qualname = module_qualname + + if node.op != "call_module": + # In this case module_qualname from torch.fx doesn't go all the + # way to the leaf function/op, so we need to append it + if len(node_qualname) > 0: + # Only append '.' if we are deeper than the top level module + node_qualname += "." + node_qualname += str(node) + + # Now we need to add an _{index} postfix on any repeated node names + # For modules we do this from scratch + # But for anything else, torch.fx already has a globally scoped + # _{index} postfix. But we want it locally (relative to direct parent) + # scoped. So first we need to undo the torch.fx postfix + if re.match(r".+_[0-9]+$", node_qualname) is not None: + node_qualname = node_qualname.rsplit("_", 1)[0] + + # ... 
and now we add on our own postfix + for existing_qualname in reversed(self.node_to_qualname.values()): + # Check to see if existing_qualname is of the form + # {node_qualname} or {node_qualname}_{int} + if re.match(rf"{node_qualname}(_[0-9]+)?$", existing_qualname) is not None: + postfix = existing_qualname.replace(node_qualname, "") + if len(postfix): + # existing_qualname is of the form {node_qualname}_{int} + next_index = int(postfix[1:]) + 1 + else: + # existing_qualname is of the form {node_qualname} + next_index = 1 + node_qualname += f"_{next_index}" + break + + return node_qualname + + +def _is_subseq(x, y): + """Check if y is a subsequence of x + https://stackoverflow.com/a/24017747/4391249 + """ + iter_x = iter(x) + return all(any(x_item == y_item for x_item in iter_x) for y_item in y) + + +def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathTracer): + """ + Utility function for warning the user if there are differences between + the train graph nodes and the eval graph nodes. + """ + train_nodes = list(train_tracer.node_to_qualname.values()) + eval_nodes = list(eval_tracer.node_to_qualname.values()) + + if len(train_nodes) == len(eval_nodes) and all(t == e for t, e in zip(train_nodes, eval_nodes)): + return + + suggestion_msg = ( + "When choosing nodes for feature extraction, you may need to specify " + "output nodes for train and eval mode separately." + ) + + if _is_subseq(train_nodes, eval_nodes): + msg = ( + "NOTE: The nodes obtained by tracing the model in eval mode " + "are a subsequence of those obtained in train mode. " + ) + elif _is_subseq(eval_nodes, train_nodes): + msg = ( + "NOTE: The nodes obtained by tracing the model in train mode " + "are a subsequence of those obtained in eval mode. " + ) + else: + msg = "The nodes obtained by tracing the model in train mode are different to those obtained in eval mode. 
" + warnings.warn(msg + suggestion_msg) + + +def _get_leaf_modules_for_ops() -> List[type]: + members = inspect.getmembers(torchvision.ops) + result = [] + for _, obj in members: + if inspect.isclass(obj) and issubclass(obj, torch.nn.Module): + result.append(obj) + return result + + +def _set_default_tracer_kwargs(original_tr_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]: + default_autowrap_modules = (math, torchvision.ops) + default_leaf_modules = _get_leaf_modules_for_ops() + result_tracer_kwargs = {} if original_tr_kwargs is None else original_tr_kwargs + result_tracer_kwargs["autowrap_modules"] = ( + tuple(set(result_tracer_kwargs["autowrap_modules"] + default_autowrap_modules)) + if "autowrap_modules" in result_tracer_kwargs + else default_autowrap_modules + ) + result_tracer_kwargs["leaf_modules"] = ( + list(set(result_tracer_kwargs["leaf_modules"] + default_leaf_modules)) + if "leaf_modules" in result_tracer_kwargs + else default_leaf_modules + ) + return result_tracer_kwargs + + +def get_graph_node_names( + model: nn.Module, + tracer_kwargs: Optional[Dict[str, Any]] = None, + suppress_diff_warning: bool = False, + concrete_args: Optional[Dict[str, Any]] = None, +) -> Tuple[List[str], List[str]]: + """ + Dev utility to return node names in order of execution. See note on node + names under :func:`create_feature_extractor`. Useful for seeing which node + names are available for feature extraction. There are two reasons that + node names can't easily be read directly from the code for a model: + + 1. Not all submodules are traced through. Modules from ``torch.nn`` all + fall within this category. + 2. Nodes representing the repeated application of the same operation + or leaf module get a ``_{counter}`` postfix. + + The model is traced twice: once in train mode, and once in eval mode. Both + sets of node names are returned. 
+ + For more details on the node naming conventions used here, please see the + :ref:`relevant subheading ` in the + `documentation `_. + + Args: + model (nn.Module): model for which we'd like to print node names + tracer_kwargs (dict, optional): a dictionary of keyword arguments for + ``NodePathTracer`` (they are eventually passed onto + `torch.fx.Tracer `_). + By default, it will be set to wrap and make leaf nodes all torchvision ops: + {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),} + WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user + provided dictionary. + suppress_diff_warning (bool, optional): whether to suppress a warning + when there are discrepancies between the train and eval version of + the graph. Defaults to False. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should + not be treated as Proxies. According to the `Pytorch docs + `_, + this parameter's API may not be guaranteed. + + Returns: + tuple(list, list): a list of node names from tracing the model in + train mode, and another from tracing the model in eval mode. + + Examples:: + + >>> model = torchvision.models.resnet18() + >>> train_nodes, eval_nodes = get_graph_node_names(model) + """ + tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs) + is_training = model.training + train_tracer = NodePathTracer(**tracer_kwargs) + train_tracer.trace(model.train(), concrete_args=concrete_args) + eval_tracer = NodePathTracer(**tracer_kwargs) + eval_tracer.trace(model.eval(), concrete_args=concrete_args) + train_nodes = list(train_tracer.node_to_qualname.values()) + eval_nodes = list(eval_tracer.node_to_qualname.values()) + if not suppress_diff_warning: + _warn_graph_differences(train_tracer, eval_tracer) + # Restore training state + model.train(is_training) + return train_nodes, eval_nodes + + +class DualGraphModule(fx.GraphModule): + """ + A derivative of `fx.GraphModule`. 
Differs in the following ways: + - Requires a train and eval version of the underlying graph + - Copies submodules according to the nodes of both train and eval graphs. + - Calling train(mode) switches between train graph and eval graph. + """ + + def __init__( + self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule" + ): + """ + Args: + root (nn.Module): module from which the copied module hierarchy is + built + train_graph (fx.Graph): the graph that should be used in train mode + eval_graph (fx.Graph): the graph that should be used in eval mode + """ + super(fx.GraphModule, self).__init__() + + self.__class__.__name__ = class_name + + self.train_graph = train_graph + self.eval_graph = eval_graph + + # Copy all get_attr and call_module ops (indicated by BOTH train and + # eval graphs) + for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)): + if node.op in ["get_attr", "call_module"]: + if not isinstance(node.target, str): + raise TypeError(f"node.target should be of type str instead of {type(node.target)}") + _copy_attr(root, self, node.target) + + # train mode by default + self.train() + self.graph = train_graph + + # (borrowed from fx.GraphModule): + # Store the Tracer class responsible for creating a Graph separately as part of the + # GraphModule state, except when the Tracer is defined in a local namespace. + # Locally defined Tracers are not pickleable. This is needed because torch.package will + # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer + # to re-create the Graph during deserialization. + if self.eval_graph._tracer_cls != self.train_graph._tracer_cls: + raise TypeError( + f"Train mode and eval mode should use the same tracer class. 
Instead got {self.eval_graph._tracer_cls} for eval vs {self.train_graph._tracer_cls} for train" + ) + self._tracer_cls = None + if self.graph._tracer_cls and "" not in self.graph._tracer_cls.__qualname__: + self._tracer_cls = self.graph._tracer_cls + + def train(self, mode=True): + """ + Swap out the graph depending on the selected training mode. + NOTE this should be safe when calling model.eval() because that just + calls this with mode == False. + """ + # NOTE: Only set self.graph if the current graph is not the desired + # one. This saves us from recompiling the graph where not necessary. + if mode and not self.training: + self.graph = self.train_graph + elif not mode and self.training: + self.graph = self.eval_graph + return super().train(mode=mode) + + +def create_feature_extractor( + model: nn.Module, + return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + tracer_kwargs: Optional[Dict[str, Any]] = None, + suppress_diff_warning: bool = False, + concrete_args: Optional[Dict[str, Any]] = None, +) -> fx.GraphModule: + """ + Creates a new graph module that returns intermediate nodes from a given + model as dictionary with user specified keys as strings, and the requested + outputs as values. This is achieved by re-writing the computation graph of + the model via FX to return the desired nodes as outputs. All unused nodes + are removed, together with their corresponding parameters. + + Desired output nodes must be specified as a ``.`` separated + path walking the module hierarchy from top level module down to leaf + operation or leaf module. For more details on the node naming conventions + used here, please see the :ref:`relevant subheading ` + in the `documentation `_. + + Not all models will be FX traceable, although with some massaging they can + be made to cooperate. 
Here's a (not exhaustive) list of tips: + + - If you don't need to trace through a particular, problematic + sub-module, turn it into a "leaf module" by passing a list of + ``leaf_modules`` as one of the ``tracer_kwargs`` (see example below). + It will not be traced through, but rather, the resulting graph will + hold a reference to that module's forward method. + - Likewise, you may turn functions into leaf functions by passing a + list of ``autowrap_functions`` as one of the ``tracer_kwargs`` (see + example below). + - Some inbuilt Python functions can be problematic. For instance, + ``int`` will raise an error during tracing. You may wrap them in your + own function and then pass that in ``autowrap_functions`` as one of + the ``tracer_kwargs``. + + For further information on FX see the + `torch.fx documentation `_. + + Args: + model (nn.Module): model on which we will extract the features + return_nodes (list or dict, optional): either a ``List`` or a ``Dict`` + containing the names (or partial names - see note above) + of the nodes for which the activations will be returned. If it is + a ``Dict``, the keys are the node names, and the values + are the user-specified keys for the graph module's returned + dictionary. If it is a ``List``, it is treated as a ``Dict`` mapping + node specification strings directly to output names. In the case + that ``train_return_nodes`` and ``eval_return_nodes`` are specified, + this should not be specified. + train_return_nodes (list or dict, optional): similar to + ``return_nodes``. This can be used if the return nodes + for train mode are different than those from eval mode. + If this is specified, ``eval_return_nodes`` must also be specified, + and ``return_nodes`` should not be specified. + eval_return_nodes (list or dict, optional): similar to + ``return_nodes``. This can be used if the return nodes + for train mode are different than those from eval mode. 
+ If this is specified, ``train_return_nodes`` must also be specified, + and `return_nodes` should not be specified. + tracer_kwargs (dict, optional): a dictionary of keyword arguments for + ``NodePathTracer`` (which passes them onto it's parent class + `torch.fx.Tracer `_). + By default, it will be set to wrap and make leaf nodes all torchvision ops: + {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),} + WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user + provided dictionary. + suppress_diff_warning (bool, optional): whether to suppress a warning + when there are discrepancies between the train and eval version of + the graph. Defaults to False. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should + not be treated as Proxies. According to the `Pytorch docs + `_, + this parameter's API may not be guaranteed. + + Examples:: + + >>> # Feature extraction with resnet + >>> model = torchvision.models.resnet18() + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> model = create_feature_extractor( + >>> model, {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = model(torch.rand(1, 3, 224, 224)) + >>> print([(k, v.shape) for k, v in out.items()]) + >>> [('feat1', torch.Size([1, 64, 56, 56])), + >>> ('feat2', torch.Size([1, 256, 14, 14]))] + + >>> # Specifying leaf modules and leaf functions + >>> def leaf_function(x): + >>> # This would raise a TypeError if traced through + >>> return int(x) + >>> + >>> class LeafModule(torch.nn.Module): + >>> def forward(self, x): + >>> # This would raise a TypeError if traced through + >>> int(x.shape[0]) + >>> return torch.nn.functional.relu(x + 4) + >>> + >>> class MyModule(torch.nn.Module): + >>> def __init__(self): + >>> super().__init__() + >>> self.conv = torch.nn.Conv2d(3, 1, 3) + >>> self.leaf_module = LeafModule() + >>> + >>> def forward(self, x): + >>> leaf_function(x.shape[0]) + >>> x = 
self.conv(x) + >>> return self.leaf_module(x) + >>> + >>> model = create_feature_extractor( + >>> MyModule(), return_nodes=['leaf_module'], + >>> tracer_kwargs={'leaf_modules': [LeafModule], + >>> 'autowrap_functions': [leaf_function]}) + + """ + tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs) + is_training = model.training + + if all(arg is None for arg in [return_nodes, train_return_nodes, eval_return_nodes]): + + raise ValueError( + "Either `return_nodes` or `train_return_nodes` and `eval_return_nodes` together, should be specified" + ) + + if (train_return_nodes is None) ^ (eval_return_nodes is None): + raise ValueError( + "If any of `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified" + ) + + if not ((return_nodes is None) ^ (train_return_nodes is None)): + raise ValueError("If `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified") + + # Put *_return_nodes into Dict[str, str] format + def to_strdict(n) -> Dict[str, str]: + if isinstance(n, list): + return {str(i): str(i) for i in n} + return {str(k): str(v) for k, v in n.items()} + + if train_return_nodes is None: + return_nodes = to_strdict(return_nodes) + train_return_nodes = deepcopy(return_nodes) + eval_return_nodes = deepcopy(return_nodes) + else: + train_return_nodes = to_strdict(train_return_nodes) + eval_return_nodes = to_strdict(eval_return_nodes) + + # Repeat the tracing and graph rewriting for train and eval mode + tracers = {} + graphs = {} + mode_return_nodes: Dict[str, Dict[str, str]] = {"train": train_return_nodes, "eval": eval_return_nodes} + for mode in ["train", "eval"]: + if mode == "train": + model.train() + elif mode == "eval": + model.eval() + + # Instantiate our NodePathTracer and use that to trace the model + tracer = NodePathTracer(**tracer_kwargs) + graph = tracer.trace(model, concrete_args=concrete_args) + + name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__ + 
graph_module = fx.GraphModule(tracer.root, graph, name) + + available_nodes = list(tracer.node_to_qualname.values()) + # FIXME We don't know if we should expect this to happen + if len(set(available_nodes)) != len(available_nodes): + raise ValueError( + "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" + ) + # Check that all outputs in return_nodes are present in the model + for query in mode_return_nodes[mode].keys(): + # To check if a query is available we need to check that at least + # one of the available names starts with it up to a . + if not any([re.match(rf"^{query}(\.|$)", n) is not None for n in available_nodes]): + raise ValueError( + f"node: '{query}' is not present in model. Hint: use " + "`get_graph_node_names` to make sure the " + "`return_nodes` you specified are present. It may even " + "be that you need to specify `train_return_nodes` and " + "`eval_return_nodes` separately." + ) + + # Remove existing output nodes (train mode) + orig_output_nodes = [] + for n in reversed(graph_module.graph.nodes): + if n.op == "output": + orig_output_nodes.append(n) + if not orig_output_nodes: + raise ValueError("No output nodes found in graph_module.graph.nodes") + + for n in orig_output_nodes: + graph_module.graph.erase_node(n) + + # Find nodes corresponding to return_nodes and make them into output_nodes + nodes = [n for n in graph_module.graph.nodes] + output_nodes = OrderedDict() + for n in reversed(nodes): + module_qualname = tracer.node_to_qualname.get(n) + if module_qualname is None: + # NOTE - Know cases where this happens: + # - Node representing creation of a tensor constant - probably + # not interesting as a return node + # - When packing outputs into a named tuple like in InceptionV3 + continue + for query in mode_return_nodes[mode]: + depth = query.count(".") + if ".".join(module_qualname.split(".")[: depth + 1]) == query: + output_nodes[mode_return_nodes[mode][query]] = n + 
mode_return_nodes[mode].pop(query) + break + output_nodes = OrderedDict(reversed(list(output_nodes.items()))) + + # And add them in the end of the graph + with graph_module.graph.inserting_after(nodes[-1]): + graph_module.graph.output(output_nodes) + + # Remove unused modules / parameters + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + # Keep track of the tracer and graph, so we can choose the main one + tracers[mode] = tracer + graphs[mode] = graph + + # Warn user if there are any discrepancies between the graphs of the + # train and eval modes + if not suppress_diff_warning: + _warn_graph_differences(tracers["train"], tracers["eval"]) + + # Build the final graph module + graph_module = DualGraphModule(model, graphs["train"], graphs["eval"], class_name=name) + + # Restore original training mode + model.train(is_training) + graph_module.train(is_training) + + return graph_module diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/inception.py b/.venv/lib/python3.11/site-packages/torchvision/models/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..447a7682d62c0e31570aaf1e32c2da55a0d697d1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/inception.py @@ -0,0 +1,478 @@ +import warnings +from collections import namedtuple +from functools import partial +from typing import Any, Callable, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from ..transforms._presets import ImageClassification +from ..utils import _log_api_usage_once +from ._api import register_model, Weights, WeightsEnum +from ._meta import _IMAGENET_CATEGORIES +from ._utils import _ovewrite_named_param, handle_legacy_interface + + +__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] + + +InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"]) +InceptionOutputs.__annotations__ = 
{"logits": Tensor, "aux_logits": Optional[Tensor]} + +# Script annotations failed with _GoogleNetOutputs = namedtuple ... +# _InceptionOutputs set here for backwards compat +_InceptionOutputs = InceptionOutputs + + +class Inception3(nn.Module): + def __init__( + self, + num_classes: int = 1000, + aux_logits: bool = True, + transform_input: bool = False, + inception_blocks: Optional[List[Callable[..., nn.Module]]] = None, + init_weights: Optional[bool] = None, + dropout: float = 0.5, + ) -> None: + super().__init__() + _log_api_usage_once(self) + if inception_blocks is None: + inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux] + if init_weights is None: + warnings.warn( + "The default weight initialization of inception_v3 will be changed in future releases of " + "torchvision. If you wish to keep the old behavior (which leads to long initialization times" + " due to scipy/scipy#11299), please set init_weights=True.", + FutureWarning, + ) + init_weights = True + if len(inception_blocks) != 7: + raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}") + conv_block = inception_blocks[0] + inception_a = inception_blocks[1] + inception_b = inception_blocks[2] + inception_c = inception_blocks[3] + inception_d = inception_blocks[4] + inception_e = inception_blocks[5] + inception_aux = inception_blocks[6] + + self.aux_logits = aux_logits + self.transform_input = transform_input + self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) + self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) + self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Mixed_5b = inception_a(192, pool_features=32) + self.Mixed_5c = inception_a(256, 
pool_features=64) + self.Mixed_5d = inception_a(288, pool_features=64) + self.Mixed_6a = inception_b(288) + self.Mixed_6b = inception_c(768, channels_7x7=128) + self.Mixed_6c = inception_c(768, channels_7x7=160) + self.Mixed_6d = inception_c(768, channels_7x7=160) + self.Mixed_6e = inception_c(768, channels_7x7=192) + self.AuxLogits: Optional[nn.Module] = None + if aux_logits: + self.AuxLogits = inception_aux(768, num_classes) + self.Mixed_7a = inception_d(768) + self.Mixed_7b = inception_e(1280) + self.Mixed_7c = inception_e(2048) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.dropout = nn.Dropout(p=dropout) + self.fc = nn.Linear(2048, num_classes) + if init_weights: + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore + torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _transform_input(self, x: Tensor) -> Tensor: + if self.transform_input: + x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 + x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 + x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + return x + + def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]: + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = self.maxpool1(x) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = self.maxpool2(x) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = 
self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + aux: Optional[Tensor] = None + if self.AuxLogits is not None: + if self.training: + aux = self.AuxLogits(x) + # N x 768 x 17 x 17 + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + # Adaptive average pooling + x = self.avgpool(x) + # N x 2048 x 1 x 1 + x = self.dropout(x) + # N x 2048 x 1 x 1 + x = torch.flatten(x, 1) + # N x 2048 + x = self.fc(x) + # N x 1000 (num_classes) + return x, aux + + @torch.jit.unused + def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs: + if self.training and self.aux_logits: + return InceptionOutputs(x, aux) + else: + return x # type: ignore[return-value] + + def forward(self, x: Tensor) -> InceptionOutputs: + x = self._transform_input(x) + x, aux = self._forward(x) + aux_defined = self.training and self.aux_logits + if torch.jit.is_scripting(): + if not aux_defined: + warnings.warn("Scripted Inception3 always returns Inception3 Tuple") + return InceptionOutputs(x, aux) + else: + return self.eager_outputs(x, aux) + + +class InceptionA(nn.Module): + def __init__( + self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None + ) -> None: + super().__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) + + self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) + self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) + + self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) + + def _forward(self, x: Tensor) -> 
List[Tensor]: + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionB(nn.Module): + def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None: + super().__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) + + def _forward(self, x: Tensor) -> List[Tensor]: + branch3x3 = self.branch3x3(x) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + + outputs = [branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionC(nn.Module): + def __init__( + self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None + ) -> None: + super().__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) + + c7 = channels_7x7 + self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7_3 = 
conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) + + self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x: Tensor) -> List[Tensor]: + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return outputs + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionD(nn.Module): + def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None: + super().__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) + + self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) + + def _forward(self, x: Tensor) -> List[Tensor]: + branch3x3 = self.branch3x3_1(x) + branch3x3 = 
self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + outputs = [branch3x3, branch7x7x3, branch_pool] + return outputs + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionE(nn.Module): + def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None: + super().__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) + + self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) + self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) + self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) + self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x: Tensor) -> List[Tensor]: + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return outputs + + def 
forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionAux(nn.Module): + def __init__( + self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None + ) -> None: + super().__init__() + if conv_block is None: + conv_block = BasicConv2d + self.conv0 = conv_block(in_channels, 128, kernel_size=1) + self.conv1 = conv_block(128, 768, kernel_size=5) + self.conv1.stddev = 0.01 # type: ignore[assignment] + self.fc = nn.Linear(768, num_classes) + self.fc.stddev = 0.001 # type: ignore[assignment] + + def forward(self, x: Tensor) -> Tensor: + # N x 768 x 17 x 17 + x = F.avg_pool2d(x, kernel_size=5, stride=3) + # N x 768 x 5 x 5 + x = self.conv0(x) + # N x 128 x 5 x 5 + x = self.conv1(x) + # N x 768 x 1 x 1 + # Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) + # N x 768 x 1 x 1 + x = torch.flatten(x, 1) + # N x 768 + x = self.fc(x) + # N x 1000 + return x + + +class BasicConv2d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x: Tensor) -> Tensor: + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) + + +class Inception_V3_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", + transforms=partial(ImageClassification, crop_size=299, resize_size=342), + meta={ + "num_params": 27161264, + "min_size": (75, 75), + "categories": _IMAGENET_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3", + "_metrics": { + "ImageNet-1K": { + "acc@1": 77.294, + "acc@5": 93.450, + } + }, + "_ops": 5.713, + "_file_size": 103.903, + "_docs": """These weights are ported from the original paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + 
+@register_model() +@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1)) +def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: + """ + Inception v3 model architecture from + `Rethinking the Inception Architecture for Computer Vision `_. + + .. note:: + **Important**: In contrast to the other models the inception_v3 expects tensors with a size of + N x 3 x 299 x 299, so ensure your images are sized accordingly. + + Args: + weights (:class:`~torchvision.models.Inception_V3_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.Inception_V3_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.Inception3`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.Inception_V3_Weights + :members: + """ + weights = Inception_V3_Weights.verify(weights) + + original_aux_logits = kwargs.get("aux_logits", True) + if weights is not None: + if "transform_input" not in kwargs: + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = Inception3(**kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + if not original_aux_logits: + model.aux_logits = False + model.AuxLogits = None + + return model diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/mnasnet.py b/.venv/lib/python3.11/site-packages/torchvision/models/mnasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5846111ab1c05b4ebca7ccf9240ac744267b859a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/mnasnet.py @@ -0,0 +1,434 @@ +import warnings +from functools import partial +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from ..transforms._presets import ImageClassification +from ..utils import _log_api_usage_once +from ._api import register_model, Weights, WeightsEnum +from ._meta import _IMAGENET_CATEGORIES +from ._utils import _ovewrite_named_param, handle_legacy_interface + + +__all__ = [ + "MNASNet", + "MNASNet0_5_Weights", + "MNASNet0_75_Weights", + "MNASNet1_0_Weights", + "MNASNet1_3_Weights", + "mnasnet0_5", + "mnasnet0_75", + "mnasnet1_0", + "mnasnet1_3", +] + + +# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is +# 1.0 - tensorflow. 
+_BN_MOMENTUM = 1 - 0.9997 + + +class _InvertedResidual(nn.Module): + def __init__( + self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1 + ) -> None: + super().__init__() + if stride not in [1, 2]: + raise ValueError(f"stride should be 1 or 2 instead of {stride}") + if kernel_size not in [3, 5]: + raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}") + mid_ch = in_ch * expansion_factor + self.apply_residual = in_ch == out_ch and stride == 1 + self.layers = nn.Sequential( + # Pointwise + nn.Conv2d(in_ch, mid_ch, 1, bias=False), + nn.BatchNorm2d(mid_ch, momentum=bn_momentum), + nn.ReLU(inplace=True), + # Depthwise + nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False), + nn.BatchNorm2d(mid_ch, momentum=bn_momentum), + nn.ReLU(inplace=True), + # Linear pointwise. Note that there's no activation. + nn.Conv2d(mid_ch, out_ch, 1, bias=False), + nn.BatchNorm2d(out_ch, momentum=bn_momentum), + ) + + def forward(self, input: Tensor) -> Tensor: + if self.apply_residual: + return self.layers(input) + input + else: + return self.layers(input) + + +def _stack( + in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float +) -> nn.Sequential: + """Creates a stack of inverted residuals.""" + if repeats < 1: + raise ValueError(f"repeats should be >= 1, instead got {repeats}") + # First one has no skip, because feature map size changes. + first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum) + remaining = [] + for _ in range(1, repeats): + remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum)) + return nn.Sequential(first, *remaining) + + +def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int: + """Asymmetric rounding to make `val` divisible by `divisor`. 
With default + bias, will round up, unless the number is no more than 10% greater than the + smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88.""" + if not 0.0 < round_up_bias < 1.0: + raise ValueError(f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}") + new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) + return new_val if new_val >= round_up_bias * val else new_val + divisor + + +def _get_depths(alpha: float) -> List[int]: + """Scales tensor depths as in reference MobileNet code, prefers rounding up + rather than down.""" + depths = [32, 16, 24, 40, 80, 96, 192, 320] + return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] + + +class MNASNet(torch.nn.Module): + """MNASNet, as described in https://arxiv.org/abs/1807.11626. This + implements the B1 variant of the model. + >>> model = MNASNet(1.0, num_classes=1000) + >>> x = torch.rand(1, 3, 224, 224) + >>> y = model(x) + >>> y.dim() + 2 + >>> y.nelement() + 1000 + """ + + # Version 2 adds depth scaling in the initial stages of the network. + _version = 2 + + def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None: + super().__init__() + _log_api_usage_once(self) + if alpha <= 0.0: + raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}") + self.alpha = alpha + self.num_classes = num_classes + depths = _get_depths(alpha) + layers = [ + # First layer: regular conv. + nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False), + nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True), + # Depthwise separable, no skip. + nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False), + nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True), + nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False), + nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM), + # MNASNet blocks: stacks of inverted residuals. 
+ _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM), + _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM), + _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM), + _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM), + _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM), + _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM), + # Final mapping to classifier input. + nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False), + nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True), + ] + self.layers = nn.Sequential(*layers) + self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes)) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid") + nn.init.zeros_(m.bias) + + def forward(self, x: Tensor) -> Tensor: + x = self.layers(x) + # Equivalent to global avgpool and removing H and W dimensions. + x = x.mean([2, 3]) + return self.classifier(x) + + def _load_from_state_dict( + self, + state_dict: Dict, + prefix: str, + local_metadata: Dict, + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + version = local_metadata.get("version", None) + if version not in [1, 2]: + raise ValueError(f"version shluld be set to 1 or 2 instead of {version}") + + if version == 1 and not self.alpha == 1.0: + # In the initial version of the model (v1), stem was fixed-size. + # All other layer configurations were the same. This will patch + # the model so that it's identical to v1. Model with alpha 1.0 is + # unaffected. 
+ depths = _get_depths(self.alpha) + v1_stem = [ + nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False), + nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True), + nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False), + nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True), + nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), + nn.BatchNorm2d(16, momentum=_BN_MOMENTUM), + _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM), + ] + for idx, layer in enumerate(v1_stem): + self.layers[idx] = layer + + # The model is now identical to v1, and must be saved as such. + self._version = 1 + warnings.warn( + "A new version of MNASNet model has been implemented. " + "Your checkpoint was saved using the previous version. " + "This checkpoint will load and work as before, but " + "you may want to upgrade by training a newer model or " + "transfer learning from an updated ImageNet checkpoint.", + UserWarning, + ) + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + +_COMMON_META = { + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "recipe": "https://github.com/1e100/mnasnet_trainer", +} + + +class MNASNet0_5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2218512, + "_metrics": { + "ImageNet-1K": { + "acc@1": 67.734, + "acc@5": 87.490, + } + }, + "_ops": 0.104, + "_file_size": 8.591, + "_docs": """These weights reproduce closely the results of the paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class MNASNet0_75_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": 
"https://github.com/pytorch/vision/pull/6019", + "num_params": 3170208, + "_metrics": { + "ImageNet-1K": { + "acc@1": 71.180, + "acc@5": 90.496, + } + }, + "_ops": 0.215, + "_file_size": 12.303, + "_docs": """ + These weights were trained from scratch by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class MNASNet1_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 4383312, + "_metrics": { + "ImageNet-1K": { + "acc@1": 73.456, + "acc@5": 91.510, + } + }, + "_ops": 0.314, + "_file_size": 16.915, + "_docs": """These weights reproduce closely the results of the paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class MNASNet1_3_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/pull/6019", + "num_params": 6282256, + "_metrics": { + "ImageNet-1K": { + "acc@1": 76.506, + "acc@5": 93.522, + } + }, + "_ops": 0.526, + "_file_size": 24.246, + "_docs": """ + These weights were trained from scratch by using TorchVision's `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = MNASNet(alpha, **kwargs) + + if weights: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1)) +def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: + """MNASNet with depth multiplier of 0.5 from + `MnasNet: Platform-Aware Neural Architecture Search for Mobile + `_ paper. + + Args: + weights (:class:`~torchvision.models.MNASNet0_5_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.MNASNet0_5_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.MNASNet0_5_Weights + :members: + """ + weights = MNASNet0_5_Weights.verify(weights) + + return _mnasnet(0.5, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1)) +def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: + """MNASNet with depth multiplier of 0.75 from + `MnasNet: Platform-Aware Neural Architecture Search for Mobile + `_ paper. + + Args: + weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The + pretrained weights to use. 
See + :class:`~torchvision.models.MNASNet0_75_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.MNASNet0_75_Weights + :members: + """ + weights = MNASNet0_75_Weights.verify(weights) + + return _mnasnet(0.75, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1)) +def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: + """MNASNet with depth multiplier of 1.0 from + `MnasNet: Platform-Aware Neural Architecture Search for Mobile + `_ paper. + + Args: + weights (:class:`~torchvision.models.MNASNet1_0_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.MNASNet1_0_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.MNASNet1_0_Weights + :members: + """ + weights = MNASNet1_0_Weights.verify(weights) + + return _mnasnet(1.0, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1)) +def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: + """MNASNet with depth multiplier of 1.3 from + `MnasNet: Platform-Aware Neural Architecture Search for Mobile + `_ paper. + + Args: + weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.MNASNet1_3_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.MNASNet1_3_Weights + :members: + """ + weights = MNASNet1_3_Weights.verify(weights) + + return _mnasnet(1.3, weights, progress, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/mobilenet.py b/.venv/lib/python3.11/site-packages/torchvision/models/mobilenet.py new file mode 100644 index 0000000000000000000000000000000000000000..0a270d14d3a4ad9eda62b68c2c01e9fdb710ef38 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/mobilenet.py @@ -0,0 +1,6 @@ +from .mobilenetv2 import * # noqa: F401, F403 +from .mobilenetv3 import * # noqa: F401, F403 +from .mobilenetv2 import __all__ as mv2_all +from .mobilenetv3 import __all__ as mv3_all + +__all__ = mv2_all + mv3_all diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/mobilenetv2.py b/.venv/lib/python3.11/site-packages/torchvision/models/mobilenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb6a4981d6dec73c563dfc845329a58eaa3e083 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/mobilenetv2.py @@ -0,0 +1,260 @@ +from functools import partial +from typing import Any, Callable, List, Optional + +import torch +from torch import nn, Tensor + +from ..ops.misc import Conv2dNormActivation +from ..transforms._presets import ImageClassification +from ..utils import _log_api_usage_once +from ._api import register_model, Weights, WeightsEnum +from ._meta import _IMAGENET_CATEGORIES +from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface + + +__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] + + +# necessary for backwards compatibility +class InvertedResidual(nn.Module): + def __init__( + self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super().__init__() + self.stride = stride + if stride not in [1, 2]: + raise ValueError(f"stride should be 1 or 2 instead of 
{stride}") + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers: List[nn.Module] = [] + if expand_ratio != 1: + # pw + layers.append( + Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) + ) + layers.extend( + [ + # dw + Conv2dNormActivation( + hidden_dim, + hidden_dim, + stride=stride, + groups=hidden_dim, + norm_layer=norm_layer, + activation_layer=nn.ReLU6, + ), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ] + ) + self.conv = nn.Sequential(*layers) + self.out_channels = oup + self._is_cn = stride > 1 + + def forward(self, x: Tensor) -> Tensor: + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__( + self, + num_classes: int = 1000, + width_mult: float = 1.0, + inverted_residual_setting: Optional[List[List[int]]] = None, + round_nearest: int = 8, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + dropout: float = 0.2, + ) -> None: + """ + MobileNet V2 main class + + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + norm_layer: Module specifying the normalization layer to use + dropout (float): The droupout probability + + """ + super().__init__() + _log_api_usage_once(self) + + if block is None: + block = InvertedResidual + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + input_channel = 32 + last_channel = 1280 + + if inverted_residual_setting is None: + 
inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError( + f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}" + ) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features: List[nn.Module] = [ + Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) + ] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) + input_channel = output_channel + # building last several layers + features.append( + Conv2dNormActivation( + input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6 + ) + ) + # make it nn.Sequential + self.features = nn.Sequential(*features) + + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(p=dropout), + nn.Linear(self.last_channel, num_classes), + ) + + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor) -> Tensor: + # This exists since TorchScript doesn't 
support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + x = self.features(x) + # Cannot use "squeeze" as batch-size can be 1 + x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +_COMMON_META = { + "num_params": 3504872, + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, +} + + +class MobileNet_V2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", + "_metrics": { + "ImageNet-1K": { + "acc@1": 71.878, + "acc@5": 90.286, + } + }, + "_ops": 0.301, + "_file_size": 13.555, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", + "_metrics": { + "ImageNet-1K": { + "acc@1": 72.154, + "acc@5": 90.822, + } + }, + "_ops": 0.301, + "_file_size": 13.598, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1)) +def mobilenet_v2( + *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV2: + """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear + Bottlenecks `_ paper. + + Args: + weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.MobileNet_V2_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.MobileNet_V2_Weights + :members: + """ + weights = MobileNet_V2_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = MobileNetV2(**kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89d2302f825ff0dbe25d02f6dc7c84d3c0790ad0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__init__.py @@ -0,0 +1 @@ +from .raft import * diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..310a7178d6306ab02a626a6438079db4d074a854 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6da64fcd9c0b3f670dc7842cc0e4a6d234594d3e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/raft.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/raft.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cc3ec96b0955411eeacd3bf5fce7331514cad55 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/__pycache__/raft.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/_utils.py b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fa2454a27315d6e560dccb6ea2ce6083da03e256 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/_utils.py @@ -0,0 +1,48 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor + + +def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None): + """Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates.""" + h, w = img.shape[-2:] + + xgrid, ygrid = absolute_grid.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (w - 1) - 1 + # Adding 
condition if h > 1 to enable this function be reused in raft-stereo + if h > 1: + ygrid = 2 * ygrid / (h - 1) - 1 + normalized_grid = torch.cat([xgrid, ygrid], dim=-1) + + return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners) + + +def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"): + device = torch.device(device) + coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij") + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch_size, 1, 1, 1) + + +def upsample_flow(flow, up_mask: Optional[Tensor] = None, factor: int = 8): + """Upsample flow by the input factor (default 8). + + If up_mask is None we just interpolate. + If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B. + Note that in appendix B the picture assumes a downsample factor of 4 instead of 8. + """ + batch_size, num_channels, h, w = flow.shape + new_h, new_w = h * factor, w * factor + + if up_mask is None: + return factor * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True) + + up_mask = up_mask.view(batch_size, 1, 9, factor, factor, h, w) + up_mask = torch.softmax(up_mask, dim=2) # "convex" == weights sum to 1 + + upsampled_flow = F.unfold(factor * flow, kernel_size=3, padding=1).view(batch_size, num_channels, 9, 1, 1, h, w) + upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2) + + return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, num_channels, new_h, new_w) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/raft.py b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..c294777ee6ffc0a9151f76f13bf2bde018580f9e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/optical_flow/raft.py @@ -0,0 +1,947 @@ +from typing import List, 
Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.instancenorm import InstanceNorm2d +from torchvision.ops import Conv2dNormActivation + +from ...transforms._presets import OpticalFlow +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._utils import handle_legacy_interface +from ._utils import grid_sample, make_coords_grid, upsample_flow + + +__all__ = ( + "RAFT", + "raft_large", + "raft_small", + "Raft_Large_Weights", + "Raft_Small_Weights", +) + + +class ResidualBlock(nn.Module): + """Slightly modified Residual block with extra relu and biases.""" + + def __init__(self, in_channels, out_channels, *, norm_layer, stride=1, always_project: bool = False): + super().__init__() + + # Note regarding bias=True: + # Usually we can pass bias=False in conv layers followed by a norm layer. + # But in the RAFT training reference, the BatchNorm2d layers are only activated for the first dataset, + # and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful + # for the rest of the datasets. Technically, we could remove the bias for other norm layers like Instance norm + # because these aren't frozen, but we don't bother (also, we wouldn't be able to load the original weights). 
+ self.convnormrelu1 = Conv2dNormActivation( + in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True + ) + self.convnormrelu2 = Conv2dNormActivation( + out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True + ) + + # make mypy happy + self.downsample: nn.Module + + if stride == 1 and not always_project: + self.downsample = nn.Identity() + else: + self.downsample = Conv2dNormActivation( + in_channels, + out_channels, + norm_layer=norm_layer, + kernel_size=1, + stride=stride, + bias=True, + activation_layer=None, + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + y = x + y = self.convnormrelu1(y) + y = self.convnormrelu2(y) + + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + """Slightly modified BottleNeck block (extra relu and biases)""" + + def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): + super().__init__() + + # See note in ResidualBlock for the reason behind bias=True + self.convnormrelu1 = Conv2dNormActivation( + in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True + ) + self.convnormrelu2 = Conv2dNormActivation( + out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True + ) + self.convnormrelu3 = Conv2dNormActivation( + out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True + ) + self.relu = nn.ReLU(inplace=True) + + if stride == 1: + self.downsample = nn.Identity() + else: + self.downsample = Conv2dNormActivation( + in_channels, + out_channels, + norm_layer=norm_layer, + kernel_size=1, + stride=stride, + bias=True, + activation_layer=None, + ) + + def forward(self, x): + y = x + y = self.convnormrelu1(y) + y = self.convnormrelu2(y) + y = self.convnormrelu3(y) + + x = self.downsample(x) + + return self.relu(x + y) + + +class FeatureEncoder(nn.Module): + """The feature encoder, used both as the actual feature encoder, and as 
the context encoder. + + It must downsample its input by 8. + """ + + def __init__( + self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), strides=(2, 1, 2, 2), norm_layer=nn.BatchNorm2d + ): + super().__init__() + + if len(layers) != 5: + raise ValueError(f"The expected number of layers is 5, instead got {len(layers)}") + + # See note in ResidualBlock for the reason behind bias=True + self.convnormrelu = Conv2dNormActivation( + 3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=strides[0], bias=True + ) + + self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=strides[1]) + self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=strides[2]) + self.layer3 = self._make_2_blocks(block, layers[2], layers[3], norm_layer=norm_layer, first_stride=strides[3]) + + self.conv = nn.Conv2d(layers[3], layers[4], kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + num_downsamples = len(list(filter(lambda s: s == 2, strides))) + self.output_dim = layers[-1] + self.downsample_factor = 2**num_downsamples + + def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride): + block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride) + block2 = block(out_channels, out_channels, norm_layer=norm_layer, stride=1) + return nn.Sequential(block1, block2) + + def forward(self, x): + x = self.convnormrelu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv(x) + + return x + + +class MotionEncoder(nn.Module): + """The motion encoder, part of the update block. 
+ + Takes the current predicted flow and the correlation features as input and returns an encoded version of these. + """ + + def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128, 64), out_channels=128): + super().__init__() + + if len(flow_layers) != 2: + raise ValueError(f"The expected number of flow_layers is 2, instead got {len(flow_layers)}") + if len(corr_layers) not in (1, 2): + raise ValueError(f"The number of corr_layers should be 1 or 2, instead got {len(corr_layers)}") + + self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1) + if len(corr_layers) == 2: + self.convcorr2 = Conv2dNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3) + else: + self.convcorr2 = nn.Identity() + + self.convflow1 = Conv2dNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7) + self.convflow2 = Conv2dNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3) + + # out_channels - 2 because we cat the flow (2 channels) at the end + self.conv = Conv2dNormActivation( + corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3 + ) + + self.out_channels = out_channels + + def forward(self, flow, corr_features): + corr = self.convcorr1(corr_features) + corr = self.convcorr2(corr) + + flow_orig = flow + flow = self.convflow1(flow) + flow = self.convflow2(flow) + + corr_flow = torch.cat([corr, flow], dim=1) + corr_flow = self.conv(corr_flow) + return torch.cat([corr_flow, flow_orig], dim=1) + + +class ConvGRU(nn.Module): + """Convolutional Gru unit.""" + + def __init__(self, *, input_size, hidden_size, kernel_size, padding): + super().__init__() + self.convz = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding) + self.convr = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding) + self.convq = nn.Conv2d(hidden_size + input_size, hidden_size, 
kernel_size=kernel_size, padding=padding) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + return h + + +def _pass_through_h(h, _): + # Declared here for torchscript + return h + + +class RecurrentBlock(nn.Module): + """Recurrent block, part of the update block. + + Takes the current hidden state and the concatenation of (motion encoder output, context) as input. + Returns an updated hidden state. + """ + + def __init__(self, *, input_size, hidden_size, kernel_size=((1, 5), (5, 1)), padding=((0, 2), (2, 0))): + super().__init__() + + if len(kernel_size) != len(padding): + raise ValueError( + f"kernel_size should have the same length as padding, instead got len(kernel_size) = {len(kernel_size)} and len(padding) = {len(padding)}" + ) + if len(kernel_size) not in (1, 2): + raise ValueError(f"kernel_size should either 1 or 2, instead got {len(kernel_size)}") + + self.convgru1 = ConvGRU( + input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[0], padding=padding[0] + ) + if len(kernel_size) == 2: + self.convgru2 = ConvGRU( + input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[1], padding=padding[1] + ) + else: + self.convgru2 = _pass_through_h + + self.hidden_size = hidden_size + + def forward(self, h, x): + h = self.convgru1(h, x) + h = self.convgru2(h, x) + return h + + +class FlowHead(nn.Module): + """Flow head, part of the update block. + + Takes the hidden state of the recurrent unit as input, and outputs the predicted "delta flow". 
+ """ + + def __init__(self, *, in_channels, hidden_size): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, hidden_size, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_size, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class UpdateBlock(nn.Module): + """The update block which contains the motion encoder, the recurrent block, and the flow head. + + It must expose a ``hidden_state_size`` attribute which is the hidden state size of its recurrent block. + """ + + def __init__(self, *, motion_encoder, recurrent_block, flow_head): + super().__init__() + self.motion_encoder = motion_encoder + self.recurrent_block = recurrent_block + self.flow_head = flow_head + + self.hidden_state_size = recurrent_block.hidden_size + + def forward(self, hidden_state, context, corr_features, flow): + motion_features = self.motion_encoder(flow, corr_features) + x = torch.cat([context, motion_features], dim=1) + + hidden_state = self.recurrent_block(hidden_state, x) + delta_flow = self.flow_head(hidden_state) + return hidden_state, delta_flow + + +class MaskPredictor(nn.Module): + """Mask predictor to be used when upsampling the predicted flow. + + It takes the hidden state of the recurrent unit as input and outputs the mask. + This is not used in the raft-small model. + """ + + def __init__(self, *, in_channels, hidden_size, multiplier=0.25): + super().__init__() + self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3) + # 8 * 8 * 9 because the predicted flow is downsampled by 8, from the downsampling of the initial FeatureEncoder, + # and we interpolate with all 9 surrounding neighbors. See paper and appendix B. + self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0) + + # In the original code, they use a factor of 0.25 to "downweight the gradients" of that branch. + # See e.g. 
https://github.com/princeton-vl/RAFT/issues/119#issuecomment-953950419 + # or https://github.com/princeton-vl/RAFT/issues/24. + # It doesn't seem to affect epe significantly and can likely be set to 1. + self.multiplier = multiplier + + def forward(self, x): + x = self.convrelu(x) + x = self.conv(x) + return self.multiplier * x + + +class CorrBlock(nn.Module): + """The correlation block. + + Creates a correlation pyramid with ``num_levels`` levels from the outputs of the feature encoder, + and then indexes from this pyramid to create correlation features. + The "indexing" of a given centroid pixel x' is done by concatenating its surrounding neighbors that + are within a ``radius``, according to the infinity norm (see paper section 3.2). + Note: typo in the paper, it should be infinity norm, not 1-norm. + """ + + def __init__(self, *, num_levels: int = 4, radius: int = 4): + super().__init__() + self.num_levels = num_levels + self.radius = radius + + self.corr_pyramid: List[Tensor] = [torch.tensor(0)] # useless, but torchscript is otherwise confused :') + + # The neighborhood of a centroid pixel x' is {x' + delta, ||delta||_inf <= radius} + # so it's a square surrounding x', and its sides have a length of 2 * radius + 1 + # The paper claims that it's ||.||_1 instead of ||.||_inf but it's a typo: + # https://github.com/princeton-vl/RAFT/issues/122 + self.out_channels = num_levels * (2 * radius + 1) ** 2 + + def build_pyramid(self, fmap1, fmap2): + """Build the correlation pyramid from two feature maps. + + The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2) + The last 2 dimensions of the correlation volume are then pooled num_levels times at different resolutions + to build the correlation pyramid. 
+ """ + + if fmap1.shape != fmap2.shape: + raise ValueError( + f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)" + ) + + # Explaining min_fmap_size below: the fmaps are down-sampled (num_levels - 1) times by a factor of 2. + # The last corr_volume most have at least 2 values (hence the 2* factor), otherwise grid_sample() would + # produce nans in its output. + min_fmap_size = 2 * (2 ** (self.num_levels - 1)) + if any(fmap_size < min_fmap_size for fmap_size in fmap1.shape[-2:]): + raise ValueError( + "Feature maps are too small to be down-sampled by the correlation pyramid. " + f"H and W of feature maps should be at least {min_fmap_size}; got: {fmap1.shape[-2:]}. " + "Remember that input images to the model are downsampled by 8, so that means their " + f"dimensions should be at least 8 * {min_fmap_size} = {8 * min_fmap_size}." + ) + + corr_volume = self._compute_corr_volume(fmap1, fmap2) + + batch_size, h, w, num_channels, _, _ = corr_volume.shape # _, _ = h, w + corr_volume = corr_volume.reshape(batch_size * h * w, num_channels, h, w) + self.corr_pyramid = [corr_volume] + for _ in range(self.num_levels - 1): + corr_volume = F.avg_pool2d(corr_volume, kernel_size=2, stride=2) + self.corr_pyramid.append(corr_volume) + + def index_pyramid(self, centroids_coords): + """Return correlation features by indexing from the pyramid.""" + neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels + di = torch.linspace(-self.radius, self.radius, neighborhood_side_len) + dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len) + delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device) + delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2) + + batch_size, _, h, w = centroids_coords.shape # _ = 2 + centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2) + + indexed_pyramid = [] + 
for corr_volume in self.corr_pyramid: + sampling_coords = centroids_coords + delta # end shape is (batch_size * h * w, side_len, side_len, 2) + indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view( + batch_size, h, w, -1 + ) + indexed_pyramid.append(indexed_corr_volume) + centroids_coords = centroids_coords / 2 + + corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous() + + expected_output_shape = (batch_size, self.out_channels, h, w) + if corr_features.shape != expected_output_shape: + raise ValueError( + f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}" + ) + + return corr_features + + def _compute_corr_volume(self, fmap1, fmap2): + batch_size, num_channels, h, w = fmap1.shape + fmap1 = fmap1.view(batch_size, num_channels, h * w) + fmap2 = fmap2.view(batch_size, num_channels, h * w) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch_size, h, w, 1, h, w) + return corr / torch.sqrt(torch.tensor(num_channels)) + + +class RAFT(nn.Module): + def __init__(self, *, feature_encoder, context_encoder, corr_block, update_block, mask_predictor=None): + """RAFT model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. + + args: + feature_encoder (nn.Module): The feature encoder. It must downsample the input by 8. + Its input is the concatenation of ``image1`` and ``image2``. + context_encoder (nn.Module): The context encoder. It must downsample the input by 8. + Its input is ``image1``. 
As in the original implementation, its output will be split into 2 parts: + + - one part will be used as the actual "context", passed to the recurrent unit of the ``update_block`` + - one part will be used to initialize the hidden state of the recurrent unit of + the ``update_block`` + + These 2 parts are split according to the ``hidden_state_size`` of the ``update_block``, so the output + of the ``context_encoder`` must be strictly greater than ``hidden_state_size``. + + corr_block (nn.Module): The correlation block, which creates a correlation pyramid from the output of the + ``feature_encoder``, and then indexes from this pyramid to create correlation features. It must expose + 2 methods: + + - a ``build_pyramid`` method that takes ``feature_map_1`` and ``feature_map_2`` as input (these are the + output of the ``feature_encoder``). + - a ``index_pyramid`` method that takes the coordinates of the centroid pixels as input, and returns + the correlation features. See paper section 3.2. + + It must expose an ``out_channels`` attribute. + + update_block (nn.Module): The update block, which contains the motion encoder, the recurrent unit, and the + flow head. It takes as input the hidden state of its recurrent unit, the context, the correlation + features, and the current predicted flow. It outputs an updated hidden state, and the ``delta_flow`` + prediction (see paper appendix A). It must expose a ``hidden_state_size`` attribute. + mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow. + The output channel must be 8 * 8 * 9 - see paper section 3.3, and Appendix B. + If ``None`` (default), the flow is upsampled using interpolation. 
+ """ + super().__init__() + _log_api_usage_once(self) + + self.feature_encoder = feature_encoder + self.context_encoder = context_encoder + self.corr_block = corr_block + self.update_block = update_block + + self.mask_predictor = mask_predictor + + if not hasattr(self.update_block, "hidden_state_size"): + raise ValueError("The update_block parameter should expose a 'hidden_state_size' attribute.") + + def forward(self, image1, image2, num_flow_updates: int = 12): + + batch_size, _, h, w = image1.shape + if (h, w) != image2.shape[-2:]: + raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}") + if not (h % 8 == 0) and (w % 8 == 0): + raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)") + + fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0)) + fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0) + if fmap1.shape[-2:] != (h // 8, w // 8): + raise ValueError("The feature encoder should downsample H and W by 8") + + self.corr_block.build_pyramid(fmap1, fmap2) + + context_out = self.context_encoder(image1) + if context_out.shape[-2:] != (h // 8, w // 8): + raise ValueError("The context encoder should downsample H and W by 8") + + # As in the original paper, the actual output of the context encoder is split in 2 parts: + # - one part is used to initialize the hidden state of the recurent units of the update block + # - the rest is the "actual" context. 
+ hidden_state_size = self.update_block.hidden_state_size + out_channels_context = context_out.shape[1] - hidden_state_size + if out_channels_context <= 0: + raise ValueError( + f"The context encoder outputs {context_out.shape[1]} channels, but it should have at strictly more than hidden_state={hidden_state_size} channels" + ) + hidden_state, context = torch.split(context_out, [hidden_state_size, out_channels_context], dim=1) + hidden_state = torch.tanh(hidden_state) + context = F.relu(context) + + coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device) + coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device) + + flow_predictions = [] + for _ in range(num_flow_updates): + coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper + corr_features = self.corr_block.index_pyramid(centroids_coords=coords1) + + flow = coords1 - coords0 + hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow) + + coords1 = coords1 + delta_flow + + up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state) + upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask) + flow_predictions.append(upsampled_flow) + + return flow_predictions + + +_COMMON_META = { + "min_size": (128, 128), +} + + +class Raft_Large_Weights(WeightsEnum): + """The metrics reported here are as follows. + + ``epe`` is the "end-point-error" and indicates how far (in pixels) the + predicted flow is from its true value. This is averaged over all pixels + of all images. ``per_image_epe`` is similar, but the average is different: + the epe is first computed on each image independently, and then averaged + over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe") + in the original paper, and it's only used on Kitti. ``fl-all`` is also a + Kitti-specific metric, defined by the author of the dataset and used for the + Kitti leaderboard. 
It corresponds to the average of pixels whose epe is + either <3px, or <5% of flow's 2-norm. + """ + + C_T_V1 = Weights( + # Weights ported from https://github.com/princeton-vl/RAFT + url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "_metrics": { + "Sintel-Train-Cleanpass": {"epe": 1.4411}, + "Sintel-Train-Finalpass": {"epe": 2.7894}, + "Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """These weights were ported from the original paper. They + are trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", + }, + ) + + C_T_V2 = Weights( + url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "_metrics": { + "Sintel-Train-Cleanpass": {"epe": 1.3822}, + "Sintel-Train-Finalpass": {"epe": 2.7161}, + "Kitti-Train": {"per_image_epe": 4.5118, "fl_all": 16.0679}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """These weights were trained from scratch on + :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", + }, + ) + + C_T_SKHT_V1 = Weights( + # Weights ported from https://github.com/princeton-vl/RAFT + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "_metrics": { + "Sintel-Test-Cleanpass": {"epe": 1.94}, + "Sintel-Test-Finalpass": {"epe": 3.18}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """ + These weights were ported from the original paper. 
They are + trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on + Sintel. The Sintel fine-tuning step is a combination of + :class:`~torchvision.datasets.Sintel`, + :class:`~torchvision.datasets.KittiFlow`, + :class:`~torchvision.datasets.HD1K`, and + :class:`~torchvision.datasets.FlyingThings3D` (clean pass). + """, + }, + ) + + C_T_SKHT_V2 = Weights( + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "_metrics": { + "Sintel-Test-Cleanpass": {"epe": 1.819}, + "Sintel-Test-Finalpass": {"epe": 3.067}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """ + These weights were trained from scratch. They are + pre-trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D` and then + fine-tuned on Sintel. The Sintel fine-tuning step is a + combination of :class:`~torchvision.datasets.Sintel`, + :class:`~torchvision.datasets.KittiFlow`, + :class:`~torchvision.datasets.HD1K`, and + :class:`~torchvision.datasets.FlyingThings3D` (clean pass). + """, + }, + ) + + C_T_SKHT_K_V1 = Weights( + # Weights ported from https://github.com/princeton-vl/RAFT + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "_metrics": { + "Kitti-Test": {"fl_all": 5.10}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """ + These weights were ported from the original paper. They are + pre-trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`, + fine-tuned on Sintel, and then fine-tuned on + :class:`~torchvision.datasets.KittiFlow`. 
The Sintel fine-tuning + step was described above. + """, + }, + ) + + C_T_SKHT_K_V2 = Weights( + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "_metrics": { + "Kitti-Test": {"fl_all": 5.19}, + }, + "_ops": 211.007, + "_file_size": 20.129, + "_docs": """ + These weights were trained from scratch. They are + pre-trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`, + fine-tuned on Sintel, and then fine-tuned on + :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning + step was described above. + """, + }, + ) + + DEFAULT = C_T_SKHT_V2 + + +class Raft_Small_Weights(WeightsEnum): + """The metrics reported here are as follows. + + ``epe`` is the "end-point-error" and indicates how far (in pixels) the + predicted flow is from its true value. This is averaged over all pixels + of all images. ``per_image_epe`` is similar, but the average is different: + the epe is first computed on each image independently, and then averaged + over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe") + in the original paper, and it's only used on Kitti. ``fl-all`` is also a + Kitti-specific metric, defined by the author of the dataset and used for the + Kitti leaderboard. It corresponds to the average of pixels whose epe is + either <3px, or <5% of flow's 2-norm. 
+ """ + + C_T_V1 = Weights( + # Weights ported from https://github.com/princeton-vl/RAFT + url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 990162, + "recipe": "https://github.com/princeton-vl/RAFT", + "_metrics": { + "Sintel-Train-Cleanpass": {"epe": 2.1231}, + "Sintel-Train-Finalpass": {"epe": 3.2790}, + "Kitti-Train": {"per_image_epe": 7.6557, "fl_all": 25.2801}, + }, + "_ops": 47.655, + "_file_size": 3.821, + "_docs": """These weights were ported from the original paper. They + are trained on :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", + }, + ) + C_T_V2 = Weights( + url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", + transforms=OpticalFlow, + meta={ + **_COMMON_META, + "num_params": 990162, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "_metrics": { + "Sintel-Train-Cleanpass": {"epe": 1.9901}, + "Sintel-Train-Finalpass": {"epe": 3.2831}, + "Kitti-Train": {"per_image_epe": 7.5978, "fl_all": 25.2369}, + }, + "_ops": 47.655, + "_file_size": 3.821, + "_docs": """These weights were trained from scratch on + :class:`~torchvision.datasets.FlyingChairs` + + :class:`~torchvision.datasets.FlyingThings3D`.""", + }, + ) + + DEFAULT = C_T_V2 + + +def _raft( + *, + weights=None, + progress=False, + # Feature encoder + feature_encoder_layers, + feature_encoder_block, + feature_encoder_norm_layer, + # Context encoder + context_encoder_layers, + context_encoder_block, + context_encoder_norm_layer, + # Correlation block + corr_block_num_levels, + corr_block_radius, + # Motion encoder + motion_encoder_corr_layers, + motion_encoder_flow_layers, + motion_encoder_out_channels, + # Recurrent block + recurrent_block_hidden_state_size, + recurrent_block_kernel_size, + recurrent_block_padding, + # Flow Head + flow_head_hidden_size, + # Mask predictor + use_mask_predictor, + 
**kwargs, +): + feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder( + block=feature_encoder_block, layers=feature_encoder_layers, norm_layer=feature_encoder_norm_layer + ) + context_encoder = kwargs.pop("context_encoder", None) or FeatureEncoder( + block=context_encoder_block, layers=context_encoder_layers, norm_layer=context_encoder_norm_layer + ) + + corr_block = kwargs.pop("corr_block", None) or CorrBlock(num_levels=corr_block_num_levels, radius=corr_block_radius) + + update_block = kwargs.pop("update_block", None) + if update_block is None: + motion_encoder = MotionEncoder( + in_channels_corr=corr_block.out_channels, + corr_layers=motion_encoder_corr_layers, + flow_layers=motion_encoder_flow_layers, + out_channels=motion_encoder_out_channels, + ) + + # See comments in forward pass of RAFT class about why we split the output of the context encoder + out_channels_context = context_encoder_layers[-1] - recurrent_block_hidden_state_size + recurrent_block = RecurrentBlock( + input_size=motion_encoder.out_channels + out_channels_context, + hidden_size=recurrent_block_hidden_state_size, + kernel_size=recurrent_block_kernel_size, + padding=recurrent_block_padding, + ) + + flow_head = FlowHead(in_channels=recurrent_block_hidden_state_size, hidden_size=flow_head_hidden_size) + + update_block = UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head) + + mask_predictor = kwargs.pop("mask_predictor", None) + if mask_predictor is None and use_mask_predictor: + mask_predictor = MaskPredictor( + in_channels=recurrent_block_hidden_state_size, + hidden_size=256, + multiplier=0.25, # See comment in MaskPredictor about this + ) + + model = RAFT( + feature_encoder=feature_encoder, + context_encoder=context_encoder, + corr_block=corr_block, + update_block=update_block, + mask_predictor=mask_predictor, + **kwargs, # not really needed, all params should be consumed by now + ) + + if weights is not None: + 
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2)) +def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT: + """RAFT model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. + + Please see the example below for a tutorial on how to use this model. + + Args: + weights(:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.optical_flow.Raft_Large_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.optical_flow.Raft_Large_Weights + :members: + """ + + weights = Raft_Large_Weights.verify(weights) + + return _raft( + weights=weights, + progress=progress, + # Feature encoder + feature_encoder_layers=(64, 64, 96, 128, 256), + feature_encoder_block=ResidualBlock, + feature_encoder_norm_layer=InstanceNorm2d, + # Context encoder + context_encoder_layers=(64, 64, 96, 128, 256), + context_encoder_block=ResidualBlock, + context_encoder_norm_layer=BatchNorm2d, + # Correlation block + corr_block_num_levels=4, + corr_block_radius=4, + # Motion encoder + motion_encoder_corr_layers=(256, 192), + motion_encoder_flow_layers=(128, 64), + motion_encoder_out_channels=128, + # Recurrent block + recurrent_block_hidden_state_size=128, + recurrent_block_kernel_size=((1, 5), (5, 1)), + recurrent_block_padding=((0, 2), (2, 0)), + # Flow head + flow_head_hidden_size=256, + # Mask predictor + use_mask_predictor=True, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2)) +def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT: + """RAFT "small" model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `__. + + Please see the example below for a tutorial on how to use this model. + + Args: + weights(:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.optical_flow.Raft_Small_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.optical_flow.Raft_Small_Weights + :members: + """ + weights = Raft_Small_Weights.verify(weights) + + return _raft( + weights=weights, + progress=progress, + # Feature encoder + feature_encoder_layers=(32, 32, 64, 96, 128), + feature_encoder_block=BottleneckBlock, + feature_encoder_norm_layer=InstanceNorm2d, + # Context encoder + context_encoder_layers=(32, 32, 64, 96, 160), + context_encoder_block=BottleneckBlock, + context_encoder_norm_layer=None, + # Correlation block + corr_block_num_levels=4, + corr_block_radius=3, + # Motion encoder + motion_encoder_corr_layers=(96,), + motion_encoder_flow_layers=(64, 32), + motion_encoder_out_channels=82, + # Recurrent block + recurrent_block_hidden_state_size=96, + recurrent_block_kernel_size=(3,), + recurrent_block_padding=(1,), + # Flow head + flow_head_hidden_size=128, + # Mask predictor + use_mask_predictor=False, + **kwargs, + ) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da8bbba3567b0b9110429354d89b65ec679a2fd5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__init__.py @@ -0,0 +1,5 @@ +from .googlenet import * +from .inception import * +from .mobilenet import * +from .resnet import * +from .shufflenetv2 import * diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d62eef18ccfbbf20167b4ee8465bfc3914c59c9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/googlenet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/googlenet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d41deca1d41dd1c51e05045c2b5cedfbe819755 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/googlenet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/inception.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/inception.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b69ab0373a3ad36e0e930d8ab05db236330c0ccd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/inception.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..436825085a01686928a8652d4d4583659bc90ee9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenetv2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenetv2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec2f6daef013780e48704dba70348b11a0c828f3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenetv2.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenetv3.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenetv3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a6f05b5e46bcd0b9adab6de0effcd8ed6dc34f2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/mobilenetv3.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/resnet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/resnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be165fa79588b48f786ca2f69e87bc45557b2e50 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/resnet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/shufflenetv2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/shufflenetv2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..373d2b9211563566cac92f29ce7d58df72707f8a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/shufflenetv2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82a57278a6bdfa33d50fbbfb7f23c6e84687d14c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/googlenet.py 
b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/googlenet.py new file mode 100644 index 0000000000000000000000000000000000000000..30ef3356ba13108b9bdc4c90a9ab4cb7f92e445a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/googlenet.py @@ -0,0 +1,210 @@ +import warnings +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from ..googlenet import BasicConv2d, GoogLeNet, GoogLeNet_Weights, GoogLeNetOutputs, Inception, InceptionAux +from .utils import _fuse_modules, _replace_relu, quantize_model + + +__all__ = [ + "QuantizableGoogLeNet", + "GoogLeNet_QuantizedWeights", + "googlenet", +] + + +class QuantizableBasicConv2d(BasicConv2d): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.relu = nn.ReLU() + + def forward(self, x: Tensor) -> Tensor: + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + _fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True) + + +class QuantizableInception(Inception): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.cat = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.cat.cat(outputs, 1) + + +class QuantizableInceptionAux(InceptionAux): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, 
**kwargs) # type: ignore[misc] + self.relu = nn.ReLU() + + def forward(self, x: Tensor) -> Tensor: + # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 + x = F.adaptive_avg_pool2d(x, (4, 4)) + # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 + x = self.conv(x) + # N x 128 x 4 x 4 + x = torch.flatten(x, 1) + # N x 2048 + x = self.relu(self.fc1(x)) + # N x 1024 + x = self.dropout(x) + # N x 1024 + x = self.fc2(x) + # N x 1000 (num_classes) + + return x + + +class QuantizableGoogLeNet(GoogLeNet): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__( # type: ignore[misc] + *args, blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], **kwargs + ) + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x: Tensor) -> GoogLeNetOutputs: + x = self._transform_input(x) + x = self.quant(x) + x, aux1, aux2 = self._forward(x) + x = self.dequant(x) + aux_defined = self.training and self.aux_logits + if torch.jit.is_scripting(): + if not aux_defined: + warnings.warn("Scripted QuantizableGoogleNet always returns GoogleNetOutputs Tuple") + return GoogLeNetOutputs(x, aux2, aux1) + else: + return self.eager_outputs(x, aux2, aux1) + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + r"""Fuse conv/bn/relu modules in googlenet model + + Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization. + Model is modified in place. 
Note that this operation does not change numerics + and the model after modification is in floating point + """ + + for m in self.modules(): + if type(m) is QuantizableBasicConv2d: + m.fuse_model(is_qat) + + +class GoogLeNet_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c81f6644.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + "num_params": 6624904, + "min_size": (15, 15), + "categories": _IMAGENET_CATEGORIES, + "backend": "fbgemm", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "unquantized": GoogLeNet_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 69.826, + "acc@5": 89.404, + } + }, + "_ops": 1.498, + "_file_size": 12.618, + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. + """, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@register_model(name="quantized_googlenet") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else GoogLeNet_Weights.IMAGENET1K_V1, + ) +) +def googlenet( + *, + weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableGoogLeNet: + """GoogLeNet (Inception v1) model architecture from `Going Deeper with Convolutions `__. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.GoogLeNet_QuantizedWeights` or :class:`~torchvision.models.GoogLeNet_Weights`, optional): The + pretrained weights for the model. 
See + :class:`~torchvision.models.quantization.GoogLeNet_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + quantize (bool, optional): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableGoogLeNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.GoogLeNet_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.GoogLeNet_Weights + :members: + :noindex: + """ + weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: + if "transform_input" not in kwargs: + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") + + model = QuantizableGoogLeNet(**kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + if not original_aux_logits: + model.aux_logits = False + model.aux1 = None # type: ignore[assignment] + model.aux2 = None # type: ignore[assignment] + else: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" + ) + + return model diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/inception.py 
b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..75c126697e99befd6ae7d3c1ee88fb8542e06d31 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/inception.py @@ -0,0 +1,273 @@ +import warnings +from functools import partial +from typing import Any, List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torchvision.models import inception as inception_module +from torchvision.models.inception import Inception_V3_Weights, InceptionOutputs + +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from .utils import _fuse_modules, _replace_relu, quantize_model + + +__all__ = [ + "QuantizableInception3", + "Inception_V3_QuantizedWeights", + "inception_v3", +] + + +class QuantizableBasicConv2d(inception_module.BasicConv2d): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.relu = nn.ReLU() + + def forward(self, x: Tensor) -> Tensor: + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + _fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True) + + +class QuantizableInceptionA(inception_module.InceptionA): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.myop = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.myop.cat(outputs, 1) + + +class QuantizableInceptionB(inception_module.InceptionB): + # TODO 
https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.myop = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.myop.cat(outputs, 1) + + +class QuantizableInceptionC(inception_module.InceptionC): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.myop = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.myop.cat(outputs, 1) + + +class QuantizableInceptionD(inception_module.InceptionD): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.myop = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.myop.cat(outputs, 1) + + +class QuantizableInceptionE(inception_module.InceptionE): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + self.myop1 = nn.quantized.FloatFunctional() + self.myop2 = nn.quantized.FloatFunctional() + self.myop3 = nn.quantized.FloatFunctional() + + def _forward(self, x: Tensor) -> List[Tensor]: + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)] + branch3x3 = self.myop1.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = 
self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = self.myop2.cat(branch3x3dbl, 1) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x: Tensor) -> Tensor: + outputs = self._forward(x) + return self.myop3.cat(outputs, 1) + + +class QuantizableInceptionAux(inception_module.InceptionAux): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs) # type: ignore[misc] + + +class QuantizableInception3(inception_module.Inception3): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__( # type: ignore[misc] + *args, + inception_blocks=[ + QuantizableBasicConv2d, + QuantizableInceptionA, + QuantizableInceptionB, + QuantizableInceptionC, + QuantizableInceptionD, + QuantizableInceptionE, + QuantizableInceptionAux, + ], + **kwargs, + ) + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x: Tensor) -> InceptionOutputs: + x = self._transform_input(x) + x = self.quant(x) + x, aux = self._forward(x) + x = self.dequant(x) + aux_defined = self.training and self.aux_logits + if torch.jit.is_scripting(): + if not aux_defined: + warnings.warn("Scripted QuantizableInception3 always returns QuantizableInception3 Tuple") + return InceptionOutputs(x, aux) + else: + return self.eager_outputs(x, aux) + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + r"""Fuse conv/bn/relu modules in inception model + + Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization. + Model is modified in place. 
Note that this operation does not change numerics + and the model after modification is in floating point + """ + + for m in self.modules(): + if type(m) is QuantizableBasicConv2d: + m.fuse_model(is_qat) + + +class Inception_V3_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-a2837893.pth", + transforms=partial(ImageClassification, crop_size=299, resize_size=342), + meta={ + "num_params": 27161264, + "min_size": (75, 75), + "categories": _IMAGENET_CATEGORIES, + "backend": "fbgemm", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "unquantized": Inception_V3_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 77.176, + "acc@5": 93.354, + } + }, + "_ops": 5.713, + "_file_size": 23.146, + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. + """, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@register_model(name="quantized_inception_v3") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else Inception_V3_Weights.IMAGENET1K_V1, + ) +) +def inception_v3( + *, + weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableInception3: + r"""Inception v3 model architecture from + `Rethinking the Inception Architecture for Computer Vision `__. + + .. note:: + **Important**: In contrast to the other models the inception_v3 expects tensors with a size of + N x 3 x 299 x 299, so ensure your images are sized accordingly. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. 
+ GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.Inception_V3_QuantizedWeights` or :class:`~torchvision.models.Inception_V3_Weights`, optional): The pretrained + weights for the model. See + :class:`~torchvision.models.quantization.Inception_V3_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. + Default is True. + quantize (bool, optional): If True, return a quantized version of the model. + Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableInception3`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.Inception_V3_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.Inception_V3_Weights + :members: + :noindex: + """ + weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: + if "transform_input" not in kwargs: + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") + + model = QuantizableInception3(**kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + if quantize and not original_aux_logits: + model.aux_logits = False + model.AuxLogits = None + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + if not quantize and not original_aux_logits: + model.aux_logits = False + model.AuxLogits = None + + return model diff --git 
a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/mobilenet.py b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/mobilenet.py new file mode 100644 index 0000000000000000000000000000000000000000..0a270d14d3a4ad9eda62b68c2c01e9fdb710ef38 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/mobilenet.py @@ -0,0 +1,6 @@ +from .mobilenetv2 import * # noqa: F401, F403 +from .mobilenetv3 import * # noqa: F401, F403 +from .mobilenetv2 import __all__ as mv2_all +from .mobilenetv3 import __all__ as mv3_all + +__all__ = mv2_all + mv3_all diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/mobilenetv2.py b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/mobilenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..4700bb4af931072f1aee3403c1e8c461ec33c76d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/mobilenetv2.py @@ -0,0 +1,154 @@ +from functools import partial +from typing import Any, Optional, Union + +from torch import nn, Tensor +from torch.ao.quantization import DeQuantStub, QuantStub +from torchvision.models.mobilenetv2 import InvertedResidual, MobileNet_V2_Weights, MobileNetV2 + +from ...ops.misc import Conv2dNormActivation +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from .utils import _fuse_modules, _replace_relu, quantize_model + + +__all__ = [ + "QuantizableMobileNetV2", + "MobileNet_V2_QuantizedWeights", + "mobilenet_v2", +] + + +class QuantizableInvertedResidual(InvertedResidual): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + if self.use_res_connect: + return 
self.skip_add.add(x, self.conv(x)) + else: + return self.conv(x) + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + for idx in range(len(self.conv)): + if type(self.conv[idx]) is nn.Conv2d: + _fuse_modules(self.conv, [str(idx), str(idx + 1)], is_qat, inplace=True) + + +class QuantizableMobileNetV2(MobileNetV2): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """ + MobileNet V2 main class + + Args: + Inherits args from floating point MobileNetV2 + """ + super().__init__(*args, **kwargs) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x: Tensor) -> Tensor: + x = self.quant(x) + x = self._forward_impl(x) + x = self.dequant(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + for m in self.modules(): + if type(m) is Conv2dNormActivation: + _fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True) + if type(m) is QuantizableInvertedResidual: + m.fuse_model(is_qat) + + +class MobileNet_V2_QuantizedWeights(WeightsEnum): + IMAGENET1K_QNNPACK_V1 = Weights( + url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + "num_params": 3504872, + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "backend": "qnnpack", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", + "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 71.658, + "acc@5": 90.150, + } + }, + "_ops": 0.301, + "_file_size": 3.423, + "_docs": """ + These weights were produced by doing Quantization Aware Training (eager mode) on top of the unquantized + weights listed below. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_QNNPACK_V1 + + +@register_model(name="quantized_mobilenet_v2") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V2_Weights.IMAGENET1K_V1, + ) +) +def mobilenet_v2( + *, + weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableMobileNetV2: + """ + Constructs a MobileNetV2 architecture from + `MobileNetV2: Inverted Residuals and Linear Bottlenecks + `_. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.MobileNet_V2_QuantizedWeights` or :class:`~torchvision.models.MobileNet_V2_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.MobileNet_V2_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + quantize (bool, optional): If True, returns a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableMobileNetV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.quantization.MobileNet_V2_QuantizedWeights + :members: + .. 
autoclass:: torchvision.models.MobileNet_V2_Weights + :members: + :noindex: + """ + weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "qnnpack") + + model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/mobilenetv3.py b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/mobilenetv3.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fdcfec9570d35683efb10344e667d3f4487fce --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/mobilenetv3.py @@ -0,0 +1,237 @@ +from functools import partial +from typing import Any, List, Optional, Union + +import torch +from torch import nn, Tensor +from torch.ao.quantization import DeQuantStub, QuantStub + +from ...ops.misc import Conv2dNormActivation, SqueezeExcitation +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from ..mobilenetv3 import ( + _mobilenet_v3_conf, + InvertedResidual, + InvertedResidualConfig, + MobileNet_V3_Large_Weights, + MobileNetV3, +) +from .utils import _fuse_modules, _replace_relu + + +__all__ = [ + "QuantizableMobileNetV3", + "MobileNet_V3_Large_QuantizedWeights", + "mobilenet_v3_large", +] + + +class QuantizableSqueezeExcitation(SqueezeExcitation): + _version = 2 + + def __init__(self, 
*args: Any, **kwargs: Any) -> None: + kwargs["scale_activation"] = nn.Hardsigmoid + super().__init__(*args, **kwargs) + self.skip_mul = nn.quantized.FloatFunctional() + + def forward(self, input: Tensor) -> Tensor: + return self.skip_mul.mul(self._scale(input), input) + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + _fuse_modules(self, ["fc1", "activation"], is_qat, inplace=True) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if hasattr(self, "qconfig") and (version is None or version < 2): + default_state_dict = { + "scale_activation.activation_post_process.scale": torch.tensor([1.0]), + "scale_activation.activation_post_process.activation_post_process.scale": torch.tensor([1.0]), + "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32), + "scale_activation.activation_post_process.activation_post_process.zero_point": torch.tensor( + [0], dtype=torch.int32 + ), + "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]), + "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]), + } + for k, v in default_state_dict.items(): + full_key = prefix + k + if full_key not in state_dict: + state_dict[full_key] = v + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class QuantizableInvertedResidual(InvertedResidual): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, se_layer=QuantizableSqueezeExcitation, **kwargs) # type: ignore[misc] + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + if self.use_res_connect: + return self.skip_add.add(x, self.block(x)) + else: + return 
self.block(x) + + +class QuantizableMobileNetV3(MobileNetV3): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """ + MobileNet V3 main class + + Args: + Inherits args from floating point MobileNetV3 + """ + super().__init__(*args, **kwargs) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x: Tensor) -> Tensor: + x = self.quant(x) + x = self._forward_impl(x) + x = self.dequant(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + for m in self.modules(): + if type(m) is Conv2dNormActivation: + modules_to_fuse = ["0", "1"] + if len(m) == 3 and type(m[2]) is nn.ReLU: + modules_to_fuse.append("2") + _fuse_modules(m, modules_to_fuse, is_qat, inplace=True) + elif type(m) is QuantizableSqueezeExcitation: + m.fuse_model(is_qat) + + +def _mobilenet_v3_model( + inverted_residual_setting: List[InvertedResidualConfig], + last_channel: int, + weights: Optional[WeightsEnum], + progress: bool, + quantize: bool, + **kwargs: Any, +) -> QuantizableMobileNetV3: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "qnnpack") + + model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) + _replace_relu(model) + + if quantize: + # Instead of quantizing the model and then loading the quantized weights we take a different approach. + # We prepare the QAT model, load the QAT weights from training and then convert it. + # This is done to avoid extremely low accuracies observed on the specific model. 
This is rather a workaround + # for an unresolved bug on the eager quantization API detailed at: https://github.com/pytorch/vision/issues/5890 + model.fuse_model(is_qat=True) + model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend) + torch.ao.quantization.prepare_qat(model, inplace=True) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + if quantize: + torch.ao.quantization.convert(model, inplace=True) + model.eval() + + return model + + +class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): + IMAGENET1K_QNNPACK_V1 = Weights( + url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + "num_params": 5483032, + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "backend": "qnnpack", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", + "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 73.004, + "acc@5": 90.858, + } + }, + "_ops": 0.217, + "_file_size": 21.554, + "_docs": """ + These weights were produced by doing Quantization Aware Training (eager mode) on top of the unquantized + weights listed below. + """, + }, + ) + DEFAULT = IMAGENET1K_QNNPACK_V1 + + +@register_model(name="quantized_mobilenet_v3_large") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V3_Large_Weights.IMAGENET1K_V1, + ) +) +def mobilenet_v3_large( + *, + weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableMobileNetV3: + """ + MobileNetV3 (Large) model from + `Searching for MobileNetV3 `_. + + .. 
note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights` or :class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool): If True, displays a progress bar of the + download to stderr. Default is True. + quantize (bool): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights + :members: + .. 
autoclass:: torchvision.models.MobileNet_V3_Large_Weights + :members: + :noindex: + """ + weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) + return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/resnet.py b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..39958a010fbd335709bc77a1aaf26c996584a398 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/resnet.py @@ -0,0 +1,484 @@ +from functools import partial +from typing import Any, List, Optional, Type, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torchvision.models.resnet import ( + BasicBlock, + Bottleneck, + ResNet, + ResNet18_Weights, + ResNet50_Weights, + ResNeXt101_32X8D_Weights, + ResNeXt101_64X4D_Weights, +) + +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from .utils import _fuse_modules, _replace_relu, quantize_model + + +__all__ = [ + "QuantizableResNet", + "ResNet18_QuantizedWeights", + "ResNet50_QuantizedWeights", + "ResNeXt101_32X8D_QuantizedWeights", + "ResNeXt101_64X4D_QuantizedWeights", + "resnet18", + "resnet50", + "resnext101_32x8d", + "resnext101_64x4d", +] + + +class QuantizableBasicBlock(BasicBlock): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.add_relu = torch.nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + 
out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = self.add_relu.add_relu(out, identity) + + return out + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True) + if self.downsample: + _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True) + + +class QuantizableBottleneck(Bottleneck): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.skip_add_relu = nn.quantized.FloatFunctional() + self.relu1 = nn.ReLU(inplace=False) + self.relu2 = nn.ReLU(inplace=False) + + def forward(self, x: Tensor) -> Tensor: + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu2(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + out = self.skip_add_relu.add_relu(out, identity) + + return out + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + _fuse_modules( + self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True + ) + if self.downsample: + _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True) + + +class QuantizableResNet(ResNet): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x: Tensor) -> Tensor: + x = self.quant(x) + # Ensure scriptability + # super(QuantizableResNet,self).forward(x) + # is not scriptable + x = self._forward_impl(x) + x = self.dequant(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + r"""Fuse conv/bn/relu modules in resnet models + + Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization. 
+ Model is modified in place. Note that this operation does not change numerics + and the model after modification is in floating point + """ + _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True) + for m in self.modules(): + if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock: + m.fuse_model(is_qat) + + +def _resnet( + block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], + layers: List[int], + weights: Optional[WeightsEnum], + progress: bool, + quantize: bool, + **kwargs: Any, +) -> QuantizableResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") + + model = QuantizableResNet(block, layers, **kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META = { + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "backend": "fbgemm", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. 
+ """, +} + + +class ResNet18_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 11689512, + "unquantized": ResNet18_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 69.494, + "acc@5": 88.882, + } + }, + "_ops": 1.814, + "_file_size": 11.238, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ResNet50_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 25557032, + "unquantized": ResNet50_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 75.920, + "acc@5": 92.814, + } + }, + "_ops": 4.089, + "_file_size": 24.759, + }, + ) + IMAGENET1K_FBGEMM_V2 = Weights( + url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 25557032, + "unquantized": ResNet50_Weights.IMAGENET1K_V2, + "_metrics": { + "ImageNet-1K": { + "acc@1": 80.282, + "acc@5": 94.976, + } + }, + "_ops": 4.089, + "_file_size": 24.953, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V2 + + +class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 88791336, + "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 78.986, + "acc@5": 94.480, + } + }, + "_ops": 16.414, + "_file_size": 86.034, + }, + ) + IMAGENET1K_FBGEMM_V2 = Weights( + 
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 88791336, + "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2, + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.574, + "acc@5": 96.132, + } + }, + "_ops": 16.414, + "_file_size": 86.645, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V2 + + +class ResNeXt101_64X4D_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnext101_64x4d_fbgemm-605a1cb3.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 83455272, + "recipe": "https://github.com/pytorch/vision/pull/5935", + "unquantized": ResNeXt101_64X4D_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.898, + "acc@5": 96.326, + } + }, + "_ops": 15.46, + "_file_size": 81.556, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@register_model(name="quantized_resnet18") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNet18_Weights.IMAGENET1K_V1, + ) +) +def resnet18( + *, + weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: + """ResNet-18 model from + `Deep Residual Learning for Image Recognition `_ + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` or :class:`~torchvision.models.ResNet18_Weights`, optional): The + pretrained weights for the model. 
See + :class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + quantize (bool, optional): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ResNet18_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ResNet18_Weights + :members: + :noindex: + """ + weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights) + + return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) + + +@register_model(name="quantized_resnet50") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNet50_Weights.IMAGENET1K_V1, + ) +) +def resnet50( + *, + weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: + """ResNet-50 model from + `Deep Residual Learning for Image Recognition `_ + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` or :class:`~torchvision.models.ResNet50_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. 
+ progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + quantize (bool, optional): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ResNet50_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ResNet50_Weights + :members: + :noindex: + """ + weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights) + + return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) + + +@register_model(name="quantized_resnext101_32x8d") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNeXt101_32X8D_Weights.IMAGENET1K_V1, + ) +) +def resnext101_32x8d( + *, + weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: + """ResNeXt-101 32x8d model from + `Aggregated Residual Transformation for Deep Neural Networks `_ + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ResNet101_32X8D_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. 
Default is True. + quantize (bool, optional): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights + :members: + :noindex: + """ + weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights) + + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) + return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) + + +@register_model(name="quantized_resnext101_64x4d") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNeXt101_64X4D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNeXt101_64X4D_Weights.IMAGENET1K_V1, + ) +) +def resnext101_64x4d( + *, + weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: + """ResNeXt-101 64x4d model from + `Aggregated Residual Transformation for Deep Neural Networks `_ + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ResNet101_64X4D_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. 
+ progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + quantize (bool, optional): If True, return a quantized version of the model. Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights + :members: + :noindex: + """ + weights = (ResNeXt101_64X4D_QuantizedWeights if quantize else ResNeXt101_64X4D_Weights).verify(weights) + + _ovewrite_named_param(kwargs, "groups", 64) + _ovewrite_named_param(kwargs, "width_per_group", 4) + return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/shufflenetv2.py b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/shufflenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..3e1b01356a74b8e4f16d66811060be698cfed199 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/shufflenetv2.py @@ -0,0 +1,427 @@ +from functools import partial +from typing import Any, List, Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torchvision.models import shufflenetv2 + +from ...transforms._presets import ImageClassification +from .._api import register_model, Weights, WeightsEnum +from .._meta import _IMAGENET_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from ..shufflenetv2 import ( + ShuffleNet_V2_X0_5_Weights, + ShuffleNet_V2_X1_0_Weights, + ShuffleNet_V2_X1_5_Weights, + ShuffleNet_V2_X2_0_Weights, +) +from .utils import _fuse_modules, _replace_relu, quantize_model + + +__all__ = [ + "QuantizableShuffleNetV2", + 
"ShuffleNet_V2_X0_5_QuantizedWeights", + "ShuffleNet_V2_X1_0_QuantizedWeights", + "ShuffleNet_V2_X1_5_QuantizedWeights", + "ShuffleNet_V2_X2_0_QuantizedWeights", + "shufflenet_v2_x0_5", + "shufflenet_v2_x1_0", + "shufflenet_v2_x1_5", + "shufflenet_v2_x2_0", +] + + +class QuantizableInvertedResidual(shufflenetv2.InvertedResidual): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.cat = nn.quantized.FloatFunctional() + + def forward(self, x: Tensor) -> Tensor: + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = self.cat.cat([x1, self.branch2(x2)], dim=1) + else: + out = self.cat.cat([self.branch1(x), self.branch2(x)], dim=1) + + out = shufflenetv2.channel_shuffle(out, 2) + + return out + + +class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2): + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs) # type: ignore[misc] + self.quant = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x: Tensor) -> Tensor: + x = self.quant(x) + x = self._forward_impl(x) + x = self.dequant(x) + return x + + def fuse_model(self, is_qat: Optional[bool] = None) -> None: + r"""Fuse conv/bn/relu modules in shufflenetv2 model + + Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization. + Model is modified in place. + + .. 
note:: + Note that this operation does not change numerics + and the model after modification is in floating point + """ + for name, m in self._modules.items(): + if name in ["conv1", "conv5"] and m is not None: + _fuse_modules(m, [["0", "1", "2"]], is_qat, inplace=True) + for m in self.modules(): + if type(m) is QuantizableInvertedResidual: + if len(m.branch1._modules.items()) > 0: + _fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], is_qat, inplace=True) + _fuse_modules( + m.branch2, + [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]], + is_qat, + inplace=True, + ) + + +def _shufflenetv2( + stages_repeats: List[int], + stages_out_channels: List[int], + *, + weights: Optional[WeightsEnum], + progress: bool, + quantize: bool, + **kwargs: Any, +) -> QuantizableShuffleNetV2: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") + + model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META = { + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "backend": "fbgemm", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "_docs": """ + These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized + weights listed below. 
+ """, +} + + +class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 1366792, + "unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 57.972, + "acc@5": 79.780, + } + }, + "_ops": 0.04, + "_file_size": 1.501, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-1e62bb32.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2278604, + "unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 68.360, + "acc@5": 87.582, + } + }, + "_ops": 0.145, + "_file_size": 2.334, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_5_fbgemm-d7401f05.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/pull/5906", + "num_params": 3503624, + "unquantized": ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 72.052, + "acc@5": 90.700, + } + }, + "_ops": 0.296, + "_file_size": 3.672, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x2_0_fbgemm-5cac526c.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/pull/5906", + 
"num_params": 7393996, + "unquantized": ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1, + "_metrics": { + "ImageNet-1K": { + "acc@1": 75.354, + "acc@5": 92.488, + } + }, + "_ops": 0.583, + "_file_size": 7.467, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@register_model(name="quantized_shufflenet_v2_x0_5") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, + ) +) +def shufflenet_v2_x0_5( + *, + weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: + """ + Constructs a ShuffleNetV2 with 0.5x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X0_5_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. + Default is True. + quantize (bool, optional): If True, return a quantized version of the model. + Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.quantization.ShuffleNet_V2_X0_5_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ShuffleNet_V2_X0_5_Weights + :members: + :noindex: + """ + weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights) + return _shufflenetv2( + [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs + ) + + +@register_model(name="quantized_shufflenet_v2_x1_0") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, + ) +) +def shufflenet_v2_x1_0( + *, + weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: + """ + Constructs a ShuffleNetV2 with 1.0x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X1_0_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. + Default is True. + quantize (bool, optional): If True, return a quantized version of the model. + Default is False. 
+ **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ShuffleNet_V2_X1_0_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ShuffleNet_V2_X1_0_Weights + :members: + :noindex: + """ + weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights) + return _shufflenetv2( + [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs + ) + + +@register_model(name="quantized_shufflenet_v2_x1_5") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X1_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1, + ) +) +def shufflenet_v2_x1_5( + *, + weights: Optional[Union[ShuffleNet_V2_X1_5_QuantizedWeights, ShuffleNet_V2_X1_5_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: + """ + Constructs a ShuffleNetV2 with 1.5x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X1_5_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. + Default is True. 
+ quantize (bool, optional): If True, return a quantized version of the model. + Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ShuffleNet_V2_X1_5_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ShuffleNet_V2_X1_5_Weights + :members: + :noindex: + """ + weights = (ShuffleNet_V2_X1_5_QuantizedWeights if quantize else ShuffleNet_V2_X1_5_Weights).verify(weights) + return _shufflenetv2( + [4, 8, 4], [24, 176, 352, 704, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs + ) + + +@register_model(name="quantized_shufflenet_v2_x2_0") +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X2_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1, + ) +) +def shufflenet_v2_x2_0( + *, + weights: Optional[Union[ShuffleNet_V2_X2_0_QuantizedWeights, ShuffleNet_V2_X2_0_Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableShuffleNetV2: + """ + Constructs a ShuffleNetV2 with 2.0x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. + + .. note:: + Note that ``quantize = True`` returns a quantized model with 8 bit + weights. Quantized models only support inference and run on CPUs. + GPU inference is not yet supported. + + Args: + weights (:class:`~torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights` or :class:`~torchvision.models.ShuffleNet_V2_X2_0_Weights`, optional): The + pretrained weights for the model. See + :class:`~torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights` below for + more details, and possible values. By default, no pre-trained + weights are used. 
+ progress (bool, optional): If True, displays a progress bar of the download to stderr. + Default is True. + quantize (bool, optional): If True, return a quantized version of the model. + Default is False. + **kwargs: parameters passed to the ``torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.quantization.ShuffleNet_V2_X2_0_QuantizedWeights + :members: + + .. autoclass:: torchvision.models.ShuffleNet_V2_X2_0_Weights + :members: + :noindex: + """ + weights = (ShuffleNet_V2_X2_0_QuantizedWeights if quantize else ShuffleNet_V2_X2_0_Weights).verify(weights) + return _shufflenetv2( + [4, 8, 4], [24, 244, 488, 976, 2048], weights=weights, progress=progress, quantize=quantize, **kwargs + ) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/quantization/utils.py b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a21e2af8e016568e79b25e35ec774f39f0595c3a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/quantization/utils.py @@ -0,0 +1,51 @@ +from typing import Any, List, Optional, Union + +import torch +from torch import nn + + +def _replace_relu(module: nn.Module) -> None: + reassign = {} + for name, mod in module.named_children(): + _replace_relu(mod) + # Checking for explicit type instead of instance + # as we only want to replace modules of the exact type + # not inherited classes + if type(mod) is nn.ReLU or type(mod) is nn.ReLU6: + reassign[name] = nn.ReLU(inplace=False) + + for key, value in reassign.items(): + module._modules[key] = value + + +def quantize_model(model: nn.Module, backend: str) -> None: + _dummy_input_data = torch.rand(1, 3, 299, 299) + if backend not in torch.backends.quantized.supported_engines: + raise RuntimeError("Quantized backend not supported ") + 
torch.backends.quantized.engine = backend + model.eval() + # Make sure that weight qconfig matches that of the serialized models + if backend == "fbgemm": + model.qconfig = torch.ao.quantization.QConfig( # type: ignore[assignment] + activation=torch.ao.quantization.default_observer, + weight=torch.ao.quantization.default_per_channel_weight_observer, + ) + elif backend == "qnnpack": + model.qconfig = torch.ao.quantization.QConfig( # type: ignore[assignment] + activation=torch.ao.quantization.default_observer, weight=torch.ao.quantization.default_weight_observer + ) + + # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 + model.fuse_model() # type: ignore[operator] + torch.ao.quantization.prepare(model, inplace=True) + model(_dummy_input_data) + torch.ao.quantization.convert(model, inplace=True) + + +def _fuse_modules( + model: nn.Module, modules_to_fuse: Union[List[str], List[List[str]]], is_qat: Optional[bool], **kwargs: Any +): + if is_qat is None: + is_qat = model.training + method = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules + return method(model, modules_to_fuse, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/regnet.py b/.venv/lib/python3.11/site-packages/torchvision/models/regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f37b2994e48fd202cccd4db6dc63b97f7c0c1ae7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/regnet.py @@ -0,0 +1,1571 @@ +import math +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import nn, Tensor + +from ..ops.misc import Conv2dNormActivation, SqueezeExcitation +from ..transforms._presets import ImageClassification, InterpolationMode +from ..utils import _log_api_usage_once +from ._api import register_model, Weights, WeightsEnum +from ._meta import _IMAGENET_CATEGORIES +from 
class SimpleStemIN(Conv2dNormActivation):
    """Simple ImageNet stem: a single stride-2 3x3 conv, then norm and activation."""

    def __init__(
        self,
        width_in: int,
        width_out: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
    ) -> None:
        # Delegate entirely to Conv2dNormActivation; the stem always uses a
        # 3x3 kernel and halves the spatial resolution (stride=2).
        super().__init__(
            width_in,
            width_out,
            kernel_size=3,
            stride=2,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
        )
class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: y = act(x + F(x)), where F is a bottleneck transform."""

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int = 1,
        bottleneck_multiplier: float = 1.0,
        se_ratio: Optional[float] = None,
    ) -> None:
        super().__init__()

        # A 1x1 projection is needed on the identity path whenever the block
        # changes channel count or spatial resolution; otherwise the raw
        # input is used as the shortcut.
        needs_projection = (width_in != width_out) or (stride != 1)
        self.proj = (
            Conv2dNormActivation(
                width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None
            )
            if needs_projection
            else None
        )
        self.f = BottleneckTransform(
            width_in,
            width_out,
            stride,
            norm_layer,
            activation_layer,
            group_width,
            bottleneck_multiplier,
            se_ratio,
        )
        self.activation = activation_layer(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        # Shortcut branch first (projected if present), then the transform.
        shortcut = x if self.proj is None else self.proj(x)
        return self.activation(shortcut + self.f(x))
class BlockParams:
    """Per-stage hyper-parameters of a RegNet trunk.

    Each list holds one entry per stage: number of blocks (``depths``),
    output width (``widths``), group width, bottleneck multiplier and
    stride.  ``se_ratio`` is shared across stages; ``None`` disables
    Squeeze-and-Excitation.
    """

    def __init__(
        self,
        depths: List[int],
        widths: List[int],
        group_widths: List[int],
        bottleneck_multipliers: List[float],
        strides: List[int],
        se_ratio: Optional[float] = None,
    ) -> None:
        self.depths = depths
        self.widths = widths
        self.group_widths = group_widths
        self.bottleneck_multipliers = bottleneck_multipliers
        self.strides = strides
        self.se_ratio = se_ratio

    @classmethod
    def from_init_params(
        cls,
        depth: int,
        w_0: int,
        w_a: float,
        w_m: float,
        group_width: int,
        bottleneck_multiplier: float = 1.0,
        se_ratio: Optional[float] = None,
        **kwargs: Any,
    ) -> "BlockParams":
        """
        Programmatically compute all the per-block settings,
        given the RegNet parameters.

        The first step is to compute the quantized linear block parameters,
        in log space. Key parameters are:
        - `w_a` is the width progression slope
        - `w_0` is the initial width
        - `w_m` is the width stepping in the log space

        In other terms
        `log(block_width) = log(w_0) + w_m * block_capacity`,
        with `block_capacity` ramping up following the w_0 and w_a params.
        This block width is finally quantized to multiples of 8.

        The second step is to compute the parameters per stage,
        taking into account the skip connection and the final 1x1 convolutions.
        We use the fact that the output width is constant within a stage.
        """

        QUANT = 8  # block widths are rounded to multiples of 8
        STRIDE = 2  # every stage downsamples spatially by 2

        if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0:
            raise ValueError("Invalid RegNet settings")
        # Compute the block widths. Each stage has one unique block width
        widths_cont = torch.arange(depth) * w_a + w_0
        block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m))
        block_widths = (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT).int().tolist()
        num_stages = len(set(block_widths))

        # Convert to per stage parameters.  A stage boundary is every index
        # where the quantized block width changes; sentinel zeros at both
        # ends mark the first and last boundaries.
        split_helper = zip(
            block_widths + [0],
            [0] + block_widths,
            block_widths + [0],
            [0] + block_widths,
        )
        splits = [w != wp or r != rp for w, wp, r, rp in split_helper]

        # Stage widths are the block widths at each boundary; stage depths
        # are the distances between consecutive boundary indices.
        stage_widths = [w for w, t in zip(block_widths, splits[:-1]) if t]
        stage_depths = torch.diff(torch.tensor([d for d, t in enumerate(splits) if t])).int().tolist()

        strides = [STRIDE] * num_stages
        bottleneck_multipliers = [bottleneck_multiplier] * num_stages
        group_widths = [group_width] * num_stages

        # Adjust the compatibility of stage widths and group widths
        stage_widths, group_widths = cls._adjust_widths_groups_compatibilty(
            stage_widths, bottleneck_multipliers, group_widths
        )

        return cls(
            depths=stage_depths,
            widths=stage_widths,
            group_widths=group_widths,
            bottleneck_multipliers=bottleneck_multipliers,
            strides=strides,
            se_ratio=se_ratio,
        )

    def _get_expanded_params(self):
        # Iterate per-stage tuples of
        # (width, stride, depth, group_width, bottleneck_multiplier).
        return zip(self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers)

    @staticmethod
    def _adjust_widths_groups_compatibilty(
        stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int]
    ) -> Tuple[List[int], List[int]]:
        """
        Adjusts the compatibility of widths and groups,
        depending on the bottleneck ratio.
        """
        # Compute all widths for the current settings
        widths = [int(w * b) for w, b in zip(stage_widths, bottleneck_ratios)]
        # Group width can never exceed the bottleneck width it divides.
        group_widths_min = [min(g, w_bot) for g, w_bot in zip(group_widths, widths)]

        # Compute the adjusted widths so that stage and group widths fit
        ws_bot = [_make_divisible(w_bot, g) for w_bot, g in zip(widths, group_widths_min)]
        stage_widths = [int(w_bot / b) for w_bot, b in zip(ws_bot, bottleneck_ratios)]
        return stage_widths, group_widths_min
def _regnet(
    block_params: BlockParams,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> RegNet:
    """Build a :class:`RegNet` from ``block_params``, optionally loading weights.

    Args:
        block_params: per-stage architecture configuration.
        weights: pretrained weights to load, or ``None`` for random init.
        progress: whether to show a download progress bar for the weights.
        **kwargs: forwarded to the :class:`RegNet` constructor.

    Returns:
        The constructed (and possibly pretrained) model.
    """
    if weights is not None:
        # Pretrained checkpoints fix the classifier size, so keep kwargs in sync.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    # BatchNorm hyper-parameters must match those used to train the checkpoints.
    default_norm_layer = partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)
    norm_layer = kwargs.pop("norm_layer", default_norm_layer)
    model = RegNet(block_params, norm_layer=norm_layer, **kwargs)

    if weights is not None:
        state_dict = weights.get_state_dict(progress=progress, check_hash=True)
        model.load_state_dict(state_dict)

    return model
class RegNet_Y_800MF_Weights(WeightsEnum):
    """Pretrained checkpoints for RegNetY-800MF (see ``regnet_y_800mf``)."""

    # Weights trained with the simple recipe from the original paper.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.420,
                    "acc@5": 93.136,
                }
            },
            "_ops": 0.834,
            "_file_size": 24.774,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Improved weights from TorchVision's updated training recipe
    # (note the larger resize_size used at evaluation time).
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.828,
                    "acc@5": 94.502,
                }
            },
            "_ops": 0.834,
            "_file_size": 24.774,
            "_docs": """
            These weights improve upon the results of the original paper by using a modified version of TorchVision's
            `new training recipe
            `_.
            """,
        },
    )
    # Best-available weights for this architecture.
    DEFAULT = IMAGENET1K_V2
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_3_2GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 19436338, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "_metrics": { + "ImageNet-1K": { + "acc@1": 78.948, + "acc@5": 94.576, + } + }, + "_ops": 3.176, + "_file_size": 74.567, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 19436338, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.982, + "acc@5": 95.972, + } + }, + "_ops": 3.176, + "_file_size": 74.567, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_8GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 39381472, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "_metrics": { + "ImageNet-1K": { + "acc@1": 80.032, + "acc@5": 95.048, + } + }, + "_ops": 8.473, + "_file_size": 150.701, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 39381472, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.828, + "acc@5": 96.330, + } + }, + "_ops": 8.473, + "_file_size": 150.701, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_16GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 83590140, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "_metrics": { + "ImageNet-1K": { + "acc@1": 80.424, + "acc@5": 95.240, + } + }, + "_ops": 15.912, + "_file_size": 319.49, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 83590140, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.886, + "acc@5": 96.328, + } + }, + "_ops": 15.912, + "_file_size": 319.49, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. + """, + }, + ) + IMAGENET1K_SWAG_E2E_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth", + transforms=partial( + ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_SWAG_META, + "num_params": 83590140, + "_metrics": { + "ImageNet-1K": { + "acc@1": 86.012, + "acc@5": 98.054, + } + }, + "_ops": 46.735, + "_file_size": 319.49, + "_docs": """ + These weights are learnt via transfer learning by end-to-end fine-tuning the original + `SWAG `_ weights on ImageNet-1K data. 
+ """, + }, + ) + IMAGENET1K_SWAG_LINEAR_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_SWAG_META, + "recipe": "https://github.com/pytorch/vision/pull/5793", + "num_params": 83590140, + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.976, + "acc@5": 97.244, + } + }, + "_ops": 15.912, + "_file_size": 319.49, + "_docs": """ + These weights are composed of the original frozen `SWAG `_ trunk + weights and a linear classifier learnt on top of them trained on ImageNet-1K data. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_32GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 145046770, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "_metrics": { + "ImageNet-1K": { + "acc@1": 80.878, + "acc@5": 95.340, + } + }, + "_ops": 32.28, + "_file_size": 554.076, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 145046770, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.368, + "acc@5": 96.498, + } + }, + "_ops": 32.28, + "_file_size": 554.076, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. 
class RegNet_Y_128GF_Weights(WeightsEnum):
    """Pretrained checkpoints for RegNetY-128GF (SWAG-derived weights only)."""

    # SWAG weights fine-tuned end-to-end on ImageNet-1K; evaluated at 384x384.
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 644812894,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.228,
                    "acc@5": 98.682,
                }
            },
            "_ops": 374.57,
            "_file_size": 2461.564,
            "_docs": """
            These weights are learnt via transfer learning by end-to-end fine-tuning the original
            `SWAG `_ weights on ImageNet-1K data.
            """,
        },
    )
    # Frozen SWAG trunk with a linear classifier trained on ImageNet-1K.
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 644812894,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.068,
                    "acc@5": 97.844,
                }
            },
            "_ops": 127.518,
            "_file_size": 2461.564,
            "_docs": """
            These weights are composed of the original frozen `SWAG `_ trunk
            weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    # Best-available weights for this architecture.
    DEFAULT = IMAGENET1K_SWAG_E2E_V1
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_800MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 7259656, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "_metrics": { + "ImageNet-1K": { + "acc@1": 75.212, + "acc@5": 92.348, + } + }, + "_ops": 0.8, + "_file_size": 27.945, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 7259656, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "_metrics": { + "ImageNet-1K": { + "acc@1": 77.522, + "acc@5": 93.826, + } + }, + "_ops": 0.8, + "_file_size": 27.945, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_1_6GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 9190136, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "_metrics": { + "ImageNet-1K": { + "acc@1": 77.040, + "acc@5": 93.440, + } + }, + "_ops": 1.603, + "_file_size": 35.339, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 9190136, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "_metrics": { + "ImageNet-1K": { + "acc@1": 79.668, + "acc@5": 94.922, + } + }, + "_ops": 1.603, + "_file_size": 35.339, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_3_2GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 15296552, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "_metrics": { + "ImageNet-1K": { + "acc@1": 78.364, + "acc@5": 93.992, + } + }, + "_ops": 3.177, + "_file_size": 58.756, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 15296552, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.196, + "acc@5": 95.430, + } + }, + "_ops": 3.177, + "_file_size": 58.756, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_8GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 39572648, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "_metrics": { + "ImageNet-1K": { + "acc@1": 79.344, + "acc@5": 94.686, + } + }, + "_ops": 7.995, + "_file_size": 151.456, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 39572648, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.682, + "acc@5": 95.678, + } + }, + "_ops": 7.995, + "_file_size": 151.456, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_16GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 54278536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "_metrics": { + "ImageNet-1K": { + "acc@1": 80.058, + "acc@5": 94.944, + } + }, + "_ops": 15.941, + "_file_size": 207.627, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 54278536, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.716, + "acc@5": 96.196, + } + }, + "_ops": 15.941, + "_file_size": 207.627, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_32GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 107811560, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "_metrics": { + "ImageNet-1K": { + "acc@1": 80.622, + "acc@5": 95.248, + } + }, + "_ops": 31.736, + "_file_size": 412.039, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 107811560, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.014, + "acc@5": 96.288, + } + }, + "_ops": 31.736, + "_file_size": 412.039, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1)) +def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_400MF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_Y_400MF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.RegNet_Y_400MF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. 
+ **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.RegNet_Y_400MF_Weights + :members: + """ + weights = RegNet_Y_400MF_Weights.verify(weights) + + params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1)) +def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_800MF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_Y_800MF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.RegNet_Y_800MF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. 
autoclass:: torchvision.models.RegNet_Y_800MF_Weights + :members: + """ + weights = RegNet_Y_800MF_Weights.verify(weights) + + params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1)) +def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_1.6GF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_Y_1_6GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.RegNet_Y_1_6GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.RegNet_Y_1_6GF_Weights + :members: + """ + weights = RegNet_Y_1_6GF_Weights.verify(weights) + + params = BlockParams.from_init_params( + depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs + ) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1)) +def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_3.2GF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_Y_3_2GF_Weights`, optional): The pretrained weights to use. 
+ See :class:`~torchvision.models.RegNet_Y_3_2GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.RegNet_Y_3_2GF_Weights + :members: + """ + weights = RegNet_Y_3_2GF_Weights.verify(weights) + + params = BlockParams.from_init_params( + depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs + ) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1)) +def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_8GF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_Y_8GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.RegNet_Y_8GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. 
autoclass:: torchvision.models.RegNet_Y_8GF_Weights + :members: + """ + weights = RegNet_Y_8GF_Weights.verify(weights) + + params = BlockParams.from_init_params( + depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs + ) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1)) +def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_16GF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_Y_16GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.RegNet_Y_16GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.RegNet_Y_16GF_Weights + :members: + """ + weights = RegNet_Y_16GF_Weights.verify(weights) + + params = BlockParams.from_init_params( + depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs + ) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1)) +def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_32GF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_Y_32GF_Weights`, optional): The pretrained weights to use. 
+ See :class:`~torchvision.models.RegNet_Y_32GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.RegNet_Y_32GF_Weights + :members: + """ + weights = RegNet_Y_32GF_Weights.verify(weights) + + params = BlockParams.from_init_params( + depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs + ) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", None)) +def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetY_128GF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_Y_128GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.RegNet_Y_128GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. 
autoclass:: torchvision.models.RegNet_Y_128GF_Weights + :members: + """ + weights = RegNet_Y_128GF_Weights.verify(weights) + + params = BlockParams.from_init_params( + depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs + ) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1)) +def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_400MF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_X_400MF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.RegNet_X_400MF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.RegNet_X_400MF_Weights + :members: + """ + weights = RegNet_X_400MF_Weights.verify(weights) + + params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1)) +def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_800MF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_X_800MF_Weights`, optional): The pretrained weights to use. 
+ See :class:`~torchvision.models.RegNet_X_800MF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.RegNet_X_800MF_Weights + :members: + """ + weights = RegNet_X_800MF_Weights.verify(weights) + + params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1)) +def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_1.6GF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_X_1_6GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.RegNet_X_1_6GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. 
autoclass:: torchvision.models.RegNet_X_1_6GF_Weights + :members: + """ + weights = RegNet_X_1_6GF_Weights.verify(weights) + + params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1)) +def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_3.2GF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_X_3_2GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.RegNet_X_3_2GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.RegNet_X_3_2GF_Weights + :members: + """ + weights = RegNet_X_3_2GF_Weights.verify(weights) + + params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1)) +def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_8GF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_X_8GF_Weights`, optional): The pretrained weights to use. 
+ See :class:`~torchvision.models.RegNet_X_8GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. autoclass:: torchvision.models.RegNet_X_8GF_Weights + :members: + """ + weights = RegNet_X_8GF_Weights.verify(weights) + + params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1)) +def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_16GF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_X_16GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.RegNet_X_16GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. 
autoclass:: torchvision.models.RegNet_X_16GF_Weights + :members: + """ + weights = RegNet_X_16GF_Weights.verify(weights) + + params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) + return _regnet(params, weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1)) +def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: + """ + Constructs a RegNetX_32GF architecture from + `Designing Network Design Spaces `_. + + Args: + weights (:class:`~torchvision.models.RegNet_X_32GF_Weights`, optional): The pretrained weights to use. + See :class:`~torchvision.models.RegNet_X_32GF_Weights` below for more details and possible values. + By default, no pretrained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or + ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code + `_ + for more detail about the classes. + + .. 
autoclass:: torchvision.models.RegNet_X_32GF_Weights + :members: + """ + weights = RegNet_X_32GF_Weights.verify(weights) + + params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) + return _regnet(params, weights, progress, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/resnet.py b/.venv/lib/python3.11/site-packages/torchvision/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..83c0340cef74d9cb4c1dc92c38f4b3024be1f731 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/resnet.py @@ -0,0 +1,985 @@ +from functools import partial +from typing import Any, Callable, List, Optional, Type, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from ..transforms._presets import ImageClassification +from ..utils import _log_api_usage_once +from ._api import register_model, Weights, WeightsEnum +from ._meta import _IMAGENET_CATEGORIES +from ._utils import _ovewrite_named_param, handle_legacy_interface + + +__all__ = [ + "ResNet", + "ResNet18_Weights", + "ResNet34_Weights", + "ResNet50_Weights", + "ResNet101_Weights", + "ResNet152_Weights", + "ResNeXt50_32X4D_Weights", + "ResNeXt101_32X8D_Weights", + "ResNeXt101_64X4D_Weights", + "Wide_ResNet50_2_Weights", + "Wide_ResNet101_2_Weights", + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x8d", + "resnext101_64x4d", + "wide_resnet50_2", + "wide_resnet101_2", +] + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, 
kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
+ + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + _log_api_usage_once(self) + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if 
len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + f"or a 3-element tuple, got {replace_stride_with_dilation}" + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck) and m.bn3.weight is not None: + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock) and m.bn2.weight is not None: + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> ResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + 
+ model = ResNet(block, layers, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META = { + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, +} + + +class ResNet18_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet18-f37072fd.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 11689512, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "_metrics": { + "ImageNet-1K": { + "acc@1": 69.758, + "acc@5": 89.078, + } + }, + "_ops": 1.814, + "_file_size": 44.661, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ResNet34_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet34-b627a593.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 21797672, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "_metrics": { + "ImageNet-1K": { + "acc@1": 73.314, + "acc@5": 91.420, + } + }, + "_ops": 3.664, + "_file_size": 83.275, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ResNet50_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet50-0676ba61.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 25557032, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "_metrics": { + "ImageNet-1K": { + "acc@1": 76.130, + "acc@5": 92.862, + } + }, + "_ops": 4.089, + "_file_size": 97.781, + "_docs": """These weights reproduce closely the results of the paper 
using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 25557032, + "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621", + "_metrics": { + "ImageNet-1K": { + "acc@1": 80.858, + "acc@5": 95.434, + } + }, + "_ops": 4.089, + "_file_size": 97.79, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNet101_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet101-63fe2227.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 44549160, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "_metrics": { + "ImageNet-1K": { + "acc@1": 77.374, + "acc@5": 93.546, + } + }, + "_ops": 7.801, + "_file_size": 170.511, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 44549160, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.886, + "acc@5": 95.780, + } + }, + "_ops": 7.801, + "_file_size": 170.53, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNet152_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet152-394f9c45.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 60192808, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "_metrics": { + "ImageNet-1K": { + "acc@1": 78.312, + "acc@5": 94.046, + } + }, + "_ops": 11.514, + "_file_size": 230.434, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet152-f82ba261.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 60192808, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.284, + "acc@5": 96.002, + } + }, + "_ops": 11.514, + "_file_size": 230.474, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt50_32X4D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 25028904, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", + "_metrics": { + "ImageNet-1K": { + "acc@1": 77.618, + "acc@5": 93.698, + } + }, + "_ops": 4.23, + "_file_size": 95.789, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 25028904, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.198, + "acc@5": 95.340, + } + }, + "_ops": 4.23, + "_file_size": 95.833, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt101_32X8D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 88791336, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", + "_metrics": { + "ImageNet-1K": { + "acc@1": 79.312, + "acc@5": 94.526, + } + }, + "_ops": 16.414, + "_file_size": 339.586, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 88791336, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.834, + "acc@5": 96.228, + } + }, + "_ops": 16.414, + "_file_size": 339.673, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt101_64X4D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 83455272, + "recipe": "https://github.com/pytorch/vision/pull/5935", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.246, + "acc@5": 96.454, + } + }, + "_ops": 15.46, + "_file_size": 319.318, + "_docs": """ + These weights were trained from scratch by using TorchVision's `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Wide_ResNet50_2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 68883240, + "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", + "_metrics": { + "ImageNet-1K": { + "acc@1": 78.468, + "acc@5": 94.086, + } + }, + "_ops": 11.398, + "_file_size": 131.82, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 68883240, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.602, + "acc@5": 95.758, + } + }, + "_ops": 11.398, + "_file_size": 263.124, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class Wide_ResNet101_2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 126886696, + "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", + "_metrics": { + "ImageNet-1K": { + "acc@1": 78.848, + "acc@5": 94.284, + } + }, + "_ops": 22.753, + "_file_size": 242.896, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 126886696, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.510, + "acc@5": 96.020, + } + }, + "_ops": 22.753, + "_file_size": 484.747, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) +def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + """ResNet-18 from `Deep Residual Learning for Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNet18_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. 
Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ResNet18_Weights + :members: + """ + weights = ResNet18_Weights.verify(weights) + + return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1)) +def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + """ResNet-34 from `Deep Residual Learning for Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNet34_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ResNet34_Weights + :members: + """ + weights = ResNet34_Weights.verify(weights) + + return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1)) +def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + """ResNet-50 from `Deep Residual Learning for Image Recognition `__. + + .. note:: + The bottleneck of TorchVision places the stride for downsampling to the second 3x3 + convolution while the original paper places it to the first 1x1 convolution. + This variant improves the accuracy and is known as `ResNet V1.5 + `_. + + Args: + weights (:class:`~torchvision.models.ResNet50_Weights`, optional): The + pretrained weights to use. 
See + :class:`~torchvision.models.ResNet50_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ResNet50_Weights + :members: + """ + weights = ResNet50_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1)) +def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + """ResNet-101 from `Deep Residual Learning for Image Recognition `__. + + .. note:: + The bottleneck of TorchVision places the stride for downsampling to the second 3x3 + convolution while the original paper places it to the first 1x1 convolution. + This variant improves the accuracy and is known as `ResNet V1.5 + `_. + + Args: + weights (:class:`~torchvision.models.ResNet101_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNet101_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.ResNet101_Weights + :members: + """ + weights = ResNet101_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1)) +def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + """ResNet-152 from `Deep Residual Learning for Image Recognition `__. + + .. note:: + The bottleneck of TorchVision places the stride for downsampling to the second 3x3 + convolution while the original paper places it to the first 1x1 convolution. + This variant improves the accuracy and is known as `ResNet V1.5 + `_. + + Args: + weights (:class:`~torchvision.models.ResNet152_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNet152_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ResNet152_Weights + :members: + """ + weights = ResNet152_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1)) +def resnext50_32x4d( + *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: + """ResNeXt-50 32x4d model from + `Aggregated Residual Transformation for Deep Neural Networks `_. + + Args: + weights (:class:`~torchvision.models.ResNeXt50_32X4D_Weights`, optional): The + pretrained weights to use. 
See + :class:`~torchvision.models.ResNext50_32X4D_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.ResNeXt50_32X4D_Weights + :members: + """ + weights = ResNeXt50_32X4D_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 4) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1)) +def resnext101_32x8d( + *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: + """ResNeXt-101 32x8d model from + `Aggregated Residual Transformation for Deep Neural Networks `_. + + Args: + weights (:class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNeXt101_32X8D_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. 
autoclass:: torchvision.models.ResNeXt101_32X8D_Weights + :members: + """ + weights = ResNeXt101_32X8D_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ResNeXt101_64X4D_Weights.IMAGENET1K_V1)) +def resnext101_64x4d( + *, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: + """ResNeXt-101 64x4d model from + `Aggregated Residual Transformation for Deep Neural Networks `_. + + Args: + weights (:class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNeXt101_64X4D_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights + :members: + """ + weights = ResNeXt101_64X4D_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "groups", 64) + _ovewrite_named_param(kwargs, "width_per_group", 4) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1)) +def wide_resnet50_2( + *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: + """Wide ResNet-50-2 model from + `Wide Residual Networks `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. 
last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + weights (:class:`~torchvision.models.Wide_ResNet50_2_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Wide_ResNet50_2_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.Wide_ResNet50_2_Weights + :members: + """ + weights = Wide_ResNet50_2_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1)) +def wide_resnet101_2( + *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: + """Wide ResNet-101-2 model from + `Wide Residual Networks `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-101 has 2048-512-2048 + channels, and in Wide ResNet-101-2 has 2048-1024-2048. + + Args: + weights (:class:`~torchvision.models.Wide_ResNet101_2_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Wide_ResNet101_2_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. 
Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.Wide_ResNet101_2_Weights + :members: + """ + weights = Wide_ResNet101_2_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6f37f958a131b76ce80306718b77d78bc3f045 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__init__.py @@ -0,0 +1,3 @@ +from .deeplabv3 import * +from .fcn import * +from .lraspp import * diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/deeplabv3.py b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/deeplabv3.py new file mode 100644 index 0000000000000000000000000000000000000000..a92ddfe3b7af41a9ffd371d34cd459ba57965c53 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/deeplabv3.py @@ -0,0 +1,390 @@ +from functools import partial +from typing import Any, Optional, Sequence + +import torch +from torch import nn +from torch.nn import functional as F + +from ...transforms._presets import SemanticSegmentation +from .._api import register_model, Weights, WeightsEnum +from .._meta import _VOC_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3 +from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights +from ._utils import _SimpleSegmentationModel +from .fcn import FCNHead + + +__all__ = [ + "DeepLabV3", + "DeepLabV3_ResNet50_Weights", + "DeepLabV3_ResNet101_Weights", + "DeepLabV3_MobileNet_V3_Large_Weights", + 
"deeplabv3_mobilenet_v3_large", + "deeplabv3_resnet50", + "deeplabv3_resnet101", +] + + +class DeepLabV3(_SimpleSegmentationModel): + """ + Implements DeepLabV3 model from + `"Rethinking Atrous Convolution for Semantic Image Segmentation" + `_. + + Args: + backbone (nn.Module): the network used to compute the features for the model. + The backbone should return an OrderedDict[Tensor], with the key being + "out" for the last feature map used, and "aux" if an auxiliary classifier + is used. + classifier (nn.Module): module that takes the "out" element returned from + the backbone and returns a dense prediction. + aux_classifier (nn.Module, optional): auxiliary classifier used during training + """ + + pass + + +class DeepLabHead(nn.Sequential): + def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None: + super().__init__( + ASPP(in_channels, atrous_rates), + nn.Conv2d(256, 256, 3, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256, num_classes, 1), + ) + + +class ASPPConv(nn.Sequential): + def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None: + modules = [ + nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ] + super().__init__(*modules) + + +class ASPPPooling(nn.Sequential): + def __init__(self, in_channels: int, out_channels: int) -> None: + super().__init__( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + size = x.shape[-2:] + for mod in self: + x = mod(x) + return F.interpolate(x, size=size, mode="bilinear", align_corners=False) + + +class ASPP(nn.Module): + def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None: + super().__init__() + modules = [] + modules.append( + 
nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU()) + ) + + rates = tuple(atrous_rates) + for rate in rates: + modules.append(ASPPConv(in_channels, out_channels, rate)) + + modules.append(ASPPPooling(in_channels, out_channels)) + + self.convs = nn.ModuleList(modules) + + self.project = nn.Sequential( + nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.Dropout(0.5), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _res = [] + for conv in self.convs: + _res.append(conv(x)) + res = torch.cat(_res, dim=1) + return self.project(res) + + +def _deeplabv3_resnet( + backbone: ResNet, + num_classes: int, + aux: Optional[bool], +) -> DeepLabV3: + return_layers = {"layer4": "out"} + if aux: + return_layers["layer3"] = "aux" + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + aux_classifier = FCNHead(1024, num_classes) if aux else None + classifier = DeepLabHead(2048, num_classes) + return DeepLabV3(backbone, classifier, aux_classifier) + + +_COMMON_META = { + "categories": _VOC_CATEGORIES, + "min_size": (1, 1), + "_docs": """ + These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC + dataset. 
+ """, +} + + +class DeepLabV3_ResNet50_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 42004074, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 66.4, + "pixel_acc": 92.4, + } + }, + "_ops": 178.722, + "_file_size": 160.515, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class DeepLabV3_ResNet101_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 60996202, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 67.4, + "pixel_acc": 92.4, + } + }, + "_ops": 258.743, + "_file_size": 233.217, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 11029328, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 60.3, + "pixel_acc": 91.2, + } + }, + "_ops": 10.452, + "_file_size": 42.301, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +def _deeplabv3_mobilenetv3( + backbone: MobileNetV3, + num_classes: int, + aux: Optional[bool], +) -> DeepLabV3: + backbone = backbone.features + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. 
+ # The first and last blocks are always included because they are the C0 (conv1) and Cn. + stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] + out_pos = stage_indices[-1] # use C5 which has output_stride = 16 + out_inplanes = backbone[out_pos].out_channels + aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8 + aux_inplanes = backbone[aux_pos].out_channels + return_layers = {str(out_pos): "out"} + if aux: + return_layers[str(aux_pos)] = "aux" + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None + classifier = DeepLabHead(out_inplanes, num_classes) + return DeepLabV3(backbone, classifier, aux_classifier) + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def deeplabv3_resnet50( + *, + weights: Optional[DeepLabV3_ResNet50_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + aux_loss: Optional[bool] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> DeepLabV3: + """Constructs a DeepLabV3 model with a ResNet-50 backbone. + + .. betastatus:: segmentation module + + Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation `__. + + Args: + weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. 
+ num_classes (int, optional): number of output classes of the model (including the background) + aux_loss (bool, optional): If True, it uses an auxiliary loss + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for the + backbone + **kwargs: unused + + .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet50_Weights + :members: + """ + weights = DeepLabV3_ResNet50_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) + model = _deeplabv3_resnet(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), +) +def deeplabv3_resnet101( + *, + weights: Optional[DeepLabV3_ResNet101_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + aux_loss: Optional[bool] = None, + weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> DeepLabV3: + """Constructs a DeepLabV3 model with a ResNet-101 backbone. + + .. betastatus:: segmentation module + + Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation `__. + + Args: + weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights`, optional): The + pretrained weights to use. 
See + :class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + aux_loss (bool, optional): If True, it uses an auxiliary loss + weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained weights for the + backbone + **kwargs: unused + + .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet101_Weights + :members: + """ + weights = DeepLabV3_ResNet101_Weights.verify(weights) + weights_backbone = ResNet101_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) + model = _deeplabv3_resnet(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) +def deeplabv3_mobilenet_v3_large( + *, + weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + aux_loss: Optional[bool] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> DeepLabV3: + """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. 
+ + Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation `__. + + Args: + weights (:class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background) + aux_loss (bool, optional): If True, it uses an auxiliary loss + weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained weights + for the backbone + **kwargs: unused + + .. autoclass:: torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights + :members: + """ + weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) + model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/fcn.py b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/fcn.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2e242adac0e7430bab6155ae0347770e29fee9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/fcn.py @@ -0,0 +1,232 @@ +from 
functools import partial +from typing import Any, Optional + +from torch import nn + +from ...transforms._presets import SemanticSegmentation +from .._api import register_model, Weights, WeightsEnum +from .._meta import _VOC_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter +from ..resnet import ResNet, resnet101, ResNet101_Weights, resnet50, ResNet50_Weights +from ._utils import _SimpleSegmentationModel + + +__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"] + + +class FCN(_SimpleSegmentationModel): + """ + Implements FCN model from + `"Fully Convolutional Networks for Semantic Segmentation" + `_. + + Args: + backbone (nn.Module): the network used to compute the features for the model. + The backbone should return an OrderedDict[Tensor], with the key being + "out" for the last feature map used, and "aux" if an auxiliary classifier + is used. + classifier (nn.Module): module that takes the "out" element returned from + the backbone and returns a dense prediction. + aux_classifier (nn.Module, optional): auxiliary classifier used during training + """ + + pass + + +class FCNHead(nn.Sequential): + def __init__(self, in_channels: int, channels: int) -> None: + inter_channels = in_channels // 4 + layers = [ + nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), + nn.BatchNorm2d(inter_channels), + nn.ReLU(), + nn.Dropout(0.1), + nn.Conv2d(inter_channels, channels, 1), + ] + + super().__init__(*layers) + + +_COMMON_META = { + "categories": _VOC_CATEGORIES, + "min_size": (1, 1), + "_docs": """ + These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC + dataset. 
+ """, +} + + +class FCN_ResNet50_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 35322218, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 60.5, + "pixel_acc": 91.4, + } + }, + "_ops": 152.717, + "_file_size": 135.009, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class FCN_ResNet101_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 54314346, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 63.7, + "pixel_acc": 91.9, + } + }, + "_ops": 232.738, + "_file_size": 207.711, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +def _fcn_resnet( + backbone: ResNet, + num_classes: int, + aux: Optional[bool], +) -> FCN: + return_layers = {"layer4": "out"} + if aux: + return_layers["layer3"] = "aux" + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + aux_classifier = FCNHead(1024, num_classes) if aux else None + classifier = FCNHead(2048, num_classes) + return FCN(backbone, classifier, aux_classifier) + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) +def fcn_resnet50( + *, + weights: Optional[FCN_ResNet50_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + aux_loss: Optional[bool] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, + **kwargs: 
Any, +) -> FCN: + """Fully-Convolutional Network model with a ResNet-50 backbone from the `Fully Convolutional + Networks for Semantic Segmentation `_ paper. + + .. betastatus:: segmentation module + + Args: + weights (:class:`~torchvision.models.segmentation.FCN_ResNet50_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.segmentation.FCN_ResNet50_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background). + aux_loss (bool, optional): If True, it uses an auxiliary loss. + weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained + weights for the backbone. + **kwargs: parameters passed to the ``torchvision.models.segmentation.fcn.FCN`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.segmentation.FCN_ResNet50_Weights + :members: + """ + + weights = FCN_ResNet50_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) + model = _fcn_resnet(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), +) +def fcn_resnet101( + *, + weights: Optional[FCN_ResNet101_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + aux_loss: Optional[bool] = None, + weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> FCN: + """Fully-Convolutional Network model with a ResNet-101 backbone from the `Fully Convolutional + Networks for Semantic Segmentation `_ paper. + + .. betastatus:: segmentation module + + Args: + weights (:class:`~torchvision.models.segmentation.FCN_ResNet101_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.segmentation.FCN_ResNet101_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background). + aux_loss (bool, optional): If True, it uses an auxiliary loss. 
+ weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained + weights for the backbone. + **kwargs: parameters passed to the ``torchvision.models.segmentation.fcn.FCN`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.segmentation.FCN_ResNet101_Weights + :members: + """ + + weights = FCN_ResNet101_Weights.verify(weights) + weights_backbone = ResNet101_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) + model = _fcn_resnet(backbone, num_classes, aux_loss) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/lraspp.py b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/lraspp.py new file mode 100644 index 0000000000000000000000000000000000000000..70bced70fd37c3c681915492cea0c68c87cf0a7e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/lraspp.py @@ -0,0 +1,178 @@ +from collections import OrderedDict +from functools import partial +from typing import Any, Dict, Optional + +from torch import nn, Tensor +from torch.nn import functional as F + +from ...transforms._presets import SemanticSegmentation +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._meta import _VOC_CATEGORIES +from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter +from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3 + + 
+__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] + + +class LRASPP(nn.Module): + """ + Implements a Lite R-ASPP Network for semantic segmentation from + `"Searching for MobileNetV3" + `_. + + Args: + backbone (nn.Module): the network used to compute the features for the model. + The backbone should return an OrderedDict[Tensor], with the key being + "high" for the high level feature map and "low" for the low level feature map. + low_channels (int): the number of channels of the low level features. + high_channels (int): the number of channels of the high level features. + num_classes (int, optional): number of output classes of the model (including the background). + inter_channels (int, optional): the number of channels for intermediate computations. + """ + + def __init__( + self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128 + ) -> None: + super().__init__() + _log_api_usage_once(self) + self.backbone = backbone + self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels) + + def forward(self, input: Tensor) -> Dict[str, Tensor]: + features = self.backbone(input) + out = self.classifier(features) + out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False) + + result = OrderedDict() + result["out"] = out + + return result + + +class LRASPPHead(nn.Module): + def __init__(self, low_channels: int, high_channels: int, num_classes: int, inter_channels: int) -> None: + super().__init__() + self.cbr = nn.Sequential( + nn.Conv2d(high_channels, inter_channels, 1, bias=False), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + ) + self.scale = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(high_channels, inter_channels, 1, bias=False), + nn.Sigmoid(), + ) + self.low_classifier = nn.Conv2d(low_channels, num_classes, 1) + self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1) + + def forward(self, 
input: Dict[str, Tensor]) -> Tensor: + low = input["low"] + high = input["high"] + + x = self.cbr(high) + s = self.scale(high) + x = x * s + x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False) + + return self.low_classifier(low) + self.high_classifier(x) + + +def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP: + backbone = backbone.features + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. + # The first and last blocks are always included because they are the C0 (conv1) and Cn. + stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] + low_pos = stage_indices[-4] # use C2 here which has output_stride = 8 + high_pos = stage_indices[-1] # use C5 which has output_stride = 16 + low_channels = backbone[low_pos].out_channels + high_channels = backbone[high_pos].out_channels + backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"}) + + return LRASPP(backbone, low_channels, high_channels, num_classes) + + +class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", + transforms=partial(SemanticSegmentation, resize_size=520), + meta={ + "num_params": 3221538, + "categories": _VOC_CATEGORIES, + "min_size": (1, 1), + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large", + "_metrics": { + "COCO-val2017-VOC-labels": { + "miou": 57.9, + "pixel_acc": 91.2, + } + }, + "_ops": 2.086, + "_file_size": 12.49, + "_docs": """ + These weights were trained on a subset of COCO, using only the 20 categories that are present in the + Pascal VOC dataset. 
+ """, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +@register_model() +@handle_legacy_interface( + weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) +def lraspp_mobilenet_v3_large( + *, + weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, + **kwargs: Any, +) -> LRASPP: + """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone from + `Searching for MobileNetV3 `_ paper. + + .. betastatus:: segmentation module + + Args: + weights (:class:`~torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + num_classes (int, optional): number of output classes of the model (including the background). + aux_loss (bool, optional): If True, it uses an auxiliary loss. + weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained + weights for the backbone. + **kwargs: parameters passed to the ``torchvision.models.segmentation.LRASPP`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights + :members: + """ + if kwargs.pop("aux_loss", False): + raise NotImplementedError("This model does not use auxiliary loss") + + weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 21 + + backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) + model = _lraspp_mobilenetv3(backbone, num_classes) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/shufflenetv2.py b/.venv/lib/python3.11/site-packages/torchvision/models/shufflenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..3f3322b7a88f183c15838308f39c38d36dea13c0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/shufflenetv2.py @@ -0,0 +1,408 @@ +from functools import partial +from typing import Any, Callable, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from ..transforms._presets import ImageClassification +from ..utils import _log_api_usage_once +from ._api import register_model, Weights, WeightsEnum +from ._meta import _IMAGENET_CATEGORIES +from ._utils import _ovewrite_named_param, handle_legacy_interface + + +__all__ = [ + "ShuffleNetV2", + "ShuffleNet_V2_X0_5_Weights", + "ShuffleNet_V2_X1_0_Weights", + "ShuffleNet_V2_X1_5_Weights", + "ShuffleNet_V2_X2_0_Weights", + "shufflenet_v2_x0_5", + "shufflenet_v2_x1_0", + "shufflenet_v2_x1_5", + "shufflenet_v2_x2_0", +] + + +def channel_shuffle(x: Tensor, groups: int) -> Tensor: + batchsize, num_channels, height, width = x.size() + channels_per_group = num_channels // groups + 
+ # reshape + x = x.view(batchsize, groups, channels_per_group, height, width) + + x = torch.transpose(x, 1, 2).contiguous() + + # flatten + x = x.view(batchsize, num_channels, height, width) + + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp: int, oup: int, stride: int) -> None: + super().__init__() + + if not (1 <= stride <= 3): + raise ValueError("illegal stride value") + self.stride = stride + + branch_features = oup // 2 + if (self.stride == 1) and (inp != branch_features << 1): + raise ValueError( + f"Invalid combination of stride {stride}, inp {inp} and oup {oup} values. If stride == 1 then inp should be equal to oup // 2 << 1." + ) + + if self.stride > 1: + self.branch1 = nn.Sequential( + self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(inp), + nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.ReLU(inplace=True), + ) + else: + self.branch1 = nn.Sequential() + + self.branch2 = nn.Sequential( + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + nn.ReLU(inplace=True), + self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(branch_features), + nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.ReLU(inplace=True), + ) + + @staticmethod + def depthwise_conv( + i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False + ) -> nn.Conv2d: + return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) + + def forward(self, x: Tensor) -> Tensor: + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + else: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + + out 
= channel_shuffle(out, 2) + + return out + + +class ShuffleNetV2(nn.Module): + def __init__( + self, + stages_repeats: List[int], + stages_out_channels: List[int], + num_classes: int = 1000, + inverted_residual: Callable[..., nn.Module] = InvertedResidual, + ) -> None: + super().__init__() + _log_api_usage_once(self) + + if len(stages_repeats) != 3: + raise ValueError("expected stages_repeats as list of 3 positive ints") + if len(stages_out_channels) != 5: + raise ValueError("expected stages_out_channels as list of 5 positive ints") + self._stage_out_channels = stages_out_channels + + input_channels = 3 + output_channels = self._stage_out_channels[0] + self.conv1 = nn.Sequential( + nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), + nn.BatchNorm2d(output_channels), + nn.ReLU(inplace=True), + ) + input_channels = output_channels + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # Static annotations for mypy + self.stage2: nn.Sequential + self.stage3: nn.Sequential + self.stage4: nn.Sequential + stage_names = [f"stage{i}" for i in [2, 3, 4]] + for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]): + seq = [inverted_residual(input_channels, output_channels, 2)] + for i in range(repeats - 1): + seq.append(inverted_residual(output_channels, output_channels, 1)) + setattr(self, name, nn.Sequential(*seq)) + input_channels = output_channels + + output_channels = self._stage_out_channels[-1] + self.conv5 = nn.Sequential( + nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(output_channels), + nn.ReLU(inplace=True), + ) + + self.fc = nn.Linear(output_channels, num_classes) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.maxpool(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.stage4(x) + x = self.conv5(x) + x = x.mean([2, 3]) # globalpool + x = self.fc(x) + return x + + def forward(self, 
x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _shufflenetv2( + weights: Optional[WeightsEnum], + progress: bool, + *args: Any, + **kwargs: Any, +) -> ShuffleNetV2: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = ShuffleNetV2(*args, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META = { + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "recipe": "https://github.com/ericsun99/Shufflenet-v2-Pytorch", +} + + +class ShuffleNet_V2_X0_5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # Weights ported from https://github.com/ericsun99/Shufflenet-v2-Pytorch + url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 1366792, + "_metrics": { + "ImageNet-1K": { + "acc@1": 60.552, + "acc@5": 81.746, + } + }, + "_ops": 0.04, + "_file_size": 5.282, + "_docs": """These weights were trained from scratch to reproduce closely the results of the paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ShuffleNet_V2_X1_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # Weights ported from https://github.com/ericsun99/Shufflenet-v2-Pytorch + url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2278604, + "_metrics": { + "ImageNet-1K": { + "acc@1": 69.362, + "acc@5": 88.316, + } + }, + "_ops": 0.145, + "_file_size": 8.791, + "_docs": """These weights were trained from scratch to reproduce closely the results of the paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ShuffleNet_V2_X1_5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth", + 
transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/pull/5906", + "num_params": 3503624, + "_metrics": { + "ImageNet-1K": { + "acc@1": 72.996, + "acc@5": 91.086, + } + }, + "_ops": 0.296, + "_file_size": 13.557, + "_docs": """ + These weights were trained from scratch by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ShuffleNet_V2_X2_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/pull/5906", + "num_params": 7393996, + "_metrics": { + "ImageNet-1K": { + "acc@1": 76.230, + "acc@5": 93.006, + } + }, + "_ops": 0.583, + "_file_size": 28.433, + "_docs": """ + These weights were trained from scratch by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1)) +def shufflenet_v2_x0_5( + *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: + """ + Constructs a ShuffleNetV2 architecture with 0.5x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. + + Args: + weights (:class:`~torchvision.models.ShuffleNet_V2_X0_5_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ShuffleNet_V2_X0_5_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.shufflenetv2.ShuffleNetV2`` + base class. 
Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ShuffleNet_V2_X0_5_Weights + :members: + """ + weights = ShuffleNet_V2_X0_5_Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)) +def shufflenet_v2_x1_0( + *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: + """ + Constructs a ShuffleNetV2 architecture with 1.0x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. + + Args: + weights (:class:`~torchvision.models.ShuffleNet_V2_X1_0_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ShuffleNet_V2_X1_0_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.shufflenetv2.ShuffleNetV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ShuffleNet_V2_X1_0_Weights + :members: + """ + weights = ShuffleNet_V2_X1_0_Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1)) +def shufflenet_v2_x1_5( + *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: + """ + Constructs a ShuffleNetV2 architecture with 1.5x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. 
+ + Args: + weights (:class:`~torchvision.models.ShuffleNet_V2_X1_5_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ShuffleNet_V2_X1_5_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.shufflenetv2.ShuffleNetV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ShuffleNet_V2_X1_5_Weights + :members: + """ + weights = ShuffleNet_V2_X1_5_Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1)) +def shufflenet_v2_x2_0( + *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: + """ + Constructs a ShuffleNetV2 architecture with 2.0x output channels, as described in + `ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design + `__. + + Args: + weights (:class:`~torchvision.models.ShuffleNet_V2_X2_0_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ShuffleNet_V2_X2_0_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.shufflenetv2.ShuffleNetV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.ShuffleNet_V2_X2_0_Weights + :members: + """ + weights = ShuffleNet_V2_X2_0_Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/swin_transformer.py b/.venv/lib/python3.11/site-packages/torchvision/models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2035f659bfc7f4e4f98d36cabff9be3802d49e59 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/swin_transformer.py @@ -0,0 +1,1033 @@ +import math +from functools import partial +from typing import Any, Callable, List, Optional + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from ..ops.misc import MLP, Permute +from ..ops.stochastic_depth import StochasticDepth +from ..transforms._presets import ImageClassification, InterpolationMode +from ..utils import _log_api_usage_once +from ._api import register_model, Weights, WeightsEnum +from ._meta import _IMAGENET_CATEGORIES +from ._utils import _ovewrite_named_param, handle_legacy_interface + + +__all__ = [ + "SwinTransformer", + "Swin_T_Weights", + "Swin_S_Weights", + "Swin_B_Weights", + "Swin_V2_T_Weights", + "Swin_V2_S_Weights", + "Swin_V2_B_Weights", + "swin_t", + "swin_s", + "swin_b", + "swin_v2_t", + "swin_v2_s", + "swin_v2_b", +] + + +def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor: + H, W, _ = x.shape[-3:] + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C + x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C + x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C + x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # ... 
H/2 W/2 4*C + return x + + +torch.fx.wrap("_patch_merging_pad") + + +def _get_relative_position_bias( + relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int] +) -> torch.Tensor: + N = window_size[0] * window_size[1] + relative_position_bias = relative_position_bias_table[relative_position_index] # type: ignore[index] + relative_position_bias = relative_position_bias.view(N, N, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + return relative_position_bias + + +torch.fx.wrap("_get_relative_position_bias") + + +class PatchMerging(nn.Module): + """Patch Merging Layer. + Args: + dim (int): Number of input channels. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + _log_api_usage_once(self) + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x: Tensor): + """ + Args: + x (Tensor): input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] + """ + x = _patch_merging_pad(x) + x = self.norm(x) + x = self.reduction(x) # ... H/2 W/2 2*C + return x + + +class PatchMergingV2(nn.Module): + """Patch Merging Layer for Swin Transformer V2. + Args: + dim (int): Number of input channels. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 
+ """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + _log_api_usage_once(self) + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) # difference + + def forward(self, x: Tensor): + """ + Args: + x (Tensor): input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] + """ + x = _patch_merging_pad(x) + x = self.reduction(x) # ... H/2 W/2 2*C + x = self.norm(x) + return x + + +def shifted_window_attention( + input: Tensor, + qkv_weight: Tensor, + proj_weight: Tensor, + relative_position_bias: Tensor, + window_size: List[int], + num_heads: int, + shift_size: List[int], + attention_dropout: float = 0.0, + dropout: float = 0.0, + qkv_bias: Optional[Tensor] = None, + proj_bias: Optional[Tensor] = None, + logit_scale: Optional[torch.Tensor] = None, + training: bool = True, +) -> Tensor: + """ + Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + input (Tensor[N, H, W, C]): The input tensor or 4-dimensions. + qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. + proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. + relative_position_bias (Tensor): The learned relative position bias added to attention. + window_size (List[int]): Window size. + num_heads (int): Number of attention heads. + shift_size (List[int]): Shift size for shifted window attention. + attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. + dropout (float): Dropout ratio of output. Default: 0.0. + qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. + proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. + logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. 
Default: None. + training (bool, optional): Training flag used by the dropout parameters. Default: True. + Returns: + Tensor[N, H, W, C]: The output tensor after shifted window attention. + """ + B, H, W, C = input.shape + # pad feature maps to multiples of window size + pad_r = (window_size[1] - W % window_size[1]) % window_size[1] + pad_b = (window_size[0] - H % window_size[0]) % window_size[0] + x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) + _, pad_H, pad_W, _ = x.shape + + shift_size = shift_size.copy() + # If window size is larger than feature size, there is no need to shift window + if window_size[0] >= pad_H: + shift_size[0] = 0 + if window_size[1] >= pad_W: + shift_size[1] = 0 + + # cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + + # partition windows + num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1]) + x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C + + # multi-head attention + if logit_scale is not None and qkv_bias is not None: + qkv_bias = qkv_bias.clone() + length = qkv_bias.numel() // 3 + qkv_bias[length : 2 * length].zero_() + qkv = F.linear(x, qkv_weight, qkv_bias) + qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + if logit_scale is not None: + # cosine attention + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp() + attn = attn * logit_scale + else: + q = q * (C // num_heads) ** -0.5 + attn = q.matmul(k.transpose(-2, -1)) + # add relative position bias + attn = attn + relative_position_bias + + if sum(shift_size) > 0: + # generate attention mask + attn_mask = x.new_zeros((pad_H, pad_W)) + h_slices = ((0, -window_size[0]), 
(-window_size[0], -shift_size[0]), (-shift_size[0], None)) + w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None)) + count = 0 + for h in h_slices: + for w in w_slices: + attn_mask[h[0] : h[1], w[0] : w[1]] = count + count += 1 + attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1]) + attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1]) + attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) + attn = attn + attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, num_heads, x.size(1), x.size(1)) + + attn = F.softmax(attn, dim=-1) + attn = F.dropout(attn, p=attention_dropout, training=training) + + x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C) + x = F.linear(x, proj_weight, proj_bias) + x = F.dropout(x, p=dropout, training=training) + + # reverse windows + x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) + + # reverse cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) + + # unpad features + x = x[:, :H, :W, :].contiguous() + return x + + +torch.fx.wrap("shifted_window_attention") + + +class ShiftedWindowAttention(nn.Module): + """ + See :func:`shifted_window_attention`. 
+ """ + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ): + super().__init__() + if len(window_size) != 2 or len(shift_size) != 2: + raise ValueError("window_size and shift_size must be of length 2") + self.window_size = window_size + self.shift_size = shift_size + self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.dropout = dropout + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + self.define_relative_position_bias_table() + self.define_relative_position_index() + + def define_relative_position_bias_table(self): + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + def define_relative_position_index(self): + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + def get_relative_position_bias(self) -> 
torch.Tensor: + return _get_relative_position_bias( + self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type] + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Tensor with layout of [B, H, W, C] + Returns: + Tensor with same layout as input, i.e. [B, H, W, C] + """ + relative_position_bias = self.get_relative_position_bias() + return shifted_window_attention( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + self.window_size, + self.num_heads, + shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + training=self.training, + ) + + +class ShiftedWindowAttentionV2(ShiftedWindowAttention): + """ + See :func:`shifted_window_attention_v2`. + """ + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ): + super().__init__( + dim, + window_size, + shift_size, + num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attention_dropout=attention_dropout, + dropout=dropout, + ) + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) + ) + if qkv_bias: + length = self.qkv.bias.numel() // 3 + self.qkv.bias[length : 2 * length].data.zero_() + + def define_relative_position_bias_table(self): + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], 
indexing="ij")) + relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + + relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 + + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = ( + torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0 + ) + self.register_buffer("relative_coords_table", relative_coords_table) + + def get_relative_position_bias(self) -> torch.Tensor: + relative_position_bias = _get_relative_position_bias( + self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads), + self.relative_position_index, # type: ignore[arg-type] + self.window_size, + ) + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + return relative_position_bias + + def forward(self, x: Tensor): + """ + Args: + x (Tensor): Tensor with layout of [B, H, W, C] + Returns: + Tensor with same layout as input, i.e. [B, H, W, C] + """ + relative_position_bias = self.get_relative_position_bias() + return shifted_window_attention( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + self.window_size, + self.num_heads, + shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + logit_scale=self.logit_scale, + training=self.training, + ) + + +class SwinTransformerBlock(nn.Module): + """ + Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (List[int]): Window size. + shift_size (List[int]): Shift size for shifted window attention. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. 
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: List[int], + shift_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, + ): + super().__init__() + _log_api_usage_once(self) + + self.norm1 = norm_layer(dim) + self.attn = attn_layer( + dim, + window_size, + shift_size, + num_heads, + attention_dropout=attention_dropout, + dropout=dropout, + ) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + self.norm2 = norm_layer(dim) + self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) + + for m in self.mlp.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x: Tensor): + x = x + self.stochastic_depth(self.attn(self.norm1(x))) + x = x + self.stochastic_depth(self.mlp(self.norm2(x))) + return x + + +class SwinTransformerBlockV2(SwinTransformerBlock): + """ + Swin Transformer V2 Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (List[int]): Window size. + shift_size (List[int]): Shift size for shifted window attention. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2. 
+ """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: List[int], + shift_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2, + ): + super().__init__( + dim, + num_heads, + window_size, + shift_size, + mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=stochastic_depth_prob, + norm_layer=norm_layer, + attn_layer=attn_layer, + ) + + def forward(self, x: Tensor): + # Here is the difference, we apply norm after the attention in V2. + # In V1 we applied norm before the attention. + x = x + self.stochastic_depth(self.norm1(self.attn(x))) + x = x + self.stochastic_depth(self.norm2(self.mlp(x))) + return x + + +class SwinTransformer(nn.Module): + """ + Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using + Shifted Windows" `_ paper. + Args: + patch_size (List[int]): Patch size. + embed_dim (int): Patch embedding dimension. + depths (List(int)): Depth of each Swin Transformer layer. + num_heads (List(int)): Number of attention heads in different layers. + window_size (List[int]): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. + num_classes (int): Number of classes for classification head. Default: 1000. + block (nn.Module, optional): SwinTransformer Block. Default: None. + norm_layer (nn.Module, optional): Normalization layer. Default: None. + downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging. 
+ """ + + def __init__( + self, + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.1, + num_classes: int = 1000, + norm_layer: Optional[Callable[..., nn.Module]] = None, + block: Optional[Callable[..., nn.Module]] = None, + downsample_layer: Callable[..., nn.Module] = PatchMerging, + ): + super().__init__() + _log_api_usage_once(self) + self.num_classes = num_classes + + if block is None: + block = SwinTransformerBlock + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-5) + + layers: List[nn.Module] = [] + # split image into non-overlapping patches + layers.append( + nn.Sequential( + nn.Conv2d( + 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]) + ), + Permute([0, 2, 3, 1]), + norm_layer(embed_dim), + ) + ) + + total_stage_blocks = sum(depths) + stage_block_id = 0 + # build SwinTransformer blocks + for i_stage in range(len(depths)): + stage: List[nn.Module] = [] + dim = embed_dim * 2**i_stage + for i_layer in range(depths[i_stage]): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) + stage.append( + block( + dim, + num_heads[i_stage], + window_size=window_size, + shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], + mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, + ) + ) + stage_block_id += 1 + layers.append(nn.Sequential(*stage)) + # add patch merging layer + if i_stage < (len(depths) - 1): + layers.append(downsample_layer(dim, norm_layer)) + self.features = nn.Sequential(*layers) + + num_features = embed_dim * 2 ** (len(depths) - 1) + self.norm = norm_layer(num_features) + self.permute = Permute([0, 3, 1, 
2]) # B H W C -> B C H W + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.flatten = nn.Flatten(1) + self.head = nn.Linear(num_features, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x): + x = self.features(x) + x = self.norm(x) + x = self.permute(x) + x = self.avgpool(x) + x = self.flatten(x) + x = self.head(x) + return x + + +def _swin_transformer( + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + stochastic_depth_prob: float, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> SwinTransformer: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = SwinTransformer( + patch_size=patch_size, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + stochastic_depth_prob=stochastic_depth_prob, + **kwargs, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META = { + "categories": _IMAGENET_CATEGORIES, +} + + +class Swin_T_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_t-704ceda3.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 28288354, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.474, + "acc@5": 95.776, + } + }, + "_ops": 4.491, + "_file_size": 108.19, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_S_Weights(WeightsEnum): + IMAGENET1K_V1 = 
Weights( + url="https://download.pytorch.org/models/swin_s-5e29d889.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 49606258, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.196, + "acc@5": 96.360, + } + }, + "_ops": 8.741, + "_file_size": 189.786, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_B_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_b-68c6b09e.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 87768224, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.582, + "acc@5": 96.640, + } + }, + "_ops": 15.431, + "_file_size": 335.364, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_V2_T_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth", + transforms=partial( + ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 28351570, + "min_size": (256, 256), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.072, + "acc@5": 96.132, + } + }, + "_ops": 5.94, + "_file_size": 108.626, + "_docs": """These weights reproduce closely the results of the paper using a similar 
training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_V2_S_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth", + transforms=partial( + ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 49737442, + "min_size": (256, 256), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.712, + "acc@5": 96.816, + } + }, + "_ops": 11.546, + "_file_size": 190.675, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_V2_B_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_v2_b-781e5279.pth", + transforms=partial( + ImageClassification, crop_size=256, resize_size=272, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 87930848, + "min_size": (256, 256), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2", + "_metrics": { + "ImageNet-1K": { + "acc@1": 84.112, + "acc@5": 96.864, + } + }, + "_ops": 20.325, + "_file_size": 336.372, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin_T_Weights.IMAGENET1K_V1)) +def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_tiny architecture from + `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. + + Args: + weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The + pretrained weights to use. 
See + :class:`~torchvision.models.Swin_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_T_Weights + :members: + """ + weights = Swin_T_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=[7, 7], + stochastic_depth_prob=0.2, + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin_S_Weights.IMAGENET1K_V1)) +def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_small architecture from + `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. + + Args: + weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.Swin_S_Weights + :members: + """ + weights = Swin_S_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=[7, 7], + stochastic_depth_prob=0.3, + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin_B_Weights.IMAGENET1K_V1)) +def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_base architecture from + `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. + + Args: + weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_B_Weights + :members: + """ + weights = Swin_B_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=[7, 7], + stochastic_depth_prob=0.5, + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin_V2_T_Weights.IMAGENET1K_V1)) +def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_v2_tiny architecture from + `Swin Transformer V2: Scaling Up Capacity and Resolution `_. 
+ + Args: + weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_V2_T_Weights + :members: + """ + weights = Swin_V2_T_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=[8, 8], + stochastic_depth_prob=0.2, + weights=weights, + progress=progress, + block=SwinTransformerBlockV2, + downsample_layer=PatchMergingV2, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin_V2_S_Weights.IMAGENET1K_V1)) +def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_v2_small architecture from + `Swin Transformer V2: Scaling Up Capacity and Resolution `_. + + Args: + weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.Swin_V2_S_Weights + :members: + """ + weights = Swin_V2_S_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=[8, 8], + stochastic_depth_prob=0.3, + weights=weights, + progress=progress, + block=SwinTransformerBlockV2, + downsample_layer=PatchMergingV2, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin_V2_B_Weights.IMAGENET1K_V1)) +def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_v2_base architecture from + `Swin Transformer V2: Scaling Up Capacity and Resolution `_. + + Args: + weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_V2_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.Swin_V2_B_Weights + :members: + """ + weights = Swin_V2_B_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=[8, 8], + stochastic_depth_prob=0.5, + weights=weights, + progress=progress, + block=SwinTransformerBlockV2, + downsample_layer=PatchMergingV2, + **kwargs, + ) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/video/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/models/video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1eedd3116001af22ec202d2ccec6eefad8090ae --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/video/__init__.py @@ -0,0 +1,4 @@ +from .mvit import * +from .resnet import * +from .s3d import * +from .swin_transformer import * diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebfe7fed1ca30b7f66ae8db070758a690cf2fabc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/mvit.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/mvit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b554ebd6133260937ffecef9ad0b29a81cbe4bf3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/mvit.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/resnet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/resnet.cpython-311.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..babc7b2e9f3b61279e264636f3942b587eac21f1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/resnet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/s3d.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/s3d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87363c7a7e5fde1dbcc01e5998bd7e5807a0baee Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/s3d.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/swin_transformer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/swin_transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bfcc92307d23564486b7b0ed46b08b10a07a18e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/video/__pycache__/swin_transformer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/video/mvit.py b/.venv/lib/python3.11/site-packages/torchvision/models/video/mvit.py new file mode 100644 index 0000000000000000000000000000000000000000..159c12a4f3eac579f4e122741f57cc60f5cd0a23 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/video/mvit.py @@ -0,0 +1,897 @@ +import math +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.fx +import torch.nn as nn + +from ...ops import MLP, StochasticDepth +from ...transforms._presets import VideoClassification +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._meta import _KINETICS400_CATEGORIES +from .._utils import _ovewrite_named_param, 
handle_legacy_interface + + +__all__ = [ + "MViT", + "MViT_V1_B_Weights", + "mvit_v1_b", + "MViT_V2_S_Weights", + "mvit_v2_s", +] + + +@dataclass +class MSBlockConfig: + num_heads: int + input_channels: int + output_channels: int + kernel_q: List[int] + kernel_kv: List[int] + stride_q: List[int] + stride_kv: List[int] + + +def _prod(s: Sequence[int]) -> int: + product = 1 + for v in s: + product *= v + return product + + +def _unsqueeze(x: torch.Tensor, target_dim: int, expand_dim: int) -> Tuple[torch.Tensor, int]: + tensor_dim = x.dim() + if tensor_dim == target_dim - 1: + x = x.unsqueeze(expand_dim) + elif tensor_dim != target_dim: + raise ValueError(f"Unsupported input dimension {x.shape}") + return x, tensor_dim + + +def _squeeze(x: torch.Tensor, target_dim: int, expand_dim: int, tensor_dim: int) -> torch.Tensor: + if tensor_dim == target_dim - 1: + x = x.squeeze(expand_dim) + return x + + +torch.fx.wrap("_unsqueeze") +torch.fx.wrap("_squeeze") + + +class Pool(nn.Module): + def __init__( + self, + pool: nn.Module, + norm: Optional[nn.Module], + activation: Optional[nn.Module] = None, + norm_before_pool: bool = False, + ) -> None: + super().__init__() + self.pool = pool + layers = [] + if norm is not None: + layers.append(norm) + if activation is not None: + layers.append(activation) + self.norm_act = nn.Sequential(*layers) if layers else None + self.norm_before_pool = norm_before_pool + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + x, tensor_dim = _unsqueeze(x, 4, 1) + + # Separate the class token and reshape the input + class_token, x = torch.tensor_split(x, indices=(1,), dim=2) + x = x.transpose(2, 3) + B, N, C = x.shape[:3] + x = x.reshape((B * N, C) + thw).contiguous() + + # normalizing prior pooling is useful when we use BN which can be absorbed to speed up inference + if self.norm_before_pool and self.norm_act is not None: + x = self.norm_act(x) + + # apply the pool on the input and add 
back the token + x = self.pool(x) + T, H, W = x.shape[2:] + x = x.reshape(B, N, C, -1).transpose(2, 3) + x = torch.cat((class_token, x), dim=2) + + if not self.norm_before_pool and self.norm_act is not None: + x = self.norm_act(x) + + x = _squeeze(x, 4, 1, tensor_dim) + return x, (T, H, W) + + +def _interpolate(embedding: torch.Tensor, d: int) -> torch.Tensor: + if embedding.shape[0] == d: + return embedding + + return ( + nn.functional.interpolate( + embedding.permute(1, 0).unsqueeze(0), + size=d, + mode="linear", + ) + .squeeze(0) + .permute(1, 0) + ) + + +def _add_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + q_thw: Tuple[int, int, int], + k_thw: Tuple[int, int, int], + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + rel_pos_t: torch.Tensor, +) -> torch.Tensor: + # Modified code from: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932 + q_t, q_h, q_w = q_thw + k_t, k_h, k_w = k_thw + dh = int(2 * max(q_h, k_h) - 1) + dw = int(2 * max(q_w, k_w) - 1) + dt = int(2 * max(q_t, k_t) - 1) + + # Scale up rel pos if shapes for q and k are different. + q_h_ratio = max(k_h / q_h, 1.0) + k_h_ratio = max(q_h / k_h, 1.0) + dist_h = torch.arange(q_h)[:, None] * q_h_ratio - (torch.arange(k_h)[None, :] + (1.0 - k_h)) * k_h_ratio + q_w_ratio = max(k_w / q_w, 1.0) + k_w_ratio = max(q_w / k_w, 1.0) + dist_w = torch.arange(q_w)[:, None] * q_w_ratio - (torch.arange(k_w)[None, :] + (1.0 - k_w)) * k_w_ratio + q_t_ratio = max(k_t / q_t, 1.0) + k_t_ratio = max(q_t / k_t, 1.0) + dist_t = torch.arange(q_t)[:, None] * q_t_ratio - (torch.arange(k_t)[None, :] + (1.0 - k_t)) * k_t_ratio + + # Interpolate rel pos if needed. 
+ rel_pos_h = _interpolate(rel_pos_h, dh) + rel_pos_w = _interpolate(rel_pos_w, dw) + rel_pos_t = _interpolate(rel_pos_t, dt) + Rh = rel_pos_h[dist_h.long()] + Rw = rel_pos_w[dist_w.long()] + Rt = rel_pos_t[dist_t.long()] + + B, n_head, _, dim = q.shape + + r_q = q[:, :, 1:].reshape(B, n_head, q_t, q_h, q_w, dim) + rel_h_q = torch.einsum("bythwc,hkc->bythwk", r_q, Rh) # [B, H, q_t, qh, qw, k_h] + rel_w_q = torch.einsum("bythwc,wkc->bythwk", r_q, Rw) # [B, H, q_t, qh, qw, k_w] + # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim] + r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w, dim) + # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t] + rel_q_t = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1) + # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t] + rel_q_t = rel_q_t.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5) + + # Combine rel pos. + rel_pos = ( + rel_h_q[:, :, :, :, :, None, :, None] + + rel_w_q[:, :, :, :, :, None, None, :] + + rel_q_t[:, :, :, :, :, :, None, None] + ).reshape(B, n_head, q_t * q_h * q_w, k_t * k_h * k_w) + + # Add it to attention + attn[:, :, 1:, 1:] += rel_pos + + return attn + + +def _add_shortcut(x: torch.Tensor, shortcut: torch.Tensor, residual_with_cls_embed: bool): + if residual_with_cls_embed: + x.add_(shortcut) + else: + x[:, :, 1:, :] += shortcut[:, :, 1:, :] + return x + + +torch.fx.wrap("_add_rel_pos") +torch.fx.wrap("_add_shortcut") + + +class MultiscaleAttention(nn.Module): + def __init__( + self, + input_size: List[int], + embed_dim: int, + output_dim: int, + num_heads: int, + kernel_q: List[int], + kernel_kv: List[int], + stride_q: List[int], + stride_kv: List[int], + residual_pool: bool, + residual_with_cls_embed: bool, + rel_pos_embed: bool, + dropout: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.output_dim 
= output_dim + self.num_heads = num_heads + self.head_dim = output_dim // num_heads + self.scaler = 1.0 / math.sqrt(self.head_dim) + self.residual_pool = residual_pool + self.residual_with_cls_embed = residual_with_cls_embed + + self.qkv = nn.Linear(embed_dim, 3 * output_dim) + layers: List[nn.Module] = [nn.Linear(output_dim, output_dim)] + if dropout > 0.0: + layers.append(nn.Dropout(dropout, inplace=True)) + self.project = nn.Sequential(*layers) + + self.pool_q: Optional[nn.Module] = None + if _prod(kernel_q) > 1 or _prod(stride_q) > 1: + padding_q = [int(q // 2) for q in kernel_q] + self.pool_q = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_q, # type: ignore[arg-type] + stride=stride_q, # type: ignore[arg-type] + padding=padding_q, # type: ignore[arg-type] + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + + self.pool_k: Optional[nn.Module] = None + self.pool_v: Optional[nn.Module] = None + if _prod(kernel_kv) > 1 or _prod(stride_kv) > 1: + padding_kv = [int(kv // 2) for kv in kernel_kv] + self.pool_k = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_kv, # type: ignore[arg-type] + stride=stride_kv, # type: ignore[arg-type] + padding=padding_kv, # type: ignore[arg-type] + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + self.pool_v = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_kv, # type: ignore[arg-type] + stride=stride_kv, # type: ignore[arg-type] + padding=padding_kv, # type: ignore[arg-type] + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + + self.rel_pos_h: Optional[nn.Parameter] = None + self.rel_pos_w: Optional[nn.Parameter] = None + self.rel_pos_t: Optional[nn.Parameter] = None + if rel_pos_embed: + size = max(input_size[1:]) + q_size = size // stride_q[1] if len(stride_q) > 0 else size + kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size + spatial_dim = 2 * max(q_size, kv_size) - 1 + temporal_dim = 2 * 
input_size[0] - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(spatial_dim, self.head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(spatial_dim, self.head_dim)) + self.rel_pos_t = nn.Parameter(torch.zeros(temporal_dim, self.head_dim)) + nn.init.trunc_normal_(self.rel_pos_h, std=0.02) + nn.init.trunc_normal_(self.rel_pos_w, std=0.02) + nn.init.trunc_normal_(self.rel_pos_t, std=0.02) + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + B, N, C = x.shape + q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2) + + if self.pool_k is not None: + k, k_thw = self.pool_k(k, thw) + else: + k_thw = thw + if self.pool_v is not None: + v = self.pool_v(v, thw)[0] + if self.pool_q is not None: + q, thw = self.pool_q(q, thw) + + attn = torch.matmul(self.scaler * q, k.transpose(2, 3)) + if self.rel_pos_h is not None and self.rel_pos_w is not None and self.rel_pos_t is not None: + attn = _add_rel_pos( + attn, + q, + thw, + k_thw, + self.rel_pos_h, + self.rel_pos_w, + self.rel_pos_t, + ) + attn = attn.softmax(dim=-1) + + x = torch.matmul(attn, v) + if self.residual_pool: + _add_shortcut(x, q, self.residual_with_cls_embed) + x = x.transpose(1, 2).reshape(B, -1, self.output_dim) + x = self.project(x) + + return x, thw + + +class MultiscaleBlock(nn.Module): + def __init__( + self, + input_size: List[int], + cnf: MSBlockConfig, + residual_pool: bool, + residual_with_cls_embed: bool, + rel_pos_embed: bool, + proj_after_attn: bool, + dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + self.proj_after_attn = proj_after_attn + + self.pool_skip: Optional[nn.Module] = None + if _prod(cnf.stride_q) > 1: + kernel_skip = [s + 1 if s > 1 else s for s in cnf.stride_q] + padding_skip = [int(k // 2) for k in kernel_skip] + self.pool_skip = Pool( + nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, 
padding=padding_skip), None # type: ignore[arg-type] + ) + + attn_dim = cnf.output_channels if proj_after_attn else cnf.input_channels + + self.norm1 = norm_layer(cnf.input_channels) + self.norm2 = norm_layer(attn_dim) + self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d) + + self.attn = MultiscaleAttention( + input_size, + cnf.input_channels, + attn_dim, + cnf.num_heads, + kernel_q=cnf.kernel_q, + kernel_kv=cnf.kernel_kv, + stride_q=cnf.stride_q, + stride_kv=cnf.stride_kv, + rel_pos_embed=rel_pos_embed, + residual_pool=residual_pool, + residual_with_cls_embed=residual_with_cls_embed, + dropout=dropout, + norm_layer=norm_layer, + ) + self.mlp = MLP( + attn_dim, + [4 * attn_dim, cnf.output_channels], + activation_layer=nn.GELU, + dropout=dropout, + inplace=None, + ) + + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + + self.project: Optional[nn.Module] = None + if cnf.input_channels != cnf.output_channels: + self.project = nn.Linear(cnf.input_channels, cnf.output_channels) + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x) + x_attn, thw_new = self.attn(x_norm1, thw) + x = x if self.project is None or not self.proj_after_attn else self.project(x_norm1) + x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0] + x = x_skip + self.stochastic_depth(x_attn) + + x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x) + x_proj = x if self.project is None or self.proj_after_attn else self.project(x_norm2) + + return x_proj + self.stochastic_depth(self.mlp(x_norm2)), thw_new + + +class PositionalEncoding(nn.Module): + def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None: + super().__init__() + self.spatial_size = spatial_size + self.temporal_size = temporal_size + 
+ self.class_token = nn.Parameter(torch.zeros(embed_size)) + self.spatial_pos: Optional[nn.Parameter] = None + self.temporal_pos: Optional[nn.Parameter] = None + self.class_pos: Optional[nn.Parameter] = None + if not rel_pos_embed: + self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size)) + self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size)) + self.class_pos = nn.Parameter(torch.zeros(embed_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1) + x = torch.cat((class_token, x), dim=1) + + if self.spatial_pos is not None and self.temporal_pos is not None and self.class_pos is not None: + hw_size, embed_size = self.spatial_pos.shape + pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0) + pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size)) + pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0) + x.add_(pos_embedding) + + return x + + +class MViT(nn.Module): + def __init__( + self, + spatial_size: Tuple[int, int], + temporal_size: int, + block_setting: Sequence[MSBlockConfig], + residual_pool: bool, + residual_with_cls_embed: bool, + rel_pos_embed: bool, + proj_after_attn: bool, + dropout: float = 0.5, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + num_classes: int = 400, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + patch_embed_kernel: Tuple[int, int, int] = (3, 7, 7), + patch_embed_stride: Tuple[int, int, int] = (2, 4, 4), + patch_embed_padding: Tuple[int, int, int] = (1, 3, 3), + ) -> None: + """ + MViT main class. + + Args: + spatial_size (tuple of ints): The spacial size of the input as ``(H, W)``. + temporal_size (int): The temporal size ``T`` of the input. 
+ block_setting (sequence of MSBlockConfig): The Network structure. + residual_pool (bool): If True, use MViTv2 pooling residual connection. + residual_with_cls_embed (bool): If True, the addition on the residual connection will include + the class embedding. + rel_pos_embed (bool): If True, use MViTv2's relative positional embeddings. + proj_after_attn (bool): If True, apply the projection after the attention. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + num_classes (int): The number of classes. + block (callable, optional): Module specifying the layer which consists of the attention and mlp. + norm_layer (callable, optional): Module specifying the normalization layer to use. + patch_embed_kernel (tuple of ints): The kernel of the convolution that patchifies the input. + patch_embed_stride (tuple of ints): The stride of the convolution that patchifies the input. + patch_embed_padding (tuple of ints): The padding of the convolution that patchifies the input. + """ + super().__init__() + # This implementation employs a different parameterization scheme than the one used at PyTorch Video: + # https://github.com/facebookresearch/pytorchvideo/blob/718d0a4/pytorchvideo/models/vision_transformers.py + # We remove any experimental configuration that didn't make it to the final variants of the models. To represent + # the configuration of the architecture we use the simplified form suggested at Table 1 of the paper. 
+ _log_api_usage_once(self) + total_stage_blocks = len(block_setting) + if total_stage_blocks == 0: + raise ValueError("The configuration parameter can't be empty.") + + if block is None: + block = MultiscaleBlock + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + # Patch Embedding module + self.conv_proj = nn.Conv3d( + in_channels=3, + out_channels=block_setting[0].input_channels, + kernel_size=patch_embed_kernel, + stride=patch_embed_stride, + padding=patch_embed_padding, + ) + + input_size = [size // stride for size, stride in zip((temporal_size,) + spatial_size, self.conv_proj.stride)] + + # Spatio-Temporal Class Positional Encoding + self.pos_encoding = PositionalEncoding( + embed_size=block_setting[0].input_channels, + spatial_size=(input_size[1], input_size[2]), + temporal_size=input_size[0], + rel_pos_embed=rel_pos_embed, + ) + + # Encoder module + self.blocks = nn.ModuleList() + for stage_block_id, cnf in enumerate(block_setting): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) + + self.blocks.append( + block( + input_size=input_size, + cnf=cnf, + residual_pool=residual_pool, + residual_with_cls_embed=residual_with_cls_embed, + rel_pos_embed=rel_pos_embed, + proj_after_attn=proj_after_attn, + dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, + ) + ) + + if len(cnf.stride_q) > 0: + input_size = [size // stride for size, stride in zip(input_size, cnf.stride_q)] + self.norm = norm_layer(block_setting[-1].output_channels) + + # Classifier module + self.head = nn.Sequential( + nn.Dropout(dropout, inplace=True), + nn.Linear(block_setting[-1].output_channels, num_classes), + ) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, 
nn.LayerNorm): + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, PositionalEncoding): + for weights in m.parameters(): + nn.init.trunc_normal_(weights, std=0.02) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Convert if necessary (B, C, H, W) -> (B, C, 1, H, W) + x = _unsqueeze(x, 5, 2)[0] + # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0]) + x = self.conv_proj(x) + x = x.flatten(2).transpose(1, 2) + + # add positional encoding + x = self.pos_encoding(x) + + # pass patches through the encoder + thw = (self.pos_encoding.temporal_size,) + self.pos_encoding.spatial_size + for block in self.blocks: + x, thw = block(x, thw) + x = self.norm(x) + + # classifier "token" as used by standard language architectures + x = x[:, 0] + x = self.head(x) + + return x + + +def _mvit( + block_setting: List[MSBlockConfig], + stochastic_depth_prob: float, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> MViT: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + assert weights.meta["min_size"][0] == weights.meta["min_size"][1] + _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"]) + _ovewrite_named_param(kwargs, "temporal_size", weights.meta["min_temporal_size"]) + spatial_size = kwargs.pop("spatial_size", (224, 224)) + temporal_size = kwargs.pop("temporal_size", 16) + + model = MViT( + spatial_size=spatial_size, + temporal_size=temporal_size, + block_setting=block_setting, + residual_pool=kwargs.pop("residual_pool", False), + residual_with_cls_embed=kwargs.pop("residual_with_cls_embed", True), + rel_pos_embed=kwargs.pop("rel_pos_embed", False), + proj_after_attn=kwargs.pop("proj_after_attn", False), + stochastic_depth_prob=stochastic_depth_prob, + **kwargs, + ) + + if weights is not None: + 
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +class MViT_V1_B_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mvit_v1_b-dbeb1030.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.45, 0.45, 0.45), + std=(0.225, 0.225, 0.225), + ), + meta={ + "min_size": (224, 224), + "min_temporal_size": 16, + "categories": _KINETICS400_CATEGORIES, + "recipe": "https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md", + "_docs": ( + "The weights were ported from the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`" + ), + "num_params": 36610672, + "_metrics": { + "Kinetics-400": { + "acc@1": 78.477, + "acc@5": 93.582, + } + }, + "_ops": 70.599, + "_file_size": 139.764, + }, + ) + DEFAULT = KINETICS400_V1 + + +class MViT_V2_S_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mvit_v2_s-ae3be167.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.45, 0.45, 0.45), + std=(0.225, 0.225, 0.225), + ), + meta={ + "min_size": (224, 224), + "min_temporal_size": 16, + "categories": _KINETICS400_CATEGORIES, + "recipe": "https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md", + "_docs": ( + "The weights were ported from the paper. 
The accuracies are estimated on video-level " + "with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`" + ), + "num_params": 34537744, + "_metrics": { + "Kinetics-400": { + "acc@1": 80.757, + "acc@5": 94.665, + } + }, + "_ops": 64.224, + "_file_size": 131.884, + }, + ) + DEFAULT = KINETICS400_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MViT_V1_B_Weights.KINETICS400_V1)) +def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT: + """ + Constructs a base MViTV1 architecture from + `Multiscale Vision Transformers `__. + + .. betastatus:: video module + + Args: + weights (:class:`~torchvision.models.video.MViT_V1_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MViT_V1_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.MViT`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.video.MViT_V1_B_Weights + :members: + """ + weights = MViT_V1_B_Weights.verify(weights) + + config: Dict[str, List] = { + "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8], + "input_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768], + "output_channels": [192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768, 768], + "kernel_q": [[], [3, 3, 3], [], [3, 3, 3], [], [], [], [], [], [], [], [], [], [], [3, 3, 3], []], + "kernel_kv": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + ], + "stride_q": [[], [1, 2, 2], [], [1, 2, 2], [], [], [], [], [], [], [], [], [], [], [1, 2, 2], []], + "stride_kv": [ + [1, 8, 8], + [1, 4, 4], + [1, 4, 4], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 1, 1], + [1, 1, 1], + ], + } + + block_setting = [] + for i in range(len(config["num_heads"])): + block_setting.append( + MSBlockConfig( + num_heads=config["num_heads"][i], + input_channels=config["input_channels"][i], + output_channels=config["output_channels"][i], + kernel_q=config["kernel_q"][i], + kernel_kv=config["kernel_kv"][i], + stride_q=config["stride_q"][i], + stride_kv=config["stride_kv"][i], + ) + ) + + return _mvit( + spatial_size=(224, 224), + temporal_size=16, + block_setting=block_setting, + residual_pool=False, + residual_with_cls_embed=False, + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MViT_V2_S_Weights.KINETICS400_V1)) +def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT: + """Constructs a small MViTV2 
architecture from + `Multiscale Vision Transformers `__ and + `MViTv2: Improved Multiscale Vision Transformers for Classification + and Detection `__. + + .. betastatus:: video module + + Args: + weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MViT_V2_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.MViT`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.MViT_V2_S_Weights + :members: + """ + weights = MViT_V2_S_Weights.verify(weights) + + config: Dict[str, List] = { + "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8], + "input_channels": [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768], + "output_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768], + "kernel_q": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + ], + "kernel_kv": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + ], + "stride_q": [ + [1, 1, 1], + [1, 2, 2], + [1, 1, 1], + [1, 2, 2], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 2, 2], + [1, 1, 1], + ], + "stride_kv": [ + [1, 8, 8], + [1, 4, 4], + [1, 4, 4], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + 
[1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 1, 1], + [1, 1, 1], + ], + } + + block_setting = [] + for i in range(len(config["num_heads"])): + block_setting.append( + MSBlockConfig( + num_heads=config["num_heads"][i], + input_channels=config["input_channels"][i], + output_channels=config["output_channels"][i], + kernel_q=config["kernel_q"][i], + kernel_kv=config["kernel_kv"][i], + stride_q=config["stride_q"][i], + stride_kv=config["stride_kv"][i], + ) + ) + + return _mvit( + spatial_size=(224, 224), + temporal_size=16, + block_setting=block_setting, + residual_pool=True, + residual_with_cls_embed=False, + rel_pos_embed=True, + proj_after_attn=True, + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), + weights=weights, + progress=progress, + **kwargs, + ) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/video/resnet.py b/.venv/lib/python3.11/site-packages/torchvision/models/video/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a1cb2884013c053118555344617e4b1efb8ddaab --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/video/resnet.py @@ -0,0 +1,503 @@ +from functools import partial +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union + +import torch.nn as nn +from torch import Tensor + +from ...transforms._presets import VideoClassification +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._meta import _KINETICS400_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface + + +__all__ = [ + "VideoResNet", + "R3D_18_Weights", + "MC3_18_Weights", + "R2Plus1D_18_Weights", + "r3d_18", + "mc3_18", + "r2plus1d_18", +] + + +class Conv3DSimple(nn.Conv3d): + def __init__( + self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1 + ) -> None: + + super().__init__( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=(3, 3, 3), + 
stride=stride, + padding=padding, + bias=False, + ) + + @staticmethod + def get_downsample_stride(stride: int) -> Tuple[int, int, int]: + return stride, stride, stride + + +class Conv2Plus1D(nn.Sequential): + def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None: + super().__init__( + nn.Conv3d( + in_planes, + midplanes, + kernel_size=(1, 3, 3), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False, + ), + nn.BatchNorm3d(midplanes), + nn.ReLU(inplace=True), + nn.Conv3d( + midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False + ), + ) + + @staticmethod + def get_downsample_stride(stride: int) -> Tuple[int, int, int]: + return stride, stride, stride + + +class Conv3DNoTemporal(nn.Conv3d): + def __init__( + self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1 + ) -> None: + + super().__init__( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=(1, 3, 3), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False, + ) + + @staticmethod + def get_downsample_stride(stride: int) -> Tuple[int, int, int]: + return 1, stride, stride + + +class BasicBlock(nn.Module): + + expansion = 1 + + def __init__( + self, + inplanes: int, + planes: int, + conv_builder: Callable[..., nn.Module], + stride: int = 1, + downsample: Optional[nn.Module] = None, + ) -> None: + midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) + + super().__init__() + self.conv1 = nn.Sequential( + conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes)) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + residual = x + + out = self.conv1(x) + out = 
self.conv2(out) + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__( + self, + inplanes: int, + planes: int, + conv_builder: Callable[..., nn.Module], + stride: int = 1, + downsample: Optional[nn.Module] = None, + ) -> None: + + super().__init__() + midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) + + # 1x1x1 + self.conv1 = nn.Sequential( + nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) + ) + # Second kernel + self.conv2 = nn.Sequential( + conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) + ) + + # 1x1x1 + self.conv3 = nn.Sequential( + nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), + nn.BatchNorm3d(planes * self.expansion), + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + residual = x + + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BasicStem(nn.Sequential): + """The default conv-batchnorm-relu stem""" + + def __init__(self) -> None: + super().__init__( + nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + ) + + +class R2Plus1dStem(nn.Sequential): + """R(2+1)D stem is different than the default one as it uses separated 3D convolution""" + + def __init__(self) -> None: + super().__init__( + nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False), + nn.BatchNorm3d(45), + nn.ReLU(inplace=True), + nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False), + nn.BatchNorm3d(64), + 
nn.ReLU(inplace=True), + ) + + +class VideoResNet(nn.Module): + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], + layers: List[int], + stem: Callable[..., nn.Module], + num_classes: int = 400, + zero_init_residual: bool = False, + ) -> None: + """Generic resnet video generator. + + Args: + block (Type[Union[BasicBlock, Bottleneck]]): resnet building block + conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator + function for each layer + layers (List[int]): number of blocks per layer + stem (Callable[..., nn.Module]): module specifying the ResNet stem. + num_classes (int, optional): Dimension of the final FC layer. Defaults to 400. + zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. + """ + super().__init__() + _log_api_usage_once(self) + self.inplanes = 64 + + self.stem = stem() + + self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1) + self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2) + + self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + # init weights + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm3d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[union-attr, arg-type] + + def forward(self, x: 
Tensor) -> Tensor: + x = self.stem(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + # Flatten the layer to fc + x = x.flatten(1) + x = self.fc(x) + + return x + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]], + planes: int, + blocks: int, + stride: int = 1, + ) -> nn.Sequential: + downsample = None + + if stride != 1 or self.inplanes != planes * block.expansion: + ds_stride = conv_builder.get_downsample_stride(stride) + downsample = nn.Sequential( + nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False), + nn.BatchNorm3d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) + + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, conv_builder)) + + return nn.Sequential(*layers) + + +def _video_resnet( + block: Type[Union[BasicBlock, Bottleneck]], + conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], + layers: List[int], + stem: Callable[..., nn.Module], + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> VideoResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = VideoResNet(block, conv_makers, layers, stem, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META = { + "min_size": (1, 1), + "categories": _KINETICS400_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification", + "_docs": ( + "The weights reproduce closely the accuracy of the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=15`, `clips_per_video=5`, and `clip_len=16`." 
+ ), +} + + +class R3D_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "num_params": 33371472, + "_metrics": { + "Kinetics-400": { + "acc@1": 63.200, + "acc@5": 83.479, + } + }, + "_ops": 40.697, + "_file_size": 127.359, + }, + ) + DEFAULT = KINETICS400_V1 + + +class MC3_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "num_params": 11695440, + "_metrics": { + "Kinetics-400": { + "acc@1": 63.960, + "acc@5": 84.130, + } + }, + "_ops": 43.343, + "_file_size": 44.672, + }, + ) + DEFAULT = KINETICS400_V1 + + +class R2Plus1D_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "num_params": 31505325, + "_metrics": { + "Kinetics-400": { + "acc@1": 67.463, + "acc@5": 86.175, + } + }, + "_ops": 40.519, + "_file_size": 120.318, + }, + ) + DEFAULT = KINETICS400_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1)) +def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: + """Construct 18 layer Resnet3D model. + + .. betastatus:: video module + + Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition `__. + + Args: + weights (:class:`~torchvision.models.video.R3D_18_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.R3D_18_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. 
+ progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class. + Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.R3D_18_Weights + :members: + """ + weights = R3D_18_Weights.verify(weights) + + return _video_resnet( + BasicBlock, + [Conv3DSimple] * 4, + [2, 2, 2, 2], + BasicStem, + weights, + progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1)) +def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: + """Construct 18 layer Mixed Convolution network as in + + .. betastatus:: video module + + Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition `__. + + Args: + weights (:class:`~torchvision.models.video.MC3_18_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MC3_18_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class. + Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.video.MC3_18_Weights + :members: + """ + weights = MC3_18_Weights.verify(weights) + + return _video_resnet( + BasicBlock, + [Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] + [2, 2, 2, 2], + BasicStem, + weights, + progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1)) +def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: + """Construct 18 layer deep R(2+1)D network as in + + .. betastatus:: video module + + Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition `__. + + Args: + weights (:class:`~torchvision.models.video.R2Plus1D_18_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.R2Plus1D_18_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class. + Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.video.R2Plus1D_18_Weights + :members: + """ + weights = R2Plus1D_18_Weights.verify(weights) + + return _video_resnet( + BasicBlock, + [Conv2Plus1D] * 4, + [2, 2, 2, 2], + R2Plus1dStem, + weights, + progress, + **kwargs, + ) + + +# The dictionary below is internal implementation detail and will be removed in v0.15 +from .._utils import _ModelURLs + + +model_urls = _ModelURLs( + { + "r3d_18": R3D_18_Weights.KINETICS400_V1.url, + "mc3_18": MC3_18_Weights.KINETICS400_V1.url, + "r2plus1d_18": R2Plus1D_18_Weights.KINETICS400_V1.url, + } +) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/video/s3d.py b/.venv/lib/python3.11/site-packages/torchvision/models/video/s3d.py new file mode 100644 index 0000000000000000000000000000000000000000..4b202829b24fb1dc314452d38a521dfe6c8e446f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/video/s3d.py @@ -0,0 +1,219 @@ +from functools import partial +from typing import Any, Callable, Optional + +import torch +from torch import nn +from torchvision.ops.misc import Conv3dNormActivation + +from ...transforms._presets import VideoClassification +from ...utils import _log_api_usage_once +from .._api import register_model, Weights, WeightsEnum +from .._meta import _KINETICS400_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface + + +__all__ = [ + "S3D", + "S3D_Weights", + "s3d", +] + + +class TemporalSeparableConv(nn.Sequential): + def __init__( + self, + in_planes: int, + out_planes: int, + kernel_size: int, + stride: int, + padding: int, + norm_layer: Callable[..., nn.Module], + ): + super().__init__( + Conv3dNormActivation( + in_planes, + out_planes, + kernel_size=(1, kernel_size, kernel_size), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False, + norm_layer=norm_layer, + ), + Conv3dNormActivation( + out_planes, + out_planes, + kernel_size=(kernel_size, 1, 1), + stride=(stride, 1, 1), + padding=(padding, 0, 0), + 
bias=False, + norm_layer=norm_layer, + ), + ) + + +class SepInceptionBlock3D(nn.Module): + def __init__( + self, + in_planes: int, + b0_out: int, + b1_mid: int, + b1_out: int, + b2_mid: int, + b2_out: int, + b3_out: int, + norm_layer: Callable[..., nn.Module], + ): + super().__init__() + + self.branch0 = Conv3dNormActivation(in_planes, b0_out, kernel_size=1, stride=1, norm_layer=norm_layer) + self.branch1 = nn.Sequential( + Conv3dNormActivation(in_planes, b1_mid, kernel_size=1, stride=1, norm_layer=norm_layer), + TemporalSeparableConv(b1_mid, b1_out, kernel_size=3, stride=1, padding=1, norm_layer=norm_layer), + ) + self.branch2 = nn.Sequential( + Conv3dNormActivation(in_planes, b2_mid, kernel_size=1, stride=1, norm_layer=norm_layer), + TemporalSeparableConv(b2_mid, b2_out, kernel_size=3, stride=1, padding=1, norm_layer=norm_layer), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1), + Conv3dNormActivation(in_planes, b3_out, kernel_size=1, stride=1, norm_layer=norm_layer), + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + + return out + + +class S3D(nn.Module): + """S3D main class. + + Args: + num_class (int): number of classes for the classification task. + dropout (float): dropout probability. + norm_layer (Optional[Callable]): Module specifying the normalization layer to use. 
+ + Inputs: + x (Tensor): batch of videos with dimensions (batch, channel, time, height, width) + """ + + def __init__( + self, + num_classes: int = 400, + dropout: float = 0.2, + norm_layer: Optional[Callable[..., torch.nn.Module]] = None, + ) -> None: + super().__init__() + _log_api_usage_once(self) + + if norm_layer is None: + norm_layer = partial(nn.BatchNorm3d, eps=0.001, momentum=0.001) + + self.features = nn.Sequential( + TemporalSeparableConv(3, 64, 7, 2, 3, norm_layer), + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), + Conv3dNormActivation( + 64, + 64, + kernel_size=1, + stride=1, + norm_layer=norm_layer, + ), + TemporalSeparableConv(64, 192, 3, 1, 1, norm_layer), + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), + SepInceptionBlock3D(192, 64, 96, 128, 16, 32, 32, norm_layer), + SepInceptionBlock3D(256, 128, 128, 192, 32, 96, 64, norm_layer), + nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)), + SepInceptionBlock3D(480, 192, 96, 208, 16, 48, 64, norm_layer), + SepInceptionBlock3D(512, 160, 112, 224, 24, 64, 64, norm_layer), + SepInceptionBlock3D(512, 128, 128, 256, 24, 64, 64, norm_layer), + SepInceptionBlock3D(512, 112, 144, 288, 32, 64, 64, norm_layer), + SepInceptionBlock3D(528, 256, 160, 320, 32, 128, 128, norm_layer), + nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)), + SepInceptionBlock3D(832, 256, 160, 320, 32, 128, 128, norm_layer), + SepInceptionBlock3D(832, 384, 192, 384, 48, 128, 128, norm_layer), + ) + self.avgpool = nn.AvgPool3d(kernel_size=(2, 7, 7), stride=1) + self.classifier = nn.Sequential( + nn.Dropout(p=dropout), + nn.Conv3d(1024, num_classes, kernel_size=1, stride=1, bias=True), + ) + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = self.classifier(x) + x = torch.mean(x, dim=(2, 3, 4)) + return x + + +class S3D_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + 
url="https://download.pytorch.org/models/s3d-d76dad2f.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256, 256), + ), + meta={ + "min_size": (224, 224), + "min_temporal_size": 14, + "categories": _KINETICS400_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification#s3d", + "_docs": ( + "The weights aim to approximate the accuracy of the paper. The accuracies are estimated on clip-level " + "with parameters `frame_rate=15`, `clips_per_video=1`, and `clip_len=128`." + ), + "num_params": 8320048, + "_metrics": { + "Kinetics-400": { + "acc@1": 68.368, + "acc@5": 88.050, + } + }, + "_ops": 17.979, + "_file_size": 31.972, + }, + ) + DEFAULT = KINETICS400_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", S3D_Weights.KINETICS400_V1)) +def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwargs: Any) -> S3D: + """Construct Separable 3D CNN model. + + Reference: `Rethinking Spatiotemporal Feature Learning `__. + + .. betastatus:: video module + + Args: + weights (:class:`~torchvision.models.video.S3D_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.S3D_Weights` + below for more details, and possible values. By default, no + pre-trained weights are used. + progress (bool): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.S3D`` base class. + Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.video.S3D_Weights + :members: + """ + weights = S3D_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = S3D(**kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/video/swin_transformer.py b/.venv/lib/python3.11/site-packages/torchvision/models/video/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d87ffbe5af6caa1de0a3760fa5c506fdf8e231 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/video/swin_transformer.py @@ -0,0 +1,743 @@ +# Modified from 2d Swin Transformers in torchvision: +# https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py + +from functools import partial +from typing import Any, Callable, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from ...transforms._presets import VideoClassification + +from ...utils import _log_api_usage_once + +from .._api import register_model, Weights, WeightsEnum + +from .._meta import _KINETICS400_CATEGORIES +from .._utils import _ovewrite_named_param, handle_legacy_interface +from ..swin_transformer import PatchMerging, SwinTransformerBlock + +__all__ = [ + "SwinTransformer3d", + "Swin3D_T_Weights", + "Swin3D_S_Weights", + "Swin3D_B_Weights", + "swin3d_t", + "swin3d_s", + "swin3d_b", +] + + +def _get_window_and_shift_size( + shift_size: List[int], size_dhw: List[int], window_size: List[int] +) -> Tuple[List[int], List[int]]: + for i in range(3): + if size_dhw[i] <= window_size[i]: + # In this case, window_size will adapt to the input size, and no need to shift + window_size[i] = size_dhw[i] + shift_size[i] = 0 + + return window_size, shift_size + + +torch.fx.wrap("_get_window_and_shift_size") + + +def 
_get_relative_position_bias( + relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int] +) -> Tensor: + window_vol = window_size[0] * window_size[1] * window_size[2] + # In 3d case we flatten the relative_position_bias + relative_position_bias = relative_position_bias_table[ + relative_position_index[:window_vol, :window_vol].flatten() # type: ignore[index] + ] + relative_position_bias = relative_position_bias.view(window_vol, window_vol, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + return relative_position_bias + + +torch.fx.wrap("_get_relative_position_bias") + + +def _compute_pad_size_3d(size_dhw: Tuple[int, int, int], patch_size: Tuple[int, int, int]) -> Tuple[int, int, int]: + pad_size = [(patch_size[i] - size_dhw[i] % patch_size[i]) % patch_size[i] for i in range(3)] + return pad_size[0], pad_size[1], pad_size[2] + + +torch.fx.wrap("_compute_pad_size_3d") + + +def _compute_attention_mask_3d( + x: Tensor, + size_dhw: Tuple[int, int, int], + window_size: Tuple[int, int, int], + shift_size: Tuple[int, int, int], +) -> Tensor: + # generate attention mask + attn_mask = x.new_zeros(*size_dhw) + num_windows = (size_dhw[0] // window_size[0]) * (size_dhw[1] // window_size[1]) * (size_dhw[2] // window_size[2]) + slices = [ + ( + (0, -window_size[i]), + (-window_size[i], -shift_size[i]), + (-shift_size[i], None), + ) + for i in range(3) + ] + count = 0 + for d in slices[0]: + for h in slices[1]: + for w in slices[2]: + attn_mask[d[0] : d[1], h[0] : h[1], w[0] : w[1]] = count + count += 1 + + # Partition window on attn_mask + attn_mask = attn_mask.view( + size_dhw[0] // window_size[0], + window_size[0], + size_dhw[1] // window_size[1], + window_size[1], + size_dhw[2] // window_size[2], + window_size[2], + ) + attn_mask = attn_mask.permute(0, 2, 4, 1, 3, 5).reshape( + num_windows, window_size[0] * window_size[1] * window_size[2] + ) + attn_mask = attn_mask.unsqueeze(1) - 
attn_mask.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + return attn_mask + + +torch.fx.wrap("_compute_attention_mask_3d") + + +def shifted_window_attention_3d( + input: Tensor, + qkv_weight: Tensor, + proj_weight: Tensor, + relative_position_bias: Tensor, + window_size: List[int], + num_heads: int, + shift_size: List[int], + attention_dropout: float = 0.0, + dropout: float = 0.0, + qkv_bias: Optional[Tensor] = None, + proj_bias: Optional[Tensor] = None, + training: bool = True, +) -> Tensor: + """ + Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + input (Tensor[B, T, H, W, C]): The input tensor, 5-dimensions. + qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. + proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. + relative_position_bias (Tensor): The learned relative position bias added to attention. + window_size (List[int]): 3-dimensions window size, T, H, W . + num_heads (int): Number of attention heads. + shift_size (List[int]): Shift size for shifted window attention (T, H, W). + attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. + dropout (float): Dropout ratio of output. Default: 0.0. + qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. + proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. + training (bool, optional): Training flag used by the dropout parameters. Default: True. + Returns: + Tensor[B, T, H, W, C]: The output tensor after shifted window attention. 
+ """ + b, t, h, w, c = input.shape + # pad feature maps to multiples of window size + pad_size = _compute_pad_size_3d((t, h, w), (window_size[0], window_size[1], window_size[2])) + x = F.pad(input, (0, 0, 0, pad_size[2], 0, pad_size[1], 0, pad_size[0])) + _, tp, hp, wp, _ = x.shape + padded_size = (tp, hp, wp) + + # cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + + # partition windows + num_windows = ( + (padded_size[0] // window_size[0]) * (padded_size[1] // window_size[1]) * (padded_size[2] // window_size[2]) + ) + x = x.view( + b, + padded_size[0] // window_size[0], + window_size[0], + padded_size[1] // window_size[1], + window_size[1], + padded_size[2] // window_size[2], + window_size[2], + c, + ) + x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape( + b * num_windows, window_size[0] * window_size[1] * window_size[2], c + ) # B*nW, Wd*Wh*Ww, C + + # multi-head attention + qkv = F.linear(x, qkv_weight, qkv_bias) + qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, c // num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * (c // num_heads) ** -0.5 + attn = q.matmul(k.transpose(-2, -1)) + # add relative position bias + attn = attn + relative_position_bias + + if sum(shift_size) > 0: + # generate attention mask to handle shifted windows with varying size + attn_mask = _compute_attention_mask_3d( + x, + (padded_size[0], padded_size[1], padded_size[2]), + (window_size[0], window_size[1], window_size[2]), + (shift_size[0], shift_size[1], shift_size[2]), + ) + attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) + attn = attn + attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, num_heads, x.size(1), x.size(1)) + + attn = F.softmax(attn, dim=-1) + attn = F.dropout(attn, p=attention_dropout, training=training) + + x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c) + x = F.linear(x, proj_weight, 
proj_bias) + x = F.dropout(x, p=dropout, training=training) + + # reverse windows + x = x.view( + b, + padded_size[0] // window_size[0], + padded_size[1] // window_size[1], + padded_size[2] // window_size[2], + window_size[0], + window_size[1], + window_size[2], + c, + ) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(b, tp, hp, wp, c) + + # reverse cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + + # unpad features + x = x[:, :t, :h, :w, :].contiguous() + return x + + +torch.fx.wrap("shifted_window_attention_3d") + + +class ShiftedWindowAttention3d(nn.Module): + """ + See :func:`shifted_window_attention_3d`. + """ + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ) -> None: + super().__init__() + if len(window_size) != 3 or len(shift_size) != 3: + raise ValueError("window_size and shift_size must be of length 2") + + self.window_size = window_size # Wd, Wh, Ww + self.shift_size = shift_size + self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.dropout = dropout + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + self.define_relative_position_bias_table() + self.define_relative_position_index() + + def define_relative_position_bias_table(self) -> None: + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1), + self.num_heads, + ) + ) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + def define_relative_position_index(self) -> None: + # get pair-wise relative position index for each token inside the window + coords_dhw = 
[torch.arange(self.window_size[i]) for i in range(3)] + coords = torch.stack( + torch.meshgrid(coords_dhw[0], coords_dhw[1], coords_dhw[2], indexing="ij") + ) # 3, Wd, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + # We don't flatten the relative_position_index here in 3d case. + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + def get_relative_position_bias(self, window_size: List[int]) -> torch.Tensor: + return _get_relative_position_bias(self.relative_position_bias_table, self.relative_position_index, window_size) # type: ignore + + def forward(self, x: Tensor) -> Tensor: + _, t, h, w, _ = x.shape + size_dhw = [t, h, w] + window_size, shift_size = self.window_size.copy(), self.shift_size.copy() + # Handle case where window_size is larger than the input tensor + window_size, shift_size = _get_window_and_shift_size(shift_size, size_dhw, window_size) + + relative_position_bias = self.get_relative_position_bias(window_size) + + return shifted_window_attention_3d( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + window_size, + self.num_heads, + shift_size=shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + training=self.training, + ) + + +# Modified from: +# 
https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py +class PatchEmbed3d(nn.Module): + """Video to Patch Embedding. + + Args: + patch_size (List[int]): Patch token size. + in_channels (int): Number of input channels. Default: 3 + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, + patch_size: List[int], + in_channels: int = 3, + embed_dim: int = 96, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + _log_api_usage_once(self) + self.tuple_patch_size = (patch_size[0], patch_size[1], patch_size[2]) + + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=self.tuple_patch_size, + stride=self.tuple_patch_size, + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """Forward function.""" + # padding + _, _, t, h, w = x.size() + pad_size = _compute_pad_size_3d((t, h, w), self.tuple_patch_size) + x = F.pad(x, (0, pad_size[2], 0, pad_size[1], 0, pad_size[0])) + x = self.proj(x) # B C T Wh Ww + x = x.permute(0, 2, 3, 4, 1) # B T Wh Ww C + if self.norm is not None: + x = self.norm(x) + return x + + +class SwinTransformer3d(nn.Module): + """ + Implements 3D Swin Transformer from the `"Video Swin Transformer" `_ paper. + Args: + patch_size (List[int]): Patch size. + embed_dim (int): Patch embedding dimension. + depths (List(int)): Depth of each Swin Transformer layer. + num_heads (List(int)): Number of attention heads in different layers. + window_size (List[int]): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. 
+ num_classes (int): Number of classes for classification head. Default: 400. + norm_layer (nn.Module, optional): Normalization layer. Default: None. + block (nn.Module, optional): SwinTransformer Block. Default: None. + downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging. + patch_embed (nn.Module, optional): Patch Embedding layer. Default: None. + """ + + def __init__( + self, + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.1, + num_classes: int = 400, + norm_layer: Optional[Callable[..., nn.Module]] = None, + block: Optional[Callable[..., nn.Module]] = None, + downsample_layer: Callable[..., nn.Module] = PatchMerging, + patch_embed: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + _log_api_usage_once(self) + self.num_classes = num_classes + + if block is None: + block = partial(SwinTransformerBlock, attn_layer=ShiftedWindowAttention3d) + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-5) + + if patch_embed is None: + patch_embed = PatchEmbed3d + + # split image into non-overlapping patches + self.patch_embed = patch_embed(patch_size=patch_size, embed_dim=embed_dim, norm_layer=norm_layer) + self.pos_drop = nn.Dropout(p=dropout) + + layers: List[nn.Module] = [] + total_stage_blocks = sum(depths) + stage_block_id = 0 + # build SwinTransformer blocks + for i_stage in range(len(depths)): + stage: List[nn.Module] = [] + dim = embed_dim * 2**i_stage + for i_layer in range(depths[i_stage]): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) + stage.append( + block( + dim, + num_heads[i_stage], + window_size=window_size, + shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], + 
mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, + attn_layer=ShiftedWindowAttention3d, + ) + ) + stage_block_id += 1 + layers.append(nn.Sequential(*stage)) + # add patch merging layer + if i_stage < (len(depths) - 1): + layers.append(downsample_layer(dim, norm_layer)) + self.features = nn.Sequential(*layers) + + self.num_features = embed_dim * 2 ** (len(depths) - 1) + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool3d(1) + self.head = nn.Linear(self.num_features, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x: Tensor) -> Tensor: + # x: B C T H W + x = self.patch_embed(x) # B _T _H _W C + x = self.pos_drop(x) + x = self.features(x) # B _T _H _W C + x = self.norm(x) + x = x.permute(0, 4, 1, 2, 3) # B, C, _T, _H, _W + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.head(x) + return x + + +def _swin_transformer3d( + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + stochastic_depth_prob: float, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> SwinTransformer3d: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = SwinTransformer3d( + patch_size=patch_size, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + stochastic_depth_prob=stochastic_depth_prob, + **kwargs, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META = { + "categories": _KINETICS400_CATEGORIES, + "min_size": (1, 1), + "min_temporal_size": 1, +} + + +class Swin3D_T_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + 
url="https://download.pytorch.org/models/swin3d_t-7615ae03.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.4850, 0.4560, 0.4060), + std=(0.2290, 0.2240, 0.2250), + ), + meta={ + **_COMMON_META, + "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400", + "_docs": ( + "The weights were ported from the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`" + ), + "num_params": 28158070, + "_metrics": { + "Kinetics-400": { + "acc@1": 77.715, + "acc@5": 93.519, + } + }, + "_ops": 43.882, + "_file_size": 121.543, + }, + ) + DEFAULT = KINETICS400_V1 + + +class Swin3D_S_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/swin3d_s-da41c237.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.4850, 0.4560, 0.4060), + std=(0.2290, 0.2240, 0.2250), + ), + meta={ + **_COMMON_META, + "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400", + "_docs": ( + "The weights were ported from the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`" + ), + "num_params": 49816678, + "_metrics": { + "Kinetics-400": { + "acc@1": 79.521, + "acc@5": 94.158, + } + }, + "_ops": 82.841, + "_file_size": 218.288, + }, + ) + DEFAULT = KINETICS400_V1 + + +class Swin3D_B_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.4850, 0.4560, 0.4060), + std=(0.2290, 0.2240, 0.2250), + ), + meta={ + **_COMMON_META, + "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400", + "_docs": ( + "The weights were ported from the paper. 
The accuracies are estimated on video-level " + "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`" + ), + "num_params": 88048984, + "_metrics": { + "Kinetics-400": { + "acc@1": 79.427, + "acc@5": 94.386, + } + }, + "_ops": 140.667, + "_file_size": 364.134, + }, + ) + KINETICS400_IMAGENET22K_V1 = Weights( + url="https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.4850, 0.4560, 0.4060), + std=(0.2290, 0.2240, 0.2250), + ), + meta={ + **_COMMON_META, + "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400", + "_docs": ( + "The weights were ported from the paper. The accuracies are estimated on video-level " + "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`" + ), + "num_params": 88048984, + "_metrics": { + "Kinetics-400": { + "acc@1": 81.643, + "acc@5": 95.574, + } + }, + "_ops": 140.667, + "_file_size": 364.134, + }, + ) + DEFAULT = KINETICS400_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin3D_T_Weights.KINETICS400_V1)) +def swin3d_t(*, weights: Optional[Swin3D_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d: + """ + Constructs a swin_tiny architecture from + `Video Swin Transformer `_. + + Args: + weights (:class:`~torchvision.models.video.Swin3D_T_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.Swin3D_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.video.Swin3D_T_Weights + :members: + """ + weights = Swin3D_T_Weights.verify(weights) + + return _swin_transformer3d( + patch_size=[2, 4, 4], + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=[8, 7, 7], + stochastic_depth_prob=0.1, + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin3D_S_Weights.KINETICS400_V1)) +def swin3d_s(*, weights: Optional[Swin3D_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d: + """ + Constructs a swin_small architecture from + `Video Swin Transformer `_. + + Args: + weights (:class:`~torchvision.models.video.Swin3D_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.Swin3D_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.Swin3D_S_Weights + :members: + """ + weights = Swin3D_S_Weights.verify(weights) + + return _swin_transformer3d( + patch_size=[2, 4, 4], + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=[8, 7, 7], + stochastic_depth_prob=0.1, + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", Swin3D_B_Weights.KINETICS400_V1)) +def swin3d_b(*, weights: Optional[Swin3D_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d: + """ + Constructs a swin_base architecture from + `Video Swin Transformer `_. 
+ + Args: + weights (:class:`~torchvision.models.video.Swin3D_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.Swin3D_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.Swin3D_B_Weights + :members: + """ + weights = Swin3D_B_Weights.verify(weights) + + return _swin_transformer3d( + patch_size=[2, 4, 4], + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=[8, 7, 7], + stochastic_depth_prob=0.1, + weights=weights, + progress=progress, + **kwargs, + ) diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/vision_transformer.py b/.venv/lib/python3.11/site-packages/torchvision/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f2983ef9db0a3adf232dd6fe1b90ce9417ce0853 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/models/vision_transformer.py @@ -0,0 +1,864 @@ +import math +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, Dict, List, NamedTuple, Optional + +import torch +import torch.nn as nn + +from ..ops.misc import Conv2dNormActivation, MLP +from ..transforms._presets import ImageClassification, InterpolationMode +from ..utils import _log_api_usage_once +from ._api import register_model, Weights, WeightsEnum +from ._meta import _IMAGENET_CATEGORIES +from ._utils import _ovewrite_named_param, handle_legacy_interface + + +__all__ = [ + "VisionTransformer", + "ViT_B_16_Weights", + "ViT_B_32_Weights", + "ViT_L_16_Weights", + "ViT_L_32_Weights", + "ViT_H_14_Weights", + "vit_b_16", 
+ "vit_b_32", + "vit_l_16", + "vit_l_32", + "vit_h_14", +] + + +class ConvStemConfig(NamedTuple): + out_channels: int + kernel_size: int + stride: int + norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d + activation_layer: Callable[..., nn.Module] = nn.ReLU + + +class MLPBlock(MLP): + """Transformer MLP block.""" + + _version = 2 + + def __init__(self, in_dim: int, mlp_dim: int, dropout: float): + super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.normal_(m.bias, std=1e-6) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 2: + # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053 + for i in range(2): + for type in ["weight", "bias"]: + old_key = f"{prefix}linear_{i+1}.{type}" + new_key = f"{prefix}{3*i}.{type}" + if old_key in state_dict: + state_dict[new_key] = state_dict.pop(old_key) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class EncoderBlock(nn.Module): + """Transformer encoder block.""" + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + self.num_heads = num_heads + + # Attention block + self.ln_1 = norm_layer(hidden_dim) + self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True) + self.dropout = nn.Dropout(dropout) + + # MLP block + self.ln_2 = norm_layer(hidden_dim) + self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout) + + def 
forward(self, input: torch.Tensor): + torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") + x = self.ln_1(input) + x, _ = self.self_attention(x, x, x, need_weights=False) + x = self.dropout(x) + x = x + input + + y = self.ln_2(x) + y = self.mlp(y) + return x + y + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation.""" + + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__() + # Note that batch_size is on the first dim because + # we have batch_first=True in nn.MultiAttention() by default + self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT + self.dropout = nn.Dropout(dropout) + layers: OrderedDict[str, nn.Module] = OrderedDict() + for i in range(num_layers): + layers[f"encoder_layer_{i}"] = EncoderBlock( + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + ) + self.layers = nn.Sequential(layers) + self.ln = norm_layer(hidden_dim) + + def forward(self, input: torch.Tensor): + torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") + input = input + self.pos_embedding + return self.ln(self.layers(self.dropout(input))) + + +class VisionTransformer(nn.Module): + """Vision Transformer as per https://arxiv.org/abs/2010.11929.""" + + def __init__( + self, + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float = 0.0, + attention_dropout: float = 0.0, + num_classes: int = 1000, + representation_size: Optional[int] = None, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + conv_stem_configs: Optional[List[ConvStemConfig]] = None, + ): + 
super().__init__() + _log_api_usage_once(self) + torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!") + self.image_size = image_size + self.patch_size = patch_size + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.attention_dropout = attention_dropout + self.dropout = dropout + self.num_classes = num_classes + self.representation_size = representation_size + self.norm_layer = norm_layer + + if conv_stem_configs is not None: + # As per https://arxiv.org/abs/2106.14881 + seq_proj = nn.Sequential() + prev_channels = 3 + for i, conv_stem_layer_config in enumerate(conv_stem_configs): + seq_proj.add_module( + f"conv_bn_relu_{i}", + Conv2dNormActivation( + in_channels=prev_channels, + out_channels=conv_stem_layer_config.out_channels, + kernel_size=conv_stem_layer_config.kernel_size, + stride=conv_stem_layer_config.stride, + norm_layer=conv_stem_layer_config.norm_layer, + activation_layer=conv_stem_layer_config.activation_layer, + ), + ) + prev_channels = conv_stem_layer_config.out_channels + seq_proj.add_module( + "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1) + ) + self.conv_proj: nn.Module = seq_proj + else: + self.conv_proj = nn.Conv2d( + in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size + ) + + seq_length = (image_size // patch_size) ** 2 + + # Add a class token + self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + seq_length += 1 + + self.encoder = Encoder( + seq_length, + num_layers, + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, + ) + self.seq_length = seq_length + + heads_layers: OrderedDict[str, nn.Module] = OrderedDict() + if representation_size is None: + heads_layers["head"] = nn.Linear(hidden_dim, num_classes) + else: + heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size) + heads_layers["act"] = nn.Tanh() + heads_layers["head"] = nn.Linear(representation_size, num_classes) 
+ + self.heads = nn.Sequential(heads_layers) + + if isinstance(self.conv_proj, nn.Conv2d): + # Init the patchify stem + fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1] + nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in)) + if self.conv_proj.bias is not None: + nn.init.zeros_(self.conv_proj.bias) + elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d): + # Init the last 1x1 conv of the conv stem + nn.init.normal_( + self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels) + ) + if self.conv_proj.conv_last.bias is not None: + nn.init.zeros_(self.conv_proj.conv_last.bias) + + if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear): + fan_in = self.heads.pre_logits.in_features + nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in)) + nn.init.zeros_(self.heads.pre_logits.bias) + + if isinstance(self.heads.head, nn.Linear): + nn.init.zeros_(self.heads.head.weight) + nn.init.zeros_(self.heads.head.bias) + + def _process_input(self, x: torch.Tensor) -> torch.Tensor: + n, c, h, w = x.shape + p = self.patch_size + torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!") + torch._assert(w == self.image_size, f"Wrong image width! 
Expected {self.image_size} but got {w}!") + n_h = h // p + n_w = w // p + + # (n, c, h, w) -> (n, hidden_dim, n_h, n_w) + x = self.conv_proj(x) + # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)) + x = x.reshape(n, self.hidden_dim, n_h * n_w) + + # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim) + # The self attention layer expects inputs in the format (N, S, E) + # where S is the source sequence length, N is the batch size, E is the + # embedding dimension + x = x.permute(0, 2, 1) + + return x + + def forward(self, x: torch.Tensor): + # Reshape and permute the input tensor + x = self._process_input(x) + n = x.shape[0] + + # Expand the class token to the full batch + batch_class_token = self.class_token.expand(n, -1, -1) + x = torch.cat([batch_class_token, x], dim=1) + + x = self.encoder(x) + + # Classifier "token" as used by standard language architectures + x = x[:, 0] + + x = self.heads(x) + + return x + + +def _vision_transformer( + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> VisionTransformer: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + assert weights.meta["min_size"][0] == weights.meta["min_size"][1] + _ovewrite_named_param(kwargs, "image_size", weights.meta["min_size"][0]) + image_size = kwargs.pop("image_size", 224) + + model = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + **kwargs, + ) + + if weights: + model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) + + return model + + +_COMMON_META: Dict[str, Any] = { + "categories": _IMAGENET_CATEGORIES, +} + +_COMMON_SWAG_META = { + **_COMMON_META, + "recipe": "https://github.com/facebookresearch/SWAG", + "license": 
"https://github.com/facebookresearch/SWAG/blob/main/LICENSE", +} + + +class ViT_B_16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 86567656, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.072, + "acc@5": 95.318, + } + }, + "_ops": 17.564, + "_file_size": 330.285, + "_docs": """ + These weights were trained from scratch by using a modified version of `DeIT + `_'s training recipe. + """, + }, + ) + IMAGENET1K_SWAG_E2E_V1 = Weights( + url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth", + transforms=partial( + ImageClassification, + crop_size=384, + resize_size=384, + interpolation=InterpolationMode.BICUBIC, + ), + meta={ + **_COMMON_SWAG_META, + "num_params": 86859496, + "min_size": (384, 384), + "_metrics": { + "ImageNet-1K": { + "acc@1": 85.304, + "acc@5": 97.650, + } + }, + "_ops": 55.484, + "_file_size": 331.398, + "_docs": """ + These weights are learnt via transfer learning by end-to-end fine-tuning the original + `SWAG `_ weights on ImageNet-1K data. + """, + }, + ) + IMAGENET1K_SWAG_LINEAR_V1 = Weights( + url="https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth", + transforms=partial( + ImageClassification, + crop_size=224, + resize_size=224, + interpolation=InterpolationMode.BICUBIC, + ), + meta={ + **_COMMON_SWAG_META, + "recipe": "https://github.com/pytorch/vision/pull/5793", + "num_params": 86567656, + "min_size": (224, 224), + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.886, + "acc@5": 96.180, + } + }, + "_ops": 17.564, + "_file_size": 330.285, + "_docs": """ + These weights are composed of the original frozen `SWAG `_ trunk + weights and a linear classifier learnt on top of them trained on ImageNet-1K data. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_B_32_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 88224232, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32", + "_metrics": { + "ImageNet-1K": { + "acc@1": 75.912, + "acc@5": 92.466, + } + }, + "_ops": 4.409, + "_file_size": 336.604, + "_docs": """ + These weights were trained from scratch by using a modified version of `DeIT + `_'s training recipe. + """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_L_16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=242), + meta={ + **_COMMON_META, + "num_params": 304326632, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16", + "_metrics": { + "ImageNet-1K": { + "acc@1": 79.662, + "acc@5": 94.638, + } + }, + "_ops": 61.555, + "_file_size": 1161.023, + "_docs": """ + These weights were trained from scratch by using a modified version of TorchVision's + `new training recipe + `_. + """, + }, + ) + IMAGENET1K_SWAG_E2E_V1 = Weights( + url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth", + transforms=partial( + ImageClassification, + crop_size=512, + resize_size=512, + interpolation=InterpolationMode.BICUBIC, + ), + meta={ + **_COMMON_SWAG_META, + "num_params": 305174504, + "min_size": (512, 512), + "_metrics": { + "ImageNet-1K": { + "acc@1": 88.064, + "acc@5": 98.512, + } + }, + "_ops": 361.986, + "_file_size": 1164.258, + "_docs": """ + These weights are learnt via transfer learning by end-to-end fine-tuning the original + `SWAG `_ weights on ImageNet-1K data. 
+ """, + }, + ) + IMAGENET1K_SWAG_LINEAR_V1 = Weights( + url="https://download.pytorch.org/models/vit_l_16_lc_swag-4d563306.pth", + transforms=partial( + ImageClassification, + crop_size=224, + resize_size=224, + interpolation=InterpolationMode.BICUBIC, + ), + meta={ + **_COMMON_SWAG_META, + "recipe": "https://github.com/pytorch/vision/pull/5793", + "num_params": 304326632, + "min_size": (224, 224), + "_metrics": { + "ImageNet-1K": { + "acc@1": 85.146, + "acc@5": 97.422, + } + }, + "_ops": 61.555, + "_file_size": 1161.023, + "_docs": """ + These weights are composed of the original frozen `SWAG `_ trunk + weights and a linear classifier learnt on top of them trained on ImageNet-1K data. + """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_L_32_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 306535400, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32", + "_metrics": { + "ImageNet-1K": { + "acc@1": 76.972, + "acc@5": 93.07, + } + }, + "_ops": 15.378, + "_file_size": 1169.449, + "_docs": """ + These weights were trained from scratch by using a modified version of `DeIT + `_'s training recipe. 
+ """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_H_14_Weights(WeightsEnum): + IMAGENET1K_SWAG_E2E_V1 = Weights( + url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth", + transforms=partial( + ImageClassification, + crop_size=518, + resize_size=518, + interpolation=InterpolationMode.BICUBIC, + ), + meta={ + **_COMMON_SWAG_META, + "num_params": 633470440, + "min_size": (518, 518), + "_metrics": { + "ImageNet-1K": { + "acc@1": 88.552, + "acc@5": 98.694, + } + }, + "_ops": 1016.717, + "_file_size": 2416.643, + "_docs": """ + These weights are learnt via transfer learning by end-to-end fine-tuning the original + `SWAG `_ weights on ImageNet-1K data. + """, + }, + ) + IMAGENET1K_SWAG_LINEAR_V1 = Weights( + url="https://download.pytorch.org/models/vit_h_14_lc_swag-c1eb923e.pth", + transforms=partial( + ImageClassification, + crop_size=224, + resize_size=224, + interpolation=InterpolationMode.BICUBIC, + ), + meta={ + **_COMMON_SWAG_META, + "recipe": "https://github.com/pytorch/vision/pull/5793", + "num_params": 632045800, + "min_size": (224, 224), + "_metrics": { + "ImageNet-1K": { + "acc@1": 85.708, + "acc@5": 97.730, + } + }, + "_ops": 167.295, + "_file_size": 2411.209, + "_docs": """ + These weights are composed of the original frozen `SWAG `_ trunk + weights and a linear classifier learnt on top of them trained on ImageNet-1K data. + """, + }, + ) + DEFAULT = IMAGENET1K_SWAG_E2E_V1 + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1)) +def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_b_16 architecture from + `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. + + Args: + weights (:class:`~torchvision.models.ViT_B_16_Weights`, optional): The pretrained + weights to use. See :class:`~torchvision.models.ViT_B_16_Weights` + below for more details and possible values. 
By default, no pre-trained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ViT_B_16_Weights + :members: + """ + weights = ViT_B_16_Weights.verify(weights) + + return _vision_transformer( + patch_size=16, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1)) +def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_b_32 architecture from + `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. + + Args: + weights (:class:`~torchvision.models.ViT_B_32_Weights`, optional): The pretrained + weights to use. See :class:`~torchvision.models.ViT_B_32_Weights` + below for more details and possible values. By default, no pre-trained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.ViT_B_32_Weights + :members: + """ + weights = ViT_B_32_Weights.verify(weights) + + return _vision_transformer( + patch_size=32, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1)) +def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_l_16 architecture from + `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. + + Args: + weights (:class:`~torchvision.models.ViT_L_16_Weights`, optional): The pretrained + weights to use. See :class:`~torchvision.models.ViT_L_16_Weights` + below for more details and possible values. By default, no pre-trained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ViT_L_16_Weights + :members: + """ + weights = ViT_L_16_Weights.verify(weights) + + return _vision_transformer( + patch_size=16, + num_layers=24, + num_heads=16, + hidden_dim=1024, + mlp_dim=4096, + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1)) +def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_l_32 architecture from + `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. + + Args: + weights (:class:`~torchvision.models.ViT_L_32_Weights`, optional): The pretrained + weights to use. 
See :class:`~torchvision.models.ViT_L_32_Weights` + below for more details and possible values. By default, no pre-trained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ViT_L_32_Weights + :members: + """ + weights = ViT_L_32_Weights.verify(weights) + + return _vision_transformer( + patch_size=32, + num_layers=24, + num_heads=16, + hidden_dim=1024, + mlp_dim=4096, + weights=weights, + progress=progress, + **kwargs, + ) + + +@register_model() +@handle_legacy_interface(weights=("pretrained", None)) +def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: + """ + Constructs a vit_h_14 architecture from + `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `_. + + Args: + weights (:class:`~torchvision.models.ViT_H_14_Weights`, optional): The pretrained + weights to use. See :class:`~torchvision.models.ViT_H_14_Weights` + below for more details and possible values. By default, no pre-trained weights are used. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vision_transformer.VisionTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.ViT_H_14_Weights + :members: + """ + weights = ViT_H_14_Weights.verify(weights) + + return _vision_transformer( + patch_size=14, + num_layers=32, + num_heads=16, + hidden_dim=1280, + mlp_dim=5120, + weights=weights, + progress=progress, + **kwargs, + ) + + +def interpolate_embeddings( + image_size: int, + patch_size: int, + model_state: "OrderedDict[str, torch.Tensor]", + interpolation_mode: str = "bicubic", + reset_heads: bool = False, +) -> "OrderedDict[str, torch.Tensor]": + """This function helps interpolate positional embeddings during checkpoint loading, + especially when you want to apply a pre-trained model on images with different resolution. + + Args: + image_size (int): Image size of the new model. + patch_size (int): Patch size of the new model. + model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model. + interpolation_mode (str): The algorithm used for upsampling. Default: bicubic. + reset_heads (bool): If true, not copying the state of heads. Default: False. + + Returns: + OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model. + """ + # Shape of pos_embedding is (1, seq_length, hidden_dim) + pos_embedding = model_state["encoder.pos_embedding"] + n, seq_length, hidden_dim = pos_embedding.shape + if n != 1: + raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}") + + new_seq_length = (image_size // patch_size) ** 2 + 1 + + # Need to interpolate the weights for the position embedding. + # We do this by reshaping the positions embeddings to a 2d grid, performing + # an interpolation in the (h, w) space and then reshaping back to a 1d grid. + if new_seq_length != seq_length: + # The class token embedding shouldn't be interpolated, so we split it up. 
+ seq_length -= 1 + new_seq_length -= 1 + pos_embedding_token = pos_embedding[:, :1, :] + pos_embedding_img = pos_embedding[:, 1:, :] + + # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length) + pos_embedding_img = pos_embedding_img.permute(0, 2, 1) + seq_length_1d = int(math.sqrt(seq_length)) + if seq_length_1d * seq_length_1d != seq_length: + raise ValueError( + f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d } and seq_length = {seq_length}" + ) + + # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d) + pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d) + new_seq_length_1d = image_size // patch_size + + # Perform interpolation. + # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) + new_pos_embedding_img = nn.functional.interpolate( + pos_embedding_img, + size=new_seq_length_1d, + mode=interpolation_mode, + align_corners=True, + ) + + # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length) + new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length) + + # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim) + new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1) + new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1) + + model_state["encoder.pos_embedding"] = new_pos_embedding + + if reset_heads: + model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict() + for k, v in model_state.items(): + if not k.startswith("heads"): + model_state_copy[k] = v + model_state = model_state_copy + + return model_state