Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- src/custom_timm/__pycache__/__init__.cpython-312.pyc +0 -0
- src/custom_timm/__pycache__/version.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/__init__.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/auto_augment.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/config.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/constants.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/dataset.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/dataset_factory.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/distributed_sampler.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/loader.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/mixup.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/random_erasing.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/real_labels.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/transforms.cpython-312.pyc +0 -0
- src/custom_timm/data/__pycache__/transforms_factory.cpython-312.pyc +0 -0
- src/custom_timm/data/parsers/__init__.py +2 -0
- src/custom_timm/data/parsers/__pycache__/class_map.cpython-312.pyc +0 -0
- src/custom_timm/data/parsers/class_map.py +22 -0
- src/custom_timm/data/parsers/img_extensions.py +50 -0
- src/custom_timm/data/parsers/parser.py +17 -0
- src/custom_timm/data/parsers/parser_factory.py +28 -0
- src/custom_timm/data/parsers/parser_image_folder.py +90 -0
- src/custom_timm/data/parsers/parser_image_in_tar.py +229 -0
- src/custom_timm/data/parsers/parser_image_tar.py +74 -0
- src/custom_timm/data/parsers/parser_tfds.py +301 -0
- src/custom_timm/models/gluon_resnet.py +245 -0
- src/custom_timm/models/gluon_xception.py +267 -0
- src/custom_timm/models/hardcorenas.py +151 -0
- src/custom_timm/models/helpers.py +796 -0
- src/custom_timm/models/hrnet.py +858 -0
- src/custom_timm/models/hub.py +170 -0
- src/custom_timm/models/inception_resnet_v2.py +382 -0
- src/custom_timm/models/inception_v3.py +475 -0
- src/custom_timm/models/inception_v4.py +330 -0
- src/custom_timm/models/levit.py +592 -0
- src/custom_timm/optim/__init__.py +15 -0
- src/custom_timm/optim/adabelief.py +201 -0
- src/custom_timm/optim/adafactor.py +167 -0
- src/custom_timm/optim/adahessian.py +156 -0
- src/custom_timm/optim/adamp.py +105 -0
- src/custom_timm/optim/adamw.py +122 -0
- src/custom_timm/optim/lamb.py +192 -0
- src/custom_timm/optim/lars.py +135 -0
- src/custom_timm/optim/lookahead.py +61 -0
- src/custom_timm/optim/madgrad.py +184 -0
- src/custom_timm/optim/nadam.py +92 -0
- src/custom_timm/optim/nvnovograd.py +120 -0
- src/custom_timm/optim/optim_factory.py +340 -0
- src/custom_timm/optim/radam.py +89 -0
- src/custom_timm/optim/rmsprop_tf.py +139 -0
src/custom_timm/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (674 Bytes). View file
|
|
|
src/custom_timm/__pycache__/version.cpython-312.pyc
ADDED
|
Binary file (274 Bytes). View file
|
|
|
src/custom_timm/data/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.14 kB). View file
|
|
|
src/custom_timm/data/__pycache__/auto_augment.cpython-312.pyc
ADDED
|
Binary file (35.2 kB). View file
|
|
|
src/custom_timm/data/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (2.85 kB). View file
|
|
|
src/custom_timm/data/__pycache__/constants.cpython-312.pyc
ADDED
|
Binary file (754 Bytes). View file
|
|
|
src/custom_timm/data/__pycache__/dataset.cpython-312.pyc
ADDED
|
Binary file (7.84 kB). View file
|
|
|
src/custom_timm/data/__pycache__/dataset_factory.cpython-312.pyc
ADDED
|
Binary file (5.98 kB). View file
|
|
|
src/custom_timm/data/__pycache__/distributed_sampler.cpython-312.pyc
ADDED
|
Binary file (7.33 kB). View file
|
|
|
src/custom_timm/data/__pycache__/loader.cpython-312.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
src/custom_timm/data/__pycache__/mixup.cpython-312.pyc
ADDED
|
Binary file (20.9 kB). View file
|
|
|
src/custom_timm/data/__pycache__/random_erasing.cpython-312.pyc
ADDED
|
Binary file (6.36 kB). View file
|
|
|
src/custom_timm/data/__pycache__/real_labels.cpython-312.pyc
ADDED
|
Binary file (3.26 kB). View file
|
|
|
src/custom_timm/data/__pycache__/transforms.cpython-312.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
src/custom_timm/data/__pycache__/transforms_factory.cpython-312.pyc
ADDED
|
Binary file (7.82 kB). View file
|
|
|
src/custom_timm/data/parsers/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .parser_factory import create_parser
|
| 2 |
+
from .img_extensions import *
|
src/custom_timm/data/parsers/__pycache__/class_map.cpython-312.pyc
ADDED
|
Binary file (1.78 kB). View file
|
|
|
src/custom_timm/data/parsers/class_map.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
|
| 4 |
+
def load_class_map(map_or_filename, root=''):
|
| 5 |
+
if isinstance(map_or_filename, dict):
|
| 6 |
+
assert dict, 'class_map dict must be non-empty'
|
| 7 |
+
return map_or_filename
|
| 8 |
+
class_map_path = map_or_filename
|
| 9 |
+
if not os.path.exists(class_map_path):
|
| 10 |
+
class_map_path = os.path.join(root, class_map_path)
|
| 11 |
+
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
|
| 12 |
+
class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
|
| 13 |
+
if class_map_ext == '.txt':
|
| 14 |
+
with open(class_map_path) as f:
|
| 15 |
+
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
|
| 16 |
+
elif class_map_ext == '.pkl':
|
| 17 |
+
with open(class_map_path,'rb') as f:
|
| 18 |
+
class_to_idx = pickle.load(f)
|
| 19 |
+
else:
|
| 20 |
+
assert False, f'Unsupported class map file extension ({class_map_ext}).'
|
| 21 |
+
return class_to_idx
|
| 22 |
+
|
src/custom_timm/data/parsers/img_extensions.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
|
| 3 |
+
__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use
|
| 7 |
+
_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _set_extensions(extensions):
|
| 11 |
+
global IMG_EXTENSIONS
|
| 12 |
+
global _IMG_EXTENSIONS_SET
|
| 13 |
+
dedupe = set() # NOTE de-duping tuple while keeping original order
|
| 14 |
+
IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x))
|
| 15 |
+
_IMG_EXTENSIONS_SET = set(extensions)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _valid_extension(x: str):
|
| 19 |
+
return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.')
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def is_img_extension(ext):
|
| 23 |
+
return ext in _IMG_EXTENSIONS_SET
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_img_extensions(as_set=False):
|
| 27 |
+
return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def set_img_extensions(extensions):
|
| 31 |
+
assert len(extensions)
|
| 32 |
+
for x in extensions:
|
| 33 |
+
assert _valid_extension(x)
|
| 34 |
+
_set_extensions(extensions)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def add_img_extensions(ext):
|
| 38 |
+
if not isinstance(ext, (list, tuple, set)):
|
| 39 |
+
ext = (ext,)
|
| 40 |
+
for x in ext:
|
| 41 |
+
assert _valid_extension(x)
|
| 42 |
+
extensions = IMG_EXTENSIONS + tuple(ext)
|
| 43 |
+
_set_extensions(extensions)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def del_img_extensions(ext):
|
| 47 |
+
if not isinstance(ext, (list, tuple, set)):
|
| 48 |
+
ext = (ext,)
|
| 49 |
+
extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext)
|
| 50 |
+
_set_extensions(extensions)
|
src/custom_timm/data/parsers/parser.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Parser:
|
| 5 |
+
def __init__(self):
|
| 6 |
+
pass
|
| 7 |
+
|
| 8 |
+
@abstractmethod
|
| 9 |
+
def _filename(self, index, basename=False, absolute=False):
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
def filename(self, index, basename=False, absolute=False):
|
| 13 |
+
return self._filename(index, basename=basename, absolute=absolute)
|
| 14 |
+
|
| 15 |
+
def filenames(self, basename=False, absolute=False):
|
| 16 |
+
return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
|
| 17 |
+
|
src/custom_timm/data/parsers/parser_factory.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from .parser_image_folder import ParserImageFolder
|
| 4 |
+
from .parser_image_in_tar import ParserImageInTar
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def create_parser(name, root, split='train', **kwargs):
|
| 8 |
+
name = name.lower()
|
| 9 |
+
name = name.split('/', 2)
|
| 10 |
+
prefix = ''
|
| 11 |
+
if len(name) > 1:
|
| 12 |
+
prefix = name[0]
|
| 13 |
+
name = name[-1]
|
| 14 |
+
|
| 15 |
+
# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
|
| 16 |
+
# explicitly select other options shortly
|
| 17 |
+
if prefix == 'tfds':
|
| 18 |
+
from .parser_tfds import ParserTfds # defer tensorflow import
|
| 19 |
+
parser = ParserTfds(root, name, split=split, **kwargs)
|
| 20 |
+
else:
|
| 21 |
+
assert os.path.exists(root)
|
| 22 |
+
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
|
| 23 |
+
# FIXME support split here, in parser?
|
| 24 |
+
if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
|
| 25 |
+
parser = ParserImageInTar(root, **kwargs)
|
| 26 |
+
else:
|
| 27 |
+
parser = ParserImageFolder(root, **kwargs)
|
| 28 |
+
return parser
|
src/custom_timm/data/parsers/parser_image_folder.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" A dataset parser that reads images from folders
|
| 2 |
+
|
| 3 |
+
Folders are scannerd recursively to find image files. Labels are based
|
| 4 |
+
on the folder hierarchy, just leaf folders by default.
|
| 5 |
+
|
| 6 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
| 7 |
+
"""
|
| 8 |
+
import os
|
| 9 |
+
from typing import Dict, List, Optional, Set, Tuple, Union
|
| 10 |
+
|
| 11 |
+
from custom_timm.utils.misc import natural_key
|
| 12 |
+
|
| 13 |
+
from .class_map import load_class_map
|
| 14 |
+
from .img_extensions import get_img_extensions
|
| 15 |
+
from .parser import Parser
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def find_images_and_targets(
|
| 19 |
+
folder: str,
|
| 20 |
+
types: Optional[Union[List, Tuple, Set]] = None,
|
| 21 |
+
class_to_idx: Optional[Dict] = None,
|
| 22 |
+
leaf_name_only: bool = True,
|
| 23 |
+
sort: bool = True
|
| 24 |
+
):
|
| 25 |
+
""" Walk folder recursively to discover images and map them to classes by folder names.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
folder: root of folder to recrusively search
|
| 29 |
+
types: types (file extensions) to search for in path
|
| 30 |
+
class_to_idx: specify mapping for class (folder name) to class index if set
|
| 31 |
+
leaf_name_only: use only leaf-name of folder walk for class names
|
| 32 |
+
sort: re-sort found images by name (for consistent ordering)
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
A list of image and target tuples, class_to_idx mapping
|
| 36 |
+
"""
|
| 37 |
+
types = get_img_extensions(as_set=True) if not types else set(types)
|
| 38 |
+
labels = []
|
| 39 |
+
filenames = []
|
| 40 |
+
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
|
| 41 |
+
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
|
| 42 |
+
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
|
| 43 |
+
for f in files:
|
| 44 |
+
base, ext = os.path.splitext(f)
|
| 45 |
+
if ext.lower() in types:
|
| 46 |
+
filenames.append(os.path.join(root, f))
|
| 47 |
+
labels.append(label)
|
| 48 |
+
if class_to_idx is None:
|
| 49 |
+
# building class index
|
| 50 |
+
unique_labels = set(labels)
|
| 51 |
+
sorted_labels = list(sorted(unique_labels, key=natural_key))
|
| 52 |
+
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
|
| 53 |
+
images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
|
| 54 |
+
if sort:
|
| 55 |
+
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
|
| 56 |
+
return images_and_targets, class_to_idx
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ParserImageFolder(Parser):
|
| 60 |
+
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
root,
|
| 64 |
+
class_map=''):
|
| 65 |
+
super().__init__()
|
| 66 |
+
|
| 67 |
+
self.root = root
|
| 68 |
+
class_to_idx = None
|
| 69 |
+
if class_map:
|
| 70 |
+
class_to_idx = load_class_map(class_map, root)
|
| 71 |
+
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
|
| 72 |
+
if len(self.samples) == 0:
|
| 73 |
+
raise RuntimeError(
|
| 74 |
+
f'Found 0 images in subfolders of {root}. '
|
| 75 |
+
f'Supported image extensions are {", ".join(get_img_extensions())}')
|
| 76 |
+
|
| 77 |
+
def __getitem__(self, index):
|
| 78 |
+
path, target = self.samples[index]
|
| 79 |
+
return open(path, 'rb'), target
|
| 80 |
+
|
| 81 |
+
def __len__(self):
|
| 82 |
+
return len(self.samples)
|
| 83 |
+
|
| 84 |
+
def _filename(self, index, basename=False, absolute=False):
|
| 85 |
+
filename = self.samples[index][0]
|
| 86 |
+
if basename:
|
| 87 |
+
filename = os.path.basename(filename)
|
| 88 |
+
elif not absolute:
|
| 89 |
+
filename = os.path.relpath(filename, self.root)
|
| 90 |
+
return filename
|
src/custom_timm/data/parsers/parser_image_in_tar.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" A dataset parser that reads tarfile based datasets
|
| 2 |
+
|
| 3 |
+
This parser can read and extract image samples from:
|
| 4 |
+
* a single tar of image files
|
| 5 |
+
* a folder of multiple tarfiles containing imagefiles
|
| 6 |
+
* a tar of tars containing image files
|
| 7 |
+
|
| 8 |
+
Labels are based on the combined folder and/or tar name structure.
|
| 9 |
+
|
| 10 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
| 11 |
+
"""
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
import pickle
|
| 15 |
+
import tarfile
|
| 16 |
+
from glob import glob
|
| 17 |
+
from typing import List, Tuple, Dict, Set, Optional, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from custom_timm.utils.misc import natural_key
|
| 22 |
+
|
| 23 |
+
from .class_map import load_class_map
|
| 24 |
+
from .img_extensions import get_img_extensions
|
| 25 |
+
from .parser import Parser
|
| 26 |
+
|
| 27 |
+
_logger = logging.getLogger(__name__)
|
| 28 |
+
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TarState:
|
| 32 |
+
|
| 33 |
+
def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None):
|
| 34 |
+
self.tf: tarfile.TarFile = tf
|
| 35 |
+
self.ti: tarfile.TarInfo = ti
|
| 36 |
+
self.children: Dict[str, TarState] = {} # child states (tars within tars)
|
| 37 |
+
|
| 38 |
+
def reset(self):
|
| 39 |
+
self.tf = None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
|
| 43 |
+
sample_count = 0
|
| 44 |
+
for i, ti in enumerate(tf):
|
| 45 |
+
if not ti.isfile():
|
| 46 |
+
continue
|
| 47 |
+
dirname, basename = os.path.split(ti.path)
|
| 48 |
+
name, ext = os.path.splitext(basename)
|
| 49 |
+
ext = ext.lower()
|
| 50 |
+
if ext == '.tar':
|
| 51 |
+
with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf:
|
| 52 |
+
child_info = dict(
|
| 53 |
+
name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[])
|
| 54 |
+
sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions)
|
| 55 |
+
_logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.')
|
| 56 |
+
parent_info['children'].append(child_info)
|
| 57 |
+
elif ext in extensions:
|
| 58 |
+
parent_info['samples'].append(ti)
|
| 59 |
+
sample_count += 1
|
| 60 |
+
return sample_count
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def extract_tarinfos(
|
| 64 |
+
root,
|
| 65 |
+
class_name_to_idx: Optional[Dict] = None,
|
| 66 |
+
cache_tarinfo: Optional[bool] = None,
|
| 67 |
+
extensions: Optional[Union[List, Tuple, Set]] = None,
|
| 68 |
+
sort: bool = True
|
| 69 |
+
):
|
| 70 |
+
extensions = get_img_extensions(as_set=True) if not extensions else set(extensions)
|
| 71 |
+
root_is_tar = False
|
| 72 |
+
if os.path.isfile(root):
|
| 73 |
+
assert os.path.splitext(root)[-1].lower() == '.tar'
|
| 74 |
+
tar_filenames = [root]
|
| 75 |
+
root, root_name = os.path.split(root)
|
| 76 |
+
root_name = os.path.splitext(root_name)[0]
|
| 77 |
+
root_is_tar = True
|
| 78 |
+
else:
|
| 79 |
+
root_name = root.strip(os.path.sep).split(os.path.sep)[-1]
|
| 80 |
+
tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
|
| 81 |
+
num_tars = len(tar_filenames)
|
| 82 |
+
tar_bytes = sum([os.path.getsize(f) for f in tar_filenames])
|
| 83 |
+
assert num_tars, f'No .tar files found at specified path ({root}).'
|
| 84 |
+
|
| 85 |
+
_logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...')
|
| 86 |
+
info = dict(tartrees=[])
|
| 87 |
+
cache_path = ''
|
| 88 |
+
if cache_tarinfo is None:
|
| 89 |
+
cache_tarinfo = True if tar_bytes > 10*1024**3 else False # FIXME magic number, 10GB
|
| 90 |
+
if cache_tarinfo:
|
| 91 |
+
cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX
|
| 92 |
+
cache_path = os.path.join(root, cache_filename)
|
| 93 |
+
if os.path.exists(cache_path):
|
| 94 |
+
_logger.info(f'Reading tar info from cache file {cache_path}.')
|
| 95 |
+
with open(cache_path, 'rb') as pf:
|
| 96 |
+
info = pickle.load(pf)
|
| 97 |
+
assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles"
|
| 98 |
+
else:
|
| 99 |
+
for i, fn in enumerate(tar_filenames):
|
| 100 |
+
path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0]
|
| 101 |
+
with tarfile.open(fn, mode='r|') as tf: # tarinfo scans done in streaming mode
|
| 102 |
+
parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[])
|
| 103 |
+
num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions)
|
| 104 |
+
num_children = len(parent_info["children"])
|
| 105 |
+
_logger.debug(
|
| 106 |
+
f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.')
|
| 107 |
+
info['tartrees'].append(parent_info)
|
| 108 |
+
if cache_path:
|
| 109 |
+
_logger.info(f'Writing tar info to cache file {cache_path}.')
|
| 110 |
+
with open(cache_path, 'wb') as pf:
|
| 111 |
+
pickle.dump(info, pf)
|
| 112 |
+
|
| 113 |
+
samples = []
|
| 114 |
+
labels = []
|
| 115 |
+
build_class_map = False
|
| 116 |
+
if class_name_to_idx is None:
|
| 117 |
+
build_class_map = True
|
| 118 |
+
|
| 119 |
+
# Flatten tartree info into lists of samples and targets w/ targets based on label id via
|
| 120 |
+
# class map arg or from unique paths.
|
| 121 |
+
# NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children
|
| 122 |
+
# this covers my current use cases and keeps things a little easier to test for now.
|
| 123 |
+
tarfiles = []
|
| 124 |
+
|
| 125 |
+
def _label_from_paths(*path, leaf_only=True):
|
| 126 |
+
path = os.path.join(*path).strip(os.path.sep)
|
| 127 |
+
return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_')
|
| 128 |
+
|
| 129 |
+
def _add_samples(info, fn):
|
| 130 |
+
added = 0
|
| 131 |
+
for s in info['samples']:
|
| 132 |
+
label = _label_from_paths(info['path'], os.path.dirname(s.path))
|
| 133 |
+
if not build_class_map and label not in class_name_to_idx:
|
| 134 |
+
continue
|
| 135 |
+
samples.append((s, fn, info['ti']))
|
| 136 |
+
labels.append(label)
|
| 137 |
+
added += 1
|
| 138 |
+
return added
|
| 139 |
+
|
| 140 |
+
_logger.info(f'Collecting samples and building tar states.')
|
| 141 |
+
for parent_info in info['tartrees']:
|
| 142 |
+
# if tartree has children, we assume all samples are at the child level
|
| 143 |
+
tar_name = None if root_is_tar else parent_info['name']
|
| 144 |
+
tar_state = TarState()
|
| 145 |
+
parent_added = 0
|
| 146 |
+
for child_info in parent_info['children']:
|
| 147 |
+
child_added = _add_samples(child_info, fn=tar_name)
|
| 148 |
+
if child_added:
|
| 149 |
+
tar_state.children[child_info['name']] = TarState(ti=child_info['ti'])
|
| 150 |
+
parent_added += child_added
|
| 151 |
+
parent_added += _add_samples(parent_info, fn=tar_name)
|
| 152 |
+
if parent_added:
|
| 153 |
+
tarfiles.append((tar_name, tar_state))
|
| 154 |
+
del info
|
| 155 |
+
|
| 156 |
+
if build_class_map:
|
| 157 |
+
# build class index
|
| 158 |
+
sorted_labels = list(sorted(set(labels), key=natural_key))
|
| 159 |
+
class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
|
| 160 |
+
|
| 161 |
+
_logger.info(f'Mapping targets and sorting samples.')
|
| 162 |
+
samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
|
| 163 |
+
if sort:
|
| 164 |
+
samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
|
| 165 |
+
samples, targets = zip(*samples_and_targets)
|
| 166 |
+
samples = np.array(samples)
|
| 167 |
+
targets = np.array(targets)
|
| 168 |
+
_logger.info(f'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.')
|
| 169 |
+
return samples, targets, class_name_to_idx, tarfiles
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class ParserImageInTar(Parser):
|
| 173 |
+
""" Multi-tarfile dataset parser where there is one .tar file per class
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
|
| 177 |
+
super().__init__()
|
| 178 |
+
|
| 179 |
+
class_name_to_idx = None
|
| 180 |
+
if class_map:
|
| 181 |
+
class_name_to_idx = load_class_map(class_map, root)
|
| 182 |
+
self.root = root
|
| 183 |
+
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
|
| 184 |
+
self.root,
|
| 185 |
+
class_name_to_idx=class_name_to_idx,
|
| 186 |
+
cache_tarinfo=cache_tarinfo
|
| 187 |
+
)
|
| 188 |
+
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
|
| 189 |
+
if len(tarfiles) == 1 and tarfiles[0][0] is None:
|
| 190 |
+
self.root_is_tar = True
|
| 191 |
+
self.tar_state = tarfiles[0][1]
|
| 192 |
+
else:
|
| 193 |
+
self.root_is_tar = False
|
| 194 |
+
self.tar_state = dict(tarfiles)
|
| 195 |
+
self.cache_tarfiles = cache_tarfiles
|
| 196 |
+
|
| 197 |
+
def __len__(self):
|
| 198 |
+
return len(self.samples)
|
| 199 |
+
|
| 200 |
+
def __getitem__(self, index):
|
| 201 |
+
sample = self.samples[index]
|
| 202 |
+
target = self.targets[index]
|
| 203 |
+
sample_ti, parent_fn, child_ti = sample
|
| 204 |
+
parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root
|
| 205 |
+
|
| 206 |
+
tf = None
|
| 207 |
+
cache_state = None
|
| 208 |
+
if self.cache_tarfiles:
|
| 209 |
+
cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn]
|
| 210 |
+
tf = cache_state.tf
|
| 211 |
+
if tf is None:
|
| 212 |
+
tf = tarfile.open(parent_abs)
|
| 213 |
+
if self.cache_tarfiles:
|
| 214 |
+
cache_state.tf = tf
|
| 215 |
+
if child_ti is not None:
|
| 216 |
+
ctf = cache_state.children[child_ti.name].tf if self.cache_tarfiles else None
|
| 217 |
+
if ctf is None:
|
| 218 |
+
ctf = tarfile.open(fileobj=tf.extractfile(child_ti))
|
| 219 |
+
if self.cache_tarfiles:
|
| 220 |
+
cache_state.children[child_ti.name].tf = ctf
|
| 221 |
+
tf = ctf
|
| 222 |
+
|
| 223 |
+
return tf.extractfile(sample_ti), target
|
| 224 |
+
|
| 225 |
+
def _filename(self, index, basename=False, absolute=False):
|
| 226 |
+
filename = self.samples[index][0].name
|
| 227 |
+
if basename:
|
| 228 |
+
filename = os.path.basename(filename)
|
| 229 |
+
return filename
|
src/custom_timm/data/parsers/parser_image_tar.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" A dataset parser that reads single tarfile based datasets
|
| 2 |
+
|
| 3 |
+
This parser can read datasets consisting if a single tarfile containing images.
|
| 4 |
+
I am planning to deprecated it in favour of ParerImageInTar.
|
| 5 |
+
|
| 6 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
| 7 |
+
"""
|
| 8 |
+
import os
|
| 9 |
+
import tarfile
|
| 10 |
+
|
| 11 |
+
from custom_timm.utils.misc import natural_key
|
| 12 |
+
|
| 13 |
+
from .class_map import load_class_map
|
| 14 |
+
from .img_extensions import get_img_extensions
|
| 15 |
+
from .parser import Parser
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
|
| 19 |
+
extensions = get_img_extensions(as_set=True)
|
| 20 |
+
files = []
|
| 21 |
+
labels = []
|
| 22 |
+
for ti in tarfile.getmembers():
|
| 23 |
+
if not ti.isfile():
|
| 24 |
+
continue
|
| 25 |
+
dirname, basename = os.path.split(ti.path)
|
| 26 |
+
label = os.path.basename(dirname)
|
| 27 |
+
ext = os.path.splitext(basename)[1]
|
| 28 |
+
if ext.lower() in extensions:
|
| 29 |
+
files.append(ti)
|
| 30 |
+
labels.append(label)
|
| 31 |
+
if class_to_idx is None:
|
| 32 |
+
unique_labels = set(labels)
|
| 33 |
+
sorted_labels = list(sorted(unique_labels, key=natural_key))
|
| 34 |
+
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
|
| 35 |
+
tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
|
| 36 |
+
if sort:
|
| 37 |
+
tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
|
| 38 |
+
return tarinfo_and_targets, class_to_idx
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ParserImageTar(Parser):
|
| 42 |
+
""" Single tarfile dataset where classes are mapped to folders within tar
|
| 43 |
+
NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
|
| 44 |
+
operate on folders of tars or tars in tars.
|
| 45 |
+
"""
|
| 46 |
+
def __init__(self, root, class_map=''):
|
| 47 |
+
super().__init__()
|
| 48 |
+
|
| 49 |
+
class_to_idx = None
|
| 50 |
+
if class_map:
|
| 51 |
+
class_to_idx = load_class_map(class_map, root)
|
| 52 |
+
assert os.path.isfile(root)
|
| 53 |
+
self.root = root
|
| 54 |
+
|
| 55 |
+
with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
|
| 56 |
+
self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
|
| 57 |
+
self.imgs = self.samples
|
| 58 |
+
self.tarfile = None # lazy init in __getitem__
|
| 59 |
+
|
| 60 |
+
def __getitem__(self, index):
|
| 61 |
+
if self.tarfile is None:
|
| 62 |
+
self.tarfile = tarfile.open(self.root)
|
| 63 |
+
tarinfo, target = self.samples[index]
|
| 64 |
+
fileobj = self.tarfile.extractfile(tarinfo)
|
| 65 |
+
return fileobj, target
|
| 66 |
+
|
| 67 |
+
def __len__(self):
|
| 68 |
+
return len(self.samples)
|
| 69 |
+
|
| 70 |
+
def _filename(self, index, basename=False, absolute=False):
|
| 71 |
+
filename = self.samples[index][0].name
|
| 72 |
+
if basename:
|
| 73 |
+
filename = os.path.basename(filename)
|
| 74 |
+
return filename
|
src/custom_timm/data/parsers/parser_tfds.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Dataset parser interface that wraps TFDS datasets
|
| 2 |
+
|
| 3 |
+
Wraps many (most?) TFDS image-classification datasets
|
| 4 |
+
from https://github.com/tensorflow/datasets
|
| 5 |
+
https://www.tensorflow.org/datasets/catalog/overview#image_classification
|
| 6 |
+
|
| 7 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
| 8 |
+
"""
|
| 9 |
+
import math
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import tensorflow as tf
|
| 16 |
+
tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
|
| 17 |
+
import tensorflow_datasets as tfds
|
| 18 |
+
try:
|
| 19 |
+
tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg
|
| 20 |
+
has_buggy_even_splits = False
|
| 21 |
+
except TypeError:
|
| 22 |
+
print("Warning: This version of tfds doesn't have the latest even_splits impl. "
|
| 23 |
+
"Please update or use tfds-nightly for better fine-grained split behaviour.")
|
| 24 |
+
has_buggy_even_splits = True
|
| 25 |
+
# NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults)
|
| 26 |
+
# import resource
|
| 27 |
+
# low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
|
| 28 |
+
# resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
|
| 29 |
+
except ImportError as e:
|
| 30 |
+
print(e)
|
| 31 |
+
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
|
| 32 |
+
exit(1)
|
| 33 |
+
from .parser import Parser
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
|
| 37 |
+
SHUFFLE_SIZE = 8192 # examples to shuffle in DS queue
|
| 38 |
+
PREFETCH_SIZE = 2048 # examples to prefetch
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def even_split_indices(split, n, num_examples):
|
| 42 |
+
partitions = [round(i * num_examples / n) for i in range(n + 1)]
|
| 43 |
+
return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_class_labels(info):
|
| 47 |
+
if 'label' not in info.features:
|
| 48 |
+
return {}
|
| 49 |
+
class_label = info.features['label']
|
| 50 |
+
class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
|
| 51 |
+
return class_to_idx
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ParserTfds(Parser):
|
| 55 |
+
""" Wrap Tensorflow Datasets for use in PyTorch
|
| 56 |
+
|
| 57 |
+
There several things to be aware of:
|
| 58 |
+
* To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of
|
| 59 |
+
dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
|
| 60 |
+
https://github.com/pytorch/pytorch/issues/33413
|
| 61 |
+
* With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
|
| 62 |
+
from each worker could be a different size. For training this is worked around by option above, for
|
| 63 |
+
validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced
|
| 64 |
+
across replicas are of same size. This will slightly alter the results, distributed validation will not be
|
| 65 |
+
100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
|
| 66 |
+
since there are up to N * J extra examples with IterableDatasets.
|
| 67 |
+
* The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
|
| 68 |
+
replicas and dataloader workers you can use. For really small datasets that only contain a few shards
|
| 69 |
+
you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
|
| 70 |
+
benefit of distributed training or fast dataloading should be much less for small datasets.
|
| 71 |
+
* This wrapper is currently configured to return individual, decompressed image examples from the TFDS
|
| 72 |
+
dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
|
| 73 |
+
to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
|
| 74 |
+
components.
|
| 75 |
+
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(
|
| 79 |
+
self,
|
| 80 |
+
root,
|
| 81 |
+
name,
|
| 82 |
+
split='train',
|
| 83 |
+
is_training=False,
|
| 84 |
+
batch_size=None,
|
| 85 |
+
download=False,
|
| 86 |
+
repeats=0,
|
| 87 |
+
seed=42,
|
| 88 |
+
input_name='image',
|
| 89 |
+
input_image='RGB',
|
| 90 |
+
target_name='label',
|
| 91 |
+
target_image='',
|
| 92 |
+
prefetch_size=None,
|
| 93 |
+
shuffle_size=None,
|
| 94 |
+
max_threadpool_size=None
|
| 95 |
+
):
|
| 96 |
+
""" Tensorflow-datasets Wrapper
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
root: root data dir (ie your TFDS_DATA_DIR. not dataset specific sub-dir)
|
| 100 |
+
name: tfds dataset name (eg `imagenet2012`)
|
| 101 |
+
split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
|
| 102 |
+
is_training: training mode, shuffle enabled, dataset len rounded by batch_size
|
| 103 |
+
batch_size: batch_size to use to unsure total examples % batch_size == 0 in training across all dis nodes
|
| 104 |
+
download: download and build TFDS dataset if set, otherwise must use tfds CLI
|
| 105 |
+
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
|
| 106 |
+
seed: common seed for shard shuffle across all distributed/worker instances
|
| 107 |
+
input_name: name of Feature to return as data (input)
|
| 108 |
+
input_image: image mode if input is an image (currently PIL mode string)
|
| 109 |
+
target_name: name of Feature to return as target (label)
|
| 110 |
+
target_image: image mode if target is an image (currently PIL mode string)
|
| 111 |
+
prefetch_size: override default tf.data prefetch buffer size
|
| 112 |
+
shuffle_size: override default tf.data shuffle buffer size
|
| 113 |
+
max_threadpool_size: override default threadpool size for tf.data
|
| 114 |
+
"""
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.root = root
|
| 117 |
+
self.split = split
|
| 118 |
+
self.is_training = is_training
|
| 119 |
+
if self.is_training:
|
| 120 |
+
assert batch_size is not None, \
|
| 121 |
+
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
|
| 122 |
+
self.batch_size = batch_size
|
| 123 |
+
self.repeats = repeats
|
| 124 |
+
self.common_seed = seed # a seed that's fixed across all worker / distributed instances
|
| 125 |
+
|
| 126 |
+
# performance settings
|
| 127 |
+
self.prefetch_size = prefetch_size or PREFETCH_SIZE
|
| 128 |
+
self.shuffle_size = shuffle_size or SHUFFLE_SIZE
|
| 129 |
+
self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE
|
| 130 |
+
|
| 131 |
+
# TFDS builder and split information
|
| 132 |
+
self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature
|
| 133 |
+
self.input_image = input_image
|
| 134 |
+
self.target_name = target_name
|
| 135 |
+
self.target_image = target_image
|
| 136 |
+
self.builder = tfds.builder(name, data_dir=root)
|
| 137 |
+
# NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
|
| 138 |
+
if download:
|
| 139 |
+
self.builder.download_and_prepare()
|
| 140 |
+
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
|
| 141 |
+
self.split_info = self.builder.info.splits[split]
|
| 142 |
+
self.num_examples = self.split_info.num_examples
|
| 143 |
+
|
| 144 |
+
# Distributed world state
|
| 145 |
+
self.dist_rank = 0
|
| 146 |
+
self.dist_num_replicas = 1
|
| 147 |
+
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
|
| 148 |
+
self.dist_rank = dist.get_rank()
|
| 149 |
+
self.dist_num_replicas = dist.get_world_size()
|
| 150 |
+
|
| 151 |
+
# Attributes that are updated in _lazy_init, including the tf.data pipeline itself
|
| 152 |
+
self.global_num_workers = 1
|
| 153 |
+
self.worker_info = None
|
| 154 |
+
self.worker_seed = 0 # seed unique to each work instance
|
| 155 |
+
self.subsplit = None # set when data is distributed across workers using sub-splits
|
| 156 |
+
self.ds = None # initialized lazily on each dataloader worker process
|
| 157 |
+
|
| 158 |
+
def _lazy_init(self):
|
| 159 |
+
""" Lazily initialize the dataset.
|
| 160 |
+
|
| 161 |
+
This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that
|
| 162 |
+
will be using the dataset instance. The __init__ method is called on the main process,
|
| 163 |
+
this will be called in a dataloader worker process.
|
| 164 |
+
|
| 165 |
+
NOTE: There will be problems if you try to re-use this dataset across different loader/worker
|
| 166 |
+
instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
|
| 167 |
+
before it is passed to dataloader.
|
| 168 |
+
"""
|
| 169 |
+
worker_info = torch.utils.data.get_worker_info()
|
| 170 |
+
|
| 171 |
+
# setup input context to split dataset across distributed processes
|
| 172 |
+
num_workers = 1
|
| 173 |
+
global_worker_id = 0
|
| 174 |
+
if worker_info is not None:
|
| 175 |
+
self.worker_info = worker_info
|
| 176 |
+
self.worker_seed = worker_info.seed
|
| 177 |
+
num_workers = worker_info.num_workers
|
| 178 |
+
self.global_num_workers = self.dist_num_replicas * num_workers
|
| 179 |
+
global_worker_id = self.dist_rank * num_workers + worker_info.id
|
| 180 |
+
|
| 181 |
+
""" Data sharding
|
| 182 |
+
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
|
| 183 |
+
My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
|
| 184 |
+
between the splits each iteration, but that understanding could be wrong.
|
| 185 |
+
|
| 186 |
+
I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
|
| 187 |
+
the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
|
| 188 |
+
in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
|
| 189 |
+
for validation where we can't drop examples and need to avoid minimize uneven splits to avoid padding.
|
| 190 |
+
"""
|
| 191 |
+
should_subsplit = self.global_num_workers > 1 and (
|
| 192 |
+
self.split_info.num_shards < self.global_num_workers or not self.is_training)
|
| 193 |
+
if should_subsplit:
|
| 194 |
+
# split the dataset w/o using sharding for more even examples / worker, can result in less optimal
|
| 195 |
+
# read patterns for distributed training (overlap across shards) so better to use InputContext there
|
| 196 |
+
if has_buggy_even_splits:
|
| 197 |
+
# my even_split workaround doesn't work on subsplits, upgrade tfds!
|
| 198 |
+
if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
|
| 199 |
+
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples)
|
| 200 |
+
self.subsplit = subsplits[global_worker_id]
|
| 201 |
+
else:
|
| 202 |
+
subsplits = tfds.even_splits(self.split, self.global_num_workers)
|
| 203 |
+
self.subsplit = subsplits[global_worker_id]
|
| 204 |
+
|
| 205 |
+
input_context = None
|
| 206 |
+
if self.global_num_workers > 1 and self.subsplit is None:
|
| 207 |
+
# set input context to divide shards among distributed replicas
|
| 208 |
+
input_context = tf.distribute.InputContext(
|
| 209 |
+
num_input_pipelines=self.global_num_workers,
|
| 210 |
+
input_pipeline_id=global_worker_id,
|
| 211 |
+
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
|
| 212 |
+
)
|
| 213 |
+
read_config = tfds.ReadConfig(
|
| 214 |
+
shuffle_seed=self.common_seed,
|
| 215 |
+
shuffle_reshuffle_each_iteration=True,
|
| 216 |
+
input_context=input_context)
|
| 217 |
+
ds = self.builder.as_dataset(
|
| 218 |
+
split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
|
| 219 |
+
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
|
| 220 |
+
options = tf.data.Options()
|
| 221 |
+
thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
|
| 222 |
+
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
|
| 223 |
+
getattr(options, thread_member).max_intra_op_parallelism = 1
|
| 224 |
+
ds = ds.with_options(options)
|
| 225 |
+
if self.is_training or self.repeats > 1:
|
| 226 |
+
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
|
| 227 |
+
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
|
| 228 |
+
ds = ds.repeat() # allow wrap around and break iteration manually
|
| 229 |
+
if self.is_training:
|
| 230 |
+
ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
|
| 231 |
+
ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
|
| 232 |
+
self.ds = tfds.as_numpy(ds)
|
| 233 |
+
|
| 234 |
+
def __iter__(self):
|
| 235 |
+
if self.ds is None:
|
| 236 |
+
self._lazy_init()
|
| 237 |
+
|
| 238 |
+
# Compute a rounded up sample count that is used to:
|
| 239 |
+
# 1. make batches even cross workers & replicas in distributed validation.
|
| 240 |
+
# This adds extra examples and will slightly alter validation results.
|
| 241 |
+
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
|
| 242 |
+
# batches are produced (underlying tfds iter wraps around)
|
| 243 |
+
target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers)
|
| 244 |
+
if self.is_training:
|
| 245 |
+
# round up to nearest batch_size per worker-replica
|
| 246 |
+
target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size
|
| 247 |
+
|
| 248 |
+
# Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
|
| 249 |
+
example_count = 0
|
| 250 |
+
for example in self.ds:
|
| 251 |
+
input_data = example[self.input_name]
|
| 252 |
+
if self.input_image:
|
| 253 |
+
input_data = Image.fromarray(input_data, mode=self.input_image)
|
| 254 |
+
target_data = example[self.target_name]
|
| 255 |
+
if self.target_image:
|
| 256 |
+
target_data = Image.fromarray(target_data, mode=self.target_image)
|
| 257 |
+
yield input_data, target_data
|
| 258 |
+
example_count += 1
|
| 259 |
+
if self.is_training and example_count >= target_example_count:
|
| 260 |
+
# Need to break out of loop when repeat() is enabled for training w/ oversampling
|
| 261 |
+
# this results in extra examples per epoch but seems more desirable than dropping
|
| 262 |
+
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
|
| 263 |
+
break
|
| 264 |
+
|
| 265 |
+
# Pad across distributed nodes (make counts equal by adding examples)
|
| 266 |
+
if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
|
| 267 |
+
0 < example_count < target_example_count:
|
| 268 |
+
# Validation batch padding only done for distributed training where results are reduced across nodes.
|
| 269 |
+
# For single process case, it won't matter if workers return different batch sizes.
|
| 270 |
+
# If using input_context or % based splits, sample count can vary significantly across workers and this
|
| 271 |
+
# approach should not be used (hence disabled if self.subsplit isn't set).
|
| 272 |
+
while example_count < target_example_count:
|
| 273 |
+
yield input_data, target_data # yield prev sample again
|
| 274 |
+
example_count += 1
|
| 275 |
+
|
| 276 |
+
def __len__(self):
|
| 277 |
+
# this is just an estimate and does not factor in extra examples added to pad batches based on
|
| 278 |
+
# complete worker & replica info (not available until init in dataloader).
|
| 279 |
+
return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas)
|
| 280 |
+
|
| 281 |
+
def _filename(self, index, basename=False, absolute=False):
|
| 282 |
+
assert False, "Not supported" # no random access to examples
|
| 283 |
+
|
| 284 |
+
def filenames(self, basename=False, absolute=False):
|
| 285 |
+
""" Return all filenames in dataset, overrides base"""
|
| 286 |
+
if self.ds is None:
|
| 287 |
+
self._lazy_init()
|
| 288 |
+
names = []
|
| 289 |
+
for sample in self.ds:
|
| 290 |
+
if len(names) > self.num_examples:
|
| 291 |
+
break # safety for ds.repeat() case
|
| 292 |
+
if 'file_name' in sample:
|
| 293 |
+
name = sample['file_name']
|
| 294 |
+
elif 'filename' in sample:
|
| 295 |
+
name = sample['filename']
|
| 296 |
+
elif 'id' in sample:
|
| 297 |
+
name = sample['id']
|
| 298 |
+
else:
|
| 299 |
+
assert False, "No supported name field present"
|
| 300 |
+
names.append(name)
|
| 301 |
+
return names
|
src/custom_timm/models/gluon_resnet.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pytorch impl of MxNet Gluon ResNet/(SE)ResNeXt variants
|
| 2 |
+
This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-ResNeXt additions
|
| 3 |
+
and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py)
|
| 4 |
+
by Ross Wightman
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 8 |
+
from .helpers import build_model_with_cfg
|
| 9 |
+
from .layers import SEModule
|
| 10 |
+
from .registry import register_model
|
| 11 |
+
from .resnet import ResNet, Bottleneck, BasicBlock
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _cfg(url='', **kwargs):
|
| 15 |
+
return {
|
| 16 |
+
'url': url,
|
| 17 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
| 18 |
+
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
| 19 |
+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
| 20 |
+
'first_conv': 'conv1', 'classifier': 'fc',
|
| 21 |
+
**kwargs
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
default_cfgs = {
|
| 26 |
+
'gluon_resnet18_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet18_v1b-0757602b.pth'),
|
| 27 |
+
'gluon_resnet34_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet34_v1b-c6d82d59.pth'),
|
| 28 |
+
'gluon_resnet50_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'),
|
| 29 |
+
'gluon_resnet101_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'),
|
| 30 |
+
'gluon_resnet152_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'),
|
| 31 |
+
'gluon_resnet50_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth',
|
| 32 |
+
first_conv='conv1.0'),
|
| 33 |
+
'gluon_resnet101_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth',
|
| 34 |
+
first_conv='conv1.0'),
|
| 35 |
+
'gluon_resnet152_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth',
|
| 36 |
+
first_conv='conv1.0'),
|
| 37 |
+
'gluon_resnet50_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth',
|
| 38 |
+
first_conv='conv1.0'),
|
| 39 |
+
'gluon_resnet101_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth',
|
| 40 |
+
first_conv='conv1.0'),
|
| 41 |
+
'gluon_resnet152_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth',
|
| 42 |
+
first_conv='conv1.0'),
|
| 43 |
+
'gluon_resnet50_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth',
|
| 44 |
+
first_conv='conv1.0'),
|
| 45 |
+
'gluon_resnet101_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth',
|
| 46 |
+
first_conv='conv1.0'),
|
| 47 |
+
'gluon_resnet152_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth',
|
| 48 |
+
first_conv='conv1.0'),
|
| 49 |
+
'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'),
|
| 50 |
+
'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'),
|
| 51 |
+
'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'),
|
| 52 |
+
'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'),
|
| 53 |
+
'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'),
|
| 54 |
+
'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'),
|
| 55 |
+
'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth',
|
| 56 |
+
first_conv='conv1.0'),
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _create_resnet(variant, pretrained=False, **kwargs):
|
| 61 |
+
return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@register_model
|
| 65 |
+
def gluon_resnet18_v1b(pretrained=False, **kwargs):
|
| 66 |
+
"""Constructs a ResNet-18 model.
|
| 67 |
+
"""
|
| 68 |
+
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
|
| 69 |
+
return _create_resnet('gluon_resnet18_v1b', pretrained, **model_args)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@register_model
|
| 73 |
+
def gluon_resnet34_v1b(pretrained=False, **kwargs):
|
| 74 |
+
"""Constructs a ResNet-34 model.
|
| 75 |
+
"""
|
| 76 |
+
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
|
| 77 |
+
return _create_resnet('gluon_resnet34_v1b', pretrained, **model_args)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@register_model
|
| 81 |
+
def gluon_resnet50_v1b(pretrained=False, **kwargs):
|
| 82 |
+
"""Constructs a ResNet-50 model.
|
| 83 |
+
"""
|
| 84 |
+
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
|
| 85 |
+
return _create_resnet('gluon_resnet50_v1b', pretrained, **model_args)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@register_model
|
| 89 |
+
def gluon_resnet101_v1b(pretrained=False, **kwargs):
|
| 90 |
+
"""Constructs a ResNet-101 model.
|
| 91 |
+
"""
|
| 92 |
+
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
|
| 93 |
+
return _create_resnet('gluon_resnet101_v1b', pretrained, **model_args)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@register_model
|
| 97 |
+
def gluon_resnet152_v1b(pretrained=False, **kwargs):
|
| 98 |
+
"""Constructs a ResNet-152 model.
|
| 99 |
+
"""
|
| 100 |
+
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
|
| 101 |
+
return _create_resnet('gluon_resnet152_v1b', pretrained, **model_args)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@register_model
|
| 105 |
+
def gluon_resnet50_v1c(pretrained=False, **kwargs):
|
| 106 |
+
"""Constructs a ResNet-50 model.
|
| 107 |
+
"""
|
| 108 |
+
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', **kwargs)
|
| 109 |
+
return _create_resnet('gluon_resnet50_v1c', pretrained, **model_args)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@register_model
|
| 113 |
+
def gluon_resnet101_v1c(pretrained=False, **kwargs):
|
| 114 |
+
"""Constructs a ResNet-101 model.
|
| 115 |
+
"""
|
| 116 |
+
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', **kwargs)
|
| 117 |
+
return _create_resnet('gluon_resnet101_v1c', pretrained, **model_args)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@register_model
|
| 121 |
+
def gluon_resnet152_v1c(pretrained=False, **kwargs):
|
| 122 |
+
"""Constructs a ResNet-152 model.
|
| 123 |
+
"""
|
| 124 |
+
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', **kwargs)
|
| 125 |
+
return _create_resnet('gluon_resnet152_v1c', pretrained, **model_args)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@register_model
|
| 129 |
+
def gluon_resnet50_v1d(pretrained=False, **kwargs):
|
| 130 |
+
"""Constructs a ResNet-50 model.
|
| 131 |
+
"""
|
| 132 |
+
model_args = dict(
|
| 133 |
+
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
|
| 134 |
+
return _create_resnet('gluon_resnet50_v1d', pretrained, **model_args)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@register_model
|
| 138 |
+
def gluon_resnet101_v1d(pretrained=False, **kwargs):
|
| 139 |
+
"""Constructs a ResNet-101 model.
|
| 140 |
+
"""
|
| 141 |
+
model_args = dict(
|
| 142 |
+
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
|
| 143 |
+
return _create_resnet('gluon_resnet101_v1d', pretrained, **model_args)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@register_model
|
| 147 |
+
def gluon_resnet152_v1d(pretrained=False, **kwargs):
|
| 148 |
+
"""Constructs a ResNet-152 model.
|
| 149 |
+
"""
|
| 150 |
+
model_args = dict(
|
| 151 |
+
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
|
| 152 |
+
return _create_resnet('gluon_resnet152_v1d', pretrained, **model_args)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@register_model
|
| 156 |
+
def gluon_resnet50_v1s(pretrained=False, **kwargs):
|
| 157 |
+
"""Constructs a ResNet-50 model.
|
| 158 |
+
"""
|
| 159 |
+
model_args = dict(
|
| 160 |
+
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep', **kwargs)
|
| 161 |
+
return _create_resnet('gluon_resnet50_v1s', pretrained, **model_args)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@register_model
|
| 166 |
+
def gluon_resnet101_v1s(pretrained=False, **kwargs):
|
| 167 |
+
"""Constructs a ResNet-101 model.
|
| 168 |
+
"""
|
| 169 |
+
model_args = dict(
|
| 170 |
+
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep', **kwargs)
|
| 171 |
+
return _create_resnet('gluon_resnet101_v1s', pretrained, **model_args)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
@register_model
|
| 175 |
+
def gluon_resnet152_v1s(pretrained=False, **kwargs):
|
| 176 |
+
"""Constructs a ResNet-152 model.
|
| 177 |
+
"""
|
| 178 |
+
model_args = dict(
|
| 179 |
+
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep', **kwargs)
|
| 180 |
+
return _create_resnet('gluon_resnet152_v1s', pretrained, **model_args)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@register_model
|
| 185 |
+
def gluon_resnext50_32x4d(pretrained=False, **kwargs):
|
| 186 |
+
"""Constructs a ResNeXt50-32x4d model.
|
| 187 |
+
"""
|
| 188 |
+
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
|
| 189 |
+
return _create_resnet('gluon_resnext50_32x4d', pretrained, **model_args)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@register_model
|
| 193 |
+
def gluon_resnext101_32x4d(pretrained=False, **kwargs):
|
| 194 |
+
"""Constructs a ResNeXt-101 model.
|
| 195 |
+
"""
|
| 196 |
+
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
|
| 197 |
+
return _create_resnet('gluon_resnext101_32x4d', pretrained, **model_args)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@register_model
|
| 201 |
+
def gluon_resnext101_64x4d(pretrained=False, **kwargs):
|
| 202 |
+
"""Constructs a ResNeXt-101 model.
|
| 203 |
+
"""
|
| 204 |
+
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs)
|
| 205 |
+
return _create_resnet('gluon_resnext101_64x4d', pretrained, **model_args)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@register_model
|
| 209 |
+
def gluon_seresnext50_32x4d(pretrained=False, **kwargs):
|
| 210 |
+
"""Constructs a SEResNeXt50-32x4d model.
|
| 211 |
+
"""
|
| 212 |
+
model_args = dict(
|
| 213 |
+
block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
|
| 214 |
+
block_args=dict(attn_layer=SEModule), **kwargs)
|
| 215 |
+
return _create_resnet('gluon_seresnext50_32x4d', pretrained, **model_args)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
@register_model
|
| 219 |
+
def gluon_seresnext101_32x4d(pretrained=False, **kwargs):
|
| 220 |
+
"""Constructs a SEResNeXt-101-32x4d model.
|
| 221 |
+
"""
|
| 222 |
+
model_args = dict(
|
| 223 |
+
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
|
| 224 |
+
block_args=dict(attn_layer=SEModule), **kwargs)
|
| 225 |
+
return _create_resnet('gluon_seresnext101_32x4d', pretrained, **model_args)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@register_model
|
| 229 |
+
def gluon_seresnext101_64x4d(pretrained=False, **kwargs):
|
| 230 |
+
"""Constructs a SEResNeXt-101-64x4d model.
|
| 231 |
+
"""
|
| 232 |
+
model_args = dict(
|
| 233 |
+
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4,
|
| 234 |
+
block_args=dict(attn_layer=SEModule), **kwargs)
|
| 235 |
+
return _create_resnet('gluon_seresnext101_64x4d', pretrained, **model_args)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@register_model
|
| 239 |
+
def gluon_senet154(pretrained=False, **kwargs):
|
| 240 |
+
"""Constructs an SENet-154 model.
|
| 241 |
+
"""
|
| 242 |
+
model_args = dict(
|
| 243 |
+
block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
|
| 244 |
+
down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer=SEModule), **kwargs)
|
| 245 |
+
return _create_resnet('gluon_senet154', pretrained, **model_args)
|
src/custom_timm/models/gluon_xception.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pytorch impl of Gluon Xception
|
| 2 |
+
This is a port of the Gluon Xception code and weights, itself ported from a PyTorch DeepLab impl.
|
| 3 |
+
|
| 4 |
+
Gluon model: (https://gluon-cv.mxnet.io/_modules/gluoncv/model_zoo/xception.html)
|
| 5 |
+
Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xception
|
| 6 |
+
|
| 7 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
| 8 |
+
"""
|
| 9 |
+
from collections import OrderedDict
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 16 |
+
from .helpers import build_model_with_cfg
|
| 17 |
+
from .layers import create_classifier, get_padding
|
| 18 |
+
from .registry import register_model
|
| 19 |
+
|
| 20 |
+
__all__ = ['Xception65']
|
| 21 |
+
|
| 22 |
+
default_cfgs = {
|
| 23 |
+
'gluon_xception65': {
|
| 24 |
+
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
|
| 25 |
+
'input_size': (3, 299, 299),
|
| 26 |
+
'crop_pct': 0.903,
|
| 27 |
+
'pool_size': (10, 10),
|
| 28 |
+
'interpolation': 'bicubic',
|
| 29 |
+
'mean': IMAGENET_DEFAULT_MEAN,
|
| 30 |
+
'std': IMAGENET_DEFAULT_STD,
|
| 31 |
+
'num_classes': 1000,
|
| 32 |
+
'first_conv': 'conv1',
|
| 33 |
+
'classifier': 'fc'
|
| 34 |
+
# The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
|
| 35 |
+
},
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
""" PADDING NOTES
|
| 39 |
+
The original PyTorch and Gluon impl of these models dutifully reproduced the
|
| 40 |
+
aligned padding added to Tensorflow models for Deeplab. This padding was compensating
|
| 41 |
+
for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SeparableConv2d(nn.Module):
|
| 46 |
+
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
|
| 47 |
+
super(SeparableConv2d, self).__init__()
|
| 48 |
+
self.kernel_size = kernel_size
|
| 49 |
+
self.dilation = dilation
|
| 50 |
+
|
| 51 |
+
# depthwise convolution
|
| 52 |
+
padding = get_padding(kernel_size, stride, dilation)
|
| 53 |
+
self.conv_dw = nn.Conv2d(
|
| 54 |
+
inplanes, inplanes, kernel_size, stride=stride,
|
| 55 |
+
padding=padding, dilation=dilation, groups=inplanes, bias=bias)
|
| 56 |
+
self.bn = norm_layer(num_features=inplanes)
|
| 57 |
+
# pointwise convolution
|
| 58 |
+
self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
|
| 59 |
+
|
| 60 |
+
def forward(self, x):
|
| 61 |
+
x = self.conv_dw(x)
|
| 62 |
+
x = self.bn(x)
|
| 63 |
+
x = self.conv_pw(x)
|
| 64 |
+
return x
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Block(nn.Module):
|
| 68 |
+
def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None):
|
| 69 |
+
super(Block, self).__init__()
|
| 70 |
+
if isinstance(planes, (list, tuple)):
|
| 71 |
+
assert len(planes) == 3
|
| 72 |
+
else:
|
| 73 |
+
planes = (planes,) * 3
|
| 74 |
+
outplanes = planes[-1]
|
| 75 |
+
|
| 76 |
+
if outplanes != inplanes or stride != 1:
|
| 77 |
+
self.skip = nn.Sequential()
|
| 78 |
+
self.skip.add_module('conv1', nn.Conv2d(
|
| 79 |
+
inplanes, outplanes, 1, stride=stride, bias=False)),
|
| 80 |
+
self.skip.add_module('bn1', norm_layer(num_features=outplanes))
|
| 81 |
+
else:
|
| 82 |
+
self.skip = None
|
| 83 |
+
|
| 84 |
+
rep = OrderedDict()
|
| 85 |
+
for i in range(3):
|
| 86 |
+
rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
|
| 87 |
+
rep['conv%d' % (i + 1)] = SeparableConv2d(
|
| 88 |
+
inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer)
|
| 89 |
+
rep['bn%d' % (i + 1)] = norm_layer(planes[i])
|
| 90 |
+
inplanes = planes[i]
|
| 91 |
+
|
| 92 |
+
if not start_with_relu:
|
| 93 |
+
del rep['act1']
|
| 94 |
+
else:
|
| 95 |
+
rep['act1'] = nn.ReLU(inplace=False)
|
| 96 |
+
self.rep = nn.Sequential(rep)
|
| 97 |
+
|
| 98 |
+
def forward(self, x):
|
| 99 |
+
skip = x
|
| 100 |
+
if self.skip is not None:
|
| 101 |
+
skip = self.skip(skip)
|
| 102 |
+
x = self.rep(x) + skip
|
| 103 |
+
return x
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class Xception65(nn.Module):
|
| 107 |
+
"""Modified Aligned Xception.
|
| 108 |
+
|
| 109 |
+
NOTE: only the 65 layer version is included here, the 71 layer variant
|
| 110 |
+
was not correct and had no pretrained weights
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
|
| 114 |
+
drop_rate=0., global_pool='avg'):
|
| 115 |
+
super(Xception65, self).__init__()
|
| 116 |
+
self.num_classes = num_classes
|
| 117 |
+
self.drop_rate = drop_rate
|
| 118 |
+
if output_stride == 32:
|
| 119 |
+
entry_block3_stride = 2
|
| 120 |
+
exit_block20_stride = 2
|
| 121 |
+
middle_dilation = 1
|
| 122 |
+
exit_dilation = (1, 1)
|
| 123 |
+
elif output_stride == 16:
|
| 124 |
+
entry_block3_stride = 2
|
| 125 |
+
exit_block20_stride = 1
|
| 126 |
+
middle_dilation = 1
|
| 127 |
+
exit_dilation = (1, 2)
|
| 128 |
+
elif output_stride == 8:
|
| 129 |
+
entry_block3_stride = 1
|
| 130 |
+
exit_block20_stride = 1
|
| 131 |
+
middle_dilation = 2
|
| 132 |
+
exit_dilation = (2, 4)
|
| 133 |
+
else:
|
| 134 |
+
raise NotImplementedError
|
| 135 |
+
|
| 136 |
+
# Entry flow
|
| 137 |
+
self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
|
| 138 |
+
self.bn1 = norm_layer(num_features=32)
|
| 139 |
+
self.act1 = nn.ReLU(inplace=True)
|
| 140 |
+
|
| 141 |
+
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
| 142 |
+
self.bn2 = norm_layer(num_features=64)
|
| 143 |
+
self.act2 = nn.ReLU(inplace=True)
|
| 144 |
+
|
| 145 |
+
self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer)
|
| 146 |
+
self.block1_act = nn.ReLU(inplace=True)
|
| 147 |
+
self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer)
|
| 148 |
+
self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer)
|
| 149 |
+
|
| 150 |
+
# Middle flow
|
| 151 |
+
self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
|
| 152 |
+
728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)]))
|
| 153 |
+
|
| 154 |
+
# Exit flow
|
| 155 |
+
self.block20 = Block(
|
| 156 |
+
728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer)
|
| 157 |
+
self.block20_act = nn.ReLU(inplace=True)
|
| 158 |
+
|
| 159 |
+
self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
|
| 160 |
+
self.bn3 = norm_layer(num_features=1536)
|
| 161 |
+
self.act3 = nn.ReLU(inplace=True)
|
| 162 |
+
|
| 163 |
+
self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
|
| 164 |
+
self.bn4 = norm_layer(num_features=1536)
|
| 165 |
+
self.act4 = nn.ReLU(inplace=True)
|
| 166 |
+
|
| 167 |
+
self.num_features = 2048
|
| 168 |
+
self.conv5 = SeparableConv2d(
|
| 169 |
+
1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
|
| 170 |
+
self.bn5 = norm_layer(num_features=self.num_features)
|
| 171 |
+
self.act5 = nn.ReLU(inplace=True)
|
| 172 |
+
self.feature_info = [
|
| 173 |
+
dict(num_chs=64, reduction=2, module='act2'),
|
| 174 |
+
dict(num_chs=128, reduction=4, module='block1_act'),
|
| 175 |
+
dict(num_chs=256, reduction=8, module='block3.rep.act1'),
|
| 176 |
+
dict(num_chs=728, reduction=16, module='block20.rep.act1'),
|
| 177 |
+
dict(num_chs=2048, reduction=32, module='act5'),
|
| 178 |
+
]
|
| 179 |
+
|
| 180 |
+
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
| 181 |
+
|
| 182 |
+
@torch.jit.ignore
|
| 183 |
+
def group_matcher(self, coarse=False):
|
| 184 |
+
matcher = dict(
|
| 185 |
+
stem=r'^conv[12]|bn[12]',
|
| 186 |
+
blocks=[
|
| 187 |
+
(r'^mid\.block(\d+)', None),
|
| 188 |
+
(r'^block(\d+)', None),
|
| 189 |
+
(r'^conv[345]|bn[345]', (99,)),
|
| 190 |
+
],
|
| 191 |
+
)
|
| 192 |
+
return matcher
|
| 193 |
+
|
| 194 |
+
@torch.jit.ignore
|
| 195 |
+
def set_grad_checkpointing(self, enable=True):
|
| 196 |
+
assert not enable, "gradient checkpointing not supported"
|
| 197 |
+
|
| 198 |
+
@torch.jit.ignore
|
| 199 |
+
def get_classifier(self):
|
| 200 |
+
return self.fc
|
| 201 |
+
|
| 202 |
+
def reset_classifier(self, num_classes, global_pool='avg'):
|
| 203 |
+
self.num_classes = num_classes
|
| 204 |
+
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
| 205 |
+
|
| 206 |
+
def forward_features(self, x):
|
| 207 |
+
# Entry flow
|
| 208 |
+
x = self.conv1(x)
|
| 209 |
+
x = self.bn1(x)
|
| 210 |
+
x = self.act1(x)
|
| 211 |
+
|
| 212 |
+
x = self.conv2(x)
|
| 213 |
+
x = self.bn2(x)
|
| 214 |
+
x = self.act2(x)
|
| 215 |
+
|
| 216 |
+
x = self.block1(x)
|
| 217 |
+
x = self.block1_act(x)
|
| 218 |
+
# c1 = x
|
| 219 |
+
x = self.block2(x)
|
| 220 |
+
# c2 = x
|
| 221 |
+
x = self.block3(x)
|
| 222 |
+
|
| 223 |
+
# Middle flow
|
| 224 |
+
x = self.mid(x)
|
| 225 |
+
# c3 = x
|
| 226 |
+
|
| 227 |
+
# Exit flow
|
| 228 |
+
x = self.block20(x)
|
| 229 |
+
x = self.block20_act(x)
|
| 230 |
+
x = self.conv3(x)
|
| 231 |
+
x = self.bn3(x)
|
| 232 |
+
x = self.act3(x)
|
| 233 |
+
|
| 234 |
+
x = self.conv4(x)
|
| 235 |
+
x = self.bn4(x)
|
| 236 |
+
x = self.act4(x)
|
| 237 |
+
|
| 238 |
+
x = self.conv5(x)
|
| 239 |
+
x = self.bn5(x)
|
| 240 |
+
x = self.act5(x)
|
| 241 |
+
return x
|
| 242 |
+
|
| 243 |
+
def forward_head(self, x):
|
| 244 |
+
x = self.global_pool(x)
|
| 245 |
+
if self.drop_rate:
|
| 246 |
+
F.dropout(x, self.drop_rate, training=self.training)
|
| 247 |
+
x = self.fc(x)
|
| 248 |
+
return x
|
| 249 |
+
|
| 250 |
+
def forward(self, x):
|
| 251 |
+
x = self.forward_features(x)
|
| 252 |
+
x = self.forward_head(x)
|
| 253 |
+
return x
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _create_gluon_xception(variant, pretrained=False, **kwargs):
|
| 257 |
+
return build_model_with_cfg(
|
| 258 |
+
Xception65, variant, pretrained,
|
| 259 |
+
feature_cfg=dict(feature_cls='hook'),
|
| 260 |
+
**kwargs)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@register_model
|
| 264 |
+
def gluon_xception65(pretrained=False, **kwargs):
|
| 265 |
+
""" Modified Aligned Xception-65
|
| 266 |
+
"""
|
| 267 |
+
return _create_gluon_xception('gluon_xception65', pretrained, **kwargs)
|
src/custom_timm/models/hardcorenas.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 6 |
+
from .efficientnet_blocks import SqueezeExcite
|
| 7 |
+
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
|
| 8 |
+
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
|
| 9 |
+
from .layers import get_act_fn
|
| 10 |
+
from .mobilenetv3 import MobileNetV3, MobileNetV3Features
|
| 11 |
+
from .registry import register_model
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _cfg(url='', **kwargs):
|
| 15 |
+
return {
|
| 16 |
+
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
| 17 |
+
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
| 18 |
+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
| 19 |
+
'first_conv': 'conv_stem', 'classifier': 'classifier',
|
| 20 |
+
**kwargs
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
default_cfgs = {
|
| 25 |
+
'hardcorenas_a': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_a_green_38ms_75_9-31dc7186.pth'),
|
| 26 |
+
'hardcorenas_b': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_b_green_40ms_76_5-32d91ff2.pth'),
|
| 27 |
+
'hardcorenas_c': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_c_green_44ms_77_1-631a0983.pth'),
|
| 28 |
+
'hardcorenas_d': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_d_green_50ms_77_4-998d9d7a.pth'),
|
| 29 |
+
'hardcorenas_e': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_e_green_55ms_77_9-482886a3.pth'),
|
| 30 |
+
'hardcorenas_f': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_f_green_60ms_78_1-14b9e780.pth'),
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
|
| 35 |
+
"""Creates a hardcorenas model
|
| 36 |
+
|
| 37 |
+
Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS
|
| 38 |
+
Paper: https://arxiv.org/abs/2102.11646
|
| 39 |
+
|
| 40 |
+
"""
|
| 41 |
+
num_features = 1280
|
| 42 |
+
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
|
| 43 |
+
model_kwargs = dict(
|
| 44 |
+
block_args=decode_arch_def(arch_def),
|
| 45 |
+
num_features=num_features,
|
| 46 |
+
stem_size=32,
|
| 47 |
+
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
| 48 |
+
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
|
| 49 |
+
se_layer=se_layer,
|
| 50 |
+
**kwargs,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
features_only = False
|
| 54 |
+
model_cls = MobileNetV3
|
| 55 |
+
kwargs_filter = None
|
| 56 |
+
if model_kwargs.pop('features_only', False):
|
| 57 |
+
features_only = True
|
| 58 |
+
kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias', 'global_pool')
|
| 59 |
+
model_cls = MobileNetV3Features
|
| 60 |
+
model = build_model_with_cfg(
|
| 61 |
+
model_cls, variant, pretrained,
|
| 62 |
+
pretrained_strict=not features_only,
|
| 63 |
+
kwargs_filter=kwargs_filter,
|
| 64 |
+
**model_kwargs)
|
| 65 |
+
if features_only:
|
| 66 |
+
model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
|
| 67 |
+
return model
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@register_model
|
| 71 |
+
def hardcorenas_a(pretrained=False, **kwargs):
|
| 72 |
+
""" hardcorenas_A """
|
| 73 |
+
arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
|
| 74 |
+
['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
|
| 75 |
+
['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25'],
|
| 76 |
+
['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25'],
|
| 77 |
+
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
|
| 78 |
+
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_a', arch_def=arch_def, **kwargs)
|
| 79 |
+
return model
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@register_model
|
| 83 |
+
def hardcorenas_b(pretrained=False, **kwargs):
|
| 84 |
+
""" hardcorenas_B """
|
| 85 |
+
arch_def = [['ds_r1_k3_s1_e1_c16_nre'],
|
| 86 |
+
['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25', 'ir_r1_k3_s1_e3_c24_nre'],
|
| 87 |
+
['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'],
|
| 88 |
+
['ir_r1_k5_s2_e3_c80', 'ir_r1_k5_s1_e3_c80', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
|
| 89 |
+
['ir_r1_k5_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
|
| 90 |
+
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
|
| 91 |
+
['cn_r1_k1_s1_c960']]
|
| 92 |
+
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_b', arch_def=arch_def, **kwargs)
|
| 93 |
+
return model
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@register_model
|
| 97 |
+
def hardcorenas_c(pretrained=False, **kwargs):
|
| 98 |
+
""" hardcorenas_C """
|
| 99 |
+
arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
|
| 100 |
+
['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre',
|
| 101 |
+
'ir_r1_k5_s1_e3_c40_nre'],
|
| 102 |
+
['ir_r1_k5_s2_e4_c80', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
|
| 103 |
+
['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
|
| 104 |
+
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
|
| 105 |
+
['cn_r1_k1_s1_c960']]
|
| 106 |
+
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_c', arch_def=arch_def, **kwargs)
|
| 107 |
+
return model
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@register_model
|
| 111 |
+
def hardcorenas_d(pretrained=False, **kwargs):
|
| 112 |
+
""" hardcorenas_D """
|
| 113 |
+
arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
|
| 114 |
+
['ir_r1_k5_s2_e3_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k3_s1_e3_c40_nre_se0.25'],
|
| 115 |
+
['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25',
|
| 116 |
+
'ir_r1_k3_s1_e3_c80_se0.25'],
|
| 117 |
+
['ir_r1_k3_s1_e4_c112_se0.25', 'ir_r1_k5_s1_e4_c112_se0.25', 'ir_r1_k3_s1_e3_c112_se0.25',
|
| 118 |
+
'ir_r1_k5_s1_e3_c112_se0.25'],
|
| 119 |
+
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
|
| 120 |
+
'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
|
| 121 |
+
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_d', arch_def=arch_def, **kwargs)
|
| 122 |
+
return model
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@register_model
|
| 126 |
+
def hardcorenas_e(pretrained=False, **kwargs):
|
| 127 |
+
""" hardcorenas_E """
|
| 128 |
+
arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
|
| 129 |
+
['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25',
|
| 130 |
+
'ir_r1_k3_s1_e3_c40_nre_se0.25'], ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e6_c80_se0.25'],
|
| 131 |
+
['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25',
|
| 132 |
+
'ir_r1_k5_s1_e3_c112_se0.25'],
|
| 133 |
+
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
|
| 134 |
+
'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
|
| 135 |
+
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_e', arch_def=arch_def, **kwargs)
|
| 136 |
+
return model
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@register_model
|
| 140 |
+
def hardcorenas_f(pretrained=False, **kwargs):
|
| 141 |
+
""" hardcorenas_F """
|
| 142 |
+
arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
|
| 143 |
+
['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
|
| 144 |
+
['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25',
|
| 145 |
+
'ir_r1_k3_s1_e3_c80_se0.25'],
|
| 146 |
+
['ir_r1_k3_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25',
|
| 147 |
+
'ir_r1_k3_s1_e3_c112_se0.25'],
|
| 148 |
+
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25',
|
| 149 |
+
'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
|
| 150 |
+
model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_f', arch_def=arch_def, **kwargs)
|
| 151 |
+
return model
|
src/custom_timm/models/helpers.py
ADDED
|
@@ -0,0 +1,796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Model creation / weight loading / state_dict helpers
|
| 2 |
+
|
| 3 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
| 4 |
+
"""
|
| 5 |
+
import collections.abc
|
| 6 |
+
import logging
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
import re
|
| 10 |
+
from collections import OrderedDict, defaultdict
|
| 11 |
+
from copy import deepcopy
|
| 12 |
+
from itertools import chain
|
| 13 |
+
from typing import Any, Callable, Optional, Tuple, Dict, Union
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.hub import load_state_dict_from_url
|
| 18 |
+
from torch.utils.checkpoint import checkpoint
|
| 19 |
+
|
| 20 |
+
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
|
| 21 |
+
from .fx_features import FeatureGraphNet
|
| 22 |
+
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
|
| 23 |
+
from .layers import Conv2dSame, Linear, BatchNormAct2d
|
| 24 |
+
from .registry import get_pretrained_cfg
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
_logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Global variables for rarely used pretrained checkpoint download progress and hash check.
|
| 31 |
+
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
|
| 32 |
+
_DOWNLOAD_PROGRESS = False
|
| 33 |
+
_CHECK_HASH = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def clean_state_dict(state_dict):
|
| 37 |
+
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
|
| 38 |
+
cleaned_state_dict = OrderedDict()
|
| 39 |
+
for k, v in state_dict.items():
|
| 40 |
+
name = k[7:] if k.startswith('module.') else k
|
| 41 |
+
cleaned_state_dict[name] = v
|
| 42 |
+
return cleaned_state_dict
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def load_state_dict(checkpoint_path, use_ema=True):
|
| 46 |
+
if checkpoint_path and os.path.isfile(checkpoint_path):
|
| 47 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
| 48 |
+
state_dict_key = ''
|
| 49 |
+
if isinstance(checkpoint, dict):
|
| 50 |
+
if use_ema and checkpoint.get('state_dict_ema', None) is not None:
|
| 51 |
+
state_dict_key = 'state_dict_ema'
|
| 52 |
+
elif use_ema and checkpoint.get('model_ema', None) is not None:
|
| 53 |
+
state_dict_key = 'model_ema'
|
| 54 |
+
elif 'state_dict' in checkpoint:
|
| 55 |
+
state_dict_key = 'state_dict'
|
| 56 |
+
elif 'model' in checkpoint:
|
| 57 |
+
state_dict_key = 'model'
|
| 58 |
+
state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
|
| 59 |
+
_logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
|
| 60 |
+
return state_dict
|
| 61 |
+
else:
|
| 62 |
+
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
| 63 |
+
raise FileNotFoundError()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
|
| 67 |
+
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
|
| 68 |
+
# numpy checkpoint, try to load via model specific load_pretrained fn
|
| 69 |
+
if hasattr(model, 'load_pretrained'):
|
| 70 |
+
model.load_pretrained(checkpoint_path)
|
| 71 |
+
else:
|
| 72 |
+
raise NotImplementedError('Model cannot load numpy checkpoint')
|
| 73 |
+
return
|
| 74 |
+
state_dict = load_state_dict(checkpoint_path, use_ema)
|
| 75 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
| 76 |
+
return incompatible_keys
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
|
| 80 |
+
resume_epoch = None
|
| 81 |
+
if os.path.isfile(checkpoint_path):
|
| 82 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
| 83 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
| 84 |
+
if log_info:
|
| 85 |
+
_logger.info('Restoring model state from checkpoint...')
|
| 86 |
+
state_dict = clean_state_dict(checkpoint['state_dict'])
|
| 87 |
+
model.load_state_dict(state_dict)
|
| 88 |
+
|
| 89 |
+
if optimizer is not None and 'optimizer' in checkpoint:
|
| 90 |
+
if log_info:
|
| 91 |
+
_logger.info('Restoring optimizer state from checkpoint...')
|
| 92 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 93 |
+
|
| 94 |
+
if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
|
| 95 |
+
if log_info:
|
| 96 |
+
_logger.info('Restoring AMP loss scaler state from checkpoint...')
|
| 97 |
+
loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
|
| 98 |
+
|
| 99 |
+
if 'epoch' in checkpoint:
|
| 100 |
+
resume_epoch = checkpoint['epoch']
|
| 101 |
+
if 'version' in checkpoint and checkpoint['version'] > 1:
|
| 102 |
+
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
|
| 103 |
+
|
| 104 |
+
if log_info:
|
| 105 |
+
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
| 106 |
+
else:
|
| 107 |
+
model.load_state_dict(checkpoint)
|
| 108 |
+
if log_info:
|
| 109 |
+
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
| 110 |
+
return resume_epoch
|
| 111 |
+
else:
|
| 112 |
+
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
| 113 |
+
raise FileNotFoundError()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _resolve_pretrained_source(pretrained_cfg):
|
| 117 |
+
cfg_source = pretrained_cfg.get('source', '')
|
| 118 |
+
pretrained_url = pretrained_cfg.get('url', None)
|
| 119 |
+
pretrained_file = pretrained_cfg.get('file', None)
|
| 120 |
+
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
|
| 121 |
+
# resolve where to load pretrained weights from
|
| 122 |
+
load_from = ''
|
| 123 |
+
pretrained_loc = ''
|
| 124 |
+
if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
|
| 125 |
+
# hf-hub specified as source via model identifier
|
| 126 |
+
load_from = 'hf-hub'
|
| 127 |
+
assert hf_hub_id
|
| 128 |
+
pretrained_loc = hf_hub_id
|
| 129 |
+
else:
|
| 130 |
+
# default source == timm or unspecified
|
| 131 |
+
if pretrained_file:
|
| 132 |
+
load_from = 'file'
|
| 133 |
+
pretrained_loc = pretrained_file
|
| 134 |
+
elif pretrained_url:
|
| 135 |
+
load_from = 'url'
|
| 136 |
+
pretrained_loc = pretrained_url
|
| 137 |
+
elif hf_hub_id and has_hf_hub(necessary=True):
|
| 138 |
+
# hf-hub available as alternate weight source in default_cfg
|
| 139 |
+
load_from = 'hf-hub'
|
| 140 |
+
pretrained_loc = hf_hub_id
|
| 141 |
+
if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg:
|
| 142 |
+
# if a filename override is set, return tuple for location w/ (hub_id, filename)
|
| 143 |
+
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
|
| 144 |
+
return load_from, pretrained_loc
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def set_pretrained_download_progress(enable=True):
|
| 148 |
+
""" Set download progress for pretrained weights on/off (globally). """
|
| 149 |
+
global _DOWNLOAD_PROGRESS
|
| 150 |
+
_DOWNLOAD_PROGRESS = enable
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def set_pretrained_check_hash(enable=True):
|
| 154 |
+
""" Set hash checking for pretrained weights on/off (globally). """
|
| 155 |
+
global _CHECK_HASH
|
| 156 |
+
_CHECK_HASH = enable
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def load_custom_pretrained(
|
| 160 |
+
model: nn.Module,
|
| 161 |
+
pretrained_cfg: Optional[Dict] = None,
|
| 162 |
+
load_fn: Optional[Callable] = None,
|
| 163 |
+
):
|
| 164 |
+
r"""Loads a custom (read non .pth) weight file
|
| 165 |
+
|
| 166 |
+
Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
|
| 167 |
+
a passed in custom load fun, or the `load_pretrained` model member fn.
|
| 168 |
+
|
| 169 |
+
If the object is already present in `model_dir`, it's deserialized and returned.
|
| 170 |
+
The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
|
| 171 |
+
`hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
model: The instantiated model to load weights into
|
| 175 |
+
pretrained_cfg (dict): Default pretrained model cfg
|
| 176 |
+
load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
|
| 177 |
+
'laod_pretrained' on the model will be called if it exists
|
| 178 |
+
"""
|
| 179 |
+
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {}
|
| 180 |
+
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
|
| 181 |
+
if not load_from:
|
| 182 |
+
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
| 183 |
+
return
|
| 184 |
+
if load_from == 'hf-hub': # FIXME
|
| 185 |
+
_logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
|
| 186 |
+
elif load_from == 'url':
|
| 187 |
+
pretrained_loc = download_cached_file(pretrained_loc, check_hash=_CHECK_HASH, progress=_DOWNLOAD_PROGRESS)
|
| 188 |
+
|
| 189 |
+
if load_fn is not None:
|
| 190 |
+
load_fn(model, pretrained_loc)
|
| 191 |
+
elif hasattr(model, 'load_pretrained'):
|
| 192 |
+
model.load_pretrained(pretrained_loc)
|
| 193 |
+
else:
|
| 194 |
+
_logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def adapt_input_conv(in_chans, conv_weight):
|
| 198 |
+
conv_type = conv_weight.dtype
|
| 199 |
+
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
|
| 200 |
+
O, I, J, K = conv_weight.shape
|
| 201 |
+
if in_chans == 1:
|
| 202 |
+
if I > 3:
|
| 203 |
+
assert conv_weight.shape[1] % 3 == 0
|
| 204 |
+
# For models with space2depth stems
|
| 205 |
+
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
|
| 206 |
+
conv_weight = conv_weight.sum(dim=2, keepdim=False)
|
| 207 |
+
else:
|
| 208 |
+
conv_weight = conv_weight.sum(dim=1, keepdim=True)
|
| 209 |
+
elif in_chans != 3:
|
| 210 |
+
if I != 3:
|
| 211 |
+
raise NotImplementedError('Weight format not supported by conversion.')
|
| 212 |
+
else:
|
| 213 |
+
# NOTE this strategy should be better than random init, but there could be other combinations of
|
| 214 |
+
# the original RGB input layer weights that'd work better for specific cases.
|
| 215 |
+
repeat = int(math.ceil(in_chans / 3))
|
| 216 |
+
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
| 217 |
+
conv_weight *= (3 / float(in_chans))
|
| 218 |
+
conv_weight = conv_weight.to(conv_type)
|
| 219 |
+
return conv_weight
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def load_pretrained(
|
| 223 |
+
model: nn.Module,
|
| 224 |
+
pretrained_cfg: Optional[Dict] = None,
|
| 225 |
+
num_classes: int = 1000,
|
| 226 |
+
in_chans: int = 3,
|
| 227 |
+
filter_fn: Optional[Callable] = None,
|
| 228 |
+
strict: bool = True,
|
| 229 |
+
):
|
| 230 |
+
""" Load pretrained checkpoint
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
model (nn.Module) : PyTorch model module
|
| 234 |
+
pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
|
| 235 |
+
num_classes (int): num_classes for model
|
| 236 |
+
in_chans (int): in_chans for model
|
| 237 |
+
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
|
| 238 |
+
strict (bool): strict load of checkpoint
|
| 239 |
+
|
| 240 |
+
"""
|
| 241 |
+
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {}
|
| 242 |
+
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
|
| 243 |
+
if load_from == 'file':
|
| 244 |
+
_logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
|
| 245 |
+
state_dict = load_state_dict(pretrained_loc)
|
| 246 |
+
elif load_from == 'url':
|
| 247 |
+
_logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
|
| 248 |
+
state_dict = load_state_dict_from_url(
|
| 249 |
+
pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
|
| 250 |
+
elif load_from == 'hf-hub':
|
| 251 |
+
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
|
| 252 |
+
if isinstance(pretrained_loc, (list, tuple)):
|
| 253 |
+
state_dict = load_state_dict_from_hf(*pretrained_loc)
|
| 254 |
+
else:
|
| 255 |
+
state_dict = load_state_dict_from_hf(pretrained_loc)
|
| 256 |
+
else:
|
| 257 |
+
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
|
| 258 |
+
return
|
| 259 |
+
|
| 260 |
+
if filter_fn is not None:
|
| 261 |
+
# for backwards compat with filter fn that take one arg, try one first, the two
|
| 262 |
+
try:
|
| 263 |
+
state_dict = filter_fn(state_dict)
|
| 264 |
+
except TypeError:
|
| 265 |
+
state_dict = filter_fn(state_dict, model)
|
| 266 |
+
|
| 267 |
+
input_convs = pretrained_cfg.get('first_conv', None)
|
| 268 |
+
if input_convs is not None and in_chans != 3:
|
| 269 |
+
if isinstance(input_convs, str):
|
| 270 |
+
input_convs = (input_convs,)
|
| 271 |
+
for input_conv_name in input_convs:
|
| 272 |
+
weight_name = input_conv_name + '.weight'
|
| 273 |
+
try:
|
| 274 |
+
state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
|
| 275 |
+
_logger.info(
|
| 276 |
+
f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
|
| 277 |
+
except NotImplementedError as e:
|
| 278 |
+
del state_dict[weight_name]
|
| 279 |
+
strict = False
|
| 280 |
+
_logger.warning(
|
| 281 |
+
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
|
| 282 |
+
|
| 283 |
+
classifiers = pretrained_cfg.get('classifier', None)
|
| 284 |
+
label_offset = pretrained_cfg.get('label_offset', 0)
|
| 285 |
+
if classifiers is not None:
|
| 286 |
+
if isinstance(classifiers, str):
|
| 287 |
+
classifiers = (classifiers,)
|
| 288 |
+
if num_classes != pretrained_cfg['num_classes']:
|
| 289 |
+
for classifier_name in classifiers:
|
| 290 |
+
# completely discard fully connected if model num_classes doesn't match pretrained weights
|
| 291 |
+
state_dict.pop(classifier_name + '.weight', None)
|
| 292 |
+
state_dict.pop(classifier_name + '.bias', None)
|
| 293 |
+
strict = False
|
| 294 |
+
elif label_offset > 0:
|
| 295 |
+
for classifier_name in classifiers:
|
| 296 |
+
# special case for pretrained weights with an extra background class in pretrained weights
|
| 297 |
+
classifier_weight = state_dict[classifier_name + '.weight']
|
| 298 |
+
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
|
| 299 |
+
classifier_bias = state_dict[classifier_name + '.bias']
|
| 300 |
+
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
|
| 301 |
+
|
| 302 |
+
model.load_state_dict(state_dict, strict=strict)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def extract_layer(model, layer):
|
| 306 |
+
layer = layer.split('.')
|
| 307 |
+
module = model
|
| 308 |
+
if hasattr(model, 'module') and layer[0] != 'module':
|
| 309 |
+
module = model.module
|
| 310 |
+
if not hasattr(model, 'module') and layer[0] == 'module':
|
| 311 |
+
layer = layer[1:]
|
| 312 |
+
for l in layer:
|
| 313 |
+
if hasattr(module, l):
|
| 314 |
+
if not l.isdigit():
|
| 315 |
+
module = getattr(module, l)
|
| 316 |
+
else:
|
| 317 |
+
module = module[int(l)]
|
| 318 |
+
else:
|
| 319 |
+
return module
|
| 320 |
+
return module
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def set_layer(model, layer, val):
|
| 324 |
+
layer = layer.split('.')
|
| 325 |
+
module = model
|
| 326 |
+
if hasattr(model, 'module') and layer[0] != 'module':
|
| 327 |
+
module = model.module
|
| 328 |
+
lst_index = 0
|
| 329 |
+
module2 = module
|
| 330 |
+
for l in layer:
|
| 331 |
+
if hasattr(module2, l):
|
| 332 |
+
if not l.isdigit():
|
| 333 |
+
module2 = getattr(module2, l)
|
| 334 |
+
else:
|
| 335 |
+
module2 = module2[int(l)]
|
| 336 |
+
lst_index += 1
|
| 337 |
+
lst_index -= 1
|
| 338 |
+
for l in layer[:lst_index]:
|
| 339 |
+
if not l.isdigit():
|
| 340 |
+
module = getattr(module, l)
|
| 341 |
+
else:
|
| 342 |
+
module = module[int(l)]
|
| 343 |
+
l = layer[lst_index]
|
| 344 |
+
setattr(module, l, val)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def adapt_model_from_string(parent_module, model_string):
|
| 348 |
+
separator = '***'
|
| 349 |
+
state_dict = {}
|
| 350 |
+
lst_shape = model_string.split(separator)
|
| 351 |
+
for k in lst_shape:
|
| 352 |
+
k = k.split(':')
|
| 353 |
+
key = k[0]
|
| 354 |
+
shape = k[1][1:-1].split(',')
|
| 355 |
+
if shape[0] != '':
|
| 356 |
+
state_dict[key] = [int(i) for i in shape]
|
| 357 |
+
|
| 358 |
+
new_module = deepcopy(parent_module)
|
| 359 |
+
for n, m in parent_module.named_modules():
|
| 360 |
+
old_module = extract_layer(parent_module, n)
|
| 361 |
+
if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
|
| 362 |
+
if isinstance(old_module, Conv2dSame):
|
| 363 |
+
conv = Conv2dSame
|
| 364 |
+
else:
|
| 365 |
+
conv = nn.Conv2d
|
| 366 |
+
s = state_dict[n + '.weight']
|
| 367 |
+
in_channels = s[1]
|
| 368 |
+
out_channels = s[0]
|
| 369 |
+
g = 1
|
| 370 |
+
if old_module.groups > 1:
|
| 371 |
+
in_channels = out_channels
|
| 372 |
+
g = in_channels
|
| 373 |
+
new_conv = conv(
|
| 374 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
|
| 375 |
+
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
|
| 376 |
+
groups=g, stride=old_module.stride)
|
| 377 |
+
set_layer(new_module, n, new_conv)
|
| 378 |
+
elif isinstance(old_module, BatchNormAct2d):
|
| 379 |
+
new_bn = BatchNormAct2d(
|
| 380 |
+
state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
| 381 |
+
affine=old_module.affine, track_running_stats=True)
|
| 382 |
+
new_bn.drop = old_module.drop
|
| 383 |
+
new_bn.act = old_module.act
|
| 384 |
+
set_layer(new_module, n, new_bn)
|
| 385 |
+
elif isinstance(old_module, nn.BatchNorm2d):
|
| 386 |
+
new_bn = nn.BatchNorm2d(
|
| 387 |
+
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
| 388 |
+
affine=old_module.affine, track_running_stats=True)
|
| 389 |
+
set_layer(new_module, n, new_bn)
|
| 390 |
+
elif isinstance(old_module, nn.Linear):
|
| 391 |
+
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
|
| 392 |
+
num_features = state_dict[n + '.weight'][1]
|
| 393 |
+
new_fc = Linear(
|
| 394 |
+
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
|
| 395 |
+
set_layer(new_module, n, new_fc)
|
| 396 |
+
if hasattr(new_module, 'num_features'):
|
| 397 |
+
new_module.num_features = num_features
|
| 398 |
+
new_module.eval()
|
| 399 |
+
parent_module.eval()
|
| 400 |
+
|
| 401 |
+
return new_module
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def adapt_model_from_file(parent_module, model_variant):
|
| 405 |
+
adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
|
| 406 |
+
with open(adapt_file, 'r') as f:
|
| 407 |
+
return adapt_model_from_string(parent_module, f.read().strip())
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def pretrained_cfg_for_features(pretrained_cfg):
|
| 411 |
+
pretrained_cfg = deepcopy(pretrained_cfg)
|
| 412 |
+
# remove default pretrained cfg fields that don't have much relevance for feature backbone
|
| 413 |
+
to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size?
|
| 414 |
+
for tr in to_remove:
|
| 415 |
+
pretrained_cfg.pop(tr, None)
|
| 416 |
+
return pretrained_cfg
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def set_default_kwargs(kwargs, names, pretrained_cfg):
|
| 420 |
+
for n in names:
|
| 421 |
+
# for legacy reasons, model __init__args uses img_size + in_chans as separate args while
|
| 422 |
+
# pretrained_cfg has one input_size=(C, H ,W) entry
|
| 423 |
+
if n == 'img_size':
|
| 424 |
+
input_size = pretrained_cfg.get('input_size', None)
|
| 425 |
+
if input_size is not None:
|
| 426 |
+
assert len(input_size) == 3
|
| 427 |
+
kwargs.setdefault(n, input_size[-2:])
|
| 428 |
+
elif n == 'in_chans':
|
| 429 |
+
input_size = pretrained_cfg.get('input_size', None)
|
| 430 |
+
if input_size is not None:
|
| 431 |
+
assert len(input_size) == 3
|
| 432 |
+
kwargs.setdefault(n, input_size[0])
|
| 433 |
+
else:
|
| 434 |
+
default_val = pretrained_cfg.get(n, None)
|
| 435 |
+
if default_val is not None:
|
| 436 |
+
kwargs.setdefault(n, pretrained_cfg[n])
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def filter_kwargs(kwargs, names):
|
| 440 |
+
if not kwargs or not names:
|
| 441 |
+
return
|
| 442 |
+
for n in names:
|
| 443 |
+
kwargs.pop(n, None)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
|
| 447 |
+
""" Update the default_cfg and kwargs before passing to model
|
| 448 |
+
|
| 449 |
+
Args:
|
| 450 |
+
pretrained_cfg: input pretrained cfg (updated in-place)
|
| 451 |
+
kwargs: keyword args passed to model build fn (updated in-place)
|
| 452 |
+
kwargs_filter: keyword arg keys that must be removed before model __init__
|
| 453 |
+
"""
|
| 454 |
+
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
|
| 455 |
+
default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
|
| 456 |
+
if pretrained_cfg.get('fixed_input_size', False):
|
| 457 |
+
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
|
| 458 |
+
default_kwarg_names += ('img_size',)
|
| 459 |
+
set_default_kwargs(kwargs, names=default_kwarg_names, pretrained_cfg=pretrained_cfg)
|
| 460 |
+
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
|
| 461 |
+
filter_kwargs(kwargs, names=kwargs_filter)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None):
|
| 465 |
+
if pretrained_cfg and isinstance(pretrained_cfg, dict):
|
| 466 |
+
# highest priority, pretrained_cfg available and passed as arg
|
| 467 |
+
return deepcopy(pretrained_cfg)
|
| 468 |
+
# fallback to looking up pretrained cfg in model registry by variant identifier
|
| 469 |
+
pretrained_cfg = get_pretrained_cfg(variant)
|
| 470 |
+
if not pretrained_cfg:
|
| 471 |
+
_logger.warning(
|
| 472 |
+
f"No pretrained configuration specified for {variant} model. Using a default."
|
| 473 |
+
f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
|
| 474 |
+
pretrained_cfg = dict(
|
| 475 |
+
url='',
|
| 476 |
+
num_classes=1000,
|
| 477 |
+
input_size=(3, 224, 224),
|
| 478 |
+
pool_size=None,
|
| 479 |
+
crop_pct=.9,
|
| 480 |
+
interpolation='bicubic',
|
| 481 |
+
first_conv='',
|
| 482 |
+
classifier='',
|
| 483 |
+
)
|
| 484 |
+
return pretrained_cfg
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def build_model_with_cfg(
|
| 488 |
+
model_cls: Callable,
|
| 489 |
+
variant: str,
|
| 490 |
+
pretrained: bool,
|
| 491 |
+
pretrained_cfg: Optional[Dict] = None,
|
| 492 |
+
model_cfg: Optional[Any] = None,
|
| 493 |
+
feature_cfg: Optional[Dict] = None,
|
| 494 |
+
pretrained_strict: bool = True,
|
| 495 |
+
pretrained_filter_fn: Optional[Callable] = None,
|
| 496 |
+
pretrained_custom_load: bool = False,
|
| 497 |
+
kwargs_filter: Optional[Tuple[str]] = None,
|
| 498 |
+
**kwargs):
|
| 499 |
+
""" Build model with specified default_cfg and optional model_cfg
|
| 500 |
+
|
| 501 |
+
This helper fn aids in the construction of a model including:
|
| 502 |
+
* handling default_cfg and associated pretrained weight loading
|
| 503 |
+
* passing through optional model_cfg for models with config based arch spec
|
| 504 |
+
* features_only model adaptation
|
| 505 |
+
* pruning config / model adaptation
|
| 506 |
+
|
| 507 |
+
Args:
|
| 508 |
+
model_cls (nn.Module): model class
|
| 509 |
+
variant (str): model variant name
|
| 510 |
+
pretrained (bool): load pretrained weights
|
| 511 |
+
pretrained_cfg (dict): model's pretrained weight/task config
|
| 512 |
+
model_cfg (Optional[Dict]): model's architecture config
|
| 513 |
+
feature_cfg (Optional[Dict]: feature extraction adapter config
|
| 514 |
+
pretrained_strict (bool): load pretrained weights strictly
|
| 515 |
+
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
|
| 516 |
+
pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
|
| 517 |
+
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
|
| 518 |
+
**kwargs: model args passed through to model __init__
|
| 519 |
+
"""
|
| 520 |
+
pruned = kwargs.pop('pruned', False)
|
| 521 |
+
features = False
|
| 522 |
+
feature_cfg = feature_cfg or {}
|
| 523 |
+
|
| 524 |
+
# resolve and update model pretrained config and model kwargs
|
| 525 |
+
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=pretrained_cfg)
|
| 526 |
+
update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter)
|
| 527 |
+
pretrained_cfg.setdefault('architecture', variant)
|
| 528 |
+
|
| 529 |
+
# Setup for feature extraction wrapper done at end of this fn
|
| 530 |
+
if kwargs.pop('features_only', False):
|
| 531 |
+
features = True
|
| 532 |
+
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
|
| 533 |
+
if 'out_indices' in kwargs:
|
| 534 |
+
feature_cfg['out_indices'] = kwargs.pop('out_indices')
|
| 535 |
+
|
| 536 |
+
# Build the model
|
| 537 |
+
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
|
| 538 |
+
model.pretrained_cfg = pretrained_cfg
|
| 539 |
+
model.default_cfg = model.pretrained_cfg # alias for backwards compat
|
| 540 |
+
|
| 541 |
+
if pruned:
|
| 542 |
+
model = adapt_model_from_file(model, variant)
|
| 543 |
+
|
| 544 |
+
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
|
| 545 |
+
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
|
| 546 |
+
if pretrained:
|
| 547 |
+
if pretrained_custom_load:
|
| 548 |
+
# FIXME improve custom load trigger
|
| 549 |
+
load_custom_pretrained(model, pretrained_cfg=pretrained_cfg)
|
| 550 |
+
else:
|
| 551 |
+
load_pretrained(
|
| 552 |
+
model,
|
| 553 |
+
pretrained_cfg=pretrained_cfg,
|
| 554 |
+
num_classes=num_classes_pretrained,
|
| 555 |
+
in_chans=kwargs.get('in_chans', 3),
|
| 556 |
+
filter_fn=pretrained_filter_fn,
|
| 557 |
+
strict=pretrained_strict)
|
| 558 |
+
|
| 559 |
+
# Wrap the model in a feature extraction module if enabled
|
| 560 |
+
if features:
|
| 561 |
+
feature_cls = FeatureListNet
|
| 562 |
+
if 'feature_cls' in feature_cfg:
|
| 563 |
+
feature_cls = feature_cfg.pop('feature_cls')
|
| 564 |
+
if isinstance(feature_cls, str):
|
| 565 |
+
feature_cls = feature_cls.lower()
|
| 566 |
+
if 'hook' in feature_cls:
|
| 567 |
+
feature_cls = FeatureHookNet
|
| 568 |
+
elif feature_cls == 'fx':
|
| 569 |
+
feature_cls = FeatureGraphNet
|
| 570 |
+
else:
|
| 571 |
+
assert False, f'Unknown feature class {feature_cls}'
|
| 572 |
+
model = feature_cls(model, **feature_cfg)
|
| 573 |
+
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg
|
| 574 |
+
model.default_cfg = model.pretrained_cfg # alias for backwards compat
|
| 575 |
+
|
| 576 |
+
return model
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def model_parameters(model, exclude_head=False):
|
| 580 |
+
if exclude_head:
|
| 581 |
+
# FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
|
| 582 |
+
return [p for p in model.parameters()][:-2]
|
| 583 |
+
else:
|
| 584 |
+
return model.parameters()
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
|
| 588 |
+
if not depth_first and include_root:
|
| 589 |
+
fn(module=module, name=name)
|
| 590 |
+
for child_name, child_module in module.named_children():
|
| 591 |
+
child_name = '.'.join((name, child_name)) if name else child_name
|
| 592 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
| 593 |
+
if depth_first and include_root:
|
| 594 |
+
fn(module=module, name=name)
|
| 595 |
+
return module
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
|
| 599 |
+
if not depth_first and include_root:
|
| 600 |
+
yield name, module
|
| 601 |
+
for child_name, child_module in module.named_children():
|
| 602 |
+
child_name = '.'.join((name, child_name)) if name else child_name
|
| 603 |
+
yield from named_modules(
|
| 604 |
+
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
| 605 |
+
if depth_first and include_root:
|
| 606 |
+
yield name, module
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
|
| 610 |
+
if module._parameters and not depth_first and include_root:
|
| 611 |
+
yield name, module
|
| 612 |
+
for child_name, child_module in module.named_children():
|
| 613 |
+
child_name = '.'.join((name, child_name)) if name else child_name
|
| 614 |
+
yield from named_modules_with_params(
|
| 615 |
+
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
| 616 |
+
if module._parameters and depth_first and include_root:
|
| 617 |
+
yield name, module
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
MATCH_PREV_GROUP = (99999,)
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def group_with_matcher(
|
| 624 |
+
named_objects,
|
| 625 |
+
group_matcher: Union[Dict, Callable],
|
| 626 |
+
output_values: bool = False,
|
| 627 |
+
reverse: bool = False
|
| 628 |
+
):
|
| 629 |
+
if isinstance(group_matcher, dict):
|
| 630 |
+
# dictionary matcher contains a dict of raw-string regex expr that must be compiled
|
| 631 |
+
compiled = []
|
| 632 |
+
for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):
|
| 633 |
+
if mspec is None:
|
| 634 |
+
continue
|
| 635 |
+
# map all matching specifications into 3-tuple (compiled re, prefix, suffix)
|
| 636 |
+
if isinstance(mspec, (tuple, list)):
|
| 637 |
+
# multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
|
| 638 |
+
for sspec in mspec:
|
| 639 |
+
compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
|
| 640 |
+
else:
|
| 641 |
+
compiled += [(re.compile(mspec), (group_ordinal,), None)]
|
| 642 |
+
group_matcher = compiled
|
| 643 |
+
|
| 644 |
+
def _get_grouping(name):
|
| 645 |
+
if isinstance(group_matcher, (list, tuple)):
|
| 646 |
+
for match_fn, prefix, suffix in group_matcher:
|
| 647 |
+
r = match_fn.match(name)
|
| 648 |
+
if r:
|
| 649 |
+
parts = (prefix, r.groups(), suffix)
|
| 650 |
+
# map all tuple elem to int for numeric sort, filter out None entries
|
| 651 |
+
return tuple(map(float, chain.from_iterable(filter(None, parts))))
|
| 652 |
+
return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal
|
| 653 |
+
else:
|
| 654 |
+
ord = group_matcher(name)
|
| 655 |
+
if not isinstance(ord, collections.abc.Iterable):
|
| 656 |
+
return ord,
|
| 657 |
+
return tuple(ord)
|
| 658 |
+
|
| 659 |
+
# map layers into groups via ordinals (ints or tuples of ints) from matcher
|
| 660 |
+
grouping = defaultdict(list)
|
| 661 |
+
for k, v in named_objects:
|
| 662 |
+
grouping[_get_grouping(k)].append(v if output_values else k)
|
| 663 |
+
|
| 664 |
+
# remap to integers
|
| 665 |
+
layer_id_to_param = defaultdict(list)
|
| 666 |
+
lid = -1
|
| 667 |
+
for k in sorted(filter(lambda x: x is not None, grouping.keys())):
|
| 668 |
+
if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
|
| 669 |
+
lid += 1
|
| 670 |
+
layer_id_to_param[lid].extend(grouping[k])
|
| 671 |
+
|
| 672 |
+
if reverse:
|
| 673 |
+
assert not output_values, "reverse mapping only sensible for name output"
|
| 674 |
+
# output reverse mapping
|
| 675 |
+
param_to_layer_id = {}
|
| 676 |
+
for lid, lm in layer_id_to_param.items():
|
| 677 |
+
for n in lm:
|
| 678 |
+
param_to_layer_id[n] = lid
|
| 679 |
+
return param_to_layer_id
|
| 680 |
+
|
| 681 |
+
return layer_id_to_param
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
def group_parameters(
|
| 685 |
+
module: nn.Module,
|
| 686 |
+
group_matcher,
|
| 687 |
+
output_values=False,
|
| 688 |
+
reverse=False,
|
| 689 |
+
):
|
| 690 |
+
return group_with_matcher(
|
| 691 |
+
module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
def group_modules(
|
| 695 |
+
module: nn.Module,
|
| 696 |
+
group_matcher,
|
| 697 |
+
output_values=False,
|
| 698 |
+
reverse=False,
|
| 699 |
+
):
|
| 700 |
+
return group_with_matcher(
|
| 701 |
+
named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
def checkpoint_seq(
|
| 705 |
+
functions,
|
| 706 |
+
x,
|
| 707 |
+
every=1,
|
| 708 |
+
flatten=False,
|
| 709 |
+
skip_last=False,
|
| 710 |
+
preserve_rng_state=True
|
| 711 |
+
):
|
| 712 |
+
r"""A helper function for checkpointing sequential models.
|
| 713 |
+
|
| 714 |
+
Sequential models execute a list of modules/functions in order
|
| 715 |
+
(sequentially). Therefore, we can divide such a sequence into segments
|
| 716 |
+
and checkpoint each segment. All segments except run in :func:`torch.no_grad`
|
| 717 |
+
manner, i.e., not storing the intermediate activations. The inputs of each
|
| 718 |
+
checkpointed segment will be saved for re-running the segment in the backward pass.
|
| 719 |
+
|
| 720 |
+
See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
|
| 721 |
+
|
| 722 |
+
.. warning::
|
| 723 |
+
Checkpointing currently only supports :func:`torch.autograd.backward`
|
| 724 |
+
and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
|
| 725 |
+
is not supported.
|
| 726 |
+
|
| 727 |
+
.. warning:
|
| 728 |
+
At least one of the inputs needs to have :code:`requires_grad=True` if
|
| 729 |
+
grads are needed for model inputs, otherwise the checkpointed part of the
|
| 730 |
+
model won't have gradients.
|
| 731 |
+
|
| 732 |
+
Args:
|
| 733 |
+
functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
|
| 734 |
+
x: A Tensor that is input to :attr:`functions`
|
| 735 |
+
every: checkpoint every-n functions (default: 1)
|
| 736 |
+
flatten (bool): flatten nn.Sequential of nn.Sequentials
|
| 737 |
+
skip_last (bool): skip checkpointing the last function in the sequence if True
|
| 738 |
+
preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
|
| 739 |
+
the RNG state during each checkpoint.
|
| 740 |
+
|
| 741 |
+
Returns:
|
| 742 |
+
Output of running :attr:`functions` sequentially on :attr:`*inputs`
|
| 743 |
+
|
| 744 |
+
Example:
|
| 745 |
+
>>> model = nn.Sequential(...)
|
| 746 |
+
>>> input_var = checkpoint_seq(model, input_var, every=2)
|
| 747 |
+
"""
|
| 748 |
+
def run_function(start, end, functions):
|
| 749 |
+
def forward(_x):
|
| 750 |
+
for j in range(start, end + 1):
|
| 751 |
+
_x = functions[j](_x)
|
| 752 |
+
return _x
|
| 753 |
+
return forward
|
| 754 |
+
|
| 755 |
+
if isinstance(functions, torch.nn.Sequential):
|
| 756 |
+
functions = functions.children()
|
| 757 |
+
if flatten:
|
| 758 |
+
functions = chain.from_iterable(functions)
|
| 759 |
+
if not isinstance(functions, (tuple, list)):
|
| 760 |
+
functions = tuple(functions)
|
| 761 |
+
|
| 762 |
+
num_checkpointed = len(functions)
|
| 763 |
+
if skip_last:
|
| 764 |
+
num_checkpointed -= 1
|
| 765 |
+
end = -1
|
| 766 |
+
for start in range(0, num_checkpointed, every):
|
| 767 |
+
end = min(start + every - 1, num_checkpointed - 1)
|
| 768 |
+
x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
|
| 769 |
+
if skip_last:
|
| 770 |
+
return run_function(end + 1, len(functions) - 1, functions)(x)
|
| 771 |
+
return x
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
|
| 775 |
+
prefix_is_tuple = isinstance(prefix, tuple)
|
| 776 |
+
if isinstance(module_types, str):
|
| 777 |
+
if module_types == 'container':
|
| 778 |
+
module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
|
| 779 |
+
else:
|
| 780 |
+
module_types = (nn.Sequential,)
|
| 781 |
+
for name, module in named_modules:
|
| 782 |
+
if depth and isinstance(module, module_types):
|
| 783 |
+
yield from flatten_modules(
|
| 784 |
+
module.named_children(),
|
| 785 |
+
depth - 1,
|
| 786 |
+
prefix=(name,) if prefix_is_tuple else name,
|
| 787 |
+
module_types=module_types,
|
| 788 |
+
)
|
| 789 |
+
else:
|
| 790 |
+
if prefix_is_tuple:
|
| 791 |
+
name = prefix + (name,)
|
| 792 |
+
yield name, module
|
| 793 |
+
else:
|
| 794 |
+
if prefix:
|
| 795 |
+
name = '.'.join([prefix, name])
|
| 796 |
+
yield name, module
|
src/custom_timm/models/hrnet.py
ADDED
|
@@ -0,0 +1,858 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" HRNet
|
| 2 |
+
|
| 3 |
+
Copied from https://github.com/HRNet/HRNet-Image-Classification
|
| 4 |
+
|
| 5 |
+
Original header:
|
| 6 |
+
Copyright (c) Microsoft
|
| 7 |
+
Licensed under the MIT License.
|
| 8 |
+
Written by Bin Xiao (Bin.Xiao@microsoft.com)
|
| 9 |
+
Modified by Ke Sun (sunk@mail.ustc.edu.cn)
|
| 10 |
+
"""
|
| 11 |
+
import logging
|
| 12 |
+
from typing import List
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 19 |
+
from .features import FeatureInfo
|
| 20 |
+
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
|
| 21 |
+
from .layers import create_classifier
|
| 22 |
+
from .registry import register_model
|
| 23 |
+
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
|
| 24 |
+
|
| 25 |
+
_BN_MOMENTUM = 0.1
|
| 26 |
+
_logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _cfg(url='', **kwargs):
|
| 30 |
+
return {
|
| 31 |
+
'url': url,
|
| 32 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
| 33 |
+
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
| 34 |
+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
| 35 |
+
'first_conv': 'conv1', 'classifier': 'classifier',
|
| 36 |
+
**kwargs
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
default_cfgs = {
|
| 41 |
+
'hrnet_w18_small': _cfg(
|
| 42 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v1-f460c6bc.pth'),
|
| 43 |
+
'hrnet_w18_small_v2': _cfg(
|
| 44 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v2-4c50a8cb.pth'),
|
| 45 |
+
'hrnet_w18': _cfg(
|
| 46 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w18-8cb57bb9.pth'),
|
| 47 |
+
'hrnet_w30': _cfg(
|
| 48 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w30-8d7f8dab.pth'),
|
| 49 |
+
'hrnet_w32': _cfg(
|
| 50 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w32-90d8c5fb.pth'),
|
| 51 |
+
'hrnet_w40': _cfg(
|
| 52 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w40-7cd397a4.pth'),
|
| 53 |
+
'hrnet_w44': _cfg(
|
| 54 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w44-c9ac8c18.pth'),
|
| 55 |
+
'hrnet_w48': _cfg(
|
| 56 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w48-abd2e6ab.pth'),
|
| 57 |
+
'hrnet_w64': _cfg(
|
| 58 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w64-b47cc881.pth'),
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
cfg_cls = dict(
|
| 62 |
+
hrnet_w18_small=dict(
|
| 63 |
+
STEM_WIDTH=64,
|
| 64 |
+
STAGE1=dict(
|
| 65 |
+
NUM_MODULES=1,
|
| 66 |
+
NUM_BRANCHES=1,
|
| 67 |
+
BLOCK='BOTTLENECK',
|
| 68 |
+
NUM_BLOCKS=(1,),
|
| 69 |
+
NUM_CHANNELS=(32,),
|
| 70 |
+
FUSE_METHOD='SUM',
|
| 71 |
+
),
|
| 72 |
+
STAGE2=dict(
|
| 73 |
+
NUM_MODULES=1,
|
| 74 |
+
NUM_BRANCHES=2,
|
| 75 |
+
BLOCK='BASIC',
|
| 76 |
+
NUM_BLOCKS=(2, 2),
|
| 77 |
+
NUM_CHANNELS=(16, 32),
|
| 78 |
+
FUSE_METHOD='SUM'
|
| 79 |
+
),
|
| 80 |
+
STAGE3=dict(
|
| 81 |
+
NUM_MODULES=1,
|
| 82 |
+
NUM_BRANCHES=3,
|
| 83 |
+
BLOCK='BASIC',
|
| 84 |
+
NUM_BLOCKS=(2, 2, 2),
|
| 85 |
+
NUM_CHANNELS=(16, 32, 64),
|
| 86 |
+
FUSE_METHOD='SUM'
|
| 87 |
+
),
|
| 88 |
+
STAGE4=dict(
|
| 89 |
+
NUM_MODULES=1,
|
| 90 |
+
NUM_BRANCHES=4,
|
| 91 |
+
BLOCK='BASIC',
|
| 92 |
+
NUM_BLOCKS=(2, 2, 2, 2),
|
| 93 |
+
NUM_CHANNELS=(16, 32, 64, 128),
|
| 94 |
+
FUSE_METHOD='SUM',
|
| 95 |
+
),
|
| 96 |
+
),
|
| 97 |
+
|
| 98 |
+
hrnet_w18_small_v2=dict(
|
| 99 |
+
STEM_WIDTH=64,
|
| 100 |
+
STAGE1=dict(
|
| 101 |
+
NUM_MODULES=1,
|
| 102 |
+
NUM_BRANCHES=1,
|
| 103 |
+
BLOCK='BOTTLENECK',
|
| 104 |
+
NUM_BLOCKS=(2,),
|
| 105 |
+
NUM_CHANNELS=(64,),
|
| 106 |
+
FUSE_METHOD='SUM',
|
| 107 |
+
),
|
| 108 |
+
STAGE2=dict(
|
| 109 |
+
NUM_MODULES=1,
|
| 110 |
+
NUM_BRANCHES=2,
|
| 111 |
+
BLOCK='BASIC',
|
| 112 |
+
NUM_BLOCKS=(2, 2),
|
| 113 |
+
NUM_CHANNELS=(18, 36),
|
| 114 |
+
FUSE_METHOD='SUM'
|
| 115 |
+
),
|
| 116 |
+
STAGE3=dict(
|
| 117 |
+
NUM_MODULES=3,
|
| 118 |
+
NUM_BRANCHES=3,
|
| 119 |
+
BLOCK='BASIC',
|
| 120 |
+
NUM_BLOCKS=(2, 2, 2),
|
| 121 |
+
NUM_CHANNELS=(18, 36, 72),
|
| 122 |
+
FUSE_METHOD='SUM'
|
| 123 |
+
),
|
| 124 |
+
STAGE4=dict(
|
| 125 |
+
NUM_MODULES=2,
|
| 126 |
+
NUM_BRANCHES=4,
|
| 127 |
+
BLOCK='BASIC',
|
| 128 |
+
NUM_BLOCKS=(2, 2, 2, 2),
|
| 129 |
+
NUM_CHANNELS=(18, 36, 72, 144),
|
| 130 |
+
FUSE_METHOD='SUM',
|
| 131 |
+
),
|
| 132 |
+
),
|
| 133 |
+
|
| 134 |
+
hrnet_w18=dict(
|
| 135 |
+
STEM_WIDTH=64,
|
| 136 |
+
STAGE1=dict(
|
| 137 |
+
NUM_MODULES=1,
|
| 138 |
+
NUM_BRANCHES=1,
|
| 139 |
+
BLOCK='BOTTLENECK',
|
| 140 |
+
NUM_BLOCKS=(4,),
|
| 141 |
+
NUM_CHANNELS=(64,),
|
| 142 |
+
FUSE_METHOD='SUM',
|
| 143 |
+
),
|
| 144 |
+
STAGE2=dict(
|
| 145 |
+
NUM_MODULES=1,
|
| 146 |
+
NUM_BRANCHES=2,
|
| 147 |
+
BLOCK='BASIC',
|
| 148 |
+
NUM_BLOCKS=(4, 4),
|
| 149 |
+
NUM_CHANNELS=(18, 36),
|
| 150 |
+
FUSE_METHOD='SUM'
|
| 151 |
+
),
|
| 152 |
+
STAGE3=dict(
|
| 153 |
+
NUM_MODULES=4,
|
| 154 |
+
NUM_BRANCHES=3,
|
| 155 |
+
BLOCK='BASIC',
|
| 156 |
+
NUM_BLOCKS=(4, 4, 4),
|
| 157 |
+
NUM_CHANNELS=(18, 36, 72),
|
| 158 |
+
FUSE_METHOD='SUM'
|
| 159 |
+
),
|
| 160 |
+
STAGE4=dict(
|
| 161 |
+
NUM_MODULES=3,
|
| 162 |
+
NUM_BRANCHES=4,
|
| 163 |
+
BLOCK='BASIC',
|
| 164 |
+
NUM_BLOCKS=(4, 4, 4, 4),
|
| 165 |
+
NUM_CHANNELS=(18, 36, 72, 144),
|
| 166 |
+
FUSE_METHOD='SUM',
|
| 167 |
+
),
|
| 168 |
+
),
|
| 169 |
+
|
| 170 |
+
hrnet_w30=dict(
|
| 171 |
+
STEM_WIDTH=64,
|
| 172 |
+
STAGE1=dict(
|
| 173 |
+
NUM_MODULES=1,
|
| 174 |
+
NUM_BRANCHES=1,
|
| 175 |
+
BLOCK='BOTTLENECK',
|
| 176 |
+
NUM_BLOCKS=(4,),
|
| 177 |
+
NUM_CHANNELS=(64,),
|
| 178 |
+
FUSE_METHOD='SUM',
|
| 179 |
+
),
|
| 180 |
+
STAGE2=dict(
|
| 181 |
+
NUM_MODULES=1,
|
| 182 |
+
NUM_BRANCHES=2,
|
| 183 |
+
BLOCK='BASIC',
|
| 184 |
+
NUM_BLOCKS=(4, 4),
|
| 185 |
+
NUM_CHANNELS=(30, 60),
|
| 186 |
+
FUSE_METHOD='SUM'
|
| 187 |
+
),
|
| 188 |
+
STAGE3=dict(
|
| 189 |
+
NUM_MODULES=4,
|
| 190 |
+
NUM_BRANCHES=3,
|
| 191 |
+
BLOCK='BASIC',
|
| 192 |
+
NUM_BLOCKS=(4, 4, 4),
|
| 193 |
+
NUM_CHANNELS=(30, 60, 120),
|
| 194 |
+
FUSE_METHOD='SUM'
|
| 195 |
+
),
|
| 196 |
+
STAGE4=dict(
|
| 197 |
+
NUM_MODULES=3,
|
| 198 |
+
NUM_BRANCHES=4,
|
| 199 |
+
BLOCK='BASIC',
|
| 200 |
+
NUM_BLOCKS=(4, 4, 4, 4),
|
| 201 |
+
NUM_CHANNELS=(30, 60, 120, 240),
|
| 202 |
+
FUSE_METHOD='SUM',
|
| 203 |
+
),
|
| 204 |
+
),
|
| 205 |
+
|
| 206 |
+
hrnet_w32=dict(
|
| 207 |
+
STEM_WIDTH=64,
|
| 208 |
+
STAGE1=dict(
|
| 209 |
+
NUM_MODULES=1,
|
| 210 |
+
NUM_BRANCHES=1,
|
| 211 |
+
BLOCK='BOTTLENECK',
|
| 212 |
+
NUM_BLOCKS=(4,),
|
| 213 |
+
NUM_CHANNELS=(64,),
|
| 214 |
+
FUSE_METHOD='SUM',
|
| 215 |
+
),
|
| 216 |
+
STAGE2=dict(
|
| 217 |
+
NUM_MODULES=1,
|
| 218 |
+
NUM_BRANCHES=2,
|
| 219 |
+
BLOCK='BASIC',
|
| 220 |
+
NUM_BLOCKS=(4, 4),
|
| 221 |
+
NUM_CHANNELS=(32, 64),
|
| 222 |
+
FUSE_METHOD='SUM'
|
| 223 |
+
),
|
| 224 |
+
STAGE3=dict(
|
| 225 |
+
NUM_MODULES=4,
|
| 226 |
+
NUM_BRANCHES=3,
|
| 227 |
+
BLOCK='BASIC',
|
| 228 |
+
NUM_BLOCKS=(4, 4, 4),
|
| 229 |
+
NUM_CHANNELS=(32, 64, 128),
|
| 230 |
+
FUSE_METHOD='SUM'
|
| 231 |
+
),
|
| 232 |
+
STAGE4=dict(
|
| 233 |
+
NUM_MODULES=3,
|
| 234 |
+
NUM_BRANCHES=4,
|
| 235 |
+
BLOCK='BASIC',
|
| 236 |
+
NUM_BLOCKS=(4, 4, 4, 4),
|
| 237 |
+
NUM_CHANNELS=(32, 64, 128, 256),
|
| 238 |
+
FUSE_METHOD='SUM',
|
| 239 |
+
),
|
| 240 |
+
),
|
| 241 |
+
|
| 242 |
+
hrnet_w40=dict(
|
| 243 |
+
STEM_WIDTH=64,
|
| 244 |
+
STAGE1=dict(
|
| 245 |
+
NUM_MODULES=1,
|
| 246 |
+
NUM_BRANCHES=1,
|
| 247 |
+
BLOCK='BOTTLENECK',
|
| 248 |
+
NUM_BLOCKS=(4,),
|
| 249 |
+
NUM_CHANNELS=(64,),
|
| 250 |
+
FUSE_METHOD='SUM',
|
| 251 |
+
),
|
| 252 |
+
STAGE2=dict(
|
| 253 |
+
NUM_MODULES=1,
|
| 254 |
+
NUM_BRANCHES=2,
|
| 255 |
+
BLOCK='BASIC',
|
| 256 |
+
NUM_BLOCKS=(4, 4),
|
| 257 |
+
NUM_CHANNELS=(40, 80),
|
| 258 |
+
FUSE_METHOD='SUM'
|
| 259 |
+
),
|
| 260 |
+
STAGE3=dict(
|
| 261 |
+
NUM_MODULES=4,
|
| 262 |
+
NUM_BRANCHES=3,
|
| 263 |
+
BLOCK='BASIC',
|
| 264 |
+
NUM_BLOCKS=(4, 4, 4),
|
| 265 |
+
NUM_CHANNELS=(40, 80, 160),
|
| 266 |
+
FUSE_METHOD='SUM'
|
| 267 |
+
),
|
| 268 |
+
STAGE4=dict(
|
| 269 |
+
NUM_MODULES=3,
|
| 270 |
+
NUM_BRANCHES=4,
|
| 271 |
+
BLOCK='BASIC',
|
| 272 |
+
NUM_BLOCKS=(4, 4, 4, 4),
|
| 273 |
+
NUM_CHANNELS=(40, 80, 160, 320),
|
| 274 |
+
FUSE_METHOD='SUM',
|
| 275 |
+
),
|
| 276 |
+
),
|
| 277 |
+
|
| 278 |
+
hrnet_w44=dict(
|
| 279 |
+
STEM_WIDTH=64,
|
| 280 |
+
STAGE1=dict(
|
| 281 |
+
NUM_MODULES=1,
|
| 282 |
+
NUM_BRANCHES=1,
|
| 283 |
+
BLOCK='BOTTLENECK',
|
| 284 |
+
NUM_BLOCKS=(4,),
|
| 285 |
+
NUM_CHANNELS=(64,),
|
| 286 |
+
FUSE_METHOD='SUM',
|
| 287 |
+
),
|
| 288 |
+
STAGE2=dict(
|
| 289 |
+
NUM_MODULES=1,
|
| 290 |
+
NUM_BRANCHES=2,
|
| 291 |
+
BLOCK='BASIC',
|
| 292 |
+
NUM_BLOCKS=(4, 4),
|
| 293 |
+
NUM_CHANNELS=(44, 88),
|
| 294 |
+
FUSE_METHOD='SUM'
|
| 295 |
+
),
|
| 296 |
+
STAGE3=dict(
|
| 297 |
+
NUM_MODULES=4,
|
| 298 |
+
NUM_BRANCHES=3,
|
| 299 |
+
BLOCK='BASIC',
|
| 300 |
+
NUM_BLOCKS=(4, 4, 4),
|
| 301 |
+
NUM_CHANNELS=(44, 88, 176),
|
| 302 |
+
FUSE_METHOD='SUM'
|
| 303 |
+
),
|
| 304 |
+
STAGE4=dict(
|
| 305 |
+
NUM_MODULES=3,
|
| 306 |
+
NUM_BRANCHES=4,
|
| 307 |
+
BLOCK='BASIC',
|
| 308 |
+
NUM_BLOCKS=(4, 4, 4, 4),
|
| 309 |
+
NUM_CHANNELS=(44, 88, 176, 352),
|
| 310 |
+
FUSE_METHOD='SUM',
|
| 311 |
+
),
|
| 312 |
+
),
|
| 313 |
+
|
| 314 |
+
hrnet_w48=dict(
|
| 315 |
+
STEM_WIDTH=64,
|
| 316 |
+
STAGE1=dict(
|
| 317 |
+
NUM_MODULES=1,
|
| 318 |
+
NUM_BRANCHES=1,
|
| 319 |
+
BLOCK='BOTTLENECK',
|
| 320 |
+
NUM_BLOCKS=(4,),
|
| 321 |
+
NUM_CHANNELS=(64,),
|
| 322 |
+
FUSE_METHOD='SUM',
|
| 323 |
+
),
|
| 324 |
+
STAGE2=dict(
|
| 325 |
+
NUM_MODULES=1,
|
| 326 |
+
NUM_BRANCHES=2,
|
| 327 |
+
BLOCK='BASIC',
|
| 328 |
+
NUM_BLOCKS=(4, 4),
|
| 329 |
+
NUM_CHANNELS=(48, 96),
|
| 330 |
+
FUSE_METHOD='SUM'
|
| 331 |
+
),
|
| 332 |
+
STAGE3=dict(
|
| 333 |
+
NUM_MODULES=4,
|
| 334 |
+
NUM_BRANCHES=3,
|
| 335 |
+
BLOCK='BASIC',
|
| 336 |
+
NUM_BLOCKS=(4, 4, 4),
|
| 337 |
+
NUM_CHANNELS=(48, 96, 192),
|
| 338 |
+
FUSE_METHOD='SUM'
|
| 339 |
+
),
|
| 340 |
+
STAGE4=dict(
|
| 341 |
+
NUM_MODULES=3,
|
| 342 |
+
NUM_BRANCHES=4,
|
| 343 |
+
BLOCK='BASIC',
|
| 344 |
+
NUM_BLOCKS=(4, 4, 4, 4),
|
| 345 |
+
NUM_CHANNELS=(48, 96, 192, 384),
|
| 346 |
+
FUSE_METHOD='SUM',
|
| 347 |
+
),
|
| 348 |
+
),
|
| 349 |
+
|
| 350 |
+
hrnet_w64=dict(
|
| 351 |
+
STEM_WIDTH=64,
|
| 352 |
+
STAGE1=dict(
|
| 353 |
+
NUM_MODULES=1,
|
| 354 |
+
NUM_BRANCHES=1,
|
| 355 |
+
BLOCK='BOTTLENECK',
|
| 356 |
+
NUM_BLOCKS=(4,),
|
| 357 |
+
NUM_CHANNELS=(64,),
|
| 358 |
+
FUSE_METHOD='SUM',
|
| 359 |
+
),
|
| 360 |
+
STAGE2=dict(
|
| 361 |
+
NUM_MODULES=1,
|
| 362 |
+
NUM_BRANCHES=2,
|
| 363 |
+
BLOCK='BASIC',
|
| 364 |
+
NUM_BLOCKS=(4, 4),
|
| 365 |
+
NUM_CHANNELS=(64, 128),
|
| 366 |
+
FUSE_METHOD='SUM'
|
| 367 |
+
),
|
| 368 |
+
STAGE3=dict(
|
| 369 |
+
NUM_MODULES=4,
|
| 370 |
+
NUM_BRANCHES=3,
|
| 371 |
+
BLOCK='BASIC',
|
| 372 |
+
NUM_BLOCKS=(4, 4, 4),
|
| 373 |
+
NUM_CHANNELS=(64, 128, 256),
|
| 374 |
+
FUSE_METHOD='SUM'
|
| 375 |
+
),
|
| 376 |
+
STAGE4=dict(
|
| 377 |
+
NUM_MODULES=3,
|
| 378 |
+
NUM_BRANCHES=4,
|
| 379 |
+
BLOCK='BASIC',
|
| 380 |
+
NUM_BLOCKS=(4, 4, 4, 4),
|
| 381 |
+
NUM_CHANNELS=(64, 128, 256, 512),
|
| 382 |
+
FUSE_METHOD='SUM',
|
| 383 |
+
),
|
| 384 |
+
)
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class HighResolutionModule(nn.Module):
|
| 389 |
+
def __init__(self, num_branches, blocks, num_blocks, num_in_chs,
|
| 390 |
+
num_channels, fuse_method, multi_scale_output=True):
|
| 391 |
+
super(HighResolutionModule, self).__init__()
|
| 392 |
+
self._check_branches(
|
| 393 |
+
num_branches, blocks, num_blocks, num_in_chs, num_channels)
|
| 394 |
+
|
| 395 |
+
self.num_in_chs = num_in_chs
|
| 396 |
+
self.fuse_method = fuse_method
|
| 397 |
+
self.num_branches = num_branches
|
| 398 |
+
|
| 399 |
+
self.multi_scale_output = multi_scale_output
|
| 400 |
+
|
| 401 |
+
self.branches = self._make_branches(
|
| 402 |
+
num_branches, blocks, num_blocks, num_channels)
|
| 403 |
+
self.fuse_layers = self._make_fuse_layers()
|
| 404 |
+
self.fuse_act = nn.ReLU(False)
|
| 405 |
+
|
| 406 |
+
def _check_branches(self, num_branches, blocks, num_blocks, num_in_chs, num_channels):
|
| 407 |
+
error_msg = ''
|
| 408 |
+
if num_branches != len(num_blocks):
|
| 409 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks))
|
| 410 |
+
elif num_branches != len(num_channels):
|
| 411 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels))
|
| 412 |
+
elif num_branches != len(num_in_chs):
|
| 413 |
+
error_msg = 'NUM_BRANCHES({}) <> num_in_chs({})'.format(num_branches, len(num_in_chs))
|
| 414 |
+
if error_msg:
|
| 415 |
+
_logger.error(error_msg)
|
| 416 |
+
raise ValueError(error_msg)
|
| 417 |
+
|
| 418 |
+
def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
|
| 419 |
+
downsample = None
|
| 420 |
+
if stride != 1 or self.num_in_chs[branch_index] != num_channels[branch_index] * block.expansion:
|
| 421 |
+
downsample = nn.Sequential(
|
| 422 |
+
nn.Conv2d(
|
| 423 |
+
self.num_in_chs[branch_index], num_channels[branch_index] * block.expansion,
|
| 424 |
+
kernel_size=1, stride=stride, bias=False),
|
| 425 |
+
nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=_BN_MOMENTUM),
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
layers = [block(self.num_in_chs[branch_index], num_channels[branch_index], stride, downsample)]
|
| 429 |
+
self.num_in_chs[branch_index] = num_channels[branch_index] * block.expansion
|
| 430 |
+
for i in range(1, num_blocks[branch_index]):
|
| 431 |
+
layers.append(block(self.num_in_chs[branch_index], num_channels[branch_index]))
|
| 432 |
+
|
| 433 |
+
return nn.Sequential(*layers)
|
| 434 |
+
|
| 435 |
+
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
| 436 |
+
branches = []
|
| 437 |
+
for i in range(num_branches):
|
| 438 |
+
branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
|
| 439 |
+
|
| 440 |
+
return nn.ModuleList(branches)
|
| 441 |
+
|
| 442 |
+
def _make_fuse_layers(self):
|
| 443 |
+
if self.num_branches == 1:
|
| 444 |
+
return nn.Identity()
|
| 445 |
+
|
| 446 |
+
num_branches = self.num_branches
|
| 447 |
+
num_in_chs = self.num_in_chs
|
| 448 |
+
fuse_layers = []
|
| 449 |
+
for i in range(num_branches if self.multi_scale_output else 1):
|
| 450 |
+
fuse_layer = []
|
| 451 |
+
for j in range(num_branches):
|
| 452 |
+
if j > i:
|
| 453 |
+
fuse_layer.append(nn.Sequential(
|
| 454 |
+
nn.Conv2d(num_in_chs[j], num_in_chs[i], 1, 1, 0, bias=False),
|
| 455 |
+
nn.BatchNorm2d(num_in_chs[i], momentum=_BN_MOMENTUM),
|
| 456 |
+
nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')))
|
| 457 |
+
elif j == i:
|
| 458 |
+
fuse_layer.append(nn.Identity())
|
| 459 |
+
else:
|
| 460 |
+
conv3x3s = []
|
| 461 |
+
for k in range(i - j):
|
| 462 |
+
if k == i - j - 1:
|
| 463 |
+
num_outchannels_conv3x3 = num_in_chs[i]
|
| 464 |
+
conv3x3s.append(nn.Sequential(
|
| 465 |
+
nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
|
| 466 |
+
nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM)))
|
| 467 |
+
else:
|
| 468 |
+
num_outchannels_conv3x3 = num_in_chs[j]
|
| 469 |
+
conv3x3s.append(nn.Sequential(
|
| 470 |
+
nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
|
| 471 |
+
nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM),
|
| 472 |
+
nn.ReLU(False)))
|
| 473 |
+
fuse_layer.append(nn.Sequential(*conv3x3s))
|
| 474 |
+
fuse_layers.append(nn.ModuleList(fuse_layer))
|
| 475 |
+
|
| 476 |
+
return nn.ModuleList(fuse_layers)
|
| 477 |
+
|
| 478 |
+
def get_num_in_chs(self):
|
| 479 |
+
return self.num_in_chs
|
| 480 |
+
|
| 481 |
+
def forward(self, x: List[torch.Tensor]):
|
| 482 |
+
if self.num_branches == 1:
|
| 483 |
+
return [self.branches[0](x[0])]
|
| 484 |
+
|
| 485 |
+
for i, branch in enumerate(self.branches):
|
| 486 |
+
x[i] = branch(x[i])
|
| 487 |
+
|
| 488 |
+
x_fuse = []
|
| 489 |
+
for i, fuse_outer in enumerate(self.fuse_layers):
|
| 490 |
+
y = x[0] if i == 0 else fuse_outer[0](x[0])
|
| 491 |
+
for j in range(1, self.num_branches):
|
| 492 |
+
if i == j:
|
| 493 |
+
y = y + x[j]
|
| 494 |
+
else:
|
| 495 |
+
y = y + fuse_outer[j](x[j])
|
| 496 |
+
x_fuse.append(self.fuse_act(y))
|
| 497 |
+
|
| 498 |
+
return x_fuse
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
blocks_dict = {
|
| 502 |
+
'BASIC': BasicBlock,
|
| 503 |
+
'BOTTLENECK': Bottleneck
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class HighResolutionNet(nn.Module):
|
| 508 |
+
|
| 509 |
+
def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0, head='classification'):
|
| 510 |
+
super(HighResolutionNet, self).__init__()
|
| 511 |
+
self.num_classes = num_classes
|
| 512 |
+
self.drop_rate = drop_rate
|
| 513 |
+
|
| 514 |
+
stem_width = cfg['STEM_WIDTH']
|
| 515 |
+
self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=3, stride=2, padding=1, bias=False)
|
| 516 |
+
self.bn1 = nn.BatchNorm2d(stem_width, momentum=_BN_MOMENTUM)
|
| 517 |
+
self.act1 = nn.ReLU(inplace=True)
|
| 518 |
+
self.conv2 = nn.Conv2d(stem_width, 64, kernel_size=3, stride=2, padding=1, bias=False)
|
| 519 |
+
self.bn2 = nn.BatchNorm2d(64, momentum=_BN_MOMENTUM)
|
| 520 |
+
self.act2 = nn.ReLU(inplace=True)
|
| 521 |
+
|
| 522 |
+
self.stage1_cfg = cfg['STAGE1']
|
| 523 |
+
num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
|
| 524 |
+
block = blocks_dict[self.stage1_cfg['BLOCK']]
|
| 525 |
+
num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
|
| 526 |
+
self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
|
| 527 |
+
stage1_out_channel = block.expansion * num_channels
|
| 528 |
+
|
| 529 |
+
self.stage2_cfg = cfg['STAGE2']
|
| 530 |
+
num_channels = self.stage2_cfg['NUM_CHANNELS']
|
| 531 |
+
block = blocks_dict[self.stage2_cfg['BLOCK']]
|
| 532 |
+
num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
|
| 533 |
+
self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels)
|
| 534 |
+
self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)
|
| 535 |
+
|
| 536 |
+
self.stage3_cfg = cfg['STAGE3']
|
| 537 |
+
num_channels = self.stage3_cfg['NUM_CHANNELS']
|
| 538 |
+
block = blocks_dict[self.stage3_cfg['BLOCK']]
|
| 539 |
+
num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
|
| 540 |
+
self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
|
| 541 |
+
self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)
|
| 542 |
+
|
| 543 |
+
self.stage4_cfg = cfg['STAGE4']
|
| 544 |
+
num_channels = self.stage4_cfg['NUM_CHANNELS']
|
| 545 |
+
block = blocks_dict[self.stage4_cfg['BLOCK']]
|
| 546 |
+
num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
|
| 547 |
+
self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
|
| 548 |
+
self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True)
|
| 549 |
+
|
| 550 |
+
self.head = head
|
| 551 |
+
self.head_channels = None # set if _make_head called
|
| 552 |
+
if head == 'classification':
|
| 553 |
+
# Classification Head
|
| 554 |
+
self.num_features = 2048
|
| 555 |
+
self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels)
|
| 556 |
+
self.global_pool, self.classifier = create_classifier(
|
| 557 |
+
self.num_features, self.num_classes, pool_type=global_pool)
|
| 558 |
+
elif head == 'incre':
|
| 559 |
+
self.num_features = 2048
|
| 560 |
+
self.incre_modules, _, _ = self._make_head(pre_stage_channels, True)
|
| 561 |
+
else:
|
| 562 |
+
self.incre_modules = None
|
| 563 |
+
self.num_features = 256
|
| 564 |
+
|
| 565 |
+
curr_stride = 2
|
| 566 |
+
# module names aren't actually valid here, hook or FeatureNet based extraction would not work
|
| 567 |
+
self.feature_info = [dict(num_chs=64, reduction=curr_stride, module='stem')]
|
| 568 |
+
for i, c in enumerate(self.head_channels if self.head_channels else num_channels):
|
| 569 |
+
curr_stride *= 2
|
| 570 |
+
c = c * 4 if self.head_channels else c # head block expansion factor of 4
|
| 571 |
+
self.feature_info += [dict(num_chs=c, reduction=curr_stride, module=f'stage{i + 1}')]
|
| 572 |
+
|
| 573 |
+
self.init_weights()
|
| 574 |
+
|
| 575 |
+
def _make_head(self, pre_stage_channels, incre_only=False):
|
| 576 |
+
head_block = Bottleneck
|
| 577 |
+
self.head_channels = [32, 64, 128, 256]
|
| 578 |
+
|
| 579 |
+
# Increasing the #channels on each resolution
|
| 580 |
+
# from C, 2C, 4C, 8C to 128, 256, 512, 1024
|
| 581 |
+
incre_modules = []
|
| 582 |
+
for i, channels in enumerate(pre_stage_channels):
|
| 583 |
+
incre_modules.append(self._make_layer(head_block, channels, self.head_channels[i], 1, stride=1))
|
| 584 |
+
incre_modules = nn.ModuleList(incre_modules)
|
| 585 |
+
if incre_only:
|
| 586 |
+
return incre_modules, None, None
|
| 587 |
+
|
| 588 |
+
# downsampling modules
|
| 589 |
+
downsamp_modules = []
|
| 590 |
+
for i in range(len(pre_stage_channels) - 1):
|
| 591 |
+
in_channels = self.head_channels[i] * head_block.expansion
|
| 592 |
+
out_channels = self.head_channels[i + 1] * head_block.expansion
|
| 593 |
+
downsamp_module = nn.Sequential(
|
| 594 |
+
nn.Conv2d(
|
| 595 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1),
|
| 596 |
+
nn.BatchNorm2d(out_channels, momentum=_BN_MOMENTUM),
|
| 597 |
+
nn.ReLU(inplace=True)
|
| 598 |
+
)
|
| 599 |
+
downsamp_modules.append(downsamp_module)
|
| 600 |
+
downsamp_modules = nn.ModuleList(downsamp_modules)
|
| 601 |
+
|
| 602 |
+
final_layer = nn.Sequential(
|
| 603 |
+
nn.Conv2d(
|
| 604 |
+
in_channels=self.head_channels[3] * head_block.expansion,
|
| 605 |
+
out_channels=self.num_features, kernel_size=1, stride=1, padding=0
|
| 606 |
+
),
|
| 607 |
+
nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM),
|
| 608 |
+
nn.ReLU(inplace=True)
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
return incre_modules, downsamp_modules, final_layer
|
| 612 |
+
|
| 613 |
+
def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
|
| 614 |
+
num_branches_cur = len(num_channels_cur_layer)
|
| 615 |
+
num_branches_pre = len(num_channels_pre_layer)
|
| 616 |
+
|
| 617 |
+
transition_layers = []
|
| 618 |
+
for i in range(num_branches_cur):
|
| 619 |
+
if i < num_branches_pre:
|
| 620 |
+
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
| 621 |
+
transition_layers.append(nn.Sequential(
|
| 622 |
+
nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False),
|
| 623 |
+
nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM),
|
| 624 |
+
nn.ReLU(inplace=True)))
|
| 625 |
+
else:
|
| 626 |
+
transition_layers.append(nn.Identity())
|
| 627 |
+
else:
|
| 628 |
+
conv3x3s = []
|
| 629 |
+
for j in range(i + 1 - num_branches_pre):
|
| 630 |
+
inchannels = num_channels_pre_layer[-1]
|
| 631 |
+
outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
|
| 632 |
+
conv3x3s.append(nn.Sequential(
|
| 633 |
+
nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
|
| 634 |
+
nn.BatchNorm2d(outchannels, momentum=_BN_MOMENTUM),
|
| 635 |
+
nn.ReLU(inplace=True)))
|
| 636 |
+
transition_layers.append(nn.Sequential(*conv3x3s))
|
| 637 |
+
|
| 638 |
+
return nn.ModuleList(transition_layers)
|
| 639 |
+
|
| 640 |
+
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
|
| 641 |
+
downsample = None
|
| 642 |
+
if stride != 1 or inplanes != planes * block.expansion:
|
| 643 |
+
downsample = nn.Sequential(
|
| 644 |
+
nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
| 645 |
+
nn.BatchNorm2d(planes * block.expansion, momentum=_BN_MOMENTUM),
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
layers = [block(inplanes, planes, stride, downsample)]
|
| 649 |
+
inplanes = planes * block.expansion
|
| 650 |
+
for i in range(1, blocks):
|
| 651 |
+
layers.append(block(inplanes, planes))
|
| 652 |
+
|
| 653 |
+
return nn.Sequential(*layers)
|
| 654 |
+
|
| 655 |
+
def _make_stage(self, layer_config, num_in_chs, multi_scale_output=True):
|
| 656 |
+
num_modules = layer_config['NUM_MODULES']
|
| 657 |
+
num_branches = layer_config['NUM_BRANCHES']
|
| 658 |
+
num_blocks = layer_config['NUM_BLOCKS']
|
| 659 |
+
num_channels = layer_config['NUM_CHANNELS']
|
| 660 |
+
block = blocks_dict[layer_config['BLOCK']]
|
| 661 |
+
fuse_method = layer_config['FUSE_METHOD']
|
| 662 |
+
|
| 663 |
+
modules = []
|
| 664 |
+
for i in range(num_modules):
|
| 665 |
+
# multi_scale_output is only used last module
|
| 666 |
+
reset_multi_scale_output = multi_scale_output or i < num_modules - 1
|
| 667 |
+
modules.append(HighResolutionModule(
|
| 668 |
+
num_branches, block, num_blocks, num_in_chs, num_channels, fuse_method, reset_multi_scale_output)
|
| 669 |
+
)
|
| 670 |
+
num_in_chs = modules[-1].get_num_in_chs()
|
| 671 |
+
|
| 672 |
+
return nn.Sequential(*modules), num_in_chs
|
| 673 |
+
|
| 674 |
+
@torch.jit.ignore
|
| 675 |
+
def init_weights(self):
|
| 676 |
+
for m in self.modules():
|
| 677 |
+
if isinstance(m, nn.Conv2d):
|
| 678 |
+
nn.init.kaiming_normal_(
|
| 679 |
+
m.weight, mode='fan_out', nonlinearity='relu')
|
| 680 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 681 |
+
nn.init.constant_(m.weight, 1)
|
| 682 |
+
nn.init.constant_(m.bias, 0)
|
| 683 |
+
|
| 684 |
+
@torch.jit.ignore
|
| 685 |
+
def group_matcher(self, coarse=False):
|
| 686 |
+
matcher = dict(
|
| 687 |
+
stem=r'^conv[12]|bn[12]',
|
| 688 |
+
blocks=r'^(?:layer|stage|transition)(\d+)' if coarse else [
|
| 689 |
+
(r'^layer(\d+)\.(\d+)', None),
|
| 690 |
+
(r'^stage(\d+)\.(\d+)', None),
|
| 691 |
+
(r'^transition(\d+)', (99999,)),
|
| 692 |
+
],
|
| 693 |
+
)
|
| 694 |
+
return matcher
|
| 695 |
+
|
| 696 |
+
@torch.jit.ignore
|
| 697 |
+
def set_grad_checkpointing(self, enable=True):
|
| 698 |
+
assert not enable, "gradient checkpointing not supported"
|
| 699 |
+
|
| 700 |
+
@torch.jit.ignore
|
| 701 |
+
def get_classifier(self):
|
| 702 |
+
return self.classifier
|
| 703 |
+
|
| 704 |
+
def reset_classifier(self, num_classes, global_pool='avg'):
|
| 705 |
+
self.num_classes = num_classes
|
| 706 |
+
self.global_pool, self.classifier = create_classifier(
|
| 707 |
+
self.num_features, self.num_classes, pool_type=global_pool)
|
| 708 |
+
|
| 709 |
+
def stages(self, x) -> List[torch.Tensor]:
|
| 710 |
+
x = self.layer1(x)
|
| 711 |
+
|
| 712 |
+
xl = [t(x) for i, t in enumerate(self.transition1)]
|
| 713 |
+
yl = self.stage2(xl)
|
| 714 |
+
|
| 715 |
+
xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition2)]
|
| 716 |
+
yl = self.stage3(xl)
|
| 717 |
+
|
| 718 |
+
xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition3)]
|
| 719 |
+
yl = self.stage4(xl)
|
| 720 |
+
return yl
|
| 721 |
+
|
| 722 |
+
def forward_features(self, x):
|
| 723 |
+
# Stem
|
| 724 |
+
x = self.conv1(x)
|
| 725 |
+
x = self.bn1(x)
|
| 726 |
+
x = self.act1(x)
|
| 727 |
+
x = self.conv2(x)
|
| 728 |
+
x = self.bn2(x)
|
| 729 |
+
x = self.act2(x)
|
| 730 |
+
|
| 731 |
+
# Stages
|
| 732 |
+
yl = self.stages(x)
|
| 733 |
+
if self.incre_modules is None or self.downsamp_modules is None:
|
| 734 |
+
return yl
|
| 735 |
+
y = self.incre_modules[0](yl[0])
|
| 736 |
+
for i, down in enumerate(self.downsamp_modules):
|
| 737 |
+
y = self.incre_modules[i + 1](yl[i + 1]) + down(y)
|
| 738 |
+
y = self.final_layer(y)
|
| 739 |
+
return y
|
| 740 |
+
|
| 741 |
+
def forward_head(self, x, pre_logits: bool = False):
|
| 742 |
+
# Classification Head
|
| 743 |
+
x = self.global_pool(x)
|
| 744 |
+
if self.drop_rate > 0.:
|
| 745 |
+
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
| 746 |
+
return x if pre_logits else self.classifier(x)
|
| 747 |
+
|
| 748 |
+
def forward(self, x):
|
| 749 |
+
y = self.forward_features(x)
|
| 750 |
+
x = self.forward_head(y)
|
| 751 |
+
return x
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
class HighResolutionNetFeatures(HighResolutionNet):
|
| 755 |
+
"""HighResolutionNet feature extraction
|
| 756 |
+
|
| 757 |
+
The design of HRNet makes it easy to grab feature maps, this class provides a simple wrapper to do so.
|
| 758 |
+
It would be more complicated to use the FeatureNet helpers.
|
| 759 |
+
|
| 760 |
+
The `feature_location=incre` allows grabbing increased channel count features using part of the
|
| 761 |
+
classification head. If `feature_location=''` the default HRNet features are returned. First stem
|
| 762 |
+
conv is used for stride 2 features.
|
| 763 |
+
"""
|
| 764 |
+
|
| 765 |
+
def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0,
|
| 766 |
+
feature_location='incre', out_indices=(0, 1, 2, 3, 4)):
|
| 767 |
+
assert feature_location in ('incre', '')
|
| 768 |
+
super(HighResolutionNetFeatures, self).__init__(
|
| 769 |
+
cfg, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool,
|
| 770 |
+
drop_rate=drop_rate, head=feature_location)
|
| 771 |
+
self.feature_info = FeatureInfo(self.feature_info, out_indices)
|
| 772 |
+
self._out_idx = {i for i in out_indices}
|
| 773 |
+
|
| 774 |
+
def forward_features(self, x):
|
| 775 |
+
assert False, 'Not supported'
|
| 776 |
+
|
| 777 |
+
def forward(self, x) -> List[torch.tensor]:
|
| 778 |
+
out = []
|
| 779 |
+
x = self.conv1(x)
|
| 780 |
+
x = self.bn1(x)
|
| 781 |
+
x = self.act1(x)
|
| 782 |
+
if 0 in self._out_idx:
|
| 783 |
+
out.append(x)
|
| 784 |
+
x = self.conv2(x)
|
| 785 |
+
x = self.bn2(x)
|
| 786 |
+
x = self.act2(x)
|
| 787 |
+
x = self.stages(x)
|
| 788 |
+
if self.incre_modules is not None:
|
| 789 |
+
x = [incre(f) for f, incre in zip(x, self.incre_modules)]
|
| 790 |
+
for i, f in enumerate(x):
|
| 791 |
+
if i + 1 in self._out_idx:
|
| 792 |
+
out.append(f)
|
| 793 |
+
return out
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
def _create_hrnet(variant, pretrained, **model_kwargs):
|
| 797 |
+
model_cls = HighResolutionNet
|
| 798 |
+
features_only = False
|
| 799 |
+
kwargs_filter = None
|
| 800 |
+
if model_kwargs.pop('features_only', False):
|
| 801 |
+
model_cls = HighResolutionNetFeatures
|
| 802 |
+
kwargs_filter = ('num_classes', 'global_pool')
|
| 803 |
+
features_only = True
|
| 804 |
+
model = build_model_with_cfg(
|
| 805 |
+
model_cls, variant, pretrained,
|
| 806 |
+
model_cfg=cfg_cls[variant],
|
| 807 |
+
pretrained_strict=not features_only,
|
| 808 |
+
kwargs_filter=kwargs_filter,
|
| 809 |
+
**model_kwargs)
|
| 810 |
+
if features_only:
|
| 811 |
+
model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
|
| 812 |
+
model.default_cfg = model.pretrained_cfg # backwards compat
|
| 813 |
+
return model
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
@register_model
|
| 817 |
+
def hrnet_w18_small(pretrained=False, **kwargs):
|
| 818 |
+
return _create_hrnet('hrnet_w18_small', pretrained, **kwargs)
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
@register_model
|
| 822 |
+
def hrnet_w18_small_v2(pretrained=False, **kwargs):
|
| 823 |
+
return _create_hrnet('hrnet_w18_small_v2', pretrained, **kwargs)
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
@register_model
|
| 827 |
+
def hrnet_w18(pretrained=False, **kwargs):
|
| 828 |
+
return _create_hrnet('hrnet_w18', pretrained, **kwargs)
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
@register_model
|
| 832 |
+
def hrnet_w30(pretrained=False, **kwargs):
|
| 833 |
+
return _create_hrnet('hrnet_w30', pretrained, **kwargs)
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
@register_model
|
| 837 |
+
def hrnet_w32(pretrained=False, **kwargs):
|
| 838 |
+
return _create_hrnet('hrnet_w32', pretrained, **kwargs)
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
@register_model
|
| 842 |
+
def hrnet_w40(pretrained=False, **kwargs):
|
| 843 |
+
return _create_hrnet('hrnet_w40', pretrained, **kwargs)
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
@register_model
|
| 847 |
+
def hrnet_w44(pretrained=False, **kwargs):
|
| 848 |
+
return _create_hrnet('hrnet_w44', pretrained, **kwargs)
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
@register_model
|
| 852 |
+
def hrnet_w48(pretrained=False, **kwargs):
|
| 853 |
+
return _create_hrnet('hrnet_w48', pretrained, **kwargs)
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
@register_model
|
| 857 |
+
def hrnet_w64(pretrained=False, **kwargs):
|
| 858 |
+
return _create_hrnet('hrnet_w64', pretrained, **kwargs)
|
src/custom_timm/models/hub.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from functools import partial
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from tempfile import TemporaryDirectory
|
| 7 |
+
from typing import Optional, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from torch.hub import get_dir
|
| 14 |
+
except ImportError:
|
| 15 |
+
from torch.hub import _get_torch_home as get_dir
|
| 16 |
+
|
| 17 |
+
from custom_timm import __version__
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from huggingface_hub import (create_repo, get_hf_file_metadata,
|
| 21 |
+
hf_hub_download, hf_hub_url,
|
| 22 |
+
repo_type_and_id_from_hf_id, upload_folder)
|
| 23 |
+
from huggingface_hub.utils import EntryNotFoundError
|
| 24 |
+
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
|
| 25 |
+
_has_hf_hub = True
|
| 26 |
+
except ImportError:
|
| 27 |
+
hf_hub_download = None
|
| 28 |
+
_has_hf_hub = False
|
| 29 |
+
|
| 30 |
+
_logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_cache_dir(child_dir=''):
|
| 34 |
+
"""
|
| 35 |
+
Returns the location of the directory where models are cached (and creates it if necessary).
|
| 36 |
+
"""
|
| 37 |
+
# Issue warning to move data if old env is set
|
| 38 |
+
if os.getenv('TORCH_MODEL_ZOO'):
|
| 39 |
+
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
|
| 40 |
+
|
| 41 |
+
hub_dir = get_dir()
|
| 42 |
+
child_dir = () if not child_dir else (child_dir,)
|
| 43 |
+
model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
|
| 44 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 45 |
+
return model_dir
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def download_cached_file(url, check_hash=True, progress=False):
|
| 49 |
+
parts = urlparse(url)
|
| 50 |
+
filename = os.path.basename(parts.path)
|
| 51 |
+
cached_file = os.path.join(get_cache_dir(), filename)
|
| 52 |
+
if not os.path.exists(cached_file):
|
| 53 |
+
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
|
| 54 |
+
hash_prefix = None
|
| 55 |
+
if check_hash:
|
| 56 |
+
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
|
| 57 |
+
hash_prefix = r.group(1) if r else None
|
| 58 |
+
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
|
| 59 |
+
return cached_file
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def has_hf_hub(necessary=False):
|
| 63 |
+
if not _has_hf_hub and necessary:
|
| 64 |
+
# if no HF Hub module installed, and it is necessary to continue, raise error
|
| 65 |
+
raise RuntimeError(
|
| 66 |
+
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
|
| 67 |
+
return _has_hf_hub
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def hf_split(hf_id):
|
| 71 |
+
# FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
|
| 72 |
+
rev_split = hf_id.split('@')
|
| 73 |
+
assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
|
| 74 |
+
hf_model_id = rev_split[0]
|
| 75 |
+
hf_revision = rev_split[-1] if len(rev_split) > 1 else None
|
| 76 |
+
return hf_model_id, hf_revision
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
|
| 80 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
| 81 |
+
text = reader.read()
|
| 82 |
+
return json.loads(text)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _download_from_hf(model_id: str, filename: str):
|
| 86 |
+
hf_model_id, hf_revision = hf_split(model_id)
|
| 87 |
+
return hf_hub_download(hf_model_id, filename, revision=hf_revision)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def load_model_config_from_hf(model_id: str):
|
| 91 |
+
assert has_hf_hub(True)
|
| 92 |
+
cached_file = _download_from_hf(model_id, 'config.json')
|
| 93 |
+
pretrained_cfg = load_cfg_from_json(cached_file)
|
| 94 |
+
pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation
|
| 95 |
+
pretrained_cfg['source'] = 'hf-hub'
|
| 96 |
+
model_name = pretrained_cfg.get('architecture')
|
| 97 |
+
return pretrained_cfg, model_name
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
|
| 101 |
+
assert has_hf_hub(True)
|
| 102 |
+
cached_file = _download_from_hf(model_id, filename)
|
| 103 |
+
state_dict = torch.load(cached_file, map_location='cpu')
|
| 104 |
+
return state_dict
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def save_for_hf(model, save_directory, model_config=None):
|
| 108 |
+
assert has_hf_hub(True)
|
| 109 |
+
model_config = model_config or {}
|
| 110 |
+
save_directory = Path(save_directory)
|
| 111 |
+
save_directory.mkdir(exist_ok=True, parents=True)
|
| 112 |
+
|
| 113 |
+
weights_path = save_directory / 'pytorch_model.bin'
|
| 114 |
+
torch.save(model.state_dict(), weights_path)
|
| 115 |
+
|
| 116 |
+
config_path = save_directory / 'config.json'
|
| 117 |
+
hf_config = model.pretrained_cfg
|
| 118 |
+
hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
|
| 119 |
+
hf_config['num_features'] = model_config.pop('num_features', model.num_features)
|
| 120 |
+
hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
|
| 121 |
+
hf_config.update(model_config)
|
| 122 |
+
|
| 123 |
+
with config_path.open('w') as f:
|
| 124 |
+
json.dump(hf_config, f, indent=2)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def push_to_hf_hub(
|
| 128 |
+
model,
|
| 129 |
+
repo_id: str,
|
| 130 |
+
commit_message: str ='Add model',
|
| 131 |
+
token: Optional[str] = None,
|
| 132 |
+
revision: Optional[str] = None,
|
| 133 |
+
private: bool = False,
|
| 134 |
+
create_pr: bool = False,
|
| 135 |
+
model_config: Optional[dict] = None,
|
| 136 |
+
):
|
| 137 |
+
# Create repo if doesn't exist yet
|
| 138 |
+
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
|
| 139 |
+
|
| 140 |
+
# Infer complete repo_id from repo_url
|
| 141 |
+
# Can be different from the input `repo_id` if repo_owner was implicit
|
| 142 |
+
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
| 143 |
+
repo_id = f"{repo_owner}/{repo_name}"
|
| 144 |
+
|
| 145 |
+
# Check if README file already exist in repo
|
| 146 |
+
try:
|
| 147 |
+
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
| 148 |
+
has_readme = True
|
| 149 |
+
except EntryNotFoundError:
|
| 150 |
+
has_readme = False
|
| 151 |
+
|
| 152 |
+
# Dump model and push to Hub
|
| 153 |
+
with TemporaryDirectory() as tmpdir:
|
| 154 |
+
# Save model weights and config.
|
| 155 |
+
save_for_hf(model, tmpdir, model_config=model_config)
|
| 156 |
+
|
| 157 |
+
# Add readme if does not exist
|
| 158 |
+
if not has_readme:
|
| 159 |
+
readme_path = Path(tmpdir) / "README.md"
|
| 160 |
+
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
|
| 161 |
+
readme_path.write_text(readme_text)
|
| 162 |
+
|
| 163 |
+
# Upload model and return
|
| 164 |
+
return upload_folder(
|
| 165 |
+
repo_id=repo_id,
|
| 166 |
+
folder_path=tmpdir,
|
| 167 |
+
revision=revision,
|
| 168 |
+
create_pr=create_pr,
|
| 169 |
+
commit_message=commit_message,
|
| 170 |
+
)
|
src/custom_timm/models/inception_resnet_v2.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Pytorch Inception-Resnet-V2 implementation
|
| 2 |
+
Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
|
| 3 |
+
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from custom_timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
| 10 |
+
from .helpers import build_model_with_cfg, flatten_modules
|
| 11 |
+
from .layers import create_classifier
|
| 12 |
+
from .registry import register_model
|
| 13 |
+
|
| 14 |
+
__all__ = ['InceptionResnetV2']
|
| 15 |
+
|
| 16 |
+
default_cfgs = {
|
| 17 |
+
# ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
|
| 18 |
+
'inception_resnet_v2': {
|
| 19 |
+
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
|
| 20 |
+
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
| 21 |
+
'crop_pct': 0.8975, 'interpolation': 'bicubic',
|
| 22 |
+
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
| 23 |
+
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
|
| 24 |
+
'label_offset': 1, # 1001 classes in pretrained weights
|
| 25 |
+
},
|
| 26 |
+
# ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
|
| 27 |
+
'ens_adv_inception_resnet_v2': {
|
| 28 |
+
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
|
| 29 |
+
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
| 30 |
+
'crop_pct': 0.8975, 'interpolation': 'bicubic',
|
| 31 |
+
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
| 32 |
+
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
|
| 33 |
+
'label_offset': 1, # 1001 classes in pretrained weights
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BasicConv2d(nn.Module):
|
| 39 |
+
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
|
| 40 |
+
super(BasicConv2d, self).__init__()
|
| 41 |
+
self.conv = nn.Conv2d(
|
| 42 |
+
in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
| 43 |
+
self.bn = nn.BatchNorm2d(out_planes, eps=.001)
|
| 44 |
+
self.relu = nn.ReLU(inplace=False)
|
| 45 |
+
|
| 46 |
+
def forward(self, x):
|
| 47 |
+
x = self.conv(x)
|
| 48 |
+
x = self.bn(x)
|
| 49 |
+
x = self.relu(x)
|
| 50 |
+
return x
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Mixed_5b(nn.Module):
|
| 54 |
+
def __init__(self):
|
| 55 |
+
super(Mixed_5b, self).__init__()
|
| 56 |
+
|
| 57 |
+
self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
|
| 58 |
+
|
| 59 |
+
self.branch1 = nn.Sequential(
|
| 60 |
+
BasicConv2d(192, 48, kernel_size=1, stride=1),
|
| 61 |
+
BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
self.branch2 = nn.Sequential(
|
| 65 |
+
BasicConv2d(192, 64, kernel_size=1, stride=1),
|
| 66 |
+
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
|
| 67 |
+
BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
self.branch3 = nn.Sequential(
|
| 71 |
+
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
|
| 72 |
+
BasicConv2d(192, 64, kernel_size=1, stride=1)
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
x0 = self.branch0(x)
|
| 77 |
+
x1 = self.branch1(x)
|
| 78 |
+
x2 = self.branch2(x)
|
| 79 |
+
x3 = self.branch3(x)
|
| 80 |
+
out = torch.cat((x0, x1, x2, x3), 1)
|
| 81 |
+
return out
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class Block35(nn.Module):
|
| 85 |
+
def __init__(self, scale=1.0):
|
| 86 |
+
super(Block35, self).__init__()
|
| 87 |
+
|
| 88 |
+
self.scale = scale
|
| 89 |
+
|
| 90 |
+
self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
|
| 91 |
+
|
| 92 |
+
self.branch1 = nn.Sequential(
|
| 93 |
+
BasicConv2d(320, 32, kernel_size=1, stride=1),
|
| 94 |
+
BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
self.branch2 = nn.Sequential(
|
| 98 |
+
BasicConv2d(320, 32, kernel_size=1, stride=1),
|
| 99 |
+
BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
|
| 100 |
+
BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
|
| 104 |
+
self.relu = nn.ReLU(inplace=False)
|
| 105 |
+
|
| 106 |
+
def forward(self, x):
|
| 107 |
+
x0 = self.branch0(x)
|
| 108 |
+
x1 = self.branch1(x)
|
| 109 |
+
x2 = self.branch2(x)
|
| 110 |
+
out = torch.cat((x0, x1, x2), 1)
|
| 111 |
+
out = self.conv2d(out)
|
| 112 |
+
out = out * self.scale + x
|
| 113 |
+
out = self.relu(out)
|
| 114 |
+
return out
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Mixed_6a(nn.Module):
|
| 118 |
+
def __init__(self):
|
| 119 |
+
super(Mixed_6a, self).__init__()
|
| 120 |
+
|
| 121 |
+
self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
|
| 122 |
+
|
| 123 |
+
self.branch1 = nn.Sequential(
|
| 124 |
+
BasicConv2d(320, 256, kernel_size=1, stride=1),
|
| 125 |
+
BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
|
| 126 |
+
BasicConv2d(256, 384, kernel_size=3, stride=2)
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
self.branch2 = nn.MaxPool2d(3, stride=2)
|
| 130 |
+
|
| 131 |
+
def forward(self, x):
|
| 132 |
+
x0 = self.branch0(x)
|
| 133 |
+
x1 = self.branch1(x)
|
| 134 |
+
x2 = self.branch2(x)
|
| 135 |
+
out = torch.cat((x0, x1, x2), 1)
|
| 136 |
+
return out
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class Block17(nn.Module):
|
| 140 |
+
def __init__(self, scale=1.0):
|
| 141 |
+
super(Block17, self).__init__()
|
| 142 |
+
|
| 143 |
+
self.scale = scale
|
| 144 |
+
|
| 145 |
+
self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
|
| 146 |
+
|
| 147 |
+
self.branch1 = nn.Sequential(
|
| 148 |
+
BasicConv2d(1088, 128, kernel_size=1, stride=1),
|
| 149 |
+
BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
|
| 150 |
+
BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0))
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
|
| 154 |
+
self.relu = nn.ReLU(inplace=False)
|
| 155 |
+
|
| 156 |
+
def forward(self, x):
|
| 157 |
+
x0 = self.branch0(x)
|
| 158 |
+
x1 = self.branch1(x)
|
| 159 |
+
out = torch.cat((x0, x1), 1)
|
| 160 |
+
out = self.conv2d(out)
|
| 161 |
+
out = out * self.scale + x
|
| 162 |
+
out = self.relu(out)
|
| 163 |
+
return out
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class Mixed_7a(nn.Module):
|
| 167 |
+
def __init__(self):
|
| 168 |
+
super(Mixed_7a, self).__init__()
|
| 169 |
+
|
| 170 |
+
self.branch0 = nn.Sequential(
|
| 171 |
+
BasicConv2d(1088, 256, kernel_size=1, stride=1),
|
| 172 |
+
BasicConv2d(256, 384, kernel_size=3, stride=2)
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
self.branch1 = nn.Sequential(
|
| 176 |
+
BasicConv2d(1088, 256, kernel_size=1, stride=1),
|
| 177 |
+
BasicConv2d(256, 288, kernel_size=3, stride=2)
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
self.branch2 = nn.Sequential(
|
| 181 |
+
BasicConv2d(1088, 256, kernel_size=1, stride=1),
|
| 182 |
+
BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
|
| 183 |
+
BasicConv2d(288, 320, kernel_size=3, stride=2)
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
self.branch3 = nn.MaxPool2d(3, stride=2)
|
| 187 |
+
|
| 188 |
+
def forward(self, x):
|
| 189 |
+
x0 = self.branch0(x)
|
| 190 |
+
x1 = self.branch1(x)
|
| 191 |
+
x2 = self.branch2(x)
|
| 192 |
+
x3 = self.branch3(x)
|
| 193 |
+
out = torch.cat((x0, x1, x2, x3), 1)
|
| 194 |
+
return out
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class Block8(nn.Module):
|
| 198 |
+
|
| 199 |
+
def __init__(self, scale=1.0, no_relu=False):
|
| 200 |
+
super(Block8, self).__init__()
|
| 201 |
+
|
| 202 |
+
self.scale = scale
|
| 203 |
+
|
| 204 |
+
self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
|
| 205 |
+
|
| 206 |
+
self.branch1 = nn.Sequential(
|
| 207 |
+
BasicConv2d(2080, 192, kernel_size=1, stride=1),
|
| 208 |
+
BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
|
| 209 |
+
BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
|
| 213 |
+
self.relu = None if no_relu else nn.ReLU(inplace=False)
|
| 214 |
+
|
| 215 |
+
def forward(self, x):
|
| 216 |
+
x0 = self.branch0(x)
|
| 217 |
+
x1 = self.branch1(x)
|
| 218 |
+
out = torch.cat((x0, x1), 1)
|
| 219 |
+
out = self.conv2d(out)
|
| 220 |
+
out = out * self.scale + x
|
| 221 |
+
if self.relu is not None:
|
| 222 |
+
out = self.relu(out)
|
| 223 |
+
return out
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class InceptionResnetV2(nn.Module):
|
| 227 |
+
def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
|
| 228 |
+
super(InceptionResnetV2, self).__init__()
|
| 229 |
+
self.drop_rate = drop_rate
|
| 230 |
+
self.num_classes = num_classes
|
| 231 |
+
self.num_features = 1536
|
| 232 |
+
assert output_stride == 32
|
| 233 |
+
|
| 234 |
+
self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
|
| 235 |
+
self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
|
| 236 |
+
self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
|
| 237 |
+
self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')]
|
| 238 |
+
|
| 239 |
+
self.maxpool_3a = nn.MaxPool2d(3, stride=2)
|
| 240 |
+
self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
|
| 241 |
+
self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
|
| 242 |
+
self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')]
|
| 243 |
+
|
| 244 |
+
self.maxpool_5a = nn.MaxPool2d(3, stride=2)
|
| 245 |
+
self.mixed_5b = Mixed_5b()
|
| 246 |
+
self.repeat = nn.Sequential(
|
| 247 |
+
Block35(scale=0.17),
|
| 248 |
+
Block35(scale=0.17),
|
| 249 |
+
Block35(scale=0.17),
|
| 250 |
+
Block35(scale=0.17),
|
| 251 |
+
Block35(scale=0.17),
|
| 252 |
+
Block35(scale=0.17),
|
| 253 |
+
Block35(scale=0.17),
|
| 254 |
+
Block35(scale=0.17),
|
| 255 |
+
Block35(scale=0.17),
|
| 256 |
+
Block35(scale=0.17)
|
| 257 |
+
)
|
| 258 |
+
self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')]
|
| 259 |
+
|
| 260 |
+
self.mixed_6a = Mixed_6a()
|
| 261 |
+
self.repeat_1 = nn.Sequential(
|
| 262 |
+
Block17(scale=0.10),
|
| 263 |
+
Block17(scale=0.10),
|
| 264 |
+
Block17(scale=0.10),
|
| 265 |
+
Block17(scale=0.10),
|
| 266 |
+
Block17(scale=0.10),
|
| 267 |
+
Block17(scale=0.10),
|
| 268 |
+
Block17(scale=0.10),
|
| 269 |
+
Block17(scale=0.10),
|
| 270 |
+
Block17(scale=0.10),
|
| 271 |
+
Block17(scale=0.10),
|
| 272 |
+
Block17(scale=0.10),
|
| 273 |
+
Block17(scale=0.10),
|
| 274 |
+
Block17(scale=0.10),
|
| 275 |
+
Block17(scale=0.10),
|
| 276 |
+
Block17(scale=0.10),
|
| 277 |
+
Block17(scale=0.10),
|
| 278 |
+
Block17(scale=0.10),
|
| 279 |
+
Block17(scale=0.10),
|
| 280 |
+
Block17(scale=0.10),
|
| 281 |
+
Block17(scale=0.10)
|
| 282 |
+
)
|
| 283 |
+
self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')]
|
| 284 |
+
|
| 285 |
+
self.mixed_7a = Mixed_7a()
|
| 286 |
+
self.repeat_2 = nn.Sequential(
|
| 287 |
+
Block8(scale=0.20),
|
| 288 |
+
Block8(scale=0.20),
|
| 289 |
+
Block8(scale=0.20),
|
| 290 |
+
Block8(scale=0.20),
|
| 291 |
+
Block8(scale=0.20),
|
| 292 |
+
Block8(scale=0.20),
|
| 293 |
+
Block8(scale=0.20),
|
| 294 |
+
Block8(scale=0.20),
|
| 295 |
+
Block8(scale=0.20)
|
| 296 |
+
)
|
| 297 |
+
self.block8 = Block8(no_relu=True)
|
| 298 |
+
self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
|
| 299 |
+
self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')]
|
| 300 |
+
|
| 301 |
+
self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
| 302 |
+
|
| 303 |
+
@torch.jit.ignore
|
| 304 |
+
def group_matcher(self, coarse=False):
|
| 305 |
+
module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))}
|
| 306 |
+
module_map.pop(('classif',))
|
| 307 |
+
|
| 308 |
+
def _matcher(name):
|
| 309 |
+
if any([name.startswith(n) for n in ('conv2d_1', 'conv2d_2')]):
|
| 310 |
+
return 0
|
| 311 |
+
elif any([name.startswith(n) for n in ('conv2d_3', 'conv2d_4')]):
|
| 312 |
+
return 1
|
| 313 |
+
elif any([name.startswith(n) for n in ('block8', 'conv2d_7')]):
|
| 314 |
+
return len(module_map) + 1
|
| 315 |
+
else:
|
| 316 |
+
for k in module_map.keys():
|
| 317 |
+
if k == tuple(name.split('.')[:len(k)]):
|
| 318 |
+
return module_map[k]
|
| 319 |
+
return float('inf')
|
| 320 |
+
return _matcher
|
| 321 |
+
|
| 322 |
+
@torch.jit.ignore
|
| 323 |
+
def set_grad_checkpointing(self, enable=True):
|
| 324 |
+
assert not enable, "checkpointing not supported"
|
| 325 |
+
|
| 326 |
+
@torch.jit.ignore
|
| 327 |
+
def get_classifier(self):
|
| 328 |
+
return self.classif
|
| 329 |
+
|
| 330 |
+
def reset_classifier(self, num_classes, global_pool='avg'):
|
| 331 |
+
self.num_classes = num_classes
|
| 332 |
+
self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
| 333 |
+
|
| 334 |
+
def forward_features(self, x):
|
| 335 |
+
x = self.conv2d_1a(x)
|
| 336 |
+
x = self.conv2d_2a(x)
|
| 337 |
+
x = self.conv2d_2b(x)
|
| 338 |
+
x = self.maxpool_3a(x)
|
| 339 |
+
x = self.conv2d_3b(x)
|
| 340 |
+
x = self.conv2d_4a(x)
|
| 341 |
+
x = self.maxpool_5a(x)
|
| 342 |
+
x = self.mixed_5b(x)
|
| 343 |
+
x = self.repeat(x)
|
| 344 |
+
x = self.mixed_6a(x)
|
| 345 |
+
x = self.repeat_1(x)
|
| 346 |
+
x = self.mixed_7a(x)
|
| 347 |
+
x = self.repeat_2(x)
|
| 348 |
+
x = self.block8(x)
|
| 349 |
+
x = self.conv2d_7b(x)
|
| 350 |
+
return x
|
| 351 |
+
|
| 352 |
+
def forward_head(self, x, pre_logits: bool = False):
|
| 353 |
+
x = self.global_pool(x)
|
| 354 |
+
if self.drop_rate > 0:
|
| 355 |
+
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
| 356 |
+
return x if pre_logits else self.classif(x)
|
| 357 |
+
|
| 358 |
+
def forward(self, x):
|
| 359 |
+
x = self.forward_features(x)
|
| 360 |
+
x = self.forward_head(x)
|
| 361 |
+
return x
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _create_inception_resnet_v2(variant, pretrained=False, **kwargs):
|
| 365 |
+
return build_model_with_cfg(InceptionResnetV2, variant, pretrained, **kwargs)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
@register_model
|
| 369 |
+
def inception_resnet_v2(pretrained=False, **kwargs):
|
| 370 |
+
r"""InceptionResnetV2 model architecture from the
|
| 371 |
+
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
|
| 372 |
+
"""
|
| 373 |
+
return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
@register_model
|
| 377 |
+
def ens_adv_inception_resnet_v2(pretrained=False, **kwargs):
|
| 378 |
+
r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
|
| 379 |
+
As per https://arxiv.org/abs/1705.07204 and
|
| 380 |
+
https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
|
| 381 |
+
"""
|
| 382 |
+
return _create_inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs)
|
src/custom_timm/models/inception_v3.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Inception-V3
|
| 2 |
+
|
| 3 |
+
Originally from torchvision Inception3 model
|
| 4 |
+
Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from custom_timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
| 11 |
+
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, flatten_modules
|
| 12 |
+
from .registry import register_model
|
| 13 |
+
from .layers import trunc_normal_, create_classifier, Linear
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _cfg(url='', **kwargs):
|
| 17 |
+
return {
|
| 18 |
+
'url': url,
|
| 19 |
+
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
| 20 |
+
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
| 21 |
+
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
| 22 |
+
'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc',
|
| 23 |
+
**kwargs
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
default_cfgs = {
|
| 28 |
+
# original PyTorch weights, ported from Tensorflow but modified
|
| 29 |
+
'inception_v3': _cfg(
|
| 30 |
+
url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
|
| 31 |
+
has_aux=True), # checkpoint has aux logit layer weights
|
| 32 |
+
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
| 33 |
+
'tf_inception_v3': _cfg(
|
| 34 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
|
| 35 |
+
num_classes=1000, has_aux=False, label_offset=1),
|
| 36 |
+
# my port of Tensorflow adversarially trained Inception V3 from
|
| 37 |
+
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
| 38 |
+
'adv_inception_v3': _cfg(
|
| 39 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
|
| 40 |
+
num_classes=1000, has_aux=False, label_offset=1),
|
| 41 |
+
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
| 42 |
+
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
| 43 |
+
'gluon_inception_v3': _cfg(
|
| 44 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
|
| 45 |
+
mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults
|
| 46 |
+
std=IMAGENET_DEFAULT_STD, # also works well with inception defaults
|
| 47 |
+
has_aux=False,
|
| 48 |
+
)
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class InceptionA(nn.Module):
|
| 53 |
+
|
| 54 |
+
def __init__(self, in_channels, pool_features, conv_block=None):
|
| 55 |
+
super(InceptionA, self).__init__()
|
| 56 |
+
if conv_block is None:
|
| 57 |
+
conv_block = BasicConv2d
|
| 58 |
+
self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
|
| 59 |
+
|
| 60 |
+
self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
|
| 61 |
+
self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
|
| 62 |
+
|
| 63 |
+
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
|
| 64 |
+
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
|
| 65 |
+
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
|
| 66 |
+
|
| 67 |
+
self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
|
| 68 |
+
|
| 69 |
+
def _forward(self, x):
|
| 70 |
+
branch1x1 = self.branch1x1(x)
|
| 71 |
+
|
| 72 |
+
branch5x5 = self.branch5x5_1(x)
|
| 73 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
| 74 |
+
|
| 75 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
| 76 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
| 77 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
| 78 |
+
|
| 79 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
|
| 80 |
+
branch_pool = self.branch_pool(branch_pool)
|
| 81 |
+
|
| 82 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
| 83 |
+
return outputs
|
| 84 |
+
|
| 85 |
+
def forward(self, x):
|
| 86 |
+
outputs = self._forward(x)
|
| 87 |
+
return torch.cat(outputs, 1)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class InceptionB(nn.Module):
|
| 91 |
+
|
| 92 |
+
def __init__(self, in_channels, conv_block=None):
|
| 93 |
+
super(InceptionB, self).__init__()
|
| 94 |
+
if conv_block is None:
|
| 95 |
+
conv_block = BasicConv2d
|
| 96 |
+
self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
|
| 97 |
+
|
| 98 |
+
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
|
| 99 |
+
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
|
| 100 |
+
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
|
| 101 |
+
|
| 102 |
+
def _forward(self, x):
|
| 103 |
+
branch3x3 = self.branch3x3(x)
|
| 104 |
+
|
| 105 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
| 106 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
| 107 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
| 108 |
+
|
| 109 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
|
| 110 |
+
|
| 111 |
+
outputs = [branch3x3, branch3x3dbl, branch_pool]
|
| 112 |
+
return outputs
|
| 113 |
+
|
| 114 |
+
def forward(self, x):
|
| 115 |
+
outputs = self._forward(x)
|
| 116 |
+
return torch.cat(outputs, 1)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class InceptionC(nn.Module):
|
| 120 |
+
|
| 121 |
+
def __init__(self, in_channels, channels_7x7, conv_block=None):
|
| 122 |
+
super(InceptionC, self).__init__()
|
| 123 |
+
if conv_block is None:
|
| 124 |
+
conv_block = BasicConv2d
|
| 125 |
+
self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
|
| 126 |
+
|
| 127 |
+
c7 = channels_7x7
|
| 128 |
+
self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
|
| 129 |
+
self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
|
| 130 |
+
self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
|
| 131 |
+
|
| 132 |
+
self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
|
| 133 |
+
self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
|
| 134 |
+
self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
|
| 135 |
+
self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
|
| 136 |
+
self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
|
| 137 |
+
|
| 138 |
+
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
|
| 139 |
+
|
| 140 |
+
def _forward(self, x):
|
| 141 |
+
branch1x1 = self.branch1x1(x)
|
| 142 |
+
|
| 143 |
+
branch7x7 = self.branch7x7_1(x)
|
| 144 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
| 145 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
| 146 |
+
|
| 147 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
| 148 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
| 149 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
| 150 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
| 151 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
| 152 |
+
|
| 153 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
|
| 154 |
+
branch_pool = self.branch_pool(branch_pool)
|
| 155 |
+
|
| 156 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
| 157 |
+
return outputs
|
| 158 |
+
|
| 159 |
+
def forward(self, x):
|
| 160 |
+
outputs = self._forward(x)
|
| 161 |
+
return torch.cat(outputs, 1)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class InceptionD(nn.Module):
|
| 165 |
+
|
| 166 |
+
def __init__(self, in_channels, conv_block=None):
|
| 167 |
+
super(InceptionD, self).__init__()
|
| 168 |
+
if conv_block is None:
|
| 169 |
+
conv_block = BasicConv2d
|
| 170 |
+
self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
|
| 171 |
+
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
|
| 172 |
+
|
| 173 |
+
self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
|
| 174 |
+
self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
|
| 175 |
+
self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
|
| 176 |
+
self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
|
| 177 |
+
|
| 178 |
+
def _forward(self, x):
|
| 179 |
+
branch3x3 = self.branch3x3_1(x)
|
| 180 |
+
branch3x3 = self.branch3x3_2(branch3x3)
|
| 181 |
+
|
| 182 |
+
branch7x7x3 = self.branch7x7x3_1(x)
|
| 183 |
+
branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
|
| 184 |
+
branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
|
| 185 |
+
branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
|
| 186 |
+
|
| 187 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
|
| 188 |
+
outputs = [branch3x3, branch7x7x3, branch_pool]
|
| 189 |
+
return outputs
|
| 190 |
+
|
| 191 |
+
def forward(self, x):
|
| 192 |
+
outputs = self._forward(x)
|
| 193 |
+
return torch.cat(outputs, 1)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class InceptionE(nn.Module):
|
| 197 |
+
|
| 198 |
+
def __init__(self, in_channels, conv_block=None):
|
| 199 |
+
super(InceptionE, self).__init__()
|
| 200 |
+
if conv_block is None:
|
| 201 |
+
conv_block = BasicConv2d
|
| 202 |
+
self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
|
| 203 |
+
|
| 204 |
+
self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
|
| 205 |
+
self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
|
| 206 |
+
self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
|
| 207 |
+
|
| 208 |
+
self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
|
| 209 |
+
self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
|
| 210 |
+
self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
|
| 211 |
+
self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
|
| 212 |
+
|
| 213 |
+
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
|
| 214 |
+
|
| 215 |
+
def _forward(self, x):
|
| 216 |
+
branch1x1 = self.branch1x1(x)
|
| 217 |
+
|
| 218 |
+
branch3x3 = self.branch3x3_1(x)
|
| 219 |
+
branch3x3 = [
|
| 220 |
+
self.branch3x3_2a(branch3x3),
|
| 221 |
+
self.branch3x3_2b(branch3x3),
|
| 222 |
+
]
|
| 223 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
| 224 |
+
|
| 225 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
| 226 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
| 227 |
+
branch3x3dbl = [
|
| 228 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
| 229 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
| 230 |
+
]
|
| 231 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
| 232 |
+
|
| 233 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
|
| 234 |
+
branch_pool = self.branch_pool(branch_pool)
|
| 235 |
+
|
| 236 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
| 237 |
+
return outputs
|
| 238 |
+
|
| 239 |
+
def forward(self, x):
|
| 240 |
+
outputs = self._forward(x)
|
| 241 |
+
return torch.cat(outputs, 1)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class InceptionAux(nn.Module):
|
| 245 |
+
|
| 246 |
+
def __init__(self, in_channels, num_classes, conv_block=None):
|
| 247 |
+
super(InceptionAux, self).__init__()
|
| 248 |
+
if conv_block is None:
|
| 249 |
+
conv_block = BasicConv2d
|
| 250 |
+
self.conv0 = conv_block(in_channels, 128, kernel_size=1)
|
| 251 |
+
self.conv1 = conv_block(128, 768, kernel_size=5)
|
| 252 |
+
self.conv1.stddev = 0.01
|
| 253 |
+
self.fc = Linear(768, num_classes)
|
| 254 |
+
self.fc.stddev = 0.001
|
| 255 |
+
|
| 256 |
+
def forward(self, x):
|
| 257 |
+
# N x 768 x 17 x 17
|
| 258 |
+
x = F.avg_pool2d(x, kernel_size=5, stride=3)
|
| 259 |
+
# N x 768 x 5 x 5
|
| 260 |
+
x = self.conv0(x)
|
| 261 |
+
# N x 128 x 5 x 5
|
| 262 |
+
x = self.conv1(x)
|
| 263 |
+
# N x 768 x 1 x 1
|
| 264 |
+
# Adaptive average pooling
|
| 265 |
+
x = F.adaptive_avg_pool2d(x, (1, 1))
|
| 266 |
+
# N x 768 x 1 x 1
|
| 267 |
+
x = torch.flatten(x, 1)
|
| 268 |
+
# N x 768
|
| 269 |
+
x = self.fc(x)
|
| 270 |
+
# N x 1000
|
| 271 |
+
return x
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class BasicConv2d(nn.Module):
|
| 275 |
+
|
| 276 |
+
def __init__(self, in_channels, out_channels, **kwargs):
|
| 277 |
+
super(BasicConv2d, self).__init__()
|
| 278 |
+
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
|
| 279 |
+
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
| 280 |
+
|
| 281 |
+
def forward(self, x):
|
| 282 |
+
x = self.conv(x)
|
| 283 |
+
x = self.bn(x)
|
| 284 |
+
return F.relu(x, inplace=True)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class InceptionV3(nn.Module):
|
| 288 |
+
"""Inception-V3 with no AuxLogits
|
| 289 |
+
FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns
|
| 290 |
+
"""
|
| 291 |
+
|
| 292 |
+
def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False):
|
| 293 |
+
super(InceptionV3, self).__init__()
|
| 294 |
+
self.num_classes = num_classes
|
| 295 |
+
self.drop_rate = drop_rate
|
| 296 |
+
self.aux_logits = aux_logits
|
| 297 |
+
|
| 298 |
+
self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
|
| 299 |
+
self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
|
| 300 |
+
self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
|
| 301 |
+
self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
|
| 302 |
+
self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
|
| 303 |
+
self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
|
| 304 |
+
self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
|
| 305 |
+
self.Mixed_5b = InceptionA(192, pool_features=32)
|
| 306 |
+
self.Mixed_5c = InceptionA(256, pool_features=64)
|
| 307 |
+
self.Mixed_5d = InceptionA(288, pool_features=64)
|
| 308 |
+
self.Mixed_6a = InceptionB(288)
|
| 309 |
+
self.Mixed_6b = InceptionC(768, channels_7x7=128)
|
| 310 |
+
self.Mixed_6c = InceptionC(768, channels_7x7=160)
|
| 311 |
+
self.Mixed_6d = InceptionC(768, channels_7x7=160)
|
| 312 |
+
self.Mixed_6e = InceptionC(768, channels_7x7=192)
|
| 313 |
+
if aux_logits:
|
| 314 |
+
self.AuxLogits = InceptionAux(768, num_classes)
|
| 315 |
+
else:
|
| 316 |
+
self.AuxLogits = None
|
| 317 |
+
self.Mixed_7a = InceptionD(768)
|
| 318 |
+
self.Mixed_7b = InceptionE(1280)
|
| 319 |
+
self.Mixed_7c = InceptionE(2048)
|
| 320 |
+
self.feature_info = [
|
| 321 |
+
dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'),
|
| 322 |
+
dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'),
|
| 323 |
+
dict(num_chs=288, reduction=8, module='Mixed_5d'),
|
| 324 |
+
dict(num_chs=768, reduction=16, module='Mixed_6e'),
|
| 325 |
+
dict(num_chs=2048, reduction=32, module='Mixed_7c'),
|
| 326 |
+
]
|
| 327 |
+
|
| 328 |
+
self.num_features = 2048
|
| 329 |
+
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
| 330 |
+
|
| 331 |
+
for m in self.modules():
|
| 332 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 333 |
+
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
|
| 334 |
+
trunc_normal_(m.weight, std=stddev)
|
| 335 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 336 |
+
nn.init.constant_(m.weight, 1)
|
| 337 |
+
nn.init.constant_(m.bias, 0)
|
| 338 |
+
|
| 339 |
+
@torch.jit.ignore
|
| 340 |
+
def group_matcher(self, coarse=False):
|
| 341 |
+
module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))}
|
| 342 |
+
module_map.pop(('fc',))
|
| 343 |
+
|
| 344 |
+
def _matcher(name):
|
| 345 |
+
if any([name.startswith(n) for n in ('Conv2d_1', 'Conv2d_2')]):
|
| 346 |
+
return 0
|
| 347 |
+
elif any([name.startswith(n) for n in ('Conv2d_3', 'Conv2d_4')]):
|
| 348 |
+
return 1
|
| 349 |
+
else:
|
| 350 |
+
for k in module_map.keys():
|
| 351 |
+
if k == tuple(name.split('.')[:len(k)]):
|
| 352 |
+
return module_map[k]
|
| 353 |
+
return float('inf')
|
| 354 |
+
return _matcher
|
| 355 |
+
|
| 356 |
+
@torch.jit.ignore
|
| 357 |
+
def set_grad_checkpointing(self, enable=True):
|
| 358 |
+
assert not enable, 'gradient checkpointing not supported'
|
| 359 |
+
|
| 360 |
+
@torch.jit.ignore
|
| 361 |
+
def get_classifier(self):
|
| 362 |
+
return self.fc
|
| 363 |
+
|
| 364 |
+
def reset_classifier(self, num_classes, global_pool='avg'):
|
| 365 |
+
self.num_classes = num_classes
|
| 366 |
+
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
| 367 |
+
|
| 368 |
+
def forward_preaux(self, x):
|
| 369 |
+
x = self.Conv2d_1a_3x3(x) # N x 32 x 149 x 149
|
| 370 |
+
x = self.Conv2d_2a_3x3(x) # N x 32 x 147 x 147
|
| 371 |
+
x = self.Conv2d_2b_3x3(x) # N x 64 x 147 x 147
|
| 372 |
+
x = self.Pool1(x) # N x 64 x 73 x 73
|
| 373 |
+
x = self.Conv2d_3b_1x1(x) # N x 80 x 73 x 73
|
| 374 |
+
x = self.Conv2d_4a_3x3(x) # N x 192 x 71 x 71
|
| 375 |
+
x = self.Pool2(x) # N x 192 x 35 x 35
|
| 376 |
+
x = self.Mixed_5b(x) # N x 256 x 35 x 35
|
| 377 |
+
x = self.Mixed_5c(x) # N x 288 x 35 x 35
|
| 378 |
+
x = self.Mixed_5d(x) # N x 288 x 35 x 35
|
| 379 |
+
x = self.Mixed_6a(x) # N x 768 x 17 x 17
|
| 380 |
+
x = self.Mixed_6b(x) # N x 768 x 17 x 17
|
| 381 |
+
x = self.Mixed_6c(x) # N x 768 x 17 x 17
|
| 382 |
+
x = self.Mixed_6d(x) # N x 768 x 17 x 17
|
| 383 |
+
x = self.Mixed_6e(x) # N x 768 x 17 x 17
|
| 384 |
+
return x
|
| 385 |
+
|
| 386 |
+
def forward_postaux(self, x):
|
| 387 |
+
x = self.Mixed_7a(x) # N x 1280 x 8 x 8
|
| 388 |
+
x = self.Mixed_7b(x) # N x 2048 x 8 x 8
|
| 389 |
+
x = self.Mixed_7c(x) # N x 2048 x 8 x 8
|
| 390 |
+
return x
|
| 391 |
+
|
| 392 |
+
def forward_features(self, x):
|
| 393 |
+
x = self.forward_preaux(x)
|
| 394 |
+
x = self.forward_postaux(x)
|
| 395 |
+
return x
|
| 396 |
+
|
| 397 |
+
def forward_head(self, x):
|
| 398 |
+
x = self.global_pool(x)
|
| 399 |
+
if self.drop_rate > 0:
|
| 400 |
+
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
| 401 |
+
x = self.fc(x)
|
| 402 |
+
return x
|
| 403 |
+
|
| 404 |
+
def forward(self, x):
|
| 405 |
+
x = self.forward_features(x)
|
| 406 |
+
x = self.forward_head(x)
|
| 407 |
+
return x
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
class InceptionV3Aux(InceptionV3):
|
| 411 |
+
"""InceptionV3 with AuxLogits
|
| 412 |
+
"""
|
| 413 |
+
|
| 414 |
+
def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=True):
|
| 415 |
+
super(InceptionV3Aux, self).__init__(
|
| 416 |
+
num_classes, in_chans, drop_rate, global_pool, aux_logits)
|
| 417 |
+
|
| 418 |
+
def forward_features(self, x):
|
| 419 |
+
x = self.forward_preaux(x)
|
| 420 |
+
aux = self.AuxLogits(x) if self.training else None
|
| 421 |
+
x = self.forward_postaux(x)
|
| 422 |
+
return x, aux
|
| 423 |
+
|
| 424 |
+
def forward(self, x):
|
| 425 |
+
x, aux = self.forward_features(x)
|
| 426 |
+
x = self.forward_head(x)
|
| 427 |
+
return x, aux
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def _create_inception_v3(variant, pretrained=False, **kwargs):
|
| 431 |
+
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
|
| 432 |
+
aux_logits = kwargs.pop('aux_logits', False)
|
| 433 |
+
if aux_logits:
|
| 434 |
+
assert not kwargs.pop('features_only', False)
|
| 435 |
+
model_cls = InceptionV3Aux
|
| 436 |
+
load_strict = pretrained_cfg['has_aux']
|
| 437 |
+
else:
|
| 438 |
+
model_cls = InceptionV3
|
| 439 |
+
load_strict = not pretrained_cfg['has_aux']
|
| 440 |
+
|
| 441 |
+
return build_model_with_cfg(
|
| 442 |
+
model_cls, variant, pretrained,
|
| 443 |
+
pretrained_cfg=pretrained_cfg,
|
| 444 |
+
pretrained_strict=load_strict,
|
| 445 |
+
**kwargs)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
@register_model
|
| 449 |
+
def inception_v3(pretrained=False, **kwargs):
|
| 450 |
+
# original PyTorch weights, ported from Tensorflow but modified
|
| 451 |
+
model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs)
|
| 452 |
+
return model
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
@register_model
|
| 456 |
+
def tf_inception_v3(pretrained=False, **kwargs):
|
| 457 |
+
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
| 458 |
+
model = _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs)
|
| 459 |
+
return model
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
@register_model
|
| 463 |
+
def adv_inception_v3(pretrained=False, **kwargs):
|
| 464 |
+
# my port of Tensorflow adversarially trained Inception V3 from
|
| 465 |
+
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
| 466 |
+
model = _create_inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs)
|
| 467 |
+
return model
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
@register_model
|
| 471 |
+
def gluon_inception_v3(pretrained=False, **kwargs):
|
| 472 |
+
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
| 473 |
+
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
| 474 |
+
model = _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
|
| 475 |
+
return model
|
src/custom_timm/models/inception_v4.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Pytorch Inception-V4 implementation
|
| 2 |
+
Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
|
| 3 |
+
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from custom_timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
| 10 |
+
from .helpers import build_model_with_cfg
|
| 11 |
+
from .layers import create_classifier
|
| 12 |
+
from .registry import register_model
|
| 13 |
+
|
| 14 |
+
__all__ = ['InceptionV4']
|
| 15 |
+
|
| 16 |
+
default_cfgs = {
|
| 17 |
+
'inception_v4': {
|
| 18 |
+
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
|
| 19 |
+
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
| 20 |
+
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
| 21 |
+
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
| 22 |
+
'first_conv': 'features.0.conv', 'classifier': 'last_linear',
|
| 23 |
+
'label_offset': 1, # 1001 classes in pretrained weights
|
| 24 |
+
}
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class BasicConv2d(nn.Module):
|
| 29 |
+
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
|
| 30 |
+
super(BasicConv2d, self).__init__()
|
| 31 |
+
self.conv = nn.Conv2d(
|
| 32 |
+
in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
| 33 |
+
self.bn = nn.BatchNorm2d(out_planes, eps=0.001)
|
| 34 |
+
self.relu = nn.ReLU(inplace=True)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
x = self.conv(x)
|
| 38 |
+
x = self.bn(x)
|
| 39 |
+
x = self.relu(x)
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Mixed3a(nn.Module):
|
| 44 |
+
def __init__(self):
|
| 45 |
+
super(Mixed3a, self).__init__()
|
| 46 |
+
self.maxpool = nn.MaxPool2d(3, stride=2)
|
| 47 |
+
self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)
|
| 48 |
+
|
| 49 |
+
def forward(self, x):
|
| 50 |
+
x0 = self.maxpool(x)
|
| 51 |
+
x1 = self.conv(x)
|
| 52 |
+
out = torch.cat((x0, x1), 1)
|
| 53 |
+
return out
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Mixed4a(nn.Module):
|
| 57 |
+
def __init__(self):
|
| 58 |
+
super(Mixed4a, self).__init__()
|
| 59 |
+
|
| 60 |
+
self.branch0 = nn.Sequential(
|
| 61 |
+
BasicConv2d(160, 64, kernel_size=1, stride=1),
|
| 62 |
+
BasicConv2d(64, 96, kernel_size=3, stride=1)
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
self.branch1 = nn.Sequential(
|
| 66 |
+
BasicConv2d(160, 64, kernel_size=1, stride=1),
|
| 67 |
+
BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
|
| 68 |
+
BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
|
| 69 |
+
BasicConv2d(64, 96, kernel_size=(3, 3), stride=1)
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
x0 = self.branch0(x)
|
| 74 |
+
x1 = self.branch1(x)
|
| 75 |
+
out = torch.cat((x0, x1), 1)
|
| 76 |
+
return out
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class Mixed5a(nn.Module):
|
| 80 |
+
def __init__(self):
|
| 81 |
+
super(Mixed5a, self).__init__()
|
| 82 |
+
self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
|
| 83 |
+
self.maxpool = nn.MaxPool2d(3, stride=2)
|
| 84 |
+
|
| 85 |
+
def forward(self, x):
|
| 86 |
+
x0 = self.conv(x)
|
| 87 |
+
x1 = self.maxpool(x)
|
| 88 |
+
out = torch.cat((x0, x1), 1)
|
| 89 |
+
return out
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class InceptionA(nn.Module):
|
| 93 |
+
def __init__(self):
|
| 94 |
+
super(InceptionA, self).__init__()
|
| 95 |
+
self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
|
| 96 |
+
|
| 97 |
+
self.branch1 = nn.Sequential(
|
| 98 |
+
BasicConv2d(384, 64, kernel_size=1, stride=1),
|
| 99 |
+
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
self.branch2 = nn.Sequential(
|
| 103 |
+
BasicConv2d(384, 64, kernel_size=1, stride=1),
|
| 104 |
+
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
|
| 105 |
+
BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
self.branch3 = nn.Sequential(
|
| 109 |
+
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
|
| 110 |
+
BasicConv2d(384, 96, kernel_size=1, stride=1)
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
x0 = self.branch0(x)
|
| 115 |
+
x1 = self.branch1(x)
|
| 116 |
+
x2 = self.branch2(x)
|
| 117 |
+
x3 = self.branch3(x)
|
| 118 |
+
out = torch.cat((x0, x1, x2, x3), 1)
|
| 119 |
+
return out
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class ReductionA(nn.Module):
|
| 123 |
+
def __init__(self):
|
| 124 |
+
super(ReductionA, self).__init__()
|
| 125 |
+
self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)
|
| 126 |
+
|
| 127 |
+
self.branch1 = nn.Sequential(
|
| 128 |
+
BasicConv2d(384, 192, kernel_size=1, stride=1),
|
| 129 |
+
BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
|
| 130 |
+
BasicConv2d(224, 256, kernel_size=3, stride=2)
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
self.branch2 = nn.MaxPool2d(3, stride=2)
|
| 134 |
+
|
| 135 |
+
def forward(self, x):
|
| 136 |
+
x0 = self.branch0(x)
|
| 137 |
+
x1 = self.branch1(x)
|
| 138 |
+
x2 = self.branch2(x)
|
| 139 |
+
out = torch.cat((x0, x1, x2), 1)
|
| 140 |
+
return out
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class InceptionB(nn.Module):
|
| 144 |
+
def __init__(self):
|
| 145 |
+
super(InceptionB, self).__init__()
|
| 146 |
+
self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)
|
| 147 |
+
|
| 148 |
+
self.branch1 = nn.Sequential(
|
| 149 |
+
BasicConv2d(1024, 192, kernel_size=1, stride=1),
|
| 150 |
+
BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
|
| 151 |
+
BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0))
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
self.branch2 = nn.Sequential(
|
| 155 |
+
BasicConv2d(1024, 192, kernel_size=1, stride=1),
|
| 156 |
+
BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
|
| 157 |
+
BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
|
| 158 |
+
BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
|
| 159 |
+
BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3))
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
self.branch3 = nn.Sequential(
|
| 163 |
+
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
|
| 164 |
+
BasicConv2d(1024, 128, kernel_size=1, stride=1)
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
def forward(self, x):
|
| 168 |
+
x0 = self.branch0(x)
|
| 169 |
+
x1 = self.branch1(x)
|
| 170 |
+
x2 = self.branch2(x)
|
| 171 |
+
x3 = self.branch3(x)
|
| 172 |
+
out = torch.cat((x0, x1, x2, x3), 1)
|
| 173 |
+
return out
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class ReductionB(nn.Module):
|
| 177 |
+
def __init__(self):
|
| 178 |
+
super(ReductionB, self).__init__()
|
| 179 |
+
|
| 180 |
+
self.branch0 = nn.Sequential(
|
| 181 |
+
BasicConv2d(1024, 192, kernel_size=1, stride=1),
|
| 182 |
+
BasicConv2d(192, 192, kernel_size=3, stride=2)
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
self.branch1 = nn.Sequential(
|
| 186 |
+
BasicConv2d(1024, 256, kernel_size=1, stride=1),
|
| 187 |
+
BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
|
| 188 |
+
BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
|
| 189 |
+
BasicConv2d(320, 320, kernel_size=3, stride=2)
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
self.branch2 = nn.MaxPool2d(3, stride=2)
|
| 193 |
+
|
| 194 |
+
def forward(self, x):
|
| 195 |
+
x0 = self.branch0(x)
|
| 196 |
+
x1 = self.branch1(x)
|
| 197 |
+
x2 = self.branch2(x)
|
| 198 |
+
out = torch.cat((x0, x1, x2), 1)
|
| 199 |
+
return out
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class InceptionC(nn.Module):
|
| 203 |
+
def __init__(self):
|
| 204 |
+
super(InceptionC, self).__init__()
|
| 205 |
+
|
| 206 |
+
self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)
|
| 207 |
+
|
| 208 |
+
self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
|
| 209 |
+
self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
|
| 210 |
+
self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
|
| 211 |
+
|
| 212 |
+
self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
|
| 213 |
+
self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
|
| 214 |
+
self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
|
| 215 |
+
self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
|
| 216 |
+
self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
|
| 217 |
+
|
| 218 |
+
self.branch3 = nn.Sequential(
|
| 219 |
+
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
|
| 220 |
+
BasicConv2d(1536, 256, kernel_size=1, stride=1)
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
def forward(self, x):
|
| 224 |
+
x0 = self.branch0(x)
|
| 225 |
+
|
| 226 |
+
x1_0 = self.branch1_0(x)
|
| 227 |
+
x1_1a = self.branch1_1a(x1_0)
|
| 228 |
+
x1_1b = self.branch1_1b(x1_0)
|
| 229 |
+
x1 = torch.cat((x1_1a, x1_1b), 1)
|
| 230 |
+
|
| 231 |
+
x2_0 = self.branch2_0(x)
|
| 232 |
+
x2_1 = self.branch2_1(x2_0)
|
| 233 |
+
x2_2 = self.branch2_2(x2_1)
|
| 234 |
+
x2_3a = self.branch2_3a(x2_2)
|
| 235 |
+
x2_3b = self.branch2_3b(x2_2)
|
| 236 |
+
x2 = torch.cat((x2_3a, x2_3b), 1)
|
| 237 |
+
|
| 238 |
+
x3 = self.branch3(x)
|
| 239 |
+
|
| 240 |
+
out = torch.cat((x0, x1, x2, x3), 1)
|
| 241 |
+
return out
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class InceptionV4(nn.Module):
|
| 245 |
+
def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
|
| 246 |
+
super(InceptionV4, self).__init__()
|
| 247 |
+
assert output_stride == 32
|
| 248 |
+
self.drop_rate = drop_rate
|
| 249 |
+
self.num_classes = num_classes
|
| 250 |
+
self.num_features = 1536
|
| 251 |
+
|
| 252 |
+
self.features = nn.Sequential(
|
| 253 |
+
BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
|
| 254 |
+
BasicConv2d(32, 32, kernel_size=3, stride=1),
|
| 255 |
+
BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
|
| 256 |
+
Mixed3a(),
|
| 257 |
+
Mixed4a(),
|
| 258 |
+
Mixed5a(),
|
| 259 |
+
InceptionA(),
|
| 260 |
+
InceptionA(),
|
| 261 |
+
InceptionA(),
|
| 262 |
+
InceptionA(),
|
| 263 |
+
ReductionA(), # Mixed6a
|
| 264 |
+
InceptionB(),
|
| 265 |
+
InceptionB(),
|
| 266 |
+
InceptionB(),
|
| 267 |
+
InceptionB(),
|
| 268 |
+
InceptionB(),
|
| 269 |
+
InceptionB(),
|
| 270 |
+
InceptionB(),
|
| 271 |
+
ReductionB(), # Mixed7a
|
| 272 |
+
InceptionC(),
|
| 273 |
+
InceptionC(),
|
| 274 |
+
InceptionC(),
|
| 275 |
+
)
|
| 276 |
+
self.feature_info = [
|
| 277 |
+
dict(num_chs=64, reduction=2, module='features.2'),
|
| 278 |
+
dict(num_chs=160, reduction=4, module='features.3'),
|
| 279 |
+
dict(num_chs=384, reduction=8, module='features.9'),
|
| 280 |
+
dict(num_chs=1024, reduction=16, module='features.17'),
|
| 281 |
+
dict(num_chs=1536, reduction=32, module='features.21'),
|
| 282 |
+
]
|
| 283 |
+
self.global_pool, self.last_linear = create_classifier(
|
| 284 |
+
self.num_features, self.num_classes, pool_type=global_pool)
|
| 285 |
+
|
| 286 |
+
@torch.jit.ignore
|
| 287 |
+
def group_matcher(self, coarse=False):
|
| 288 |
+
return dict(
|
| 289 |
+
stem=r'^features\.[012]\.',
|
| 290 |
+
blocks=r'^features\.(\d+)'
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
@torch.jit.ignore
|
| 294 |
+
def set_grad_checkpointing(self, enable=True):
|
| 295 |
+
assert not enable, 'gradient checkpointing not supported'
|
| 296 |
+
|
| 297 |
+
@torch.jit.ignore
|
| 298 |
+
def get_classifier(self):
|
| 299 |
+
return self.last_linear
|
| 300 |
+
|
| 301 |
+
def reset_classifier(self, num_classes, global_pool='avg'):
|
| 302 |
+
self.num_classes = num_classes
|
| 303 |
+
self.global_pool, self.last_linear = create_classifier(
|
| 304 |
+
self.num_features, self.num_classes, pool_type=global_pool)
|
| 305 |
+
|
| 306 |
+
def forward_features(self, x):
|
| 307 |
+
return self.features(x)
|
| 308 |
+
|
| 309 |
+
def forward_head(self, x, pre_logits: bool = False):
|
| 310 |
+
x = self.global_pool(x)
|
| 311 |
+
if self.drop_rate > 0:
|
| 312 |
+
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
| 313 |
+
return x if pre_logits else self.last_linear(x)
|
| 314 |
+
|
| 315 |
+
def forward(self, x):
|
| 316 |
+
x = self.forward_features(x)
|
| 317 |
+
x = self.forward_head(x)
|
| 318 |
+
return x
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _create_inception_v4(variant, pretrained=False, **kwargs):
|
| 322 |
+
return build_model_with_cfg(
|
| 323 |
+
InceptionV4, variant, pretrained,
|
| 324 |
+
feature_cfg=dict(flatten_sequential=True),
|
| 325 |
+
**kwargs)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@register_model
|
| 329 |
+
def inception_v4(pretrained=False, **kwargs):
|
| 330 |
+
return _create_inception_v4('inception_v4', pretrained, **kwargs)
|
src/custom_timm/models/levit.py
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" LeViT
|
| 2 |
+
|
| 3 |
+
Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
|
| 4 |
+
- https://arxiv.org/abs/2104.01136
|
| 5 |
+
|
| 6 |
+
@article{graham2021levit,
|
| 7 |
+
title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
|
| 8 |
+
author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
|
| 9 |
+
journal={arXiv preprint arXiv:22104.01136},
|
| 10 |
+
year={2021}
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow.
|
| 14 |
+
|
| 15 |
+
This version combines both conv/linear models and fixes torchscript compatibility.
|
| 16 |
+
|
| 17 |
+
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 21 |
+
# All rights reserved.
|
| 22 |
+
|
| 23 |
+
# Modified from
|
| 24 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
| 25 |
+
# Copyright 2020 Ross Wightman, Apache-2.0 License
|
| 26 |
+
import itertools
|
| 27 |
+
from copy import deepcopy
|
| 28 |
+
from functools import partial
|
| 29 |
+
from typing import Dict
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
|
| 34 |
+
from custom_timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
|
| 35 |
+
from .helpers import build_model_with_cfg, checkpoint_seq
|
| 36 |
+
from .layers import to_ntuple, get_act_layer
|
| 37 |
+
from .vision_transformer import trunc_normal_
|
| 38 |
+
from .registry import register_model
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _cfg(url='', **kwargs):
|
| 42 |
+
return {
|
| 43 |
+
'url': url,
|
| 44 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
| 45 |
+
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
| 46 |
+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
| 47 |
+
'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'),
|
| 48 |
+
**kwargs
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
default_cfgs = dict(
|
| 53 |
+
levit_128s=_cfg(
|
| 54 |
+
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'
|
| 55 |
+
),
|
| 56 |
+
levit_128=_cfg(
|
| 57 |
+
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'
|
| 58 |
+
),
|
| 59 |
+
levit_192=_cfg(
|
| 60 |
+
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'
|
| 61 |
+
),
|
| 62 |
+
levit_256=_cfg(
|
| 63 |
+
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'
|
| 64 |
+
),
|
| 65 |
+
levit_384=_cfg(
|
| 66 |
+
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'
|
| 67 |
+
),
|
| 68 |
+
|
| 69 |
+
levit_256d=_cfg(url='', classifier='head.l'),
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
model_cfgs = dict(
|
| 73 |
+
levit_128s=dict(
|
| 74 |
+
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
|
| 75 |
+
levit_128=dict(
|
| 76 |
+
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
|
| 77 |
+
levit_192=dict(
|
| 78 |
+
embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
|
| 79 |
+
levit_256=dict(
|
| 80 |
+
embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
|
| 81 |
+
levit_384=dict(
|
| 82 |
+
embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),
|
| 83 |
+
|
| 84 |
+
levit_256d=dict(
|
| 85 |
+
embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 8, 6)),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
__all__ = ['Levit']
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@register_model
|
| 92 |
+
def levit_128s(pretrained=False, use_conv=False, **kwargs):
|
| 93 |
+
return create_levit(
|
| 94 |
+
'levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@register_model
|
| 98 |
+
def levit_128(pretrained=False, use_conv=False, **kwargs):
|
| 99 |
+
return create_levit(
|
| 100 |
+
'levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@register_model
|
| 104 |
+
def levit_192(pretrained=False, use_conv=False, **kwargs):
|
| 105 |
+
return create_levit(
|
| 106 |
+
'levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@register_model
|
| 110 |
+
def levit_256(pretrained=False, use_conv=False, **kwargs):
|
| 111 |
+
return create_levit(
|
| 112 |
+
'levit_256', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@register_model
|
| 116 |
+
def levit_384(pretrained=False, use_conv=False, **kwargs):
|
| 117 |
+
return create_levit(
|
| 118 |
+
'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@register_model
|
| 122 |
+
def levit_256d(pretrained=False, use_conv=False, **kwargs):
|
| 123 |
+
return create_levit(
|
| 124 |
+
'levit_256d', pretrained=pretrained, use_conv=use_conv, distilled=False, **kwargs)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class ConvNorm(nn.Sequential):
|
| 128 |
+
def __init__(
|
| 129 |
+
self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
|
| 130 |
+
groups=1, bn_weight_init=1, resolution=-10000):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.add_module('c', nn.Conv2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
|
| 133 |
+
self.add_module('bn', nn.BatchNorm2d(out_chs))
|
| 134 |
+
|
| 135 |
+
nn.init.constant_(self.bn.weight, bn_weight_init)
|
| 136 |
+
|
| 137 |
+
@torch.no_grad()
|
| 138 |
+
def fuse(self):
|
| 139 |
+
c, bn = self._modules.values()
|
| 140 |
+
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
| 141 |
+
w = c.weight * w[:, None, None, None]
|
| 142 |
+
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
| 143 |
+
m = nn.Conv2d(
|
| 144 |
+
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
|
| 145 |
+
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
|
| 146 |
+
m.weight.data.copy_(w)
|
| 147 |
+
m.bias.data.copy_(b)
|
| 148 |
+
return m
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class LinearNorm(nn.Sequential):
|
| 152 |
+
def __init__(self, in_features, out_features, bn_weight_init=1, resolution=-100000):
|
| 153 |
+
super().__init__()
|
| 154 |
+
self.add_module('c', nn.Linear(in_features, out_features, bias=False))
|
| 155 |
+
self.add_module('bn', nn.BatchNorm1d(out_features))
|
| 156 |
+
|
| 157 |
+
nn.init.constant_(self.bn.weight, bn_weight_init)
|
| 158 |
+
|
| 159 |
+
@torch.no_grad()
|
| 160 |
+
def fuse(self):
|
| 161 |
+
l, bn = self._modules.values()
|
| 162 |
+
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
| 163 |
+
w = l.weight * w[:, None]
|
| 164 |
+
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
| 165 |
+
m = nn.Linear(w.size(1), w.size(0))
|
| 166 |
+
m.weight.data.copy_(w)
|
| 167 |
+
m.bias.data.copy_(b)
|
| 168 |
+
return m
|
| 169 |
+
|
| 170 |
+
def forward(self, x):
|
| 171 |
+
x = self.c(x)
|
| 172 |
+
return self.bn(x.flatten(0, 1)).reshape_as(x)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class NormLinear(nn.Sequential):
|
| 176 |
+
def __init__(self, in_features, out_features, bias=True, std=0.02):
|
| 177 |
+
super().__init__()
|
| 178 |
+
self.add_module('bn', nn.BatchNorm1d(in_features))
|
| 179 |
+
self.add_module('l', nn.Linear(in_features, out_features, bias=bias))
|
| 180 |
+
|
| 181 |
+
trunc_normal_(self.l.weight, std=std)
|
| 182 |
+
if self.l.bias is not None:
|
| 183 |
+
nn.init.constant_(self.l.bias, 0)
|
| 184 |
+
|
| 185 |
+
@torch.no_grad()
|
| 186 |
+
def fuse(self):
|
| 187 |
+
bn, l = self._modules.values()
|
| 188 |
+
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
| 189 |
+
b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
|
| 190 |
+
w = l.weight * w[None, :]
|
| 191 |
+
if l.bias is None:
|
| 192 |
+
b = b @ self.l.weight.T
|
| 193 |
+
else:
|
| 194 |
+
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
|
| 195 |
+
m = nn.Linear(w.size(1), w.size(0))
|
| 196 |
+
m.weight.data.copy_(w)
|
| 197 |
+
m.bias.data.copy_(b)
|
| 198 |
+
return m
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def stem_b16(in_chs, out_chs, activation, resolution=224):
|
| 202 |
+
return nn.Sequential(
|
| 203 |
+
ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution),
|
| 204 |
+
activation(),
|
| 205 |
+
ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2),
|
| 206 |
+
activation(),
|
| 207 |
+
ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4),
|
| 208 |
+
activation(),
|
| 209 |
+
ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class Residual(nn.Module):
|
| 213 |
+
def __init__(self, m, drop):
|
| 214 |
+
super().__init__()
|
| 215 |
+
self.m = m
|
| 216 |
+
self.drop = drop
|
| 217 |
+
|
| 218 |
+
def forward(self, x):
|
| 219 |
+
if self.training and self.drop > 0:
|
| 220 |
+
return x + self.m(x) * torch.rand(
|
| 221 |
+
x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
|
| 222 |
+
else:
|
| 223 |
+
return x + self.m(x)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class Subsample(nn.Module):
|
| 227 |
+
def __init__(self, stride, resolution):
|
| 228 |
+
super().__init__()
|
| 229 |
+
self.stride = stride
|
| 230 |
+
self.resolution = resolution
|
| 231 |
+
|
| 232 |
+
def forward(self, x):
|
| 233 |
+
B, N, C = x.shape
|
| 234 |
+
x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
|
| 235 |
+
return x.reshape(B, -1, C)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class Attention(nn.Module):
|
| 239 |
+
ab: Dict[str, torch.Tensor]
|
| 240 |
+
|
| 241 |
+
def __init__(
|
| 242 |
+
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
|
| 243 |
+
super().__init__()
|
| 244 |
+
ln_layer = ConvNorm if use_conv else LinearNorm
|
| 245 |
+
self.use_conv = use_conv
|
| 246 |
+
self.num_heads = num_heads
|
| 247 |
+
self.scale = key_dim ** -0.5
|
| 248 |
+
self.key_dim = key_dim
|
| 249 |
+
self.key_attn_dim = key_dim * num_heads
|
| 250 |
+
self.val_dim = int(attn_ratio * key_dim)
|
| 251 |
+
self.val_attn_dim = int(attn_ratio * key_dim) * num_heads
|
| 252 |
+
|
| 253 |
+
self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2, resolution=resolution)
|
| 254 |
+
self.proj = nn.Sequential(
|
| 255 |
+
act_layer(),
|
| 256 |
+
ln_layer(self.val_attn_dim, dim, bn_weight_init=0, resolution=resolution)
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution ** 2))
|
| 260 |
+
pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1)
|
| 261 |
+
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
|
| 262 |
+
rel_pos = (rel_pos[0] * resolution) + rel_pos[1]
|
| 263 |
+
self.register_buffer('attention_bias_idxs', rel_pos)
|
| 264 |
+
self.ab = {}
|
| 265 |
+
|
| 266 |
+
@torch.no_grad()
|
| 267 |
+
def train(self, mode=True):
|
| 268 |
+
super().train(mode)
|
| 269 |
+
if mode and self.ab:
|
| 270 |
+
self.ab = {} # clear ab cache
|
| 271 |
+
|
| 272 |
+
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
|
| 273 |
+
if self.training:
|
| 274 |
+
return self.attention_biases[:, self.attention_bias_idxs]
|
| 275 |
+
else:
|
| 276 |
+
device_key = str(device)
|
| 277 |
+
if device_key not in self.ab:
|
| 278 |
+
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
|
| 279 |
+
return self.ab[device_key]
|
| 280 |
+
|
| 281 |
+
def forward(self, x): # x (B,C,H,W)
|
| 282 |
+
if self.use_conv:
|
| 283 |
+
B, C, H, W = x.shape
|
| 284 |
+
q, k, v = self.qkv(x).view(
|
| 285 |
+
B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.val_dim], dim=2)
|
| 286 |
+
|
| 287 |
+
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
|
| 288 |
+
attn = attn.softmax(dim=-1)
|
| 289 |
+
|
| 290 |
+
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
|
| 291 |
+
else:
|
| 292 |
+
B, N, C = x.shape
|
| 293 |
+
q, k, v = self.qkv(x).view(
|
| 294 |
+
B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
|
| 295 |
+
q = q.permute(0, 2, 1, 3)
|
| 296 |
+
k = k.permute(0, 2, 3, 1)
|
| 297 |
+
v = v.permute(0, 2, 1, 3)
|
| 298 |
+
|
| 299 |
+
attn = q @ k * self.scale + self.get_attention_biases(x.device)
|
| 300 |
+
attn = attn.softmax(dim=-1)
|
| 301 |
+
|
| 302 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim)
|
| 303 |
+
x = self.proj(x)
|
| 304 |
+
return x
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class AttentionSubsample(nn.Module):
|
| 308 |
+
ab: Dict[str, torch.Tensor]
|
| 309 |
+
|
| 310 |
+
def __init__(
|
| 311 |
+
self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
|
| 312 |
+
act_layer=None, stride=2, resolution=14, resolution_out=7, use_conv=False):
|
| 313 |
+
super().__init__()
|
| 314 |
+
self.stride = stride
|
| 315 |
+
self.num_heads = num_heads
|
| 316 |
+
self.scale = key_dim ** -0.5
|
| 317 |
+
self.key_dim = key_dim
|
| 318 |
+
self.key_attn_dim = key_dim * num_heads
|
| 319 |
+
self.val_dim = int(attn_ratio * key_dim)
|
| 320 |
+
self.val_attn_dim = self.val_dim * self.num_heads
|
| 321 |
+
self.resolution = resolution
|
| 322 |
+
self.resolution_out_area = resolution_out ** 2
|
| 323 |
+
|
| 324 |
+
self.use_conv = use_conv
|
| 325 |
+
if self.use_conv:
|
| 326 |
+
ln_layer = ConvNorm
|
| 327 |
+
sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0)
|
| 328 |
+
else:
|
| 329 |
+
ln_layer = LinearNorm
|
| 330 |
+
sub_layer = partial(Subsample, resolution=resolution)
|
| 331 |
+
|
| 332 |
+
self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim, resolution=resolution)
|
| 333 |
+
self.q = nn.Sequential(
|
| 334 |
+
sub_layer(stride=stride),
|
| 335 |
+
ln_layer(in_dim, self.key_attn_dim, resolution=resolution_out)
|
| 336 |
+
)
|
| 337 |
+
self.proj = nn.Sequential(
|
| 338 |
+
act_layer(),
|
| 339 |
+
ln_layer(self.val_attn_dim, out_dim, resolution=resolution_out)
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.resolution ** 2))
|
| 343 |
+
k_pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1)
|
| 344 |
+
q_pos = torch.stack(torch.meshgrid(
|
| 345 |
+
torch.arange(0, resolution, step=stride),
|
| 346 |
+
torch.arange(0, resolution, step=stride))).flatten(1)
|
| 347 |
+
rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
|
| 348 |
+
rel_pos = (rel_pos[0] * resolution) + rel_pos[1]
|
| 349 |
+
self.register_buffer('attention_bias_idxs', rel_pos)
|
| 350 |
+
|
| 351 |
+
self.ab = {} # per-device attention_biases cache
|
| 352 |
+
|
| 353 |
+
@torch.no_grad()
|
| 354 |
+
def train(self, mode=True):
|
| 355 |
+
super().train(mode)
|
| 356 |
+
if mode and self.ab:
|
| 357 |
+
self.ab = {} # clear ab cache
|
| 358 |
+
|
| 359 |
+
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
|
| 360 |
+
if self.training:
|
| 361 |
+
return self.attention_biases[:, self.attention_bias_idxs]
|
| 362 |
+
else:
|
| 363 |
+
device_key = str(device)
|
| 364 |
+
if device_key not in self.ab:
|
| 365 |
+
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
|
| 366 |
+
return self.ab[device_key]
|
| 367 |
+
|
| 368 |
+
def forward(self, x):
|
| 369 |
+
if self.use_conv:
|
| 370 |
+
B, C, H, W = x.shape
|
| 371 |
+
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2)
|
| 372 |
+
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_out_area)
|
| 373 |
+
|
| 374 |
+
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
|
| 375 |
+
attn = attn.softmax(dim=-1)
|
| 376 |
+
|
| 377 |
+
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution, self.resolution)
|
| 378 |
+
else:
|
| 379 |
+
B, N, C = x.shape
|
| 380 |
+
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
|
| 381 |
+
k = k.permute(0, 2, 3, 1) # BHCN
|
| 382 |
+
v = v.permute(0, 2, 1, 3) # BHNC
|
| 383 |
+
q = self.q(x).view(B, self.resolution_out_area, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
|
| 384 |
+
|
| 385 |
+
attn = q @ k * self.scale + self.get_attention_biases(x.device)
|
| 386 |
+
attn = attn.softmax(dim=-1)
|
| 387 |
+
|
| 388 |
+
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim)
|
| 389 |
+
x = self.proj(x)
|
| 390 |
+
return x
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class Levit(nn.Module):
|
| 394 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
| 395 |
+
|
| 396 |
+
NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems
|
| 397 |
+
w/ train scripts that don't take tuple outputs,
|
| 398 |
+
"""
|
| 399 |
+
|
| 400 |
+
def __init__(
|
| 401 |
+
self,
|
| 402 |
+
img_size=224,
|
| 403 |
+
patch_size=16,
|
| 404 |
+
in_chans=3,
|
| 405 |
+
num_classes=1000,
|
| 406 |
+
embed_dim=(192,),
|
| 407 |
+
key_dim=64,
|
| 408 |
+
depth=(12,),
|
| 409 |
+
num_heads=(3,),
|
| 410 |
+
attn_ratio=2,
|
| 411 |
+
mlp_ratio=2,
|
| 412 |
+
hybrid_backbone=None,
|
| 413 |
+
down_ops=None,
|
| 414 |
+
act_layer='hard_swish',
|
| 415 |
+
attn_act_layer='hard_swish',
|
| 416 |
+
use_conv=False,
|
| 417 |
+
global_pool='avg',
|
| 418 |
+
drop_rate=0.,
|
| 419 |
+
drop_path_rate=0.):
|
| 420 |
+
super().__init__()
|
| 421 |
+
act_layer = get_act_layer(act_layer)
|
| 422 |
+
attn_act_layer = get_act_layer(attn_act_layer)
|
| 423 |
+
ln_layer = ConvNorm if use_conv else LinearNorm
|
| 424 |
+
self.use_conv = use_conv
|
| 425 |
+
if isinstance(img_size, tuple):
|
| 426 |
+
# FIXME origin impl passes single img/res dim through whole hierarchy,
|
| 427 |
+
# not sure this model will be used enough to spend time fixing it.
|
| 428 |
+
assert img_size[0] == img_size[1]
|
| 429 |
+
img_size = img_size[0]
|
| 430 |
+
self.num_classes = num_classes
|
| 431 |
+
self.global_pool = global_pool
|
| 432 |
+
self.num_features = embed_dim[-1]
|
| 433 |
+
self.embed_dim = embed_dim
|
| 434 |
+
self.grad_checkpointing = False
|
| 435 |
+
|
| 436 |
+
num_stages = len(embed_dim)
|
| 437 |
+
assert len(depth) == len(num_heads) == num_stages
|
| 438 |
+
key_dim = to_ntuple(num_stages)(key_dim)
|
| 439 |
+
attn_ratio = to_ntuple(num_stages)(attn_ratio)
|
| 440 |
+
mlp_ratio = to_ntuple(num_stages)(mlp_ratio)
|
| 441 |
+
down_ops = down_ops or (
|
| 442 |
+
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
|
| 443 |
+
('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2),
|
| 444 |
+
('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2),
|
| 445 |
+
('',)
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer)
|
| 449 |
+
|
| 450 |
+
self.blocks = []
|
| 451 |
+
resolution = img_size // patch_size
|
| 452 |
+
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
|
| 453 |
+
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
|
| 454 |
+
for _ in range(dpth):
|
| 455 |
+
self.blocks.append(
|
| 456 |
+
Residual(
|
| 457 |
+
Attention(
|
| 458 |
+
ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
|
| 459 |
+
resolution=resolution, use_conv=use_conv),
|
| 460 |
+
drop_path_rate))
|
| 461 |
+
if mr > 0:
|
| 462 |
+
h = int(ed * mr)
|
| 463 |
+
self.blocks.append(
|
| 464 |
+
Residual(nn.Sequential(
|
| 465 |
+
ln_layer(ed, h, resolution=resolution),
|
| 466 |
+
act_layer(),
|
| 467 |
+
ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
|
| 468 |
+
), drop_path_rate))
|
| 469 |
+
if do[0] == 'Subsample':
|
| 470 |
+
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
|
| 471 |
+
resolution_out = (resolution - 1) // do[5] + 1
|
| 472 |
+
self.blocks.append(
|
| 473 |
+
AttentionSubsample(
|
| 474 |
+
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
|
| 475 |
+
attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
|
| 476 |
+
resolution=resolution, resolution_out=resolution_out, use_conv=use_conv))
|
| 477 |
+
resolution = resolution_out
|
| 478 |
+
if do[4] > 0: # mlp_ratio
|
| 479 |
+
h = int(embed_dim[i + 1] * do[4])
|
| 480 |
+
self.blocks.append(
|
| 481 |
+
Residual(nn.Sequential(
|
| 482 |
+
ln_layer(embed_dim[i + 1], h, resolution=resolution),
|
| 483 |
+
act_layer(),
|
| 484 |
+
ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
|
| 485 |
+
), drop_path_rate))
|
| 486 |
+
self.blocks = nn.Sequential(*self.blocks)
|
| 487 |
+
|
| 488 |
+
# Classifier head
|
| 489 |
+
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
| 490 |
+
|
| 491 |
+
@torch.jit.ignore
|
| 492 |
+
def no_weight_decay(self):
|
| 493 |
+
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
|
| 494 |
+
|
| 495 |
+
@torch.jit.ignore
|
| 496 |
+
def group_matcher(self, coarse=False):
|
| 497 |
+
matcher = dict(
|
| 498 |
+
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
|
| 499 |
+
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
|
| 500 |
+
)
|
| 501 |
+
return matcher
|
| 502 |
+
|
| 503 |
+
@torch.jit.ignore
|
| 504 |
+
def set_grad_checkpointing(self, enable=True):
|
| 505 |
+
self.grad_checkpointing = enable
|
| 506 |
+
|
| 507 |
+
@torch.jit.ignore
|
| 508 |
+
def get_classifier(self):
|
| 509 |
+
return self.head
|
| 510 |
+
|
| 511 |
+
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
|
| 512 |
+
self.num_classes = num_classes
|
| 513 |
+
if global_pool is not None:
|
| 514 |
+
self.global_pool = global_pool
|
| 515 |
+
self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
| 516 |
+
|
| 517 |
+
def forward_features(self, x):
|
| 518 |
+
x = self.patch_embed(x)
|
| 519 |
+
if not self.use_conv:
|
| 520 |
+
x = x.flatten(2).transpose(1, 2)
|
| 521 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
| 522 |
+
x = checkpoint_seq(self.blocks, x)
|
| 523 |
+
else:
|
| 524 |
+
x = self.blocks(x)
|
| 525 |
+
return x
|
| 526 |
+
|
| 527 |
+
def forward_head(self, x, pre_logits: bool = False):
|
| 528 |
+
if self.global_pool == 'avg':
|
| 529 |
+
x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
|
| 530 |
+
return x if pre_logits else self.head(x)
|
| 531 |
+
|
| 532 |
+
def forward(self, x):
|
| 533 |
+
x = self.forward_features(x)
|
| 534 |
+
x = self.forward_head(x)
|
| 535 |
+
return x
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class LevitDistilled(Levit):
|
| 539 |
+
def __init__(self, *args, **kwargs):
|
| 540 |
+
super().__init__(*args, **kwargs)
|
| 541 |
+
self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
|
| 542 |
+
self.distilled_training = False # must set this True to train w/ distillation token
|
| 543 |
+
|
| 544 |
+
@torch.jit.ignore
|
| 545 |
+
def get_classifier(self):
|
| 546 |
+
return self.head, self.head_dist
|
| 547 |
+
|
| 548 |
+
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
|
| 549 |
+
self.num_classes = num_classes
|
| 550 |
+
if global_pool is not None:
|
| 551 |
+
self.global_pool = global_pool
|
| 552 |
+
self.head = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
| 553 |
+
self.head_dist = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
| 554 |
+
|
| 555 |
+
@torch.jit.ignore
|
| 556 |
+
def set_distilled_training(self, enable=True):
|
| 557 |
+
self.distilled_training = enable
|
| 558 |
+
|
| 559 |
+
def forward_head(self, x):
|
| 560 |
+
if self.global_pool == 'avg':
|
| 561 |
+
x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
|
| 562 |
+
x, x_dist = self.head(x), self.head_dist(x)
|
| 563 |
+
if self.distilled_training and self.training and not torch.jit.is_scripting():
|
| 564 |
+
# only return separate classification predictions when training in distilled mode
|
| 565 |
+
return x, x_dist
|
| 566 |
+
else:
|
| 567 |
+
# during standard train/finetune, inference average the classifier predictions
|
| 568 |
+
return (x + x_dist) / 2
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def checkpoint_filter_fn(state_dict, model):
|
| 572 |
+
if 'model' in state_dict:
|
| 573 |
+
# For deit models
|
| 574 |
+
state_dict = state_dict['model']
|
| 575 |
+
D = model.state_dict()
|
| 576 |
+
for k in state_dict.keys():
|
| 577 |
+
if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2:
|
| 578 |
+
state_dict[k] = state_dict[k][:, :, None, None]
|
| 579 |
+
return state_dict
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def create_levit(variant, pretrained=False, distilled=True, **kwargs):
|
| 583 |
+
if kwargs.get('features_only', None):
|
| 584 |
+
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
| 585 |
+
|
| 586 |
+
model_cfg = dict(**model_cfgs[variant], **kwargs)
|
| 587 |
+
model = build_model_with_cfg(
|
| 588 |
+
LevitDistilled if distilled else Levit, variant, pretrained,
|
| 589 |
+
pretrained_filter_fn=checkpoint_filter_fn,
|
| 590 |
+
**model_cfg)
|
| 591 |
+
return model
|
| 592 |
+
|
src/custom_timm/optim/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .adabelief import AdaBelief
|
| 2 |
+
from .adafactor import Adafactor
|
| 3 |
+
from .adahessian import Adahessian
|
| 4 |
+
from .adamp import AdamP
|
| 5 |
+
from .adamw import AdamW
|
| 6 |
+
from .lamb import Lamb
|
| 7 |
+
from .lars import Lars
|
| 8 |
+
from .lookahead import Lookahead
|
| 9 |
+
from .madgrad import MADGRAD
|
| 10 |
+
from .nadam import Nadam
|
| 11 |
+
from .nvnovograd import NvNovoGrad
|
| 12 |
+
from .radam import RAdam
|
| 13 |
+
from .rmsprop_tf import RMSpropTF
|
| 14 |
+
from .sgdp import SGDP
|
| 15 |
+
from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
|
src/custom_timm/optim/adabelief.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch.optim.optimizer import Optimizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class AdaBelief(Optimizer):
|
| 7 |
+
r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch
|
| 8 |
+
|
| 9 |
+
Arguments:
|
| 10 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
| 11 |
+
parameter groups
|
| 12 |
+
lr (float, optional): learning rate (default: 1e-3)
|
| 13 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
| 14 |
+
running averages of gradient and its square (default: (0.9, 0.999))
|
| 15 |
+
eps (float, optional): term added to the denominator to improve
|
| 16 |
+
numerical stability (default: 1e-16)
|
| 17 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
| 18 |
+
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
| 19 |
+
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
| 20 |
+
(default: False)
|
| 21 |
+
decoupled_decay (boolean, optional): (default: True) If set as True, then
|
| 22 |
+
the optimizer uses decoupled weight decay as in AdamW
|
| 23 |
+
fixed_decay (boolean, optional): (default: False) This is used when weight_decouple
|
| 24 |
+
is set as True.
|
| 25 |
+
When fixed_decay == True, the weight decay is performed as
|
| 26 |
+
$W_{new} = W_{old} - W_{old} \times decay$.
|
| 27 |
+
When fixed_decay == False, the weight decay is performed as
|
| 28 |
+
$W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the
|
| 29 |
+
weight decay ratio decreases with learning rate (lr).
|
| 30 |
+
rectify (boolean, optional): (default: True) If set as True, then perform the rectified
|
| 31 |
+
update similar to RAdam
|
| 32 |
+
degenerated_to_sgd (boolean, optional) (default:True) If set as True, then perform SGD update
|
| 33 |
+
when variance of gradient is high
|
| 34 |
+
reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020
|
| 35 |
+
|
| 36 |
+
For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer'
|
| 37 |
+
For example train/args for EfficientNet see these gists
|
| 38 |
+
- link to train_scipt: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037
|
| 39 |
+
- link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False,
|
| 44 |
+
decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True):
|
| 45 |
+
|
| 46 |
+
if not 0.0 <= lr:
|
| 47 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
| 48 |
+
if not 0.0 <= eps:
|
| 49 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
| 50 |
+
if not 0.0 <= betas[0] < 1.0:
|
| 51 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
| 52 |
+
if not 0.0 <= betas[1] < 1.0:
|
| 53 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
| 54 |
+
|
| 55 |
+
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
|
| 56 |
+
for param in params:
|
| 57 |
+
if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
|
| 58 |
+
param['buffer'] = [[None, None, None] for _ in range(10)]
|
| 59 |
+
|
| 60 |
+
defaults = dict(
|
| 61 |
+
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad,
|
| 62 |
+
degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify,
|
| 63 |
+
fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)])
|
| 64 |
+
super(AdaBelief, self).__init__(params, defaults)
|
| 65 |
+
|
| 66 |
+
def __setstate__(self, state):
|
| 67 |
+
super(AdaBelief, self).__setstate__(state)
|
| 68 |
+
for group in self.param_groups:
|
| 69 |
+
group.setdefault('amsgrad', False)
|
| 70 |
+
|
| 71 |
+
@torch.no_grad()
|
| 72 |
+
def reset(self):
|
| 73 |
+
for group in self.param_groups:
|
| 74 |
+
for p in group['params']:
|
| 75 |
+
state = self.state[p]
|
| 76 |
+
amsgrad = group['amsgrad']
|
| 77 |
+
|
| 78 |
+
# State initialization
|
| 79 |
+
state['step'] = 0
|
| 80 |
+
# Exponential moving average of gradient values
|
| 81 |
+
state['exp_avg'] = torch.zeros_like(p)
|
| 82 |
+
|
| 83 |
+
# Exponential moving average of squared gradient values
|
| 84 |
+
state['exp_avg_var'] = torch.zeros_like(p)
|
| 85 |
+
if amsgrad:
|
| 86 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
| 87 |
+
state['max_exp_avg_var'] = torch.zeros_like(p)
|
| 88 |
+
|
| 89 |
+
@torch.no_grad()
|
| 90 |
+
def step(self, closure=None):
|
| 91 |
+
"""Performs a single optimization step.
|
| 92 |
+
Arguments:
|
| 93 |
+
closure (callable, optional): A closure that reevaluates the model
|
| 94 |
+
and returns the loss.
|
| 95 |
+
"""
|
| 96 |
+
loss = None
|
| 97 |
+
if closure is not None:
|
| 98 |
+
with torch.enable_grad():
|
| 99 |
+
loss = closure()
|
| 100 |
+
|
| 101 |
+
for group in self.param_groups:
|
| 102 |
+
for p in group['params']:
|
| 103 |
+
if p.grad is None:
|
| 104 |
+
continue
|
| 105 |
+
grad = p.grad
|
| 106 |
+
if grad.dtype in {torch.float16, torch.bfloat16}:
|
| 107 |
+
grad = grad.float()
|
| 108 |
+
if grad.is_sparse:
|
| 109 |
+
raise RuntimeError(
|
| 110 |
+
'AdaBelief does not support sparse gradients, please consider SparseAdam instead')
|
| 111 |
+
|
| 112 |
+
p_fp32 = p
|
| 113 |
+
if p.dtype in {torch.float16, torch.bfloat16}:
|
| 114 |
+
p_fp32 = p_fp32.float()
|
| 115 |
+
|
| 116 |
+
amsgrad = group['amsgrad']
|
| 117 |
+
beta1, beta2 = group['betas']
|
| 118 |
+
state = self.state[p]
|
| 119 |
+
# State initialization
|
| 120 |
+
if len(state) == 0:
|
| 121 |
+
state['step'] = 0
|
| 122 |
+
# Exponential moving average of gradient values
|
| 123 |
+
state['exp_avg'] = torch.zeros_like(p_fp32)
|
| 124 |
+
# Exponential moving average of squared gradient values
|
| 125 |
+
state['exp_avg_var'] = torch.zeros_like(p_fp32)
|
| 126 |
+
if amsgrad:
|
| 127 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
| 128 |
+
state['max_exp_avg_var'] = torch.zeros_like(p_fp32)
|
| 129 |
+
|
| 130 |
+
# perform weight decay, check if decoupled weight decay
|
| 131 |
+
if group['decoupled_decay']:
|
| 132 |
+
if not group['fixed_decay']:
|
| 133 |
+
p_fp32.mul_(1.0 - group['lr'] * group['weight_decay'])
|
| 134 |
+
else:
|
| 135 |
+
p_fp32.mul_(1.0 - group['weight_decay'])
|
| 136 |
+
else:
|
| 137 |
+
if group['weight_decay'] != 0:
|
| 138 |
+
grad.add_(p_fp32, alpha=group['weight_decay'])
|
| 139 |
+
|
| 140 |
+
# get current state variable
|
| 141 |
+
exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
|
| 142 |
+
|
| 143 |
+
state['step'] += 1
|
| 144 |
+
bias_correction1 = 1 - beta1 ** state['step']
|
| 145 |
+
bias_correction2 = 1 - beta2 ** state['step']
|
| 146 |
+
|
| 147 |
+
# Update first and second moment running average
|
| 148 |
+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
| 149 |
+
grad_residual = grad - exp_avg
|
| 150 |
+
exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
|
| 151 |
+
|
| 152 |
+
if amsgrad:
|
| 153 |
+
max_exp_avg_var = state['max_exp_avg_var']
|
| 154 |
+
# Maintains the maximum of all 2nd moment running avg. till now
|
| 155 |
+
torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var)
|
| 156 |
+
|
| 157 |
+
# Use the max. for normalizing running avg. of gradient
|
| 158 |
+
denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
| 159 |
+
else:
|
| 160 |
+
denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
| 161 |
+
|
| 162 |
+
# update
|
| 163 |
+
if not group['rectify']:
|
| 164 |
+
# Default update
|
| 165 |
+
step_size = group['lr'] / bias_correction1
|
| 166 |
+
p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
|
| 167 |
+
else:
|
| 168 |
+
# Rectified update, forked from RAdam
|
| 169 |
+
buffered = group['buffer'][int(state['step'] % 10)]
|
| 170 |
+
if state['step'] == buffered[0]:
|
| 171 |
+
num_sma, step_size = buffered[1], buffered[2]
|
| 172 |
+
else:
|
| 173 |
+
buffered[0] = state['step']
|
| 174 |
+
beta2_t = beta2 ** state['step']
|
| 175 |
+
num_sma_max = 2 / (1 - beta2) - 1
|
| 176 |
+
num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
|
| 177 |
+
buffered[1] = num_sma
|
| 178 |
+
|
| 179 |
+
# more conservative since it's an approximated value
|
| 180 |
+
if num_sma >= 5:
|
| 181 |
+
step_size = math.sqrt(
|
| 182 |
+
(1 - beta2_t) *
|
| 183 |
+
(num_sma - 4) / (num_sma_max - 4) *
|
| 184 |
+
(num_sma - 2) / num_sma *
|
| 185 |
+
num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
|
| 186 |
+
elif group['degenerated_to_sgd']:
|
| 187 |
+
step_size = 1.0 / (1 - beta1 ** state['step'])
|
| 188 |
+
else:
|
| 189 |
+
step_size = -1
|
| 190 |
+
buffered[2] = step_size
|
| 191 |
+
|
| 192 |
+
if num_sma >= 5:
|
| 193 |
+
denom = exp_avg_var.sqrt().add_(group['eps'])
|
| 194 |
+
p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
|
| 195 |
+
elif step_size > 0:
|
| 196 |
+
p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
|
| 197 |
+
|
| 198 |
+
if p.dtype in {torch.float16, torch.bfloat16}:
|
| 199 |
+
p.copy_(p_fp32)
|
| 200 |
+
|
| 201 |
+
return loss
|
src/custom_timm/optim/adafactor.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Adafactor Optimizer
|
| 2 |
+
|
| 3 |
+
Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
|
| 4 |
+
|
| 5 |
+
Original header/copyright below.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 9 |
+
#
|
| 10 |
+
# This source code is licensed under the MIT license found in the
|
| 11 |
+
# LICENSE file in the root directory of this source tree.
|
| 12 |
+
import torch
|
| 13 |
+
import math
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Adafactor(torch.optim.Optimizer):
|
| 17 |
+
"""Implements Adafactor algorithm.
|
| 18 |
+
This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
|
| 19 |
+
(see https://arxiv.org/abs/1804.04235)
|
| 20 |
+
|
| 21 |
+
Note that this optimizer internally adjusts the learning rate depending on the
|
| 22 |
+
*scale_parameter*, *relative_step* and *warmup_init* options.
|
| 23 |
+
|
| 24 |
+
To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
|
| 25 |
+
`relative_step=False`.
|
| 26 |
+
|
| 27 |
+
Arguments:
|
| 28 |
+
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
|
| 29 |
+
lr (float, optional): external learning rate (default: None)
|
| 30 |
+
eps (tuple[float, float]): regularization constants for square gradient
|
| 31 |
+
and parameter scale respectively (default: (1e-30, 1e-3))
|
| 32 |
+
clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
|
| 33 |
+
decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
|
| 34 |
+
beta1 (float): coefficient used for computing running averages of gradient (default: None)
|
| 35 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
| 36 |
+
scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
|
| 37 |
+
warmup_init (bool): time-dependent learning rate computation depends on
|
| 38 |
+
whether warm-up initialization is being used (default: False)
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
|
| 42 |
+
decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
|
| 43 |
+
relative_step = not lr
|
| 44 |
+
if warmup_init and not relative_step:
|
| 45 |
+
raise ValueError('warmup_init requires relative_step=True')
|
| 46 |
+
|
| 47 |
+
beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
|
| 48 |
+
defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
|
| 49 |
+
beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
|
| 50 |
+
relative_step=relative_step, warmup_init=warmup_init)
|
| 51 |
+
super(Adafactor, self).__init__(params, defaults)
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def _get_lr(param_group, param_state):
|
| 55 |
+
if param_group['relative_step']:
|
| 56 |
+
min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
|
| 57 |
+
lr_t = min(min_step, 1.0 / math.sqrt(param_state['step']))
|
| 58 |
+
param_scale = 1.0
|
| 59 |
+
if param_group['scale_parameter']:
|
| 60 |
+
param_scale = max(param_group['eps_scale'], param_state['RMS'])
|
| 61 |
+
param_group['lr'] = lr_t * param_scale
|
| 62 |
+
return param_group['lr']
|
| 63 |
+
|
| 64 |
+
@staticmethod
|
| 65 |
+
def _get_options(param_group, param_shape):
|
| 66 |
+
factored = len(param_shape) >= 2
|
| 67 |
+
use_first_moment = param_group['beta1'] is not None
|
| 68 |
+
return factored, use_first_moment
|
| 69 |
+
|
| 70 |
+
@staticmethod
|
| 71 |
+
def _rms(tensor):
|
| 72 |
+
return tensor.norm(2) / (tensor.numel() ** 0.5)
|
| 73 |
+
|
| 74 |
+
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
|
| 75 |
+
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
|
| 76 |
+
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
|
| 77 |
+
return torch.mul(r_factor, c_factor)
|
| 78 |
+
|
| 79 |
+
@torch.no_grad()
|
| 80 |
+
def step(self, closure=None):
|
| 81 |
+
"""Performs a single optimization step.
|
| 82 |
+
Arguments:
|
| 83 |
+
closure (callable, optional): A closure that reevaluates the model and returns the loss.
|
| 84 |
+
"""
|
| 85 |
+
loss = None
|
| 86 |
+
if closure is not None:
|
| 87 |
+
with torch.enable_grad():
|
| 88 |
+
loss = closure()
|
| 89 |
+
|
| 90 |
+
for group in self.param_groups:
|
| 91 |
+
for p in group['params']:
|
| 92 |
+
if p.grad is None:
|
| 93 |
+
continue
|
| 94 |
+
grad = p.grad
|
| 95 |
+
if grad.dtype in {torch.float16, torch.bfloat16}:
|
| 96 |
+
grad = grad.float()
|
| 97 |
+
if grad.is_sparse:
|
| 98 |
+
raise RuntimeError('Adafactor does not support sparse gradients.')
|
| 99 |
+
|
| 100 |
+
state = self.state[p]
|
| 101 |
+
|
| 102 |
+
factored, use_first_moment = self._get_options(group, grad.shape)
|
| 103 |
+
# State Initialization
|
| 104 |
+
if len(state) == 0:
|
| 105 |
+
state['step'] = 0
|
| 106 |
+
|
| 107 |
+
if use_first_moment:
|
| 108 |
+
# Exponential moving average of gradient values
|
| 109 |
+
state['exp_avg'] = torch.zeros_like(grad)
|
| 110 |
+
if factored:
|
| 111 |
+
state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
|
| 112 |
+
state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
|
| 113 |
+
else:
|
| 114 |
+
state['exp_avg_sq'] = torch.zeros_like(grad)
|
| 115 |
+
|
| 116 |
+
state['RMS'] = 0
|
| 117 |
+
else:
|
| 118 |
+
if use_first_moment:
|
| 119 |
+
state['exp_avg'] = state['exp_avg'].to(grad)
|
| 120 |
+
if factored:
|
| 121 |
+
state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
|
| 122 |
+
state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
|
| 123 |
+
else:
|
| 124 |
+
state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
|
| 125 |
+
|
| 126 |
+
p_fp32 = p
|
| 127 |
+
if p.dtype in {torch.float16, torch.bfloat16}:
|
| 128 |
+
p_fp32 = p_fp32.float()
|
| 129 |
+
|
| 130 |
+
state['step'] += 1
|
| 131 |
+
state['RMS'] = self._rms(p_fp32)
|
| 132 |
+
lr_t = self._get_lr(group, state)
|
| 133 |
+
|
| 134 |
+
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
|
| 135 |
+
update = grad ** 2 + group['eps']
|
| 136 |
+
if factored:
|
| 137 |
+
exp_avg_sq_row = state['exp_avg_sq_row']
|
| 138 |
+
exp_avg_sq_col = state['exp_avg_sq_col']
|
| 139 |
+
|
| 140 |
+
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
|
| 141 |
+
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
|
| 142 |
+
|
| 143 |
+
# Approximation of exponential moving average of square of gradient
|
| 144 |
+
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
| 145 |
+
update.mul_(grad)
|
| 146 |
+
else:
|
| 147 |
+
exp_avg_sq = state['exp_avg_sq']
|
| 148 |
+
|
| 149 |
+
exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
|
| 150 |
+
update = exp_avg_sq.rsqrt().mul_(grad)
|
| 151 |
+
|
| 152 |
+
update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
|
| 153 |
+
update.mul_(lr_t)
|
| 154 |
+
|
| 155 |
+
if use_first_moment:
|
| 156 |
+
exp_avg = state['exp_avg']
|
| 157 |
+
exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
|
| 158 |
+
update = exp_avg
|
| 159 |
+
|
| 160 |
+
if group['weight_decay'] != 0:
|
| 161 |
+
p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t)
|
| 162 |
+
|
| 163 |
+
p_fp32.add_(-update)
|
| 164 |
+
if p.dtype in {torch.float16, torch.bfloat16}:
|
| 165 |
+
p.copy_(p_fp32)
|
| 166 |
+
|
| 167 |
+
return loss
|
src/custom_timm/optim/adahessian.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" AdaHessian Optimizer
|
| 2 |
+
|
| 3 |
+
Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
|
| 4 |
+
Originally licensed MIT, Copyright 2020, David Samuel
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Adahessian(torch.optim.Optimizer):
|
| 10 |
+
"""
|
| 11 |
+
Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning"
|
| 12 |
+
|
| 13 |
+
Arguments:
|
| 14 |
+
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
|
| 15 |
+
lr (float, optional): learning rate (default: 0.1)
|
| 16 |
+
betas ((float, float), optional): coefficients used for computing running averages of gradient and the
|
| 17 |
+
squared hessian trace (default: (0.9, 0.999))
|
| 18 |
+
eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
|
| 19 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
|
| 20 |
+
hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
|
| 21 |
+
update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
|
| 22 |
+
(to save time) (default: 1)
|
| 23 |
+
n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
|
| 27 |
+
hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
|
| 28 |
+
if not 0.0 <= lr:
|
| 29 |
+
raise ValueError(f"Invalid learning rate: {lr}")
|
| 30 |
+
if not 0.0 <= eps:
|
| 31 |
+
raise ValueError(f"Invalid epsilon value: {eps}")
|
| 32 |
+
if not 0.0 <= betas[0] < 1.0:
|
| 33 |
+
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
|
| 34 |
+
if not 0.0 <= betas[1] < 1.0:
|
| 35 |
+
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
|
| 36 |
+
if not 0.0 <= hessian_power <= 1.0:
|
| 37 |
+
raise ValueError(f"Invalid Hessian power value: {hessian_power}")
|
| 38 |
+
|
| 39 |
+
self.n_samples = n_samples
|
| 40 |
+
self.update_each = update_each
|
| 41 |
+
self.avg_conv_kernel = avg_conv_kernel
|
| 42 |
+
|
| 43 |
+
# use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
|
| 44 |
+
self.seed = 2147483647
|
| 45 |
+
self.generator = torch.Generator().manual_seed(self.seed)
|
| 46 |
+
|
| 47 |
+
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
|
| 48 |
+
super(Adahessian, self).__init__(params, defaults)
|
| 49 |
+
|
| 50 |
+
for p in self.get_params():
|
| 51 |
+
p.hess = 0.0
|
| 52 |
+
self.state[p]["hessian step"] = 0
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def is_second_order(self):
|
| 56 |
+
return True
|
| 57 |
+
|
| 58 |
+
def get_params(self):
|
| 59 |
+
"""
|
| 60 |
+
Gets all parameters in all param_groups with gradients
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
|
| 64 |
+
|
| 65 |
+
def zero_hessian(self):
|
| 66 |
+
"""
|
| 67 |
+
Zeros out the accumalated hessian traces.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
for p in self.get_params():
|
| 71 |
+
if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
|
| 72 |
+
p.hess.zero_()
|
| 73 |
+
|
| 74 |
+
@torch.no_grad()
|
| 75 |
+
def set_hessian(self):
|
| 76 |
+
"""
|
| 77 |
+
Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
params = []
|
| 81 |
+
for p in filter(lambda p: p.grad is not None, self.get_params()):
|
| 82 |
+
if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step
|
| 83 |
+
params.append(p)
|
| 84 |
+
self.state[p]["hessian step"] += 1
|
| 85 |
+
|
| 86 |
+
if len(params) == 0:
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
if self.generator.device != params[0].device: # hackish way of casting the generator to the right device
|
| 90 |
+
self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
|
| 91 |
+
|
| 92 |
+
grads = [p.grad for p in params]
|
| 93 |
+
|
| 94 |
+
for i in range(self.n_samples):
|
| 95 |
+
# Rademacher distribution {-1.0, 1.0}
|
| 96 |
+
zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
|
| 97 |
+
h_zs = torch.autograd.grad(
|
| 98 |
+
grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
|
| 99 |
+
for h_z, z, p in zip(h_zs, zs, params):
|
| 100 |
+
p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z)
|
| 101 |
+
|
| 102 |
+
@torch.no_grad()
|
| 103 |
+
def step(self, closure=None):
|
| 104 |
+
"""
|
| 105 |
+
Performs a single optimization step.
|
| 106 |
+
Arguments:
|
| 107 |
+
closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
loss = None
|
| 111 |
+
if closure is not None:
|
| 112 |
+
loss = closure()
|
| 113 |
+
|
| 114 |
+
self.zero_hessian()
|
| 115 |
+
self.set_hessian()
|
| 116 |
+
|
| 117 |
+
for group in self.param_groups:
|
| 118 |
+
for p in group['params']:
|
| 119 |
+
if p.grad is None or p.hess is None:
|
| 120 |
+
continue
|
| 121 |
+
|
| 122 |
+
if self.avg_conv_kernel and p.dim() == 4:
|
| 123 |
+
p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
|
| 124 |
+
|
| 125 |
+
# Perform correct stepweight decay as in AdamW
|
| 126 |
+
p.mul_(1 - group['lr'] * group['weight_decay'])
|
| 127 |
+
|
| 128 |
+
state = self.state[p]
|
| 129 |
+
|
| 130 |
+
# State initialization
|
| 131 |
+
if len(state) == 1:
|
| 132 |
+
state['step'] = 0
|
| 133 |
+
# Exponential moving average of gradient values
|
| 134 |
+
state['exp_avg'] = torch.zeros_like(p)
|
| 135 |
+
# Exponential moving average of Hessian diagonal square values
|
| 136 |
+
state['exp_hessian_diag_sq'] = torch.zeros_like(p)
|
| 137 |
+
|
| 138 |
+
exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
|
| 139 |
+
beta1, beta2 = group['betas']
|
| 140 |
+
state['step'] += 1
|
| 141 |
+
|
| 142 |
+
# Decay the first and second moment running average coefficient
|
| 143 |
+
exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
|
| 144 |
+
exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)
|
| 145 |
+
|
| 146 |
+
bias_correction1 = 1 - beta1 ** state['step']
|
| 147 |
+
bias_correction2 = 1 - beta2 ** state['step']
|
| 148 |
+
|
| 149 |
+
k = group['hessian_power']
|
| 150 |
+
denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])
|
| 151 |
+
|
| 152 |
+
# make update
|
| 153 |
+
step_size = group['lr'] / bias_correction1
|
| 154 |
+
p.addcdiv_(exp_avg, denom, value=-step_size)
|
| 155 |
+
|
| 156 |
+
return loss
|
src/custom_timm/optim/adamp.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
|
| 3 |
+
|
| 4 |
+
Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
|
| 5 |
+
Code: https://github.com/clovaai/AdamP
|
| 6 |
+
|
| 7 |
+
Copyright (c) 2020-present NAVER Corp.
|
| 8 |
+
MIT license
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch.optim.optimizer import Optimizer
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _channel_view(x) -> torch.Tensor:
|
| 18 |
+
return x.reshape(x.size(0), -1)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _layer_view(x) -> torch.Tensor:
|
| 22 |
+
return x.reshape(1, -1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
|
| 26 |
+
wd = 1.
|
| 27 |
+
expand_size = (-1,) + (1,) * (len(p.shape) - 1)
|
| 28 |
+
for view_func in [_channel_view, _layer_view]:
|
| 29 |
+
param_view = view_func(p)
|
| 30 |
+
grad_view = view_func(grad)
|
| 31 |
+
cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()
|
| 32 |
+
|
| 33 |
+
# FIXME this is a problem for PyTorch XLA
|
| 34 |
+
if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
|
| 35 |
+
p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
|
| 36 |
+
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
|
| 37 |
+
wd = wd_ratio
|
| 38 |
+
return perturb, wd
|
| 39 |
+
|
| 40 |
+
return perturb, wd
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class AdamP(Optimizer):
|
| 44 |
+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
| 45 |
+
weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
|
| 46 |
+
defaults = dict(
|
| 47 |
+
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
|
| 48 |
+
delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
|
| 49 |
+
super(AdamP, self).__init__(params, defaults)
|
| 50 |
+
|
| 51 |
+
@torch.no_grad()
|
| 52 |
+
def step(self, closure=None):
|
| 53 |
+
loss = None
|
| 54 |
+
if closure is not None:
|
| 55 |
+
with torch.enable_grad():
|
| 56 |
+
loss = closure()
|
| 57 |
+
|
| 58 |
+
for group in self.param_groups:
|
| 59 |
+
for p in group['params']:
|
| 60 |
+
if p.grad is None:
|
| 61 |
+
continue
|
| 62 |
+
|
| 63 |
+
grad = p.grad
|
| 64 |
+
beta1, beta2 = group['betas']
|
| 65 |
+
nesterov = group['nesterov']
|
| 66 |
+
|
| 67 |
+
state = self.state[p]
|
| 68 |
+
|
| 69 |
+
# State initialization
|
| 70 |
+
if len(state) == 0:
|
| 71 |
+
state['step'] = 0
|
| 72 |
+
state['exp_avg'] = torch.zeros_like(p)
|
| 73 |
+
state['exp_avg_sq'] = torch.zeros_like(p)
|
| 74 |
+
|
| 75 |
+
# Adam
|
| 76 |
+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
| 77 |
+
|
| 78 |
+
state['step'] += 1
|
| 79 |
+
bias_correction1 = 1 - beta1 ** state['step']
|
| 80 |
+
bias_correction2 = 1 - beta2 ** state['step']
|
| 81 |
+
|
| 82 |
+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
| 83 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
| 84 |
+
|
| 85 |
+
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
| 86 |
+
step_size = group['lr'] / bias_correction1
|
| 87 |
+
|
| 88 |
+
if nesterov:
|
| 89 |
+
perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
|
| 90 |
+
else:
|
| 91 |
+
perturb = exp_avg / denom
|
| 92 |
+
|
| 93 |
+
# Projection
|
| 94 |
+
wd_ratio = 1.
|
| 95 |
+
if len(p.shape) > 1:
|
| 96 |
+
perturb, wd_ratio = projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
|
| 97 |
+
|
| 98 |
+
# Weight decay
|
| 99 |
+
if group['weight_decay'] > 0:
|
| 100 |
+
p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
|
| 101 |
+
|
| 102 |
+
# Step
|
| 103 |
+
p.add_(perturb, alpha=-step_size)
|
| 104 |
+
|
| 105 |
+
return loss
|
src/custom_timm/optim/adamw.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" AdamW Optimizer
|
| 2 |
+
Impl copied from PyTorch master
|
| 3 |
+
|
| 4 |
+
NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed
|
| 5 |
+
someday
|
| 6 |
+
"""
|
| 7 |
+
import math
|
| 8 |
+
import torch
|
| 9 |
+
from torch.optim.optimizer import Optimizer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AdamW(Optimizer):
|
| 13 |
+
r"""Implements AdamW algorithm.
|
| 14 |
+
|
| 15 |
+
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
|
| 16 |
+
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
|
| 17 |
+
|
| 18 |
+
Arguments:
|
| 19 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
| 20 |
+
parameter groups
|
| 21 |
+
lr (float, optional): learning rate (default: 1e-3)
|
| 22 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
| 23 |
+
running averages of gradient and its square (default: (0.9, 0.999))
|
| 24 |
+
eps (float, optional): term added to the denominator to improve
|
| 25 |
+
numerical stability (default: 1e-8)
|
| 26 |
+
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
|
| 27 |
+
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
| 28 |
+
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
| 29 |
+
(default: False)
|
| 30 |
+
|
| 31 |
+
.. _Adam\: A Method for Stochastic Optimization:
|
| 32 |
+
https://arxiv.org/abs/1412.6980
|
| 33 |
+
.. _Decoupled Weight Decay Regularization:
|
| 34 |
+
https://arxiv.org/abs/1711.05101
|
| 35 |
+
.. _On the Convergence of Adam and Beyond:
|
| 36 |
+
https://openreview.net/forum?id=ryQu7f-RZ
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
| 40 |
+
weight_decay=1e-2, amsgrad=False):
|
| 41 |
+
if not 0.0 <= lr:
|
| 42 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
| 43 |
+
if not 0.0 <= eps:
|
| 44 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
| 45 |
+
if not 0.0 <= betas[0] < 1.0:
|
| 46 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
| 47 |
+
if not 0.0 <= betas[1] < 1.0:
|
| 48 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
| 49 |
+
defaults = dict(lr=lr, betas=betas, eps=eps,
|
| 50 |
+
weight_decay=weight_decay, amsgrad=amsgrad)
|
| 51 |
+
super(AdamW, self).__init__(params, defaults)
|
| 52 |
+
|
| 53 |
+
def __setstate__(self, state):
|
| 54 |
+
super(AdamW, self).__setstate__(state)
|
| 55 |
+
for group in self.param_groups:
|
| 56 |
+
group.setdefault('amsgrad', False)
|
| 57 |
+
|
| 58 |
+
@torch.no_grad()
|
| 59 |
+
def step(self, closure=None):
|
| 60 |
+
"""Performs a single optimization step.
|
| 61 |
+
|
| 62 |
+
Arguments:
|
| 63 |
+
closure (callable, optional): A closure that reevaluates the model
|
| 64 |
+
and returns the loss.
|
| 65 |
+
"""
|
| 66 |
+
loss = None
|
| 67 |
+
if closure is not None:
|
| 68 |
+
with torch.enable_grad():
|
| 69 |
+
loss = closure()
|
| 70 |
+
|
| 71 |
+
for group in self.param_groups:
|
| 72 |
+
for p in group['params']:
|
| 73 |
+
if p.grad is None:
|
| 74 |
+
continue
|
| 75 |
+
|
| 76 |
+
# Perform stepweight decay
|
| 77 |
+
p.data.mul_(1 - group['lr'] * group['weight_decay'])
|
| 78 |
+
|
| 79 |
+
# Perform optimization step
|
| 80 |
+
grad = p.grad
|
| 81 |
+
if grad.is_sparse:
|
| 82 |
+
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
| 83 |
+
amsgrad = group['amsgrad']
|
| 84 |
+
|
| 85 |
+
state = self.state[p]
|
| 86 |
+
|
| 87 |
+
# State initialization
|
| 88 |
+
if len(state) == 0:
|
| 89 |
+
state['step'] = 0
|
| 90 |
+
# Exponential moving average of gradient values
|
| 91 |
+
state['exp_avg'] = torch.zeros_like(p)
|
| 92 |
+
# Exponential moving average of squared gradient values
|
| 93 |
+
state['exp_avg_sq'] = torch.zeros_like(p)
|
| 94 |
+
if amsgrad:
|
| 95 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
| 96 |
+
state['max_exp_avg_sq'] = torch.zeros_like(p)
|
| 97 |
+
|
| 98 |
+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
| 99 |
+
if amsgrad:
|
| 100 |
+
max_exp_avg_sq = state['max_exp_avg_sq']
|
| 101 |
+
beta1, beta2 = group['betas']
|
| 102 |
+
|
| 103 |
+
state['step'] += 1
|
| 104 |
+
bias_correction1 = 1 - beta1 ** state['step']
|
| 105 |
+
bias_correction2 = 1 - beta2 ** state['step']
|
| 106 |
+
|
| 107 |
+
# Decay the first and second moment running average coefficient
|
| 108 |
+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
| 109 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
| 110 |
+
if amsgrad:
|
| 111 |
+
# Maintains the maximum of all 2nd moment running avg. till now
|
| 112 |
+
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
|
| 113 |
+
# Use the max. for normalizing running avg. of gradient
|
| 114 |
+
denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
| 115 |
+
else:
|
| 116 |
+
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
| 117 |
+
|
| 118 |
+
step_size = group['lr'] / bias_correction1
|
| 119 |
+
|
| 120 |
+
p.addcdiv_(exp_avg, denom, value=-step_size)
|
| 121 |
+
|
| 122 |
+
return loss
|
src/custom_timm/optim/lamb.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
|
| 2 |
+
|
| 3 |
+
This optimizer code was adapted from the following (starting with latest)
|
| 4 |
+
* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
|
| 5 |
+
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
|
| 6 |
+
* https://github.com/cybertronai/pytorch-lamb
|
| 7 |
+
|
| 8 |
+
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
|
| 9 |
+
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
|
| 10 |
+
|
| 11 |
+
In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
|
| 12 |
+
|
| 13 |
+
Original copyrights for above sources are below.
|
| 14 |
+
|
| 15 |
+
Modifications Copyright 2021 Ross Wightman
|
| 16 |
+
"""
|
| 17 |
+
# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
|
| 18 |
+
|
| 19 |
+
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
|
| 20 |
+
#
|
| 21 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 22 |
+
# you may not use this file except in compliance with the License.
|
| 23 |
+
# You may obtain a copy of the License at
|
| 24 |
+
#
|
| 25 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 26 |
+
#
|
| 27 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 28 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 29 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 30 |
+
# See the License for the specific language governing permissions and
|
| 31 |
+
# limitations under the License.
|
| 32 |
+
|
| 33 |
+
# MIT License
|
| 34 |
+
#
|
| 35 |
+
# Copyright (c) 2019 cybertronai
|
| 36 |
+
#
|
| 37 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 38 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 39 |
+
# in the Software without restriction, including without limitation the rights
|
| 40 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 41 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 42 |
+
# furnished to do so, subject to the following conditions:
|
| 43 |
+
#
|
| 44 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 45 |
+
# copies or substantial portions of the Software.
|
| 46 |
+
#
|
| 47 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 48 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 49 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 50 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 51 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 52 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 53 |
+
# SOFTWARE.
|
| 54 |
+
import math
|
| 55 |
+
|
| 56 |
+
import torch
|
| 57 |
+
from torch.optim import Optimizer
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Lamb(Optimizer):
|
| 61 |
+
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
|
| 62 |
+
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
|
| 63 |
+
|
| 64 |
+
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
|
| 65 |
+
|
| 66 |
+
Arguments:
|
| 67 |
+
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
|
| 68 |
+
lr (float, optional): learning rate. (default: 1e-3)
|
| 69 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
| 70 |
+
running averages of gradient and its norm. (default: (0.9, 0.999))
|
| 71 |
+
eps (float, optional): term added to the denominator to improve
|
| 72 |
+
numerical stability. (default: 1e-8)
|
| 73 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
| 74 |
+
grad_averaging (bool, optional): whether apply (1-beta2) to grad when
|
| 75 |
+
calculating running averages of gradient. (default: True)
|
| 76 |
+
max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
|
| 77 |
+
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
|
| 78 |
+
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
|
| 79 |
+
weight decay parameter (default: False)
|
| 80 |
+
|
| 81 |
+
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
|
| 82 |
+
https://arxiv.org/abs/1904.00962
|
| 83 |
+
.. _On the Convergence of Adam and Beyond:
|
| 84 |
+
https://openreview.net/forum?id=ryQu7f-RZ
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def __init__(
|
| 88 |
+
self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
|
| 89 |
+
weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
|
| 90 |
+
defaults = dict(
|
| 91 |
+
lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
|
| 92 |
+
grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
|
| 93 |
+
trust_clip=trust_clip, always_adapt=always_adapt)
|
| 94 |
+
super().__init__(params, defaults)
|
| 95 |
+
|
| 96 |
+
@torch.no_grad()
|
| 97 |
+
def step(self, closure=None):
|
| 98 |
+
"""Performs a single optimization step.
|
| 99 |
+
Arguments:
|
| 100 |
+
closure (callable, optional): A closure that reevaluates the model
|
| 101 |
+
and returns the loss.
|
| 102 |
+
"""
|
| 103 |
+
loss = None
|
| 104 |
+
if closure is not None:
|
| 105 |
+
with torch.enable_grad():
|
| 106 |
+
loss = closure()
|
| 107 |
+
|
| 108 |
+
device = self.param_groups[0]['params'][0].device
|
| 109 |
+
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
|
| 110 |
+
global_grad_norm = torch.zeros(1, device=device)
|
| 111 |
+
for group in self.param_groups:
|
| 112 |
+
for p in group['params']:
|
| 113 |
+
if p.grad is None:
|
| 114 |
+
continue
|
| 115 |
+
grad = p.grad
|
| 116 |
+
if grad.is_sparse:
|
| 117 |
+
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
|
| 118 |
+
global_grad_norm.add_(grad.pow(2).sum())
|
| 119 |
+
|
| 120 |
+
global_grad_norm = torch.sqrt(global_grad_norm)
|
| 121 |
+
# FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
|
| 122 |
+
# scalar types properly https://github.com/pytorch/pytorch/issues/9190
|
| 123 |
+
max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
|
| 124 |
+
clip_global_grad_norm = torch.where(
|
| 125 |
+
global_grad_norm > max_grad_norm,
|
| 126 |
+
global_grad_norm / max_grad_norm,
|
| 127 |
+
one_tensor)
|
| 128 |
+
|
| 129 |
+
for group in self.param_groups:
|
| 130 |
+
bias_correction = 1 if group['bias_correction'] else 0
|
| 131 |
+
beta1, beta2 = group['betas']
|
| 132 |
+
grad_averaging = 1 if group['grad_averaging'] else 0
|
| 133 |
+
beta3 = 1 - beta1 if grad_averaging else 1.0
|
| 134 |
+
|
| 135 |
+
# assume same step across group now to simplify things
|
| 136 |
+
# per parameter step can be easily support by making it tensor, or pass list into kernel
|
| 137 |
+
if 'step' in group:
|
| 138 |
+
group['step'] += 1
|
| 139 |
+
else:
|
| 140 |
+
group['step'] = 1
|
| 141 |
+
|
| 142 |
+
if bias_correction:
|
| 143 |
+
bias_correction1 = 1 - beta1 ** group['step']
|
| 144 |
+
bias_correction2 = 1 - beta2 ** group['step']
|
| 145 |
+
else:
|
| 146 |
+
bias_correction1, bias_correction2 = 1.0, 1.0
|
| 147 |
+
|
| 148 |
+
for p in group['params']:
|
| 149 |
+
if p.grad is None:
|
| 150 |
+
continue
|
| 151 |
+
grad = p.grad.div_(clip_global_grad_norm)
|
| 152 |
+
state = self.state[p]
|
| 153 |
+
|
| 154 |
+
# State initialization
|
| 155 |
+
if len(state) == 0:
|
| 156 |
+
# Exponential moving average of gradient valuesa
|
| 157 |
+
state['exp_avg'] = torch.zeros_like(p)
|
| 158 |
+
# Exponential moving average of squared gradient values
|
| 159 |
+
state['exp_avg_sq'] = torch.zeros_like(p)
|
| 160 |
+
|
| 161 |
+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
| 162 |
+
|
| 163 |
+
# Decay the first and second moment running average coefficient
|
| 164 |
+
exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
|
| 165 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t
|
| 166 |
+
|
| 167 |
+
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
| 168 |
+
update = (exp_avg / bias_correction1).div_(denom)
|
| 169 |
+
|
| 170 |
+
weight_decay = group['weight_decay']
|
| 171 |
+
if weight_decay != 0:
|
| 172 |
+
update.add_(p, alpha=weight_decay)
|
| 173 |
+
|
| 174 |
+
if weight_decay != 0 or group['always_adapt']:
|
| 175 |
+
# Layer-wise LR adaptation. By default, skip adaptation on parameters that are
|
| 176 |
+
# excluded from weight decay, unless always_adapt == True, then always enabled.
|
| 177 |
+
w_norm = p.norm(2.0)
|
| 178 |
+
g_norm = update.norm(2.0)
|
| 179 |
+
# FIXME nested where required since logical and/or not working in PT XLA
|
| 180 |
+
trust_ratio = torch.where(
|
| 181 |
+
w_norm > 0,
|
| 182 |
+
torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
|
| 183 |
+
one_tensor,
|
| 184 |
+
)
|
| 185 |
+
if group['trust_clip']:
|
| 186 |
+
# LAMBC trust clipping, upper bound fixed at one
|
| 187 |
+
trust_ratio = torch.minimum(trust_ratio, one_tensor)
|
| 188 |
+
update.mul_(trust_ratio)
|
| 189 |
+
|
| 190 |
+
p.add_(update, alpha=-group['lr'])
|
| 191 |
+
|
| 192 |
+
return loss
|
src/custom_timm/optim/lars.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" PyTorch LARS / LARC Optimizer
|
| 2 |
+
|
| 3 |
+
An implementation of LARS (SGD) + LARC in PyTorch
|
| 4 |
+
|
| 5 |
+
Based on:
|
| 6 |
+
* PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
|
| 7 |
+
* NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
|
| 8 |
+
|
| 9 |
+
Additional cleanup and modifications to properly support PyTorch XLA.
|
| 10 |
+
|
| 11 |
+
Copyright 2021 Ross Wightman
|
| 12 |
+
"""
|
| 13 |
+
import torch
|
| 14 |
+
from torch.optim.optimizer import Optimizer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Lars(Optimizer):
|
| 18 |
+
""" LARS for PyTorch
|
| 19 |
+
|
| 20 |
+
Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
|
| 24 |
+
lr (float, optional): learning rate (default: 1.0).
|
| 25 |
+
momentum (float, optional): momentum factor (default: 0)
|
| 26 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
| 27 |
+
dampening (float, optional): dampening for momentum (default: 0)
|
| 28 |
+
nesterov (bool, optional): enables Nesterov momentum (default: False)
|
| 29 |
+
trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
|
| 30 |
+
eps (float): eps for division denominator (default: 1e-8)
|
| 31 |
+
trust_clip (bool): enable LARC trust ratio clipping (default: False)
|
| 32 |
+
always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False)
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
params,
|
| 38 |
+
lr=1.0,
|
| 39 |
+
momentum=0,
|
| 40 |
+
dampening=0,
|
| 41 |
+
weight_decay=0,
|
| 42 |
+
nesterov=False,
|
| 43 |
+
trust_coeff=0.001,
|
| 44 |
+
eps=1e-8,
|
| 45 |
+
trust_clip=False,
|
| 46 |
+
always_adapt=False,
|
| 47 |
+
):
|
| 48 |
+
if lr < 0.0:
|
| 49 |
+
raise ValueError(f"Invalid learning rate: {lr}")
|
| 50 |
+
if momentum < 0.0:
|
| 51 |
+
raise ValueError(f"Invalid momentum value: {momentum}")
|
| 52 |
+
if weight_decay < 0.0:
|
| 53 |
+
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
| 54 |
+
if nesterov and (momentum <= 0 or dampening != 0):
|
| 55 |
+
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
|
| 56 |
+
|
| 57 |
+
defaults = dict(
|
| 58 |
+
lr=lr,
|
| 59 |
+
momentum=momentum,
|
| 60 |
+
dampening=dampening,
|
| 61 |
+
weight_decay=weight_decay,
|
| 62 |
+
nesterov=nesterov,
|
| 63 |
+
trust_coeff=trust_coeff,
|
| 64 |
+
eps=eps,
|
| 65 |
+
trust_clip=trust_clip,
|
| 66 |
+
always_adapt=always_adapt,
|
| 67 |
+
)
|
| 68 |
+
super().__init__(params, defaults)
|
| 69 |
+
|
| 70 |
+
def __setstate__(self, state):
|
| 71 |
+
super().__setstate__(state)
|
| 72 |
+
for group in self.param_groups:
|
| 73 |
+
group.setdefault("nesterov", False)
|
| 74 |
+
|
| 75 |
+
@torch.no_grad()
|
| 76 |
+
def step(self, closure=None):
|
| 77 |
+
"""Performs a single optimization step.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
closure (callable, optional): A closure that reevaluates the model and returns the loss.
|
| 81 |
+
"""
|
| 82 |
+
loss = None
|
| 83 |
+
if closure is not None:
|
| 84 |
+
with torch.enable_grad():
|
| 85 |
+
loss = closure()
|
| 86 |
+
|
| 87 |
+
device = self.param_groups[0]['params'][0].device
|
| 88 |
+
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
|
| 89 |
+
|
| 90 |
+
for group in self.param_groups:
|
| 91 |
+
weight_decay = group['weight_decay']
|
| 92 |
+
momentum = group['momentum']
|
| 93 |
+
dampening = group['dampening']
|
| 94 |
+
nesterov = group['nesterov']
|
| 95 |
+
trust_coeff = group['trust_coeff']
|
| 96 |
+
eps = group['eps']
|
| 97 |
+
|
| 98 |
+
for p in group['params']:
|
| 99 |
+
if p.grad is None:
|
| 100 |
+
continue
|
| 101 |
+
grad = p.grad
|
| 102 |
+
|
| 103 |
+
# apply LARS LR adaptation, LARC clipping, weight decay
|
| 104 |
+
# ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
|
| 105 |
+
if weight_decay != 0 or group['always_adapt']:
|
| 106 |
+
w_norm = p.norm(2.0)
|
| 107 |
+
g_norm = grad.norm(2.0)
|
| 108 |
+
trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
|
| 109 |
+
# FIXME nested where required since logical and/or not working in PT XLA
|
| 110 |
+
trust_ratio = torch.where(
|
| 111 |
+
w_norm > 0,
|
| 112 |
+
torch.where(g_norm > 0, trust_ratio, one_tensor),
|
| 113 |
+
one_tensor,
|
| 114 |
+
)
|
| 115 |
+
if group['trust_clip']:
|
| 116 |
+
trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
|
| 117 |
+
grad.add_(p, alpha=weight_decay)
|
| 118 |
+
grad.mul_(trust_ratio)
|
| 119 |
+
|
| 120 |
+
# apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
|
| 121 |
+
if momentum != 0:
|
| 122 |
+
param_state = self.state[p]
|
| 123 |
+
if 'momentum_buffer' not in param_state:
|
| 124 |
+
buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
|
| 125 |
+
else:
|
| 126 |
+
buf = param_state['momentum_buffer']
|
| 127 |
+
buf.mul_(momentum).add_(grad, alpha=1. - dampening)
|
| 128 |
+
if nesterov:
|
| 129 |
+
grad = grad.add(buf, alpha=momentum)
|
| 130 |
+
else:
|
| 131 |
+
grad = buf
|
| 132 |
+
|
| 133 |
+
p.add_(grad, alpha=-group['lr'])
|
| 134 |
+
|
| 135 |
+
return loss
|
src/custom_timm/optim/lookahead.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Lookahead Optimizer Wrapper.
|
| 2 |
+
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
|
| 3 |
+
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
|
| 4 |
+
|
| 5 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
| 6 |
+
"""
|
| 7 |
+
import torch
|
| 8 |
+
from torch.optim.optimizer import Optimizer
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Lookahead(Optimizer):
|
| 13 |
+
def __init__(self, base_optimizer, alpha=0.5, k=6):
|
| 14 |
+
# NOTE super().__init__() not called on purpose
|
| 15 |
+
if not 0.0 <= alpha <= 1.0:
|
| 16 |
+
raise ValueError(f'Invalid slow update rate: {alpha}')
|
| 17 |
+
if not 1 <= k:
|
| 18 |
+
raise ValueError(f'Invalid lookahead steps: {k}')
|
| 19 |
+
defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
|
| 20 |
+
self._base_optimizer = base_optimizer
|
| 21 |
+
self.param_groups = base_optimizer.param_groups
|
| 22 |
+
self.defaults = base_optimizer.defaults
|
| 23 |
+
self.defaults.update(defaults)
|
| 24 |
+
self.state = defaultdict(dict)
|
| 25 |
+
# manually add our defaults to the param groups
|
| 26 |
+
for name, default in defaults.items():
|
| 27 |
+
for group in self._base_optimizer.param_groups:
|
| 28 |
+
group.setdefault(name, default)
|
| 29 |
+
|
| 30 |
+
@torch.no_grad()
|
| 31 |
+
def update_slow(self, group):
|
| 32 |
+
for fast_p in group["params"]:
|
| 33 |
+
if fast_p.grad is None:
|
| 34 |
+
continue
|
| 35 |
+
param_state = self._base_optimizer.state[fast_p]
|
| 36 |
+
if 'lookahead_slow_buff' not in param_state:
|
| 37 |
+
param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
|
| 38 |
+
param_state['lookahead_slow_buff'].copy_(fast_p)
|
| 39 |
+
slow = param_state['lookahead_slow_buff']
|
| 40 |
+
slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
|
| 41 |
+
fast_p.copy_(slow)
|
| 42 |
+
|
| 43 |
+
def sync_lookahead(self):
|
| 44 |
+
for group in self._base_optimizer.param_groups:
|
| 45 |
+
self.update_slow(group)
|
| 46 |
+
|
| 47 |
+
@torch.no_grad()
|
| 48 |
+
def step(self, closure=None):
|
| 49 |
+
loss = self._base_optimizer.step(closure)
|
| 50 |
+
for group in self._base_optimizer.param_groups:
|
| 51 |
+
group['lookahead_step'] += 1
|
| 52 |
+
if group['lookahead_step'] % group['lookahead_k'] == 0:
|
| 53 |
+
self.update_slow(group)
|
| 54 |
+
return loss
|
| 55 |
+
|
| 56 |
+
def state_dict(self):
|
| 57 |
+
return self._base_optimizer.state_dict()
|
| 58 |
+
|
| 59 |
+
def load_state_dict(self, state_dict):
|
| 60 |
+
self._base_optimizer.load_state_dict(state_dict)
|
| 61 |
+
self.param_groups = self._base_optimizer.param_groups
|
src/custom_timm/optim/madgrad.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" PyTorch MADGRAD optimizer
|
| 2 |
+
|
| 3 |
+
MADGRAD: https://arxiv.org/abs/2101.11075
|
| 4 |
+
|
| 5 |
+
Code from: https://github.com/facebookresearch/madgrad
|
| 6 |
+
"""
|
| 7 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 8 |
+
#
|
| 9 |
+
# This source code is licensed under the MIT license found in the
|
| 10 |
+
# LICENSE file in the root directory of this source tree.
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
from typing import TYPE_CHECKING, Any, Callable, Optional
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.optim
|
| 17 |
+
|
| 18 |
+
if TYPE_CHECKING:
|
| 19 |
+
from torch.optim.optimizer import _params_t
|
| 20 |
+
else:
|
| 21 |
+
_params_t = Any
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MADGRAD(torch.optim.Optimizer):
|
| 25 |
+
"""
|
| 26 |
+
MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
|
| 27 |
+
Optimization.
|
| 28 |
+
|
| 29 |
+
.. _MADGRAD: https://arxiv.org/abs/2101.11075
|
| 30 |
+
|
| 31 |
+
MADGRAD is a general purpose optimizer that can be used in place of SGD or
|
| 32 |
+
Adam may converge faster and generalize better. Currently GPU-only.
|
| 33 |
+
Typically, the same learning rate schedule that is used for SGD or Adam may
|
| 34 |
+
be used. The overall learning rate is not comparable to either method and
|
| 35 |
+
should be determined by a hyper-parameter sweep.
|
| 36 |
+
|
| 37 |
+
MADGRAD requires less weight decay than other methods, often as little as
|
| 38 |
+
zero. Momentum values used for SGD or Adam's beta1 should work here also.
|
| 39 |
+
|
| 40 |
+
On sparse problems both weight_decay and momentum should be set to 0.
|
| 41 |
+
|
| 42 |
+
Arguments:
|
| 43 |
+
params (iterable):
|
| 44 |
+
Iterable of parameters to optimize or dicts defining parameter groups.
|
| 45 |
+
lr (float):
|
| 46 |
+
Learning rate (default: 1e-2).
|
| 47 |
+
momentum (float):
|
| 48 |
+
Momentum value in the range [0,1) (default: 0.9).
|
| 49 |
+
weight_decay (float):
|
| 50 |
+
Weight decay, i.e. a L2 penalty (default: 0).
|
| 51 |
+
eps (float):
|
| 52 |
+
Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6).
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
params: _params_t,
|
| 58 |
+
lr: float = 1e-2,
|
| 59 |
+
momentum: float = 0.9,
|
| 60 |
+
weight_decay: float = 0,
|
| 61 |
+
eps: float = 1e-6,
|
| 62 |
+
decoupled_decay: bool = False,
|
| 63 |
+
):
|
| 64 |
+
if momentum < 0 or momentum >= 1:
|
| 65 |
+
raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
|
| 66 |
+
if lr <= 0:
|
| 67 |
+
raise ValueError(f"Learning rate {lr} must be positive")
|
| 68 |
+
if weight_decay < 0:
|
| 69 |
+
raise ValueError(f"Weight decay {weight_decay} must be non-negative")
|
| 70 |
+
if eps < 0:
|
| 71 |
+
raise ValueError(f"Eps must be non-negative")
|
| 72 |
+
|
| 73 |
+
defaults = dict(
|
| 74 |
+
lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
|
| 75 |
+
super().__init__(params, defaults)
|
| 76 |
+
|
| 77 |
+
@property
|
| 78 |
+
def supports_memory_efficient_fp16(self) -> bool:
|
| 79 |
+
return False
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def supports_flat_params(self) -> bool:
|
| 83 |
+
return True
|
| 84 |
+
|
| 85 |
+
@torch.no_grad()
|
| 86 |
+
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
|
| 87 |
+
"""Performs a single optimization step.
|
| 88 |
+
|
| 89 |
+
Arguments:
|
| 90 |
+
closure (callable, optional): A closure that reevaluates the model and returns the loss.
|
| 91 |
+
"""
|
| 92 |
+
loss = None
|
| 93 |
+
if closure is not None:
|
| 94 |
+
with torch.enable_grad():
|
| 95 |
+
loss = closure()
|
| 96 |
+
|
| 97 |
+
for group in self.param_groups:
|
| 98 |
+
eps = group['eps']
|
| 99 |
+
lr = group['lr'] + eps
|
| 100 |
+
weight_decay = group['weight_decay']
|
| 101 |
+
momentum = group['momentum']
|
| 102 |
+
ck = 1 - momentum
|
| 103 |
+
|
| 104 |
+
for p in group["params"]:
|
| 105 |
+
if p.grad is None:
|
| 106 |
+
continue
|
| 107 |
+
grad = p.grad
|
| 108 |
+
if momentum != 0.0 and grad.is_sparse:
|
| 109 |
+
raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
|
| 110 |
+
|
| 111 |
+
state = self.state[p]
|
| 112 |
+
if len(state) == 0:
|
| 113 |
+
state['step'] = 0
|
| 114 |
+
state['grad_sum_sq'] = torch.zeros_like(p)
|
| 115 |
+
state['s'] = torch.zeros_like(p)
|
| 116 |
+
if momentum != 0:
|
| 117 |
+
state['x0'] = torch.clone(p).detach()
|
| 118 |
+
|
| 119 |
+
state['step'] += 1
|
| 120 |
+
grad_sum_sq = state['grad_sum_sq']
|
| 121 |
+
s = state['s']
|
| 122 |
+
lamb = lr * math.sqrt(state['step'])
|
| 123 |
+
|
| 124 |
+
# Apply weight decay
|
| 125 |
+
if weight_decay != 0:
|
| 126 |
+
if group['decoupled_decay']:
|
| 127 |
+
p.mul_(1.0 - group['lr'] * weight_decay)
|
| 128 |
+
else:
|
| 129 |
+
if grad.is_sparse:
|
| 130 |
+
raise RuntimeError("weight_decay option is not compatible with sparse gradients")
|
| 131 |
+
grad.add_(p, alpha=weight_decay)
|
| 132 |
+
|
| 133 |
+
if grad.is_sparse:
|
| 134 |
+
grad = grad.coalesce()
|
| 135 |
+
grad_val = grad._values()
|
| 136 |
+
|
| 137 |
+
p_masked = p.sparse_mask(grad)
|
| 138 |
+
grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
|
| 139 |
+
s_masked = s.sparse_mask(grad)
|
| 140 |
+
|
| 141 |
+
# Compute x_0 from other known quantities
|
| 142 |
+
rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
|
| 143 |
+
x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)
|
| 144 |
+
|
| 145 |
+
# Dense + sparse op
|
| 146 |
+
grad_sq = grad * grad
|
| 147 |
+
grad_sum_sq.add_(grad_sq, alpha=lamb)
|
| 148 |
+
grad_sum_sq_masked.add_(grad_sq, alpha=lamb)
|
| 149 |
+
|
| 150 |
+
rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
|
| 151 |
+
|
| 152 |
+
s.add_(grad, alpha=lamb)
|
| 153 |
+
s_masked._values().add_(grad_val, alpha=lamb)
|
| 154 |
+
|
| 155 |
+
# update masked copy of p
|
| 156 |
+
p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
|
| 157 |
+
# Copy updated masked p to dense p using an add operation
|
| 158 |
+
p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
|
| 159 |
+
p.add_(p_masked, alpha=-1)
|
| 160 |
+
else:
|
| 161 |
+
if momentum == 0:
|
| 162 |
+
# Compute x_0 from other known quantities
|
| 163 |
+
rms = grad_sum_sq.pow(1 / 3).add_(eps)
|
| 164 |
+
x0 = p.addcdiv(s, rms, value=1)
|
| 165 |
+
else:
|
| 166 |
+
x0 = state['x0']
|
| 167 |
+
|
| 168 |
+
# Accumulate second moments
|
| 169 |
+
grad_sum_sq.addcmul_(grad, grad, value=lamb)
|
| 170 |
+
rms = grad_sum_sq.pow(1 / 3).add_(eps)
|
| 171 |
+
|
| 172 |
+
# Update s
|
| 173 |
+
s.add_(grad, alpha=lamb)
|
| 174 |
+
|
| 175 |
+
# Step
|
| 176 |
+
if momentum == 0:
|
| 177 |
+
p.copy_(x0.addcdiv(s, rms, value=-1))
|
| 178 |
+
else:
|
| 179 |
+
z = x0.addcdiv(s, rms, value=-1)
|
| 180 |
+
|
| 181 |
+
# p is a moving average of z
|
| 182 |
+
p.mul_(1 - ck).add_(z, alpha=ck)
|
| 183 |
+
|
| 184 |
+
return loss
|
src/custom_timm/optim/nadam.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.optim.optimizer import Optimizer
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Nadam(Optimizer):
|
| 8 |
+
"""Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
|
| 9 |
+
|
| 10 |
+
It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
|
| 11 |
+
|
| 12 |
+
Arguments:
|
| 13 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
| 14 |
+
parameter groups
|
| 15 |
+
lr (float, optional): learning rate (default: 2e-3)
|
| 16 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
| 17 |
+
running averages of gradient and its square
|
| 18 |
+
eps (float, optional): term added to the denominator to improve
|
| 19 |
+
numerical stability (default: 1e-8)
|
| 20 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
| 21 |
+
schedule_decay (float, optional): momentum schedule decay (default: 4e-3)
|
| 22 |
+
|
| 23 |
+
__ http://cs229.stanford.edu/proj2015/054_report.pdf
|
| 24 |
+
__ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf
|
| 25 |
+
|
| 26 |
+
Originally taken from: https://github.com/pytorch/pytorch/pull/1408
|
| 27 |
+
NOTE: Has potential issues but does work well on some problems.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
|
| 31 |
+
weight_decay=0, schedule_decay=4e-3):
|
| 32 |
+
if not 0.0 <= lr:
|
| 33 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
| 34 |
+
defaults = dict(
|
| 35 |
+
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
|
| 36 |
+
super(Nadam, self).__init__(params, defaults)
|
| 37 |
+
|
| 38 |
+
@torch.no_grad()
|
| 39 |
+
def step(self, closure=None):
|
| 40 |
+
"""Performs a single optimization step.
|
| 41 |
+
|
| 42 |
+
Arguments:
|
| 43 |
+
closure (callable, optional): A closure that reevaluates the model
|
| 44 |
+
and returns the loss.
|
| 45 |
+
"""
|
| 46 |
+
loss = None
|
| 47 |
+
if closure is not None:
|
| 48 |
+
with torch.enable_grad():
|
| 49 |
+
loss = closure()
|
| 50 |
+
|
| 51 |
+
for group in self.param_groups:
|
| 52 |
+
for p in group['params']:
|
| 53 |
+
if p.grad is None:
|
| 54 |
+
continue
|
| 55 |
+
grad = p.grad
|
| 56 |
+
state = self.state[p]
|
| 57 |
+
|
| 58 |
+
# State initialization
|
| 59 |
+
if len(state) == 0:
|
| 60 |
+
state['step'] = 0
|
| 61 |
+
state['m_schedule'] = 1.
|
| 62 |
+
state['exp_avg'] = torch.zeros_like(p)
|
| 63 |
+
state['exp_avg_sq'] = torch.zeros_like(p)
|
| 64 |
+
|
| 65 |
+
# Warming momentum schedule
|
| 66 |
+
m_schedule = state['m_schedule']
|
| 67 |
+
schedule_decay = group['schedule_decay']
|
| 68 |
+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
| 69 |
+
beta1, beta2 = group['betas']
|
| 70 |
+
eps = group['eps']
|
| 71 |
+
state['step'] += 1
|
| 72 |
+
t = state['step']
|
| 73 |
+
bias_correction2 = 1 - beta2 ** t
|
| 74 |
+
|
| 75 |
+
if group['weight_decay'] != 0:
|
| 76 |
+
grad = grad.add(p, alpha=group['weight_decay'])
|
| 77 |
+
|
| 78 |
+
momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
|
| 79 |
+
momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
|
| 80 |
+
m_schedule_new = m_schedule * momentum_cache_t
|
| 81 |
+
m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
|
| 82 |
+
state['m_schedule'] = m_schedule_new
|
| 83 |
+
|
| 84 |
+
# Decay the first and second moment running average coefficient
|
| 85 |
+
exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
|
| 86 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
|
| 87 |
+
|
| 88 |
+
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
|
| 89 |
+
p.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
|
| 90 |
+
p.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))
|
| 91 |
+
|
| 92 |
+
return loss
|
src/custom_timm/optim/nvnovograd.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Nvidia NovoGrad Optimizer.
|
| 2 |
+
Original impl by Nvidia from Jasper example:
|
| 3 |
+
- https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
|
| 4 |
+
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
|
| 5 |
+
- https://arxiv.org/abs/1905.11286
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.optim.optimizer import Optimizer
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class NvNovoGrad(Optimizer):
|
| 14 |
+
"""
|
| 15 |
+
Implements Novograd algorithm.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
| 19 |
+
parameter groups
|
| 20 |
+
lr (float, optional): learning rate (default: 1e-3)
|
| 21 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
| 22 |
+
running averages of gradient and its square (default: (0.95, 0.98))
|
| 23 |
+
eps (float, optional): term added to the denominator to improve
|
| 24 |
+
numerical stability (default: 1e-8)
|
| 25 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
| 26 |
+
grad_averaging: gradient averaging
|
| 27 |
+
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
| 28 |
+
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
| 29 |
+
(default: False)
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
|
| 33 |
+
weight_decay=0, grad_averaging=False, amsgrad=False):
|
| 34 |
+
if not 0.0 <= lr:
|
| 35 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
| 36 |
+
if not 0.0 <= eps:
|
| 37 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
| 38 |
+
if not 0.0 <= betas[0] < 1.0:
|
| 39 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
| 40 |
+
if not 0.0 <= betas[1] < 1.0:
|
| 41 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
| 42 |
+
defaults = dict(lr=lr, betas=betas, eps=eps,
|
| 43 |
+
weight_decay=weight_decay,
|
| 44 |
+
grad_averaging=grad_averaging,
|
| 45 |
+
amsgrad=amsgrad)
|
| 46 |
+
|
| 47 |
+
super(NvNovoGrad, self).__init__(params, defaults)
|
| 48 |
+
|
| 49 |
+
def __setstate__(self, state):
|
| 50 |
+
super(NvNovoGrad, self).__setstate__(state)
|
| 51 |
+
for group in self.param_groups:
|
| 52 |
+
group.setdefault('amsgrad', False)
|
| 53 |
+
|
| 54 |
+
@torch.no_grad()
|
| 55 |
+
def step(self, closure=None):
|
| 56 |
+
"""Performs a single optimization step.
|
| 57 |
+
|
| 58 |
+
Arguments:
|
| 59 |
+
closure (callable, optional): A closure that reevaluates the model
|
| 60 |
+
and returns the loss.
|
| 61 |
+
"""
|
| 62 |
+
loss = None
|
| 63 |
+
if closure is not None:
|
| 64 |
+
with torch.enable_grad():
|
| 65 |
+
loss = closure()
|
| 66 |
+
|
| 67 |
+
for group in self.param_groups:
|
| 68 |
+
for p in group['params']:
|
| 69 |
+
if p.grad is None:
|
| 70 |
+
continue
|
| 71 |
+
grad = p.grad
|
| 72 |
+
if grad.is_sparse:
|
| 73 |
+
raise RuntimeError('Sparse gradients are not supported.')
|
| 74 |
+
amsgrad = group['amsgrad']
|
| 75 |
+
|
| 76 |
+
state = self.state[p]
|
| 77 |
+
|
| 78 |
+
# State initialization
|
| 79 |
+
if len(state) == 0:
|
| 80 |
+
state['step'] = 0
|
| 81 |
+
# Exponential moving average of gradient values
|
| 82 |
+
state['exp_avg'] = torch.zeros_like(p)
|
| 83 |
+
# Exponential moving average of squared gradient values
|
| 84 |
+
state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
|
| 85 |
+
if amsgrad:
|
| 86 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
| 87 |
+
state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
|
| 88 |
+
|
| 89 |
+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
| 90 |
+
if amsgrad:
|
| 91 |
+
max_exp_avg_sq = state['max_exp_avg_sq']
|
| 92 |
+
beta1, beta2 = group['betas']
|
| 93 |
+
|
| 94 |
+
state['step'] += 1
|
| 95 |
+
|
| 96 |
+
norm = torch.sum(torch.pow(grad, 2))
|
| 97 |
+
|
| 98 |
+
if exp_avg_sq == 0:
|
| 99 |
+
exp_avg_sq.copy_(norm)
|
| 100 |
+
else:
|
| 101 |
+
exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)
|
| 102 |
+
|
| 103 |
+
if amsgrad:
|
| 104 |
+
# Maintains the maximum of all 2nd moment running avg. till now
|
| 105 |
+
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
|
| 106 |
+
# Use the max. for normalizing running avg. of gradient
|
| 107 |
+
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
|
| 108 |
+
else:
|
| 109 |
+
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
| 110 |
+
|
| 111 |
+
grad.div_(denom)
|
| 112 |
+
if group['weight_decay'] != 0:
|
| 113 |
+
grad.add_(p, alpha=group['weight_decay'])
|
| 114 |
+
if group['grad_averaging']:
|
| 115 |
+
grad.mul_(1 - beta1)
|
| 116 |
+
exp_avg.mul_(beta1).add_(grad)
|
| 117 |
+
|
| 118 |
+
p.add_(exp_avg, alpha=-group['lr'])
|
| 119 |
+
|
| 120 |
+
return loss
|
src/custom_timm/optim/optim_factory.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Optimizer Factory w/ Custom Weight Decay
|
| 2 |
+
Hacked together by / Copyright 2021 Ross Wightman
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
from itertools import islice
|
| 6 |
+
from typing import Optional, Callable, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.optim as optim
|
| 11 |
+
|
| 12 |
+
from custom_timm.models.helpers import group_parameters
|
| 13 |
+
|
| 14 |
+
from .adabelief import AdaBelief
|
| 15 |
+
from .adafactor import Adafactor
|
| 16 |
+
from .adahessian import Adahessian
|
| 17 |
+
from .adamp import AdamP
|
| 18 |
+
from .lamb import Lamb
|
| 19 |
+
from .lars import Lars
|
| 20 |
+
from .lookahead import Lookahead
|
| 21 |
+
from .madgrad import MADGRAD
|
| 22 |
+
from .nadam import Nadam
|
| 23 |
+
from .nvnovograd import NvNovoGrad
|
| 24 |
+
from .radam import RAdam
|
| 25 |
+
from .rmsprop_tf import RMSpropTF
|
| 26 |
+
from .sgdp import SGDP
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
|
| 30 |
+
has_apex = True
|
| 31 |
+
except ImportError:
|
| 32 |
+
has_apex = False
|
| 33 |
+
|
| 34 |
+
_logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def param_groups_weight_decay(
|
| 38 |
+
model: nn.Module,
|
| 39 |
+
weight_decay=1e-5,
|
| 40 |
+
no_weight_decay_list=()
|
| 41 |
+
):
|
| 42 |
+
no_weight_decay_list = set(no_weight_decay_list)
|
| 43 |
+
decay = []
|
| 44 |
+
no_decay = []
|
| 45 |
+
for name, param in model.named_parameters():
|
| 46 |
+
if not param.requires_grad:
|
| 47 |
+
continue
|
| 48 |
+
|
| 49 |
+
if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
|
| 50 |
+
no_decay.append(param)
|
| 51 |
+
else:
|
| 52 |
+
decay.append(param)
|
| 53 |
+
|
| 54 |
+
return [
|
| 55 |
+
{'params': no_decay, 'weight_decay': 0.},
|
| 56 |
+
{'params': decay, 'weight_decay': weight_decay}]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _group(it, size):
|
| 60 |
+
it = iter(it)
|
| 61 |
+
return iter(lambda: tuple(islice(it, size)), ())
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _layer_map(model, layers_per_group=12, num_groups=None):
|
| 65 |
+
def _in_head(n, hp):
|
| 66 |
+
if not hp:
|
| 67 |
+
return True
|
| 68 |
+
elif isinstance(hp, (tuple, list)):
|
| 69 |
+
return any([n.startswith(hpi) for hpi in hp])
|
| 70 |
+
else:
|
| 71 |
+
return n.startswith(hp)
|
| 72 |
+
|
| 73 |
+
head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
|
| 74 |
+
names_trunk = []
|
| 75 |
+
names_head = []
|
| 76 |
+
for n, _ in model.named_parameters():
|
| 77 |
+
names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)
|
| 78 |
+
|
| 79 |
+
# group non-head layers
|
| 80 |
+
num_trunk_layers = len(names_trunk)
|
| 81 |
+
if num_groups is not None:
|
| 82 |
+
layers_per_group = -(num_trunk_layers // -num_groups)
|
| 83 |
+
names_trunk = list(_group(names_trunk, layers_per_group))
|
| 84 |
+
|
| 85 |
+
num_trunk_groups = len(names_trunk)
|
| 86 |
+
layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
|
| 87 |
+
layer_map.update({n: num_trunk_groups for n in names_head})
|
| 88 |
+
return layer_map
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def param_groups_layer_decay(
|
| 92 |
+
model: nn.Module,
|
| 93 |
+
weight_decay: float = 0.05,
|
| 94 |
+
no_weight_decay_list: Tuple[str] = (),
|
| 95 |
+
layer_decay: float = .75,
|
| 96 |
+
end_layer_decay: Optional[float] = None,
|
| 97 |
+
verbose: bool = False,
|
| 98 |
+
):
|
| 99 |
+
"""
|
| 100 |
+
Parameter groups for layer-wise lr decay & weight decay
|
| 101 |
+
Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
|
| 102 |
+
"""
|
| 103 |
+
no_weight_decay_list = set(no_weight_decay_list)
|
| 104 |
+
param_group_names = {} # NOTE for debugging
|
| 105 |
+
param_groups = {}
|
| 106 |
+
|
| 107 |
+
if hasattr(model, 'group_matcher'):
|
| 108 |
+
# FIXME interface needs more work
|
| 109 |
+
layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
|
| 110 |
+
else:
|
| 111 |
+
# fallback
|
| 112 |
+
layer_map = _layer_map(model)
|
| 113 |
+
num_layers = max(layer_map.values()) + 1
|
| 114 |
+
layer_max = num_layers - 1
|
| 115 |
+
layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
|
| 116 |
+
|
| 117 |
+
for name, param in model.named_parameters():
|
| 118 |
+
if not param.requires_grad:
|
| 119 |
+
continue
|
| 120 |
+
|
| 121 |
+
# no decay: all 1D parameters and model specific ones
|
| 122 |
+
if param.ndim == 1 or name in no_weight_decay_list:
|
| 123 |
+
g_decay = "no_decay"
|
| 124 |
+
this_decay = 0.
|
| 125 |
+
else:
|
| 126 |
+
g_decay = "decay"
|
| 127 |
+
this_decay = weight_decay
|
| 128 |
+
|
| 129 |
+
layer_id = layer_map.get(name, layer_max)
|
| 130 |
+
group_name = "layer_%d_%s" % (layer_id, g_decay)
|
| 131 |
+
|
| 132 |
+
if group_name not in param_groups:
|
| 133 |
+
this_scale = layer_scales[layer_id]
|
| 134 |
+
param_group_names[group_name] = {
|
| 135 |
+
"lr_scale": this_scale,
|
| 136 |
+
"weight_decay": this_decay,
|
| 137 |
+
"param_names": [],
|
| 138 |
+
}
|
| 139 |
+
param_groups[group_name] = {
|
| 140 |
+
"lr_scale": this_scale,
|
| 141 |
+
"weight_decay": this_decay,
|
| 142 |
+
"params": [],
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
param_group_names[group_name]["param_names"].append(name)
|
| 146 |
+
param_groups[group_name]["params"].append(param)
|
| 147 |
+
|
| 148 |
+
if verbose:
|
| 149 |
+
import json
|
| 150 |
+
_logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
|
| 151 |
+
|
| 152 |
+
return list(param_groups.values())
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def optimizer_kwargs(cfg):
|
| 156 |
+
""" cfg/argparse to kwargs helper
|
| 157 |
+
Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
|
| 158 |
+
"""
|
| 159 |
+
kwargs = dict(
|
| 160 |
+
opt=cfg.opt,
|
| 161 |
+
lr=cfg.lr,
|
| 162 |
+
weight_decay=cfg.weight_decay,
|
| 163 |
+
momentum=cfg.momentum)
|
| 164 |
+
if getattr(cfg, 'opt_eps', None) is not None:
|
| 165 |
+
kwargs['eps'] = cfg.opt_eps
|
| 166 |
+
if getattr(cfg, 'opt_betas', None) is not None:
|
| 167 |
+
kwargs['betas'] = cfg.opt_betas
|
| 168 |
+
if getattr(cfg, 'layer_decay', None) is not None:
|
| 169 |
+
kwargs['layer_decay'] = cfg.layer_decay
|
| 170 |
+
if getattr(cfg, 'opt_args', None) is not None:
|
| 171 |
+
kwargs.update(cfg.opt_args)
|
| 172 |
+
return kwargs
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def create_optimizer(args, model, filter_bias_and_bn=True):
|
| 176 |
+
""" Legacy optimizer factory for backwards compatibility.
|
| 177 |
+
NOTE: Use create_optimizer_v2 for new code.
|
| 178 |
+
"""
|
| 179 |
+
return create_optimizer_v2(
|
| 180 |
+
model,
|
| 181 |
+
**optimizer_kwargs(cfg=args),
|
| 182 |
+
filter_bias_and_bn=filter_bias_and_bn,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def create_optimizer_v2(
|
| 187 |
+
model_or_params,
|
| 188 |
+
opt: str = 'sgd',
|
| 189 |
+
lr: Optional[float] = None,
|
| 190 |
+
weight_decay: float = 0.,
|
| 191 |
+
momentum: float = 0.9,
|
| 192 |
+
filter_bias_and_bn: bool = True,
|
| 193 |
+
layer_decay: Optional[float] = None,
|
| 194 |
+
param_group_fn: Optional[Callable] = None,
|
| 195 |
+
**kwargs):
|
| 196 |
+
""" Create an optimizer.
|
| 197 |
+
|
| 198 |
+
TODO currently the model is passed in and all parameters are selected for optimization.
|
| 199 |
+
For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
|
| 200 |
+
* a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
|
| 201 |
+
* expose the parameters interface and leave it up to caller
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
model_or_params (nn.Module): model containing parameters to optimize
|
| 205 |
+
opt: name of optimizer to create
|
| 206 |
+
lr: initial learning rate
|
| 207 |
+
weight_decay: weight decay to apply in optimizer
|
| 208 |
+
momentum: momentum for momentum based optimizers (others may use betas via kwargs)
|
| 209 |
+
filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
|
| 210 |
+
**kwargs: extra optimizer specific kwargs to pass through
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
Optimizer
|
| 214 |
+
"""
|
| 215 |
+
if isinstance(model_or_params, nn.Module):
|
| 216 |
+
# a model was passed in, extract parameters and add weight decays to appropriate layers
|
| 217 |
+
no_weight_decay = {}
|
| 218 |
+
if hasattr(model_or_params, 'no_weight_decay'):
|
| 219 |
+
no_weight_decay = model_or_params.no_weight_decay()
|
| 220 |
+
|
| 221 |
+
if param_group_fn:
|
| 222 |
+
parameters = param_group_fn(model_or_params)
|
| 223 |
+
elif layer_decay is not None:
|
| 224 |
+
parameters = param_groups_layer_decay(
|
| 225 |
+
model_or_params,
|
| 226 |
+
weight_decay=weight_decay,
|
| 227 |
+
layer_decay=layer_decay,
|
| 228 |
+
no_weight_decay_list=no_weight_decay)
|
| 229 |
+
weight_decay = 0.
|
| 230 |
+
elif weight_decay and filter_bias_and_bn:
|
| 231 |
+
parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
|
| 232 |
+
weight_decay = 0.
|
| 233 |
+
else:
|
| 234 |
+
parameters = model_or_params.parameters()
|
| 235 |
+
else:
|
| 236 |
+
# iterable of parameters or param groups passed in
|
| 237 |
+
parameters = model_or_params
|
| 238 |
+
|
| 239 |
+
opt_lower = opt.lower()
|
| 240 |
+
opt_split = opt_lower.split('_')
|
| 241 |
+
opt_lower = opt_split[-1]
|
| 242 |
+
if 'fused' in opt_lower:
|
| 243 |
+
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
|
| 244 |
+
|
| 245 |
+
opt_args = dict(weight_decay=weight_decay, **kwargs)
|
| 246 |
+
if lr is not None:
|
| 247 |
+
opt_args.setdefault('lr', lr)
|
| 248 |
+
|
| 249 |
+
# basic SGD & related
|
| 250 |
+
if opt_lower == 'sgd' or opt_lower == 'nesterov':
|
| 251 |
+
# NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
|
| 252 |
+
opt_args.pop('eps', None)
|
| 253 |
+
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
|
| 254 |
+
elif opt_lower == 'momentum':
|
| 255 |
+
opt_args.pop('eps', None)
|
| 256 |
+
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
|
| 257 |
+
elif opt_lower == 'sgdp':
|
| 258 |
+
optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
|
| 259 |
+
|
| 260 |
+
# adaptive
|
| 261 |
+
elif opt_lower == 'adam':
|
| 262 |
+
optimizer = optim.Adam(parameters, **opt_args)
|
| 263 |
+
elif opt_lower == 'adamw':
|
| 264 |
+
optimizer = optim.AdamW(parameters, **opt_args)
|
| 265 |
+
elif opt_lower == 'adamp':
|
| 266 |
+
optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
|
| 267 |
+
elif opt_lower == 'nadam':
|
| 268 |
+
try:
|
| 269 |
+
# NOTE PyTorch >= 1.10 should have native NAdam
|
| 270 |
+
optimizer = optim.Nadam(parameters, **opt_args)
|
| 271 |
+
except AttributeError:
|
| 272 |
+
optimizer = Nadam(parameters, **opt_args)
|
| 273 |
+
elif opt_lower == 'radam':
|
| 274 |
+
optimizer = RAdam(parameters, **opt_args)
|
| 275 |
+
elif opt_lower == 'adamax':
|
| 276 |
+
optimizer = optim.Adamax(parameters, **opt_args)
|
| 277 |
+
elif opt_lower == 'adabelief':
|
| 278 |
+
optimizer = AdaBelief(parameters, rectify=False, **opt_args)
|
| 279 |
+
elif opt_lower == 'radabelief':
|
| 280 |
+
optimizer = AdaBelief(parameters, rectify=True, **opt_args)
|
| 281 |
+
elif opt_lower == 'adadelta':
|
| 282 |
+
optimizer = optim.Adadelta(parameters, **opt_args)
|
| 283 |
+
elif opt_lower == 'adagrad':
|
| 284 |
+
opt_args.setdefault('eps', 1e-8)
|
| 285 |
+
optimizer = optim.Adagrad(parameters, **opt_args)
|
| 286 |
+
elif opt_lower == 'adafactor':
|
| 287 |
+
optimizer = Adafactor(parameters, **opt_args)
|
| 288 |
+
elif opt_lower == 'lamb':
|
| 289 |
+
optimizer = Lamb(parameters, **opt_args)
|
| 290 |
+
elif opt_lower == 'lambc':
|
| 291 |
+
optimizer = Lamb(parameters, trust_clip=True, **opt_args)
|
| 292 |
+
elif opt_lower == 'larc':
|
| 293 |
+
optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
|
| 294 |
+
elif opt_lower == 'lars':
|
| 295 |
+
optimizer = Lars(parameters, momentum=momentum, **opt_args)
|
| 296 |
+
elif opt_lower == 'nlarc':
|
| 297 |
+
optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
|
| 298 |
+
elif opt_lower == 'nlars':
|
| 299 |
+
optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
|
| 300 |
+
elif opt_lower == 'madgrad':
|
| 301 |
+
optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
|
| 302 |
+
elif opt_lower == 'madgradw':
|
| 303 |
+
optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
|
| 304 |
+
elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
|
| 305 |
+
optimizer = NvNovoGrad(parameters, **opt_args)
|
| 306 |
+
elif opt_lower == 'rmsprop':
|
| 307 |
+
optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
|
| 308 |
+
elif opt_lower == 'rmsproptf':
|
| 309 |
+
optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
|
| 310 |
+
|
| 311 |
+
# second order
|
| 312 |
+
elif opt_lower == 'adahessian':
|
| 313 |
+
optimizer = Adahessian(parameters, **opt_args)
|
| 314 |
+
|
| 315 |
+
# NVIDIA fused optimizers, require APEX to be installed
|
| 316 |
+
elif opt_lower == 'fusedsgd':
|
| 317 |
+
opt_args.pop('eps', None)
|
| 318 |
+
optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
|
| 319 |
+
elif opt_lower == 'fusedmomentum':
|
| 320 |
+
opt_args.pop('eps', None)
|
| 321 |
+
optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
|
| 322 |
+
elif opt_lower == 'fusedadam':
|
| 323 |
+
optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
|
| 324 |
+
elif opt_lower == 'fusedadamw':
|
| 325 |
+
optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
|
| 326 |
+
elif opt_lower == 'fusedlamb':
|
| 327 |
+
optimizer = FusedLAMB(parameters, **opt_args)
|
| 328 |
+
elif opt_lower == 'fusednovograd':
|
| 329 |
+
opt_args.setdefault('betas', (0.95, 0.98))
|
| 330 |
+
optimizer = FusedNovoGrad(parameters, **opt_args)
|
| 331 |
+
|
| 332 |
+
else:
|
| 333 |
+
assert False and "Invalid optimizer"
|
| 334 |
+
raise ValueError
|
| 335 |
+
|
| 336 |
+
if len(opt_split) > 1:
|
| 337 |
+
if opt_split[0] == 'lookahead':
|
| 338 |
+
optimizer = Lookahead(optimizer)
|
| 339 |
+
|
| 340 |
+
return optimizer
|
src/custom_timm/optim/radam.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RAdam Optimizer.
|
| 2 |
+
Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
|
| 3 |
+
Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
|
| 4 |
+
"""
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
from torch.optim.optimizer import Optimizer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class RAdam(Optimizer):
|
| 11 |
+
|
| 12 |
+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
|
| 13 |
+
defaults = dict(
|
| 14 |
+
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
|
| 15 |
+
buffer=[[None, None, None] for _ in range(10)])
|
| 16 |
+
super(RAdam, self).__init__(params, defaults)
|
| 17 |
+
|
| 18 |
+
def __setstate__(self, state):
|
| 19 |
+
super(RAdam, self).__setstate__(state)
|
| 20 |
+
|
| 21 |
+
@torch.no_grad()
|
| 22 |
+
def step(self, closure=None):
|
| 23 |
+
loss = None
|
| 24 |
+
if closure is not None:
|
| 25 |
+
with torch.enable_grad():
|
| 26 |
+
loss = closure()
|
| 27 |
+
|
| 28 |
+
for group in self.param_groups:
|
| 29 |
+
|
| 30 |
+
for p in group['params']:
|
| 31 |
+
if p.grad is None:
|
| 32 |
+
continue
|
| 33 |
+
grad = p.grad.float()
|
| 34 |
+
if grad.is_sparse:
|
| 35 |
+
raise RuntimeError('RAdam does not support sparse gradients')
|
| 36 |
+
|
| 37 |
+
p_fp32 = p.float()
|
| 38 |
+
|
| 39 |
+
state = self.state[p]
|
| 40 |
+
|
| 41 |
+
if len(state) == 0:
|
| 42 |
+
state['step'] = 0
|
| 43 |
+
state['exp_avg'] = torch.zeros_like(p_fp32)
|
| 44 |
+
state['exp_avg_sq'] = torch.zeros_like(p_fp32)
|
| 45 |
+
else:
|
| 46 |
+
state['exp_avg'] = state['exp_avg'].type_as(p_fp32)
|
| 47 |
+
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32)
|
| 48 |
+
|
| 49 |
+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
| 50 |
+
beta1, beta2 = group['betas']
|
| 51 |
+
|
| 52 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
| 53 |
+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
| 54 |
+
|
| 55 |
+
state['step'] += 1
|
| 56 |
+
buffered = group['buffer'][int(state['step'] % 10)]
|
| 57 |
+
if state['step'] == buffered[0]:
|
| 58 |
+
num_sma, step_size = buffered[1], buffered[2]
|
| 59 |
+
else:
|
| 60 |
+
buffered[0] = state['step']
|
| 61 |
+
beta2_t = beta2 ** state['step']
|
| 62 |
+
num_sma_max = 2 / (1 - beta2) - 1
|
| 63 |
+
num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
|
| 64 |
+
buffered[1] = num_sma
|
| 65 |
+
|
| 66 |
+
# more conservative since it's an approximated value
|
| 67 |
+
if num_sma >= 5:
|
| 68 |
+
step_size = group['lr'] * math.sqrt(
|
| 69 |
+
(1 - beta2_t) *
|
| 70 |
+
(num_sma - 4) / (num_sma_max - 4) *
|
| 71 |
+
(num_sma - 2) / num_sma *
|
| 72 |
+
num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
|
| 73 |
+
else:
|
| 74 |
+
step_size = group['lr'] / (1 - beta1 ** state['step'])
|
| 75 |
+
buffered[2] = step_size
|
| 76 |
+
|
| 77 |
+
if group['weight_decay'] != 0:
|
| 78 |
+
p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
|
| 79 |
+
|
| 80 |
+
# more conservative since it's an approximated value
|
| 81 |
+
if num_sma >= 5:
|
| 82 |
+
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
| 83 |
+
p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
|
| 84 |
+
else:
|
| 85 |
+
p_fp32.add_(exp_avg, alpha=-step_size)
|
| 86 |
+
|
| 87 |
+
p.copy_(p_fp32)
|
| 88 |
+
|
| 89 |
+
return loss
|
src/custom_timm/optim/rmsprop_tf.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" RMSProp modified to behave like Tensorflow impl
|
| 2 |
+
|
| 3 |
+
Originally cut & paste from PyTorch RMSProp
|
| 4 |
+
https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
|
| 5 |
+
Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
|
| 6 |
+
|
| 7 |
+
Modifications Copyright 2021 Ross Wightman
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch.optim import Optimizer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RMSpropTF(Optimizer):
|
| 15 |
+
"""Implements RMSprop algorithm (TensorFlow style epsilon)
|
| 16 |
+
|
| 17 |
+
NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
|
| 18 |
+
and a few other modifications to closer match Tensorflow for matching hyper-params.
|
| 19 |
+
|
| 20 |
+
Noteworthy changes include:
|
| 21 |
+
1. Epsilon applied inside square-root
|
| 22 |
+
2. square_avg initialized to ones
|
| 23 |
+
3. LR scaling of update accumulated in momentum buffer
|
| 24 |
+
|
| 25 |
+
Proposed by G. Hinton in his
|
| 26 |
+
`course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
|
| 27 |
+
|
| 28 |
+
The centered version first appears in `Generating Sequences
|
| 29 |
+
With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
|
| 30 |
+
|
| 31 |
+
Arguments:
|
| 32 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
| 33 |
+
parameter groups
|
| 34 |
+
lr (float, optional): learning rate (default: 1e-2)
|
| 35 |
+
momentum (float, optional): momentum factor (default: 0)
|
| 36 |
+
alpha (float, optional): smoothing (decay) constant (default: 0.9)
|
| 37 |
+
eps (float, optional): term added to the denominator to improve
|
| 38 |
+
numerical stability (default: 1e-10)
|
| 39 |
+
centered (bool, optional) : if ``True``, compute the centered RMSProp,
|
| 40 |
+
the gradient is normalized by an estimation of its variance
|
| 41 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
| 42 |
+
decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
|
| 43 |
+
lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
|
| 44 |
+
update as per defaults in Tensorflow
|
| 45 |
+
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
|
| 49 |
+
decoupled_decay=False, lr_in_momentum=True):
|
| 50 |
+
if not 0.0 <= lr:
|
| 51 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
| 52 |
+
if not 0.0 <= eps:
|
| 53 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
| 54 |
+
if not 0.0 <= momentum:
|
| 55 |
+
raise ValueError("Invalid momentum value: {}".format(momentum))
|
| 56 |
+
if not 0.0 <= weight_decay:
|
| 57 |
+
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
| 58 |
+
if not 0.0 <= alpha:
|
| 59 |
+
raise ValueError("Invalid alpha value: {}".format(alpha))
|
| 60 |
+
|
| 61 |
+
defaults = dict(
|
| 62 |
+
lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
|
| 63 |
+
decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
|
| 64 |
+
super(RMSpropTF, self).__init__(params, defaults)
|
| 65 |
+
|
| 66 |
+
def __setstate__(self, state):
|
| 67 |
+
super(RMSpropTF, self).__setstate__(state)
|
| 68 |
+
for group in self.param_groups:
|
| 69 |
+
group.setdefault('momentum', 0)
|
| 70 |
+
group.setdefault('centered', False)
|
| 71 |
+
|
| 72 |
+
@torch.no_grad()
|
| 73 |
+
def step(self, closure=None):
|
| 74 |
+
"""Performs a single optimization step.
|
| 75 |
+
|
| 76 |
+
Arguments:
|
| 77 |
+
closure (callable, optional): A closure that reevaluates the model
|
| 78 |
+
and returns the loss.
|
| 79 |
+
"""
|
| 80 |
+
loss = None
|
| 81 |
+
if closure is not None:
|
| 82 |
+
with torch.enable_grad():
|
| 83 |
+
loss = closure()
|
| 84 |
+
|
| 85 |
+
for group in self.param_groups:
|
| 86 |
+
for p in group['params']:
|
| 87 |
+
if p.grad is None:
|
| 88 |
+
continue
|
| 89 |
+
grad = p.grad
|
| 90 |
+
if grad.is_sparse:
|
| 91 |
+
raise RuntimeError('RMSprop does not support sparse gradients')
|
| 92 |
+
state = self.state[p]
|
| 93 |
+
|
| 94 |
+
# State initialization
|
| 95 |
+
if len(state) == 0:
|
| 96 |
+
state['step'] = 0
|
| 97 |
+
state['square_avg'] = torch.ones_like(p) # PyTorch inits to zero
|
| 98 |
+
if group['momentum'] > 0:
|
| 99 |
+
state['momentum_buffer'] = torch.zeros_like(p)
|
| 100 |
+
if group['centered']:
|
| 101 |
+
state['grad_avg'] = torch.zeros_like(p)
|
| 102 |
+
|
| 103 |
+
square_avg = state['square_avg']
|
| 104 |
+
one_minus_alpha = 1. - group['alpha']
|
| 105 |
+
|
| 106 |
+
state['step'] += 1
|
| 107 |
+
|
| 108 |
+
if group['weight_decay'] != 0:
|
| 109 |
+
if group['decoupled_decay']:
|
| 110 |
+
p.mul_(1. - group['lr'] * group['weight_decay'])
|
| 111 |
+
else:
|
| 112 |
+
grad = grad.add(p, alpha=group['weight_decay'])
|
| 113 |
+
|
| 114 |
+
# Tensorflow order of ops for updating squared avg
|
| 115 |
+
square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
|
| 116 |
+
# square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) # PyTorch original
|
| 117 |
+
|
| 118 |
+
if group['centered']:
|
| 119 |
+
grad_avg = state['grad_avg']
|
| 120 |
+
grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
|
| 121 |
+
avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(group['eps']).sqrt_() # eps in sqrt
|
| 122 |
+
# grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) # PyTorch original
|
| 123 |
+
else:
|
| 124 |
+
avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt
|
| 125 |
+
|
| 126 |
+
if group['momentum'] > 0:
|
| 127 |
+
buf = state['momentum_buffer']
|
| 128 |
+
# Tensorflow accumulates the LR scaling in the momentum buffer
|
| 129 |
+
if group['lr_in_momentum']:
|
| 130 |
+
buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
|
| 131 |
+
p.add_(-buf)
|
| 132 |
+
else:
|
| 133 |
+
# PyTorch scales the param update by LR
|
| 134 |
+
buf.mul_(group['momentum']).addcdiv_(grad, avg)
|
| 135 |
+
p.add_(buf, alpha=-group['lr'])
|
| 136 |
+
else:
|
| 137 |
+
p.addcdiv_(grad, avg, value=-group['lr'])
|
| 138 |
+
|
| 139 |
+
return loss
|