nux1111 committed on
Commit
43f65fd
·
verified ·
1 Parent(s): bbe83be

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. src/custom_timm/__pycache__/__init__.cpython-312.pyc +0 -0
  2. src/custom_timm/__pycache__/version.cpython-312.pyc +0 -0
  3. src/custom_timm/data/__pycache__/__init__.cpython-312.pyc +0 -0
  4. src/custom_timm/data/__pycache__/auto_augment.cpython-312.pyc +0 -0
  5. src/custom_timm/data/__pycache__/config.cpython-312.pyc +0 -0
  6. src/custom_timm/data/__pycache__/constants.cpython-312.pyc +0 -0
  7. src/custom_timm/data/__pycache__/dataset.cpython-312.pyc +0 -0
  8. src/custom_timm/data/__pycache__/dataset_factory.cpython-312.pyc +0 -0
  9. src/custom_timm/data/__pycache__/distributed_sampler.cpython-312.pyc +0 -0
  10. src/custom_timm/data/__pycache__/loader.cpython-312.pyc +0 -0
  11. src/custom_timm/data/__pycache__/mixup.cpython-312.pyc +0 -0
  12. src/custom_timm/data/__pycache__/random_erasing.cpython-312.pyc +0 -0
  13. src/custom_timm/data/__pycache__/real_labels.cpython-312.pyc +0 -0
  14. src/custom_timm/data/__pycache__/transforms.cpython-312.pyc +0 -0
  15. src/custom_timm/data/__pycache__/transforms_factory.cpython-312.pyc +0 -0
  16. src/custom_timm/data/parsers/__init__.py +2 -0
  17. src/custom_timm/data/parsers/__pycache__/class_map.cpython-312.pyc +0 -0
  18. src/custom_timm/data/parsers/class_map.py +22 -0
  19. src/custom_timm/data/parsers/img_extensions.py +50 -0
  20. src/custom_timm/data/parsers/parser.py +17 -0
  21. src/custom_timm/data/parsers/parser_factory.py +28 -0
  22. src/custom_timm/data/parsers/parser_image_folder.py +90 -0
  23. src/custom_timm/data/parsers/parser_image_in_tar.py +229 -0
  24. src/custom_timm/data/parsers/parser_image_tar.py +74 -0
  25. src/custom_timm/data/parsers/parser_tfds.py +301 -0
  26. src/custom_timm/models/gluon_resnet.py +245 -0
  27. src/custom_timm/models/gluon_xception.py +267 -0
  28. src/custom_timm/models/hardcorenas.py +151 -0
  29. src/custom_timm/models/helpers.py +796 -0
  30. src/custom_timm/models/hrnet.py +858 -0
  31. src/custom_timm/models/hub.py +170 -0
  32. src/custom_timm/models/inception_resnet_v2.py +382 -0
  33. src/custom_timm/models/inception_v3.py +475 -0
  34. src/custom_timm/models/inception_v4.py +330 -0
  35. src/custom_timm/models/levit.py +592 -0
  36. src/custom_timm/optim/__init__.py +15 -0
  37. src/custom_timm/optim/adabelief.py +201 -0
  38. src/custom_timm/optim/adafactor.py +167 -0
  39. src/custom_timm/optim/adahessian.py +156 -0
  40. src/custom_timm/optim/adamp.py +105 -0
  41. src/custom_timm/optim/adamw.py +122 -0
  42. src/custom_timm/optim/lamb.py +192 -0
  43. src/custom_timm/optim/lars.py +135 -0
  44. src/custom_timm/optim/lookahead.py +61 -0
  45. src/custom_timm/optim/madgrad.py +184 -0
  46. src/custom_timm/optim/nadam.py +92 -0
  47. src/custom_timm/optim/nvnovograd.py +120 -0
  48. src/custom_timm/optim/optim_factory.py +340 -0
  49. src/custom_timm/optim/radam.py +89 -0
  50. src/custom_timm/optim/rmsprop_tf.py +139 -0
src/custom_timm/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (674 Bytes). View file
 
src/custom_timm/__pycache__/version.cpython-312.pyc ADDED
Binary file (274 Bytes). View file
 
src/custom_timm/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.14 kB). View file
 
src/custom_timm/data/__pycache__/auto_augment.cpython-312.pyc ADDED
Binary file (35.2 kB). View file
 
src/custom_timm/data/__pycache__/config.cpython-312.pyc ADDED
Binary file (2.85 kB). View file
 
src/custom_timm/data/__pycache__/constants.cpython-312.pyc ADDED
Binary file (754 Bytes). View file
 
src/custom_timm/data/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (7.84 kB). View file
 
src/custom_timm/data/__pycache__/dataset_factory.cpython-312.pyc ADDED
Binary file (5.98 kB). View file
 
src/custom_timm/data/__pycache__/distributed_sampler.cpython-312.pyc ADDED
Binary file (7.33 kB). View file
 
src/custom_timm/data/__pycache__/loader.cpython-312.pyc ADDED
Binary file (14 kB). View file
 
src/custom_timm/data/__pycache__/mixup.cpython-312.pyc ADDED
Binary file (20.9 kB). View file
 
src/custom_timm/data/__pycache__/random_erasing.cpython-312.pyc ADDED
Binary file (6.36 kB). View file
 
src/custom_timm/data/__pycache__/real_labels.cpython-312.pyc ADDED
Binary file (3.26 kB). View file
 
src/custom_timm/data/__pycache__/transforms.cpython-312.pyc ADDED
Binary file (10.6 kB). View file
 
src/custom_timm/data/__pycache__/transforms_factory.cpython-312.pyc ADDED
Binary file (7.82 kB). View file
 
src/custom_timm/data/parsers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .parser_factory import create_parser
2
+ from .img_extensions import *
src/custom_timm/data/parsers/__pycache__/class_map.cpython-312.pyc ADDED
Binary file (1.78 kB). View file
 
src/custom_timm/data/parsers/class_map.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+
4
def load_class_map(map_or_filename, root=''):
    """Load a class-name -> class-index mapping.

    Args:
        map_or_filename: either an already-built mapping (``dict``), which is returned
            unchanged, or the path of a ``.txt`` / ``.pkl`` class map file.
        root: directory tried as a prefix when ``map_or_filename`` does not exist as given.

    Returns:
        For a dict argument, that same dict. For a ``.txt`` file, a dict mapping each
        stripped line to its zero-based line number. For a ``.pkl`` file, whatever
        object the pickle contains.

    Raises:
        AssertionError: if the dict is empty, the file cannot be located, or the file
            extension is neither ``.txt`` nor ``.pkl``.
    """
    if isinstance(map_or_filename, dict):
        # Check the argument itself; asserting on the builtin `dict` type is always true.
        assert map_or_filename, 'class_map dict must be non-empty'
        return map_or_filename
    class_map_path = map_or_filename
    if not os.path.exists(class_map_path):
        class_map_path = os.path.join(root, class_map_path)
        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
    class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
    if class_map_ext == '.txt':
        with open(class_map_path) as f:
            class_to_idx = {v.strip(): k for k, v in enumerate(f)}
    elif class_map_ext == '.pkl':
        # NOTE(security): unpickling executes arbitrary code; only load class maps from trusted sources.
        with open(class_map_path, 'rb') as f:
            class_to_idx = pickle.load(f)
    else:
        # Explicit raise (same exception type as before) so the check survives `python -O`.
        raise AssertionError(f'Unsupported class map file extension ({class_map_ext}).')
    return class_to_idx
22
+
src/custom_timm/data/parsers/img_extensions.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+
3
+ __all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
4
+
5
+
6
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')  # singleton, kept public for bwd compat use
_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS)  # set version, private, kept in sync


def _set_extensions(extensions):
    """Replace both module-level extension containers from the iterable ``extensions``.

    Duplicates are dropped while keeping first-seen order. The set is derived from the
    de-duplicated tuple so the two containers cannot diverge, even when ``extensions``
    is a one-shot iterable such as a generator.
    """
    global IMG_EXTENSIONS
    global _IMG_EXTENSIONS_SET
    IMG_EXTENSIONS = tuple(dict.fromkeys(extensions))  # dict keys keep insertion order
    _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS)


def _valid_extension(x: str):
    """Return True if ``x`` is a string of at least two characters that starts with '.'."""
    return bool(x and isinstance(x, str) and len(x) >= 2 and x.startswith('.'))


def is_img_extension(ext):
    """Return True if ``ext`` is one of the currently registered image extensions."""
    return ext in _IMG_EXTENSIONS_SET


def get_img_extensions(as_set=False):
    """Return a copy of the registered extensions, as a set if ``as_set`` else as an ordered tuple."""
    return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS)


def set_img_extensions(extensions):
    """Replace the registered extensions with ``extensions`` (any non-empty iterable of '.ext' strings).

    Raises:
        AssertionError: if ``extensions`` is empty or contains an invalid entry.
    """
    extensions = tuple(extensions)  # materialise once so generators can be validated and then stored
    assert len(extensions), 'extensions must be non-empty'
    for x in extensions:
        assert _valid_extension(x), f'invalid image extension: {x!r}'
    _set_extensions(extensions)


def add_img_extensions(ext):
    """Register one extension, or a list/tuple/set of extensions, after the existing ones.

    Raises:
        AssertionError: if any entry is not a valid '.ext' string.
    """
    if not isinstance(ext, (list, tuple, set)):
        ext = (ext,)
    for x in ext:
        assert _valid_extension(x), f'invalid image extension: {x!r}'
    extensions = IMG_EXTENSIONS + tuple(ext)
    _set_extensions(extensions)


def del_img_extensions(ext):
    """Unregister one extension, or a list/tuple/set of extensions; unknown entries are ignored."""
    if not isinstance(ext, (list, tuple, set)):
        ext = (ext,)
    extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext)
    _set_extensions(extensions)
src/custom_timm/data/parsers/parser.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+
3
+
4
class Parser:
    """Base class for dataset parsers.

    Subclasses provide ``_filename`` and ``__len__``; the public ``filename`` and
    ``filenames`` helpers are built on top of those.
    NOTE: the class does not use ``ABCMeta``, so ``@abstractmethod`` is not enforced
    at instantiation time.
    """

    def __init__(self):
        pass

    @abstractmethod
    def _filename(self, index, basename=False, absolute=False):
        """Return the filename of the sample at ``index`` (implemented by subclasses)."""
        pass

    def filename(self, index, basename=False, absolute=False):
        """Return the filename for one sample by delegating to ``_filename``."""
        return self._filename(index, basename=basename, absolute=absolute)

    def filenames(self, basename=False, absolute=False):
        """Return the filenames of all samples, in index order, as a list."""
        collected = []
        for position in range(len(self)):
            collected.append(self._filename(position, basename=basename, absolute=absolute))
        return collected
17
+
src/custom_timm/data/parsers/parser_factory.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from .parser_image_folder import ParserImageFolder
4
+ from .parser_image_in_tar import ParserImageInTar
5
+
6
+
7
def create_parser(name, root, split='train', **kwargs):
    """Create a dataset parser for ``name`` / ``root``.

    Args:
        name: parser selector, lower-cased before use. A ``'<prefix>/<dataset>'`` form
            selects a parser family by prefix; currently only the ``tfds`` prefix is
            recognised, everything else falls back to the filesystem parsers.
        root: data root. For the fallback path it must exist; a ``.tar`` file selects
            ``ParserImageInTar``, anything else selects ``ParserImageFolder``.
        split: dataset split, forwarded to the TFDS parser only.
        **kwargs: forwarded to the selected parser's constructor.

    Returns:
        The constructed parser instance.

    Raises:
        AssertionError: if the fallback path is taken and ``root`` does not exist.
    """
    name = name.lower()
    # Only the first component is the prefix; the remainder is kept whole so dataset
    # names that themselves contain '/' are not truncated to their last component.
    prefix, sep, remainder = name.partition('/')
    if sep:
        name = remainder
    else:
        prefix = ''

    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
    # explicitly select other options shortly
    if prefix == 'tfds':
        from .parser_tfds import ParserTfds  # defer tensorflow import
        parser = ParserTfds(root, name, split=split, **kwargs)
    else:
        assert os.path.exists(root), f'Dataset root does not exist ({root}).'
        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
        # FIXME support split here, in parser?
        # Extension check is case-insensitive, matching the check done in extract_tarinfos.
        if os.path.isfile(root) and os.path.splitext(root)[1].lower() == '.tar':
            parser = ParserImageInTar(root, **kwargs)
        else:
            parser = ParserImageFolder(root, **kwargs)
    return parser
src/custom_timm/data/parsers/parser_image_folder.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ A dataset parser that reads images from folders
2
+
3
+ Folders are scanned recursively to find image files. Labels are based
4
+ on the folder hierarchy, just leaf folders by default.
5
+
6
+ Hacked together by / Copyright 2020 Ross Wightman
7
+ """
8
+ import os
9
+ from typing import Dict, List, Optional, Set, Tuple, Union
10
+
11
+ from custom_timm.utils.misc import natural_key
12
+
13
+ from .class_map import load_class_map
14
+ from .img_extensions import get_img_extensions
15
+ from .parser import Parser
16
+
17
+
18
def find_images_and_targets(
        folder: str,
        types: Optional[Union[List, Tuple, Set]] = None,
        class_to_idx: Optional[Dict] = None,
        leaf_name_only: bool = True,
        sort: bool = True
):
    """ Walk folder recursively to discover images and map them to classes by folder names.

    Args:
        folder: root of folder to recursively search
        types: types (file extensions) to search for in path; matched case-insensitively.
            Defaults to the registered image extensions when empty or None.
        class_to_idx: specify mapping for class (folder name) to class index if set;
            images whose label is not a key of the mapping are dropped
        leaf_name_only: use only leaf-name of folder walk for class names, otherwise the
            relative path with separators replaced by '_'
        sort: re-sort found images by name (for consistent ordering)

    Returns:
        A list of image and target tuples, class_to_idx mapping
    """
    # File extensions are lower-cased before the membership test below, so the requested
    # types must be lower-cased too or an entry such as '.PNG' could never match.
    types = get_img_extensions(as_set=True) if not types else {t.lower() for t in types}
    labels = []
    filenames = []
    for root, _subdirs, files in os.walk(folder, topdown=False, followlinks=True):
        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
        for f in files:
            ext = os.path.splitext(f)[1]
            if ext.lower() in types:
                filenames.append(os.path.join(root, f))
                labels.append(label)
    if class_to_idx is None:
        # building class index from the naturally-sorted unique labels
        sorted_labels = sorted(set(labels), key=natural_key)
        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
    if sort:
        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
    return images_and_targets, class_to_idx
57
+
58
+
59
class ParserImageFolder(Parser):
    """Parser over a folder tree of images, with labels taken from the folder names."""

    def __init__(
            self,
            root,
            class_map=''):
        """Scan ``root`` for images.

        Args:
            root: folder to scan.
            class_map: optional class map (dict or filename) passed to ``load_class_map``;
                when empty the class index is built from the folders that were found.

        Raises:
            RuntimeError: if no images were found under ``root``.
        """
        super().__init__()

        self.root = root
        class_to_idx = load_class_map(class_map, root) if class_map else None
        found = find_images_and_targets(root, class_to_idx=class_to_idx)
        self.samples, self.class_to_idx = found
        if not self.samples:
            raise RuntimeError(
                f'Found 0 images in subfolders of {root}. '
                f'Supported image extensions are {", ".join(get_img_extensions())}')

    def __getitem__(self, index):
        """Return ``(open binary file handle, target)``; the handle is not closed here."""
        path, target = self.samples[index]
        return open(path, 'rb'), target

    def __len__(self):
        """Number of discovered samples."""
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        """Return the sample path: basename only, the stored path, or relative to ``root``."""
        stored_path = self.samples[index][0]
        if basename:
            return os.path.basename(stored_path)
        if absolute:
            return stored_path
        return os.path.relpath(stored_path, self.root)
src/custom_timm/data/parsers/parser_image_in_tar.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ A dataset parser that reads tarfile based datasets
2
+
3
+ This parser can read and extract image samples from:
4
+ * a single tar of image files
5
+ * a folder of multiple tarfiles containing imagefiles
6
+ * a tar of tars containing image files
7
+
8
+ Labels are based on the combined folder and/or tar name structure.
9
+
10
+ Hacked together by / Copyright 2020 Ross Wightman
11
+ """
12
+ import logging
13
+ import os
14
+ import pickle
15
+ import tarfile
16
+ from glob import glob
17
+ from typing import List, Tuple, Dict, Set, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from custom_timm.utils.misc import natural_key
22
+
23
+ from .class_map import load_class_map
24
+ from .img_extensions import get_img_extensions
25
+ from .parser import Parser
26
+
27
+ _logger = logging.getLogger(__name__)
28
+ CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
29
+
30
+
31
class TarState:
    """Holds a tar file handle, the TarInfo that located it, and per-child states."""

    def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None):
        # child states (tars within tars), keyed by name
        self.children: Dict[str, TarState] = {}
        self.ti: tarfile.TarInfo = ti
        self.tf: tarfile.TarFile = tf

    def reset(self):
        """Drop the reference to the tar handle (it is not closed; children are untouched)."""
        self.tf = None
40
+
41
+
42
+ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
43
+ sample_count = 0
44
+ for i, ti in enumerate(tf):
45
+ if not ti.isfile():
46
+ continue
47
+ dirname, basename = os.path.split(ti.path)
48
+ name, ext = os.path.splitext(basename)
49
+ ext = ext.lower()
50
+ if ext == '.tar':
51
+ with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf:
52
+ child_info = dict(
53
+ name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[])
54
+ sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions)
55
+ _logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.')
56
+ parent_info['children'].append(child_info)
57
+ elif ext in extensions:
58
+ parent_info['samples'].append(ti)
59
+ sample_count += 1
60
+ return sample_count
61
+
62
+
63
def extract_tarinfos(
        root,
        class_name_to_idx: Optional[Dict] = None,
        cache_tarinfo: Optional[bool] = None,
        extensions: Optional[Union[List, Tuple, Set]] = None,
        sort: bool = True
):
    """Scan tar file(s) at ``root`` and build flat sample / target arrays.

    Args:
        root: a single ``.tar`` file, or a folder containing ``.tar`` files.
        class_name_to_idx: optional label -> index mapping; samples whose label is not a key
            are skipped. When None the mapping is built from the labels that were found.
        cache_tarinfo: read/write a pickle cache of the scanned tar info next to the tars.
            When None, caching is enabled only if the tars total more than 10GB.
        extensions: file extensions treated as samples (matched case-insensitively);
            defaults to the registered image extensions.
        sort: sort the samples by their path inside the tar using ``natural_key``.

    Returns:
        ``(samples, targets, class_name_to_idx, tarfiles)`` where ``samples`` is a numpy array
        built from ``(sample TarInfo, parent tar name or None, child TarInfo or None)`` tuples,
        ``targets`` is a numpy array of class indices, and ``tarfiles`` is a list of
        ``(tar name or None, TarState)`` pairs for the tars that contributed samples.

    Raises:
        AssertionError: if ``root`` is a file without a ``.tar`` extension, no tar files are
            found, or a cache file does not match the number of tars.
        ValueError: if no samples were collected.
    """
    # Member extensions are lower-cased during the scan, so normalise the requested set too.
    extensions = get_img_extensions(as_set=True) if not extensions else {e.lower() for e in extensions}
    root_is_tar = False
    if os.path.isfile(root):
        assert os.path.splitext(root)[-1].lower() == '.tar'
        tar_filenames = [root]
        root, root_name = os.path.split(root)
        root_name = os.path.splitext(root_name)[0]
        root_is_tar = True
    else:
        root_name = root.strip(os.path.sep).split(os.path.sep)[-1]
        # NOTE: recursive=True has no effect without '**' in the pattern, so only tars
        # directly inside `root` are matched.
        tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
    num_tars = len(tar_filenames)
    tar_bytes = sum(os.path.getsize(f) for f in tar_filenames)
    assert num_tars, f'No .tar files found at specified path ({root}).'

    _logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...')
    info = dict(tartrees=[])
    cache_path = ''
    if cache_tarinfo is None:
        cache_tarinfo = tar_bytes > 10*1024**3  # FIXME magic number, 10GB
    if cache_tarinfo:
        cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX
        cache_path = os.path.join(root, cache_filename)
    if os.path.exists(cache_path):
        _logger.info(f'Reading tar info from cache file {cache_path}.')
        # NOTE(security): unpickling executes arbitrary code; the cache file must be trusted.
        with open(cache_path, 'rb') as pf:
            info = pickle.load(pf)
        assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles"
    else:
        for i, fn in enumerate(tar_filenames):
            path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0]
            with tarfile.open(fn, mode='r|') as tf:  # tarinfo scans done in streaming mode
                parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[])
                num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions)
                num_children = len(parent_info["children"])
                _logger.debug(
                    f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.')
            info['tartrees'].append(parent_info)
        if cache_path:
            _logger.info(f'Writing tar info to cache file {cache_path}.')
            with open(cache_path, 'wb') as pf:
                pickle.dump(info, pf)

    samples = []
    labels = []
    build_class_map = class_name_to_idx is None

    # Flatten tartree info into lists of samples and targets w/ targets based on label id via
    # class map arg or from unique paths.
    # NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children
    # this covers my current use cases and keeps things a little easier to test for now.
    tarfiles = []

    def _label_from_paths(*path, leaf_only=True):
        # label is the last path component, or the whole path joined with '_'
        path = os.path.join(*path).strip(os.path.sep)
        return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_')

    def _add_samples(tar_info, fn):
        # append the samples of one tar-info dict; returns how many were kept
        added = 0
        for s in tar_info['samples']:
            label = _label_from_paths(tar_info['path'], os.path.dirname(s.path))
            if not build_class_map and label not in class_name_to_idx:
                continue
            samples.append((s, fn, tar_info['ti']))
            labels.append(label)
            added += 1
        return added

    _logger.info('Collecting samples and building tar states.')
    for parent_info in info['tartrees']:
        # if tartree has children, we assume all samples are at the child level
        tar_name = None if root_is_tar else parent_info['name']
        tar_state = TarState()
        parent_added = 0
        for child_info in parent_info['children']:
            child_added = _add_samples(child_info, fn=tar_name)
            if child_added:
                tar_state.children[child_info['name']] = TarState(ti=child_info['ti'])
            parent_added += child_added
        parent_added += _add_samples(parent_info, fn=tar_name)
        if parent_added:
            tarfiles.append((tar_name, tar_state))
    del info

    if build_class_map:
        # build class index
        sorted_labels = sorted(set(labels), key=natural_key)
        class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}

    _logger.info('Mapping targets and sorting samples.')
    samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
    if not samples_and_targets:
        # zip(*[]) below would fail with an opaque "not enough values to unpack" ValueError
        raise ValueError(f'No samples found in the tar file(s) at specified path ({root}).')
    if sort:
        samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
    samples, targets = zip(*samples_and_targets)
    samples = np.array(samples)
    targets = np.array(targets)
    _logger.info(f'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.')
    return samples, targets, class_name_to_idx, tarfiles
170
+
171
+
172
class ParserImageInTar(Parser):
    """ Multi-tarfile dataset parser where there is one .tar file per class
    """

    def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
        """Scan the tar(s) at ``root`` and prepare per-tar state.

        Args:
            root: a ``.tar`` file or folder of ``.tar`` files, passed to ``extract_tarinfos``.
            class_map: optional class map (dict or filename) passed to ``load_class_map``.
            cache_tarfiles: keep opened tar handles in the tar state for reuse across reads.
            cache_tarinfo: forwarded to ``extract_tarinfos``.
        """
        super().__init__()

        class_name_to_idx = load_class_map(class_map, root) if class_map else None
        self.root = root
        extracted = extract_tarinfos(
            self.root,
            class_name_to_idx=class_name_to_idx,
            cache_tarinfo=cache_tarinfo
        )
        self.samples, self.targets, self.class_name_to_idx, tarfiles = extracted
        self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
        # a single entry with no tar name means root itself is the tar file
        self.root_is_tar = len(tarfiles) == 1 and tarfiles[0][0] is None
        self.tar_state = tarfiles[0][1] if self.root_is_tar else dict(tarfiles)
        self.cache_tarfiles = cache_tarfiles

    def __len__(self):
        """Number of samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(file object extracted from the tar, target)`` for the sample at ``index``."""
        sample_ti, parent_fn, child_ti = self.samples[index]
        target = self.targets[index]
        parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root

        if self.cache_tarfiles:
            # reuse (or lazily open and remember) the parent and child tar handles
            state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn]
            if state.tf is None:
                state.tf = tarfile.open(parent_abs)
            archive = state.tf
            if child_ti is not None:
                child_state = state.children[child_ti.name]
                if child_state.tf is None:
                    child_state.tf = tarfile.open(fileobj=archive.extractfile(child_ti))
                archive = child_state.tf
        else:
            # no caching: open the tar (and nested tar, if any) for this read
            archive = tarfile.open(parent_abs)
            if child_ti is not None:
                archive = tarfile.open(fileobj=archive.extractfile(child_ti))

        return archive.extractfile(sample_ti), target

    def _filename(self, index, basename=False, absolute=False):
        """Return the sample's name inside its tar, or only its basename; ``absolute`` is unused."""
        member_name = self.samples[index][0].name
        return os.path.basename(member_name) if basename else member_name
src/custom_timm/data/parsers/parser_image_tar.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ A dataset parser that reads single tarfile based datasets
2
+
3
+ This parser can read datasets consisting of a single tarfile containing images.
4
+ I am planning to deprecate it in favour of ParserImageInTar.
5
+
6
+ Hacked together by / Copyright 2020 Ross Wightman
7
+ """
8
+ import os
9
+ import tarfile
10
+
11
+ from custom_timm.utils.misc import natural_key
12
+
13
+ from .class_map import load_class_map
14
+ from .img_extensions import get_img_extensions
15
+ from .parser import Parser
16
+
17
+
18
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
    """Collect image members of an open tar file and map them to class indices.

    Args:
        tarfile: an open tar file object (this parameter shadows the ``tarfile`` module here).
        class_to_idx: optional label -> index mapping; members whose label is not a key are
            dropped. When None it is built from the naturally-sorted unique labels.
        sort: sort the result by member path using ``natural_key``.

    Returns:
        ``(list of (TarInfo, class index) tuples, class_to_idx)``. The label of a member is
        the name of its immediate parent directory.
    """
    extensions = get_img_extensions(as_set=True)
    members, labels = [], []
    for member in tarfile.getmembers():
        if not member.isfile():
            continue
        parent_dir, file_name = os.path.split(member.path)
        if os.path.splitext(file_name)[1].lower() not in extensions:
            continue
        members.append(member)
        labels.append(os.path.basename(parent_dir))
    if class_to_idx is None:
        ordered_labels = sorted(set(labels), key=natural_key)
        class_to_idx = {c: idx for idx, c in enumerate(ordered_labels)}
    tarinfo_and_targets = [
        (member, class_to_idx[label])
        for member, label in zip(members, labels)
        if label in class_to_idx
    ]
    if sort:
        tarinfo_and_targets.sort(key=lambda k: natural_key(k[0].path))
    return tarinfo_and_targets, class_to_idx
39
+
40
+
41
class ParserImageTar(Parser):
    """ Single tarfile dataset where classes are mapped to folders within tar
    NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
    operate on folders of tars or tars in tars.
    """
    def __init__(self, root, class_map=''):
        """Index the tar file at ``root``.

        Args:
            root: path of the tar file; must be an existing file.
            class_map: optional class map (dict or filename) passed to ``load_class_map``.
        """
        super().__init__()

        class_to_idx = load_class_map(class_map, root) if class_map else None
        assert os.path.isfile(root)
        self.root = root

        # cannot keep this open across processes, reopen later
        with tarfile.open(root) as tf:
            scanned = extract_tarinfo(tf, class_to_idx)
        self.samples, self.class_to_idx = scanned
        self.imgs = self.samples
        self.tarfile = None  # lazy init in __getitem__

    def __getitem__(self, index):
        """Return ``(extracted file object, target)``, opening the tar on first use."""
        if self.tarfile is None:
            self.tarfile = tarfile.open(self.root)
        tarinfo, target = self.samples[index]
        return self.tarfile.extractfile(tarinfo), target

    def __len__(self):
        """Number of samples."""
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        """Return the member name inside the tar, or only its basename; ``absolute`` is unused."""
        member_name = self.samples[index][0].name
        return os.path.basename(member_name) if basename else member_name
src/custom_timm/data/parsers/parser_tfds.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Dataset parser interface that wraps TFDS datasets
2
+
3
+ Wraps many (most?) TFDS image-classification datasets
4
+ from https://github.com/tensorflow/datasets
5
+ https://www.tensorflow.org/datasets/catalog/overview#image_classification
6
+
7
+ Hacked together by / Copyright 2020 Ross Wightman
8
+ """
9
+ import math
10
+ import torch
11
+ import torch.distributed as dist
12
+ from PIL import Image
13
+
14
+ try:
15
+ import tensorflow as tf
16
+ tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
17
+ import tensorflow_datasets as tfds
18
+ try:
19
+ tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg
20
+ has_buggy_even_splits = False
21
+ except TypeError:
22
+ print("Warning: This version of tfds doesn't have the latest even_splits impl. "
23
+ "Please update or use tfds-nightly for better fine-grained split behaviour.")
24
+ has_buggy_even_splits = True
25
+ # NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults)
26
+ # import resource
27
+ # low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
28
+ # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
29
+ except ImportError as e:
30
+ print(e)
31
+ print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
32
+ exit(1)
33
+ from .parser import Parser
34
+
35
+
36
+ MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
37
+ SHUFFLE_SIZE = 8192 # examples to shuffle in DS queue
38
+ PREFETCH_SIZE = 2048 # examples to prefetch
39
+
40
+
41
+ def even_split_indices(split, n, num_examples):
42
+ partitions = [round(i * num_examples / n) for i in range(n + 1)]
43
+ return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
44
+
45
+
46
def get_class_labels(info):
    """Build a class-name -> index dict from the ``'label'`` feature of ``info``.

    Returns an empty dict when ``info.features`` has no ``'label'`` entry; otherwise maps
    each entry of the label feature's ``names`` to ``str2int(name)``.
    """
    if 'label' not in info.features:
        return {}
    label_feature = info.features['label']
    mapping = {}
    for class_name in label_feature.names:
        mapping[class_name] = label_feature.str2int(class_name)
    return mapping
52
+
53
+
54
+ class ParserTfds(Parser):
55
+ """ Wrap Tensorflow Datasets for use in PyTorch
56
+
57
+ There several things to be aware of:
58
+ * To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of
59
+ dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
60
+ https://github.com/pytorch/pytorch/issues/33413
61
+ * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
62
+ from each worker could be a different size. For training this is worked around by option above, for
63
+ validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced
64
+ across replicas are of same size. This will slightly alter the results, distributed validation will not be
65
+ 100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
66
+ since there are up to N * J extra examples with IterableDatasets.
67
+ * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
68
+ replicas and dataloader workers you can use. For really small datasets that only contain a few shards
69
+ you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
70
+ benefit of distributed training or fast dataloading should be much less for small datasets.
71
+ * This wrapper is currently configured to return individual, decompressed image examples from the TFDS
72
+ dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
73
+ to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
74
+ components.
75
+
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ root,
81
+ name,
82
+ split='train',
83
+ is_training=False,
84
+ batch_size=None,
85
+ download=False,
86
+ repeats=0,
87
+ seed=42,
88
+ input_name='image',
89
+ input_image='RGB',
90
+ target_name='label',
91
+ target_image='',
92
+ prefetch_size=None,
93
+ shuffle_size=None,
94
+ max_threadpool_size=None
95
+ ):
96
+ """ Tensorflow-datasets Wrapper
97
+
98
+ Args:
99
+ root: root data dir (ie your TFDS_DATA_DIR. not dataset specific sub-dir)
100
+ name: tfds dataset name (eg `imagenet2012`)
101
+ split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
102
+ is_training: training mode, shuffle enabled, dataset len rounded by batch_size
103
+ batch_size: batch_size to use to unsure total examples % batch_size == 0 in training across all dis nodes
104
+ download: download and build TFDS dataset if set, otherwise must use tfds CLI
105
+ repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
106
+ seed: common seed for shard shuffle across all distributed/worker instances
107
+ input_name: name of Feature to return as data (input)
108
+ input_image: image mode if input is an image (currently PIL mode string)
109
+ target_name: name of Feature to return as target (label)
110
+ target_image: image mode if target is an image (currently PIL mode string)
111
+ prefetch_size: override default tf.data prefetch buffer size
112
+ shuffle_size: override default tf.data shuffle buffer size
113
+ max_threadpool_size: override default threadpool size for tf.data
114
+ """
115
+ super().__init__()
116
+ self.root = root
117
+ self.split = split
118
+ self.is_training = is_training
119
+ if self.is_training:
120
+ assert batch_size is not None, \
121
+ "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
122
+ self.batch_size = batch_size
123
+ self.repeats = repeats
124
+ self.common_seed = seed # a seed that's fixed across all worker / distributed instances
125
+
126
+ # performance settings
127
+ self.prefetch_size = prefetch_size or PREFETCH_SIZE
128
+ self.shuffle_size = shuffle_size or SHUFFLE_SIZE
129
+ self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE
130
+
131
+ # TFDS builder and split information
132
+ self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature
133
+ self.input_image = input_image
134
+ self.target_name = target_name
135
+ self.target_image = target_image
136
+ self.builder = tfds.builder(name, data_dir=root)
137
+ # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
138
+ if download:
139
+ self.builder.download_and_prepare()
140
+ self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
141
+ self.split_info = self.builder.info.splits[split]
142
+ self.num_examples = self.split_info.num_examples
143
+
144
+ # Distributed world state
145
+ self.dist_rank = 0
146
+ self.dist_num_replicas = 1
147
+ if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
148
+ self.dist_rank = dist.get_rank()
149
+ self.dist_num_replicas = dist.get_world_size()
150
+
151
+ # Attributes that are updated in _lazy_init, including the tf.data pipeline itself
152
+ self.global_num_workers = 1
153
+ self.worker_info = None
154
+ self.worker_seed = 0 # seed unique to each work instance
155
+ self.subsplit = None # set when data is distributed across workers using sub-splits
156
+ self.ds = None # initialized lazily on each dataloader worker process
157
+
158
def _lazy_init(self):
    """ Lazily initialize the dataset.

    This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that
    will be using the dataset instance. The __init__ method is called on the main process,
    this will be called in a dataloader worker process.

    NOTE: There will be problems if you try to re-use this dataset across different loader/worker
    instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
    before it is passed to dataloader.

    Side effects: may update ``self.worker_info``, ``self.worker_seed``, ``self.global_num_workers``
    and ``self.subsplit``; always assigns the numpy-iterable pipeline to ``self.ds``.
    """
    worker_info = torch.utils.data.get_worker_info()

    # setup input context to split dataset across distributed processes
    # defaults below apply when not running inside a dataloader worker (worker_info is None)
    num_workers = 1
    global_worker_id = 0
    if worker_info is not None:
        self.worker_info = worker_info
        self.worker_seed = worker_info.seed
        num_workers = worker_info.num_workers
        # total pipelines = distributed replicas x dataloader workers per replica
        self.global_num_workers = self.dist_num_replicas * num_workers
        # flatten (rank, worker id) into a single index in [0, global_num_workers)
        global_worker_id = self.dist_rank * num_workers + worker_info.id

    """ Data sharding
    InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
    My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
    between the splits each iteration, but that understanding could be wrong.

    I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
    the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
    in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
    for validation where we can't drop examples and need to avoid minimize uneven splits to avoid padding.
    """
    # sub-split only with multiple pipelines, and then either for eval or when shards are too few
    should_subsplit = self.global_num_workers > 1 and (
            self.split_info.num_shards < self.global_num_workers or not self.is_training)
    if should_subsplit:
        # split the dataset w/o using sharding for more even examples / worker, can result in less optimal
        # read patterns for distributed training (overlap across shards) so better to use InputContext there
        if has_buggy_even_splits:
            # my even_split workaround doesn't work on subsplits, upgrade tfds!
            # NOTE: when split_info is already a SubSplitInfo, self.subsplit stays None here and the
            # InputContext path below is taken instead
            if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
                subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples)
                self.subsplit = subsplits[global_worker_id]
        else:
            subsplits = tfds.even_splits(self.split, self.global_num_workers)
            self.subsplit = subsplits[global_worker_id]

    input_context = None
    if self.global_num_workers > 1 and self.subsplit is None:
        # set input context to divide shards among distributed replicas
        input_context = tf.distribute.InputContext(
            num_input_pipelines=self.global_num_workers,
            input_pipeline_id=global_worker_id,
            num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
        )
    # file-level shuffle uses the seed shared by all workers / replicas
    read_config = tfds.ReadConfig(
        shuffle_seed=self.common_seed,
        shuffle_reshuffle_each_iteration=True,
        input_context=input_context)
    # files are only shuffled in training; a sub-split, when set, takes precedence over the full split
    ds = self.builder.as_dataset(
        split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
    # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
    options = tf.data.Options()
    # the options attribute is looked up under either name, whichever this tf.data.Options exposes
    thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
    getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
    getattr(options, thread_member).max_intra_op_parallelism = 1
    ds = ds.with_options(options)
    if self.is_training or self.repeats > 1:
        # to prevent excessive drop_last batch behaviour w/ IterableDatasets
        # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
        ds = ds.repeat()  # allow wrap around and break iteration manually
    if self.is_training:
        # example-level shuffle is seeded per worker (worker_seed), buffer divided across pipelines
        ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
    ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
    self.ds = tfds.as_numpy(ds)
233
+
234
def __iter__(self):
    """Yield ``(input, target)`` pairs from the underlying dataset.

    The pipeline is built on first use via ``_lazy_init``. Each example is a mapping
    indexed by ``self.input_name`` / ``self.target_name``; either value is converted to
    a PIL image when the matching ``input_image`` / ``target_image`` mode string is set.

    Iteration stops when the dataset is exhausted, or, when the dataset repeats
    (training, or ``repeats > 1``), once the per-worker target example count is reached.
    In distributed validation using sub-splits, the last example is re-yielded to pad
    every worker up to the same count.
    """
    if self.ds is None:
        self._lazy_init()

    # Compute a rounded up sample count that is used to:
    #   1. make batches even cross workers & replicas in distributed validation.
    #     This adds extra examples and will slightly alter validation results.
    #   2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
    #     batches are produced (underlying tfds iter wraps around)
    target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers)
    if self.is_training:
        # round up to nearest batch_size per worker-replica
        target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size

    # _lazy_init applies ds.repeat() when training OR when repeats > 1, so the manual stop
    # must cover both cases; breaking only in training left eval w/ repeats > 1 looping forever.
    is_repeating = self.is_training or self.repeats > 1

    # Iterate until exhausted or sample count hits target when the dataset repeats
    example_count = 0
    for example in self.ds:
        input_data = example[self.input_name]
        if self.input_image:
            input_data = Image.fromarray(input_data, mode=self.input_image)
        target_data = example[self.target_name]
        if self.target_image:
            target_data = Image.fromarray(target_data, mode=self.target_image)
        yield input_data, target_data
        example_count += 1
        if is_repeating and example_count >= target_example_count:
            # Need to break out of loop when repeat() is enabled for training w/ oversampling
            # this results in extra examples per epoch but seems more desirable than dropping
            # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
            break

    # Pad across distributed nodes (make counts equal by adding examples)
    if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
            0 < example_count < target_example_count:
        # Validation batch padding only done for distributed training where results are reduced across nodes.
        # For single process case, it won't matter if workers return different batch sizes.
        # If using input_context or % based splits, sample count can vary significantly across workers and this
        # approach should not be used (hence disabled if self.subsplit isn't set).
        while example_count < target_example_count:
            yield input_data, target_data  # yield prev sample again
            example_count += 1
275
+
276
def __len__(self):
    """Return an estimated number of examples seen by one distributed replica.

    Only an estimate: padding added to even out batches depends on worker & replica
    info that is not known until the pipeline is initialized inside the dataloader.
    """
    effective_repeats = max(1, self.repeats)
    total_examples = effective_repeats * self.num_examples
    return math.ceil(total_examples / self.dist_num_replicas)
280
+
281
def _filename(self, index, basename=False, absolute=False):
    """Per-index filename lookup is not supported (no random access to examples).

    Raises:
        AssertionError: always. Raised explicitly so the behaviour does not depend on
            asserts being enabled (``assert`` statements are removed under ``python -O``).
    """
    raise AssertionError("Not supported")  # no random access to examples
283
+
284
def filenames(self, basename=False, absolute=False):
    """ Return all filenames in dataset, overrides base

    Iterates the dataset (initializing it first if needed) and collects, per sample, the
    value of the first present key among ``'file_name'``, ``'filename'`` and ``'id'``.
    At most ``self.num_examples`` names are returned, which bounds the loop when the
    underlying dataset repeats.

    Args:
        basename: unused by this implementation.
        absolute: unused by this implementation.

    Raises:
        AssertionError: if a sample has none of the supported name fields.
    """
    if self.ds is None:
        self._lazy_init()
    names = []
    for sample in self.ds:
        # '>=' (not '>') so a repeating dataset yields exactly num_examples names, not one extra
        if len(names) >= self.num_examples:
            break  # safety for ds.repeat() case
        for key in ('file_name', 'filename', 'id'):
            if key in sample:
                names.append(sample[key])
                break
        else:
            # explicit raise instead of `assert False` so the check survives `python -O`
            raise AssertionError("No supported name field present")
    return names
src/custom_timm/models/gluon_resnet.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pytorch impl of MxNet Gluon ResNet/(SE)ResNeXt variants
2
+ This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-ResNeXt additions
3
+ and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py)
4
+ by Ross Wightman
5
+ """
6
+
7
+ from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
8
+ from .helpers import build_model_with_cfg
9
+ from .layers import SEModule
10
+ from .registry import register_model
11
+ from .resnet import ResNet, Bottleneck, BasicBlock
12
+
13
+
14
def _cfg(url='', **kwargs):
    """Build a pretrained-config dict for a Gluon ResNet variant.

    Starts from the shared defaults (1000 classes, 3x224x224 input, 7x7 pool, bicubic
    interpolation, ImageNet mean/std, ``conv1`` / ``fc`` layer names) and lets any
    keyword argument override or extend them.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='conv1',
        classifier='fc',
    )
    # caller-supplied entries win over the defaults above
    cfg.update(kwargs)
    return cfg
23
+
24
+
25
# Pretrained configs keyed by model name; each key matches a registered factory function below.
# Entries passing first_conv='conv1.0' override the default 'conv1' first-conv name from _cfg.
default_cfgs = {
    # v1b variants
    'gluon_resnet18_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet18_v1b-0757602b.pth'),
    'gluon_resnet34_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet34_v1b-c6d82d59.pth'),
    'gluon_resnet50_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'),
    'gluon_resnet101_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'),
    'gluon_resnet152_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'),
    # v1c variants
    'gluon_resnet50_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth',
                               first_conv='conv1.0'),
    'gluon_resnet101_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth',
                                first_conv='conv1.0'),
    'gluon_resnet152_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth',
                                first_conv='conv1.0'),
    # v1d variants
    'gluon_resnet50_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth',
                               first_conv='conv1.0'),
    'gluon_resnet101_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth',
                                first_conv='conv1.0'),
    'gluon_resnet152_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth',
                                first_conv='conv1.0'),
    # v1s variants
    'gluon_resnet50_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth',
                               first_conv='conv1.0'),
    'gluon_resnet101_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth',
                                first_conv='conv1.0'),
    'gluon_resnet152_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth',
                                first_conv='conv1.0'),
    # ResNeXt variants
    'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'),
    'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'),
    'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'),
    # SE-ResNeXt / SENet variants
    'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'),
    'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'),
    'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'),
    'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth',
                           first_conv='conv1.0'),
}
58
+
59
+
60
def _create_resnet(variant, pretrained=False, **kwargs):
    """Build the ``ResNet`` model for ``variant`` through ``build_model_with_cfg``.

    ``pretrained`` and all extra keyword arguments are forwarded unchanged.
    """
    model = build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
    return model
62
+
63
+
64
@register_model
def gluon_resnet18_v1b(pretrained=False, **kwargs):
    """Gluon ResNet-18 v1b: BasicBlock with stage depths [2, 2, 2, 2]."""
    return _create_resnet(
        'gluon_resnet18_v1b', pretrained,
        block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
70
+
71
+
72
@register_model
def gluon_resnet34_v1b(pretrained=False, **kwargs):
    """Gluon ResNet-34 v1b: BasicBlock with stage depths [3, 4, 6, 3]."""
    return _create_resnet(
        'gluon_resnet34_v1b', pretrained,
        block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
78
+
79
+
80
@register_model
def gluon_resnet50_v1b(pretrained=False, **kwargs):
    """Gluon ResNet-50 v1b: Bottleneck with stage depths [3, 4, 6, 3]."""
    return _create_resnet(
        'gluon_resnet50_v1b', pretrained,
        block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
86
+
87
+
88
@register_model
def gluon_resnet101_v1b(pretrained=False, **kwargs):
    """Gluon ResNet-101 v1b: Bottleneck with stage depths [3, 4, 23, 3]."""
    return _create_resnet(
        'gluon_resnet101_v1b', pretrained,
        block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
94
+
95
+
96
@register_model
def gluon_resnet152_v1b(pretrained=False, **kwargs):
    """Gluon ResNet-152 v1b: Bottleneck with stage depths [3, 8, 36, 3]."""
    return _create_resnet(
        'gluon_resnet152_v1b', pretrained,
        block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
102
+
103
+
104
@register_model
def gluon_resnet50_v1c(pretrained=False, **kwargs):
    """Gluon ResNet-50 v1c: Bottleneck [3, 4, 6, 3] with a deep stem of width 32."""
    return _create_resnet(
        'gluon_resnet50_v1c', pretrained,
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', **kwargs)
110
+
111
+
112
@register_model
def gluon_resnet101_v1c(pretrained=False, **kwargs):
    """Gluon ResNet-101 v1c: Bottleneck [3, 4, 23, 3] with a deep stem of width 32."""
    return _create_resnet(
        'gluon_resnet101_v1c', pretrained,
        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', **kwargs)
118
+
119
+
120
@register_model
def gluon_resnet152_v1c(pretrained=False, **kwargs):
    """Gluon ResNet-152 v1c: Bottleneck [3, 8, 36, 3] with a deep stem of width 32."""
    return _create_resnet(
        'gluon_resnet152_v1c', pretrained,
        block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', **kwargs)
126
+
127
+
128
@register_model
def gluon_resnet50_v1d(pretrained=False, **kwargs):
    """Gluon ResNet-50 v1d: Bottleneck [3, 4, 6, 3], deep stem (width 32), avg_down enabled."""
    return _create_resnet(
        'gluon_resnet50_v1d', pretrained,
        block=Bottleneck, layers=[3, 4, 6, 3],
        stem_width=32, stem_type='deep', avg_down=True, **kwargs)
135
+
136
+
137
@register_model
def gluon_resnet101_v1d(pretrained=False, **kwargs):
    """Gluon ResNet-101 v1d: Bottleneck [3, 4, 23, 3], deep stem (width 32), avg_down enabled."""
    return _create_resnet(
        'gluon_resnet101_v1d', pretrained,
        block=Bottleneck, layers=[3, 4, 23, 3],
        stem_width=32, stem_type='deep', avg_down=True, **kwargs)
144
+
145
+
146
@register_model
def gluon_resnet152_v1d(pretrained=False, **kwargs):
    """Gluon ResNet-152 v1d: Bottleneck [3, 8, 36, 3], deep stem (width 32), avg_down enabled."""
    return _create_resnet(
        'gluon_resnet152_v1d', pretrained,
        block=Bottleneck, layers=[3, 8, 36, 3],
        stem_width=32, stem_type='deep', avg_down=True, **kwargs)
153
+
154
+
155
@register_model
def gluon_resnet50_v1s(pretrained=False, **kwargs):
    """Gluon ResNet-50 v1s: Bottleneck [3, 4, 6, 3] with a deep stem of width 64."""
    return _create_resnet(
        'gluon_resnet50_v1s', pretrained,
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep', **kwargs)
162
+
163
+
164
+
165
@register_model
def gluon_resnet101_v1s(pretrained=False, **kwargs):
    """Gluon ResNet-101 v1s: Bottleneck [3, 4, 23, 3] with a deep stem of width 64."""
    return _create_resnet(
        'gluon_resnet101_v1s', pretrained,
        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep', **kwargs)
172
+
173
+
174
@register_model
def gluon_resnet152_v1s(pretrained=False, **kwargs):
    """Gluon ResNet-152 v1s: Bottleneck [3, 8, 36, 3] with a deep stem of width 64."""
    return _create_resnet(
        'gluon_resnet152_v1s', pretrained,
        block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep', **kwargs)
181
+
182
+
183
+
184
@register_model
def gluon_resnext50_32x4d(pretrained=False, **kwargs):
    """Gluon ResNeXt-50 32x4d: Bottleneck [3, 4, 6, 3], cardinality 32, base width 4."""
    return _create_resnet(
        'gluon_resnext50_32x4d', pretrained,
        block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
190
+
191
+
192
@register_model
def gluon_resnext101_32x4d(pretrained=False, **kwargs):
    """Gluon ResNeXt-101 32x4d: Bottleneck [3, 4, 23, 3], cardinality 32, base width 4."""
    return _create_resnet(
        'gluon_resnext101_32x4d', pretrained,
        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
198
+
199
+
200
@register_model
def gluon_resnext101_64x4d(pretrained=False, **kwargs):
    """Gluon ResNeXt-101 64x4d: Bottleneck [3, 4, 23, 3], cardinality 64, base width 4."""
    return _create_resnet(
        'gluon_resnext101_64x4d', pretrained,
        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs)
206
+
207
+
208
@register_model
def gluon_seresnext50_32x4d(pretrained=False, **kwargs):
    """Gluon SE-ResNeXt-50 32x4d: ResNeXt-50 32x4d layout with SEModule attention in each block."""
    return _create_resnet(
        'gluon_seresnext50_32x4d', pretrained,
        block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
        block_args=dict(attn_layer=SEModule), **kwargs)
216
+
217
+
218
@register_model
def gluon_seresnext101_32x4d(pretrained=False, **kwargs):
    """Gluon SE-ResNeXt-101 32x4d: ResNeXt-101 32x4d layout with SEModule attention in each block."""
    return _create_resnet(
        'gluon_seresnext101_32x4d', pretrained,
        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
        block_args=dict(attn_layer=SEModule), **kwargs)
226
+
227
+
228
@register_model
def gluon_seresnext101_64x4d(pretrained=False, **kwargs):
    """Gluon SE-ResNeXt-101 64x4d: ResNeXt-101 64x4d layout with SEModule attention in each block."""
    return _create_resnet(
        'gluon_seresnext101_64x4d', pretrained,
        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4,
        block_args=dict(attn_layer=SEModule), **kwargs)
236
+
237
+
238
@register_model
def gluon_senet154(pretrained=False, **kwargs):
    """Gluon SENet-154: Bottleneck [3, 8, 36, 3], cardinality 64, deep stem, SEModule attention."""
    return _create_resnet(
        'gluon_senet154', pretrained,
        block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
        down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer=SEModule), **kwargs)
src/custom_timm/models/gluon_xception.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pytorch impl of Gluon Xception
2
+ This is a port of the Gluon Xception code and weights, itself ported from a PyTorch DeepLab impl.
3
+
4
+ Gluon model: (https://gluon-cv.mxnet.io/_modules/gluoncv/model_zoo/xception.html)
5
+ Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xception
6
+
7
+ Hacked together by / Copyright 2020 Ross Wightman
8
+ """
9
+ from collections import OrderedDict
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16
+ from .helpers import build_model_with_cfg
17
+ from .layers import create_classifier, get_padding
18
+ from .registry import register_model
19
+
20
+ __all__ = ['Xception65']
21
+
22
# Pretrained config for the single variant defined in this module, keyed by model name.
default_cfgs = {
    'gluon_xception65': {
        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
        'input_size': (3, 299, 299),
        'crop_pct': 0.903,
        'pool_size': (10, 10),
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'num_classes': 1000,
        # attribute names of the first conv and the classifier layer on the model
        'first_conv': 'conv1',
        'classifier': 'fc'
        # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
    },
}
37
+
38
+ """ PADDING NOTES
39
+ The original PyTorch and Gluon impl of these models dutifully reproduced the
40
+ aligned padding added to Tensorflow models for Deeplab. This padding was compensating
41
+ for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to.
42
+ """
43
+
44
+
45
class SeparableConv2d(nn.Module):
    """Depthwise conv -> norm -> 1x1 pointwise conv.

    Args:
        inplanes: number of input channels (also the depthwise group count).
        planes: number of output channels of the pointwise conv.
        kernel_size: depthwise kernel size.
        stride: depthwise stride.
        dilation: depthwise dilation.
        bias: whether both convs use a bias term.
        norm_layer: callable taking ``num_features`` and returning the norm module applied
            between the two convs. ``None`` (the default) falls back to ``nn.BatchNorm2d``;
            previously the default raised ``TypeError`` because ``None`` was called directly.
    """

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
        super(SeparableConv2d, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.kernel_size = kernel_size
        self.dilation = dilation

        # depthwise convolution
        padding = get_padding(kernel_size, stride, dilation)
        self.conv_dw = nn.Conv2d(
            inplanes, inplanes, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=inplanes, bias=bias)
        self.bn = norm_layer(num_features=inplanes)
        # pointwise convolution
        self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.conv_dw(x)
        x = self.bn(x)
        x = self.conv_pw(x)
        return x
65
+
66
+
67
class Block(nn.Module):
    """Residual block of three (ReLU -> SeparableConv2d -> norm) units plus a skip path.

    Args:
        inplanes: number of input channels.
        planes: output channels of the three separable convs; either a single int (used
            for all three) or a list/tuple of exactly three ints.
        stride: stride applied by the third separable conv and by the skip projection.
        dilation: dilation passed to every separable conv.
        start_with_relu: if False, the leading ReLU (``act1``) is removed.
        norm_layer: callable building the norm modules. ``None`` (the default) falls back
            to ``nn.BatchNorm2d``; previously the default raised ``TypeError``.
    """

    def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None):
        super(Block, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if isinstance(planes, (list, tuple)):
            assert len(planes) == 3
        else:
            planes = (planes,) * 3
        outplanes = planes[-1]

        # a 1x1 projection is needed on the skip path whenever shape changes (channels or stride)
        if outplanes != inplanes or stride != 1:
            self.skip = nn.Sequential()
            # NOTE: the original statement ended in a stray comma, building and discarding a tuple
            self.skip.add_module('conv1', nn.Conv2d(
                inplanes, outplanes, 1, stride=stride, bias=False))
            self.skip.add_module('bn1', norm_layer(num_features=outplanes))
        else:
            self.skip = None

        rep = OrderedDict()
        for i in range(3):
            rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
            # only the last separable conv applies the block stride
            rep['conv%d' % (i + 1)] = SeparableConv2d(
                inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer)
            rep['bn%d' % (i + 1)] = norm_layer(planes[i])
            inplanes = planes[i]

        if not start_with_relu:
            del rep['act1']
        else:
            # first activation is made non in-place: the block input is also consumed by the skip path
            rep['act1'] = nn.ReLU(inplace=False)
        self.rep = nn.Sequential(rep)

    def forward(self, x):
        skip = x
        if self.skip is not None:
            skip = self.skip(skip)
        x = self.rep(x) + skip
        return x
104
+
105
+
106
class Xception65(nn.Module):
    """Modified Aligned Xception.

    NOTE: only the 65 layer version is included here, the 71 layer variant
    was not correct and had no pretrained weights

    Args:
        num_classes: number of classifier outputs.
        in_chans: number of input image channels.
        output_stride: overall feature stride; one of 32, 16 or 8. Lower values replace
            strides with dilation in the later blocks.
        norm_layer: callable building the normalization modules.
        drop_rate: dropout probability applied to pooled features before the classifier.
        global_pool: pooling type passed to ``create_classifier``.

    Raises:
        NotImplementedError: if ``output_stride`` is not 32, 16 or 8.
    """

    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
                 drop_rate=0., global_pool='avg'):
        super(Xception65, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        # pick strides / dilations of the later stages for the requested output stride
        if output_stride == 32:
            entry_block3_stride = 2
            exit_block20_stride = 2
            middle_dilation = 1
            exit_dilation = (1, 1)
        elif output_stride == 16:
            entry_block3_stride = 2
            exit_block20_stride = 1
            middle_dilation = 1
            exit_dilation = (1, 2)
        elif output_stride == 8:
            entry_block3_stride = 1
            exit_block20_stride = 1
            middle_dilation = 2
            exit_dilation = (2, 4)
        else:
            raise NotImplementedError

        # Entry flow
        self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = norm_layer(num_features=32)
        self.act1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = norm_layer(num_features=64)
        self.act2 = nn.ReLU(inplace=True)

        self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer)
        self.block1_act = nn.ReLU(inplace=True)
        self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer)
        self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer)

        # Middle flow
        self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
            728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)]))

        # Exit flow
        self.block20 = Block(
            728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer)
        self.block20_act = nn.ReLU(inplace=True)

        self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
        self.bn3 = norm_layer(num_features=1536)
        self.act3 = nn.ReLU(inplace=True)

        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
        self.bn4 = norm_layer(num_features=1536)
        self.act4 = nn.ReLU(inplace=True)

        self.num_features = 2048
        self.conv5 = SeparableConv2d(
            1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
        self.bn5 = norm_layer(num_features=self.num_features)
        self.act5 = nn.ReLU(inplace=True)
        # module names, channel counts and reductions advertised for feature extraction
        self.feature_info = [
            dict(num_chs=64, reduction=2, module='act2'),
            dict(num_chs=128, reduction=4, module='block1_act'),
            dict(num_chs=256, reduction=8, module='block3.rep.act1'),
            dict(num_chs=728, reduction=16, module='block20.rep.act1'),
            dict(num_chs=2048, reduction=32, module='act5'),
        ]

        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return the regex patterns used to group parameters into stem / blocks."""
        matcher = dict(
            stem=r'^conv[12]|bn[12]',
            blocks=[
                (r'^mid\.block(\d+)', None),
                (r'^block(\d+)', None),
                (r'^conv[345]|bn[345]', (99,)),
            ],
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, "gradient checkpointing not supported"

    @torch.jit.ignore
    def get_classifier(self):
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling + classifier head for a new number of classes."""
        self.num_classes = num_classes
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        """Run the entry, middle and exit flows and return the final feature map."""
        # Entry flow
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act2(x)

        x = self.block1(x)
        x = self.block1_act(x)
        # c1 = x
        x = self.block2(x)
        # c2 = x
        x = self.block3(x)

        # Middle flow
        x = self.mid(x)
        # c3 = x

        # Exit flow
        x = self.block20(x)
        x = self.block20_act(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.act3(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.act4(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.act5(x)
        return x

    def forward_head(self, x):
        """Pool the features, apply dropout when ``drop_rate`` is set, then classify."""
        x = self.global_pool(x)
        if self.drop_rate:
            # F.dropout is not in-place: its result must be assigned, otherwise drop_rate is a no-op
            x = F.dropout(x, self.drop_rate, training=self.training)
        x = self.fc(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
254
+
255
+
256
def _create_gluon_xception(variant, pretrained=False, **kwargs):
    """Build an ``Xception65`` for ``variant`` via ``build_model_with_cfg``, using hook-based feature extraction."""
    feature_cfg = dict(feature_cls='hook')
    model = build_model_with_cfg(Xception65, variant, pretrained, feature_cfg=feature_cfg, **kwargs)
    return model
261
+
262
+
263
@register_model
def gluon_xception65(pretrained=False, **kwargs):
    """Registered entry point for the Modified Aligned Xception-65 model."""
    model = _create_gluon_xception('gluon_xception65', pretrained, **kwargs)
    return model
src/custom_timm/models/hardcorenas.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch.nn as nn
4
+
5
+ from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
6
+ from .efficientnet_blocks import SqueezeExcite
7
+ from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
8
+ from .helpers import build_model_with_cfg, pretrained_cfg_for_features
9
+ from .layers import get_act_fn
10
+ from .mobilenetv3 import MobileNetV3, MobileNetV3Features
11
+ from .registry import register_model
12
+
13
+
14
def _cfg(url='', **kwargs):
    """Return a pretrained-config dict with this file's defaults.

    Args:
        url: Value stored under the 'url' key.
        **kwargs: Extra entries; they override the defaults on key collision.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='conv_stem',
        classifier='classifier',
    )
    cfg.update(kwargs)
    return cfg
22
+
23
+
24
# Pretrained-config registry: one entry per hardcorenas variant, each built by
# _cfg() from the URL of its released weights.
default_cfgs = {
    variant: _cfg(url=weights_url)
    for variant, weights_url in (
        ('hardcorenas_a',
         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_a_green_38ms_75_9-31dc7186.pth'),
        ('hardcorenas_b',
         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_b_green_40ms_76_5-32d91ff2.pth'),
        ('hardcorenas_c',
         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_c_green_44ms_77_1-631a0983.pth'),
        ('hardcorenas_d',
         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_d_green_50ms_77_4-998d9d7a.pth'),
        ('hardcorenas_e',
         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_e_green_55ms_77_9-482886a3.pth'),
        ('hardcorenas_f',
         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/hardcorenas_f_green_60ms_78_1-14b9e780.pth'),
    )
}
32
+
33
+
34
def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
    """Creates a hardcorenas model

    Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS
    Paper: https://arxiv.org/abs/2102.11646

    Args:
        pretrained: Forwarded to ``build_model_with_cfg``.
        variant: Model variant name forwarded to ``build_model_with_cfg``.
        arch_def: Architecture definition decoded with ``decode_arch_def``.
        **kwargs: Extra model arguments. A truthy ``features_only`` entry
            selects ``MobileNetV3Features`` instead of ``MobileNetV3``.

    Returns:
        The model returned by ``build_model_with_cfg``.
    """
    num_features = 1280
    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
    # NOTE(review): resolve_bn_args / resolve_act_layer receive ``kwargs`` before it is
    # splatted below; presumably they consume their own keys -- confirm in efficientnet_builder.
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=num_features,
        stem_size=32,
        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
        act_layer=resolve_act_layer(kwargs, 'hard_swish'),
        se_layer=se_layer,
        **kwargs,
    )

    features_only = False
    model_cls = MobileNetV3
    kwargs_filter = None
    if model_kwargs.pop('features_only', False):
        features_only = True
        # Classifier-head arguments to strip for the features-only class
        # ('global_pool' was previously listed twice; once is enough).
        kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias')
        model_cls = MobileNetV3Features
    model = build_model_with_cfg(
        model_cls, variant, pretrained,
        pretrained_strict=not features_only,
        kwargs_filter=kwargs_filter,
        **model_kwargs)
    if features_only:
        model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
    return model
68
+
69
+
70
@register_model
def hardcorenas_a(pretrained=False, **kwargs):
    """Build the hardcorenas_A variant via ``_gen_hardcorenas``."""
    stages = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e3_c40_nre',
            'ir_r1_k5_s1_e6_c40_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c80_se0.25',
            'ir_r1_k5_s1_e6_c80_se0.25',
        ],
        [
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e6_c112_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_a', arch_def=stages, **kwargs)
80
+
81
+
82
@register_model
def hardcorenas_b(pretrained=False, **kwargs):
    """Build the hardcorenas_B variant via ``_gen_hardcorenas``."""
    stages = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
            'ir_r1_k3_s1_e3_c24_nre',
        ],
        [
            'ir_r1_k5_s2_e3_c40_nre',
            'ir_r1_k5_s1_e3_c40_nre',
            'ir_r1_k5_s1_e3_c40_nre',
        ],
        [
            'ir_r1_k5_s2_e3_c80',
            'ir_r1_k5_s1_e3_c80',
            'ir_r1_k3_s1_e3_c80',
            'ir_r1_k3_s1_e3_c80',
        ],
        [
            'ir_r1_k5_s1_e3_c112',
            'ir_r1_k3_s1_e3_c112',
            'ir_r1_k3_s1_e3_c112',
            'ir_r1_k3_s1_e3_c112',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e3_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_b', arch_def=stages, **kwargs)
94
+
95
+
96
@register_model
def hardcorenas_c(pretrained=False, **kwargs):
    """Build the hardcorenas_C variant via ``_gen_hardcorenas``."""
    stages = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e3_c40_nre',
            'ir_r1_k5_s1_e3_c40_nre',
            'ir_r1_k5_s1_e3_c40_nre',
            'ir_r1_k5_s1_e3_c40_nre',
        ],
        [
            'ir_r1_k5_s2_e4_c80',
            'ir_r1_k5_s1_e6_c80_se0.25',
            'ir_r1_k3_s1_e3_c80',
            'ir_r1_k3_s1_e3_c80',
        ],
        [
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k3_s1_e3_c112',
            'ir_r1_k3_s1_e3_c112',
            'ir_r1_k3_s1_e3_c112',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e3_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_c', arch_def=stages, **kwargs)
108
+
109
+
110
@register_model
def hardcorenas_d(pretrained=False, **kwargs):
    """Build the hardcorenas_D variant via ``_gen_hardcorenas``."""
    stages = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre_se0.25',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e3_c40_nre_se0.25',
            'ir_r1_k5_s1_e4_c40_nre_se0.25',
            'ir_r1_k3_s1_e3_c40_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e4_c80_se0.25',
            'ir_r1_k3_s1_e3_c80_se0.25',
            'ir_r1_k3_s1_e3_c80_se0.25',
            'ir_r1_k3_s1_e3_c80_se0.25',
        ],
        [
            'ir_r1_k3_s1_e4_c112_se0.25',
            'ir_r1_k5_s1_e4_c112_se0.25',
            'ir_r1_k3_s1_e3_c112_se0.25',
            'ir_r1_k5_s1_e3_c112_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e6_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_d', arch_def=stages, **kwargs)
123
+
124
+
125
@register_model
def hardcorenas_e(pretrained=False, **kwargs):
    """Build the hardcorenas_E variant via ``_gen_hardcorenas``."""
    stages = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre_se0.25',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c40_nre_se0.25',
            'ir_r1_k5_s1_e4_c40_nre_se0.25',
            'ir_r1_k5_s1_e4_c40_nre_se0.25',
            'ir_r1_k3_s1_e3_c40_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e4_c80_se0.25',
            'ir_r1_k3_s1_e6_c80_se0.25',
        ],
        [
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e3_c112_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e6_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_e', arch_def=stages, **kwargs)
137
+
138
+
139
@register_model
def hardcorenas_f(pretrained=False, **kwargs):
    """Build the hardcorenas_F variant via ``_gen_hardcorenas``."""
    stages = [
        ['ds_r1_k3_s1_e1_c16_nre'],
        [
            'ir_r1_k5_s2_e3_c24_nre_se0.25',
            'ir_r1_k5_s1_e3_c24_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c40_nre_se0.25',
            'ir_r1_k5_s1_e6_c40_nre_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c80_se0.25',
            'ir_r1_k5_s1_e6_c80_se0.25',
            'ir_r1_k3_s1_e3_c80_se0.25',
            'ir_r1_k3_s1_e3_c80_se0.25',
        ],
        [
            'ir_r1_k3_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k5_s1_e6_c112_se0.25',
            'ir_r1_k3_s1_e3_c112_se0.25',
        ],
        [
            'ir_r1_k5_s2_e6_c192_se0.25',
            'ir_r1_k5_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e6_c192_se0.25',
            'ir_r1_k3_s1_e6_c192_se0.25',
        ],
        ['cn_r1_k1_s1_c960'],
    ]
    return _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_f', arch_def=stages, **kwargs)
src/custom_timm/models/helpers.py ADDED
@@ -0,0 +1,796 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Model creation / weight loading / state_dict helpers
2
+
3
+ Hacked together by / Copyright 2020 Ross Wightman
4
+ """
5
+ import collections.abc
6
+ import logging
7
+ import math
8
+ import os
9
+ import re
10
+ from collections import OrderedDict, defaultdict
11
+ from copy import deepcopy
12
+ from itertools import chain
13
+ from typing import Any, Callable, Optional, Tuple, Dict, Union
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.hub import load_state_dict_from_url
18
+ from torch.utils.checkpoint import checkpoint
19
+
20
+ from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
21
+ from .fx_features import FeatureGraphNet
22
+ from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
23
+ from .layers import Conv2dSame, Linear, BatchNormAct2d
24
+ from .registry import get_pretrained_cfg
25
+
26
+
27
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)


# Global variables for rarely used pretrained checkpoint download progress and hash check.
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
# Both start disabled; the loaders in this module pass them on when fetching weights from a URL.
_DOWNLOAD_PROGRESS = False
_CHECK_HASH = False
34
+
35
+
36
def clean_state_dict(state_dict):
    """Return an ``OrderedDict`` copy of ``state_dict`` without the 'module.' key prefix.

    Only one leading 'module.' is removed per key (the prefix left behind by
    parallel-training wrappers); all other keys are copied unchanged.
    """
    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
43
+
44
+
45
def load_state_dict(checkpoint_path, use_ema=True):
    """Load a state dict from a checkpoint file on disk.

    Args:
        checkpoint_path: Path of the checkpoint file.
        use_ema: Prefer the 'state_dict_ema' / 'model_ema' entries when the
            checkpoint contains a non-None value for them.

    Returns:
        The selected state dict, passed through ``clean_state_dict``.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is empty or not a file.
    """
    if not (checkpoint_path and os.path.isfile(checkpoint_path)):
        # Carry the path in the exception as well as in the log record.
        message = "No checkpoint found at '{}'".format(checkpoint_path)
        _logger.error(message)
        raise FileNotFoundError(message)

    # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
    # The local is named ``ckpt`` to avoid shadowing the imported ``checkpoint`` function.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state_dict_key = ''
    if isinstance(ckpt, dict):
        if use_ema and ckpt.get('state_dict_ema', None) is not None:
            state_dict_key = 'state_dict_ema'
        elif use_ema and ckpt.get('model_ema', None) is not None:
            state_dict_key = 'model_ema'
        elif 'state_dict' in ckpt:
            state_dict_key = 'state_dict'
        elif 'model' in ckpt:
            state_dict_key = 'model'
    # With no recognised key, the loaded object itself is treated as the state dict.
    state_dict = clean_state_dict(ckpt[state_dict_key] if state_dict_key else ckpt)
    _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
    return state_dict
64
+
65
+
66
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
    """Load weights from ``checkpoint_path`` into ``model``.

    Files ending in '.npz' / '.npy' are handed to the model's own
    ``load_pretrained`` method; anything else goes through ``load_state_dict``.

    Returns:
        The result of ``model.load_state_dict`` for regular checkpoints,
        ``None`` for numpy checkpoints.

    Raises:
        NotImplementedError: For a numpy checkpoint when the model has no
            ``load_pretrained`` method.
    """
    extension = os.path.splitext(checkpoint_path)[-1].lower()
    if extension not in ('.npz', '.npy'):
        weights = load_state_dict(checkpoint_path, use_ema)
        return model.load_state_dict(weights, strict=strict)

    # numpy checkpoint, try to load via model specific load_pretrained fn
    if not hasattr(model, 'load_pretrained'):
        raise NotImplementedError('Model cannot load numpy checkpoint')
    model.load_pretrained(checkpoint_path)
    return None
77
+
78
+
79
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
    """Restore model (and optionally optimizer / loss-scaler) state from a checkpoint.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        checkpoint_path: Path of the checkpoint file.
        optimizer: Optional optimizer, restored from the 'optimizer' entry if present.
        loss_scaler: Optional scaler, restored from the entry named by its
            ``state_dict_key`` attribute if present.
        log_info: Emit progress messages through the module logger.

    Returns:
        The epoch to resume from, or ``None`` when the checkpoint carries no
        'epoch' entry or is a bare state dict.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not a file.
    """
    if not os.path.isfile(checkpoint_path):
        message = "No checkpoint found at '{}'".format(checkpoint_path)
        _logger.error(message)
        raise FileNotFoundError(message)

    # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    if not (isinstance(ckpt, dict) and 'state_dict' in ckpt):
        # Bare state dict: nothing but weights to restore.
        model.load_state_dict(ckpt)
        if log_info:
            _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
        return None

    if log_info:
        _logger.info('Restoring model state from checkpoint...')
    model.load_state_dict(clean_state_dict(ckpt['state_dict']))

    if optimizer is not None and 'optimizer' in ckpt:
        if log_info:
            _logger.info('Restoring optimizer state from checkpoint...')
        optimizer.load_state_dict(ckpt['optimizer'])

    if loss_scaler is not None and loss_scaler.state_dict_key in ckpt:
        if log_info:
            _logger.info('Restoring AMP loss scaler state from checkpoint...')
        loss_scaler.load_state_dict(ckpt[loss_scaler.state_dict_key])

    resume_epoch = None
    if 'epoch' in ckpt:
        resume_epoch = ckpt['epoch']
        if 'version' in ckpt and ckpt['version'] > 1:
            resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save

    if log_info:
        # Use .get(): a checkpoint without an 'epoch' entry must not raise KeyError here.
        _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, ckpt.get('epoch')))
    return resume_epoch
114
+
115
+
116
def _resolve_pretrained_source(pretrained_cfg):
    """Decide where pretrained weights should be loaded from.

    Args:
        pretrained_cfg: Mapping that may hold 'source', 'url', 'file',
            'hf_hub_id' and 'hf_hub_filename' entries.

    Returns:
        ``(load_from, pretrained_loc)`` where ``load_from`` is one of
        'hf-hub', 'file', 'url' or '' (nothing usable), and ``pretrained_loc``
        is the matching location. For 'hf-hub' with a filename override the
        location is a ``(hub_id, filename)`` tuple.
    """
    cfg_source = pretrained_cfg.get('source', '')
    url = pretrained_cfg.get('url', None)
    local_file = pretrained_cfg.get('file', None)
    hub_id = pretrained_cfg.get('hf_hub_id', None)

    load_from, pretrained_loc = '', ''
    if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
        # hf-hub specified as source via model identifier
        assert hub_id
        load_from, pretrained_loc = 'hf-hub', hub_id
    elif local_file:
        # default source == timm or unspecified: a local file wins over a url
        load_from, pretrained_loc = 'file', local_file
    elif url:
        load_from, pretrained_loc = 'url', url
    elif hub_id and has_hf_hub(necessary=True):
        # hf-hub available as alternate weight source in default_cfg
        load_from, pretrained_loc = 'hf-hub', hub_id

    if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg:
        # if a filename override is set, return tuple for location w/ (hub_id, filename)
        pretrained_loc = (pretrained_loc, pretrained_cfg['hf_hub_filename'])
    return load_from, pretrained_loc
145
+
146
+
147
def set_pretrained_download_progress(enable=True):
    """Globally switch the download progress display for pretrained weights.

    Args:
        enable: New value of the module-level ``_DOWNLOAD_PROGRESS`` flag.
    """
    global _DOWNLOAD_PROGRESS
    _DOWNLOAD_PROGRESS = enable
151
+
152
+
153
def set_pretrained_check_hash(enable=True):
    """Globally switch hash checking for downloaded pretrained weights.

    Args:
        enable: New value of the module-level ``_CHECK_HASH`` flag.
    """
    global _CHECK_HASH
    _CHECK_HASH = enable
157
+
158
+
159
def load_custom_pretrained(
        model: nn.Module,
        pretrained_cfg: Optional[Dict] = None,
        load_fn: Optional[Callable] = None,
):
    r"""Load a custom (i.e. non-.pth) weight file into ``model``.

    The weight location is resolved from the pretrained cfg. URL sources are
    first downloaded with ``download_cached_file``; the resulting location is
    then handed to ``load_fn`` or, failing that, to the model's own
    ``load_pretrained`` method.

    Args:
        model: The instantiated model to load weights into
        pretrained_cfg (dict): Pretrained model cfg; falls back to
            ``model.pretrained_cfg`` and then to an empty dict
        load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
            'load_pretrained' on the model will be called if it exists
    """
    cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {}
    source, location = _resolve_pretrained_source(cfg)
    if not source:
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return

    if source == 'url':
        location = download_cached_file(location, check_hash=_CHECK_HASH, progress=_DOWNLOAD_PROGRESS)
    elif source == 'hf-hub':  # FIXME
        # NOTE(review): only a warning is emitted; execution still falls through and passes
        # the hub id to the loader below -- confirm whether an early return is intended.
        _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")

    if load_fn is not None:
        load_fn(model, location)
        return
    if hasattr(model, 'load_pretrained'):
        model.load_pretrained(location)
        return
    _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
195
+
196
+
197
def adapt_input_conv(in_chans, conv_weight):
    """Adapt a first-conv weight to a different number of input channels.

    Args:
        in_chans: Number of input channels the model will receive.
        conv_weight: 4-D weight tensor, unpacked as (out, in, kh, kw).

    Returns:
        The adapted weight, cast back to the dtype of ``conv_weight``.
        With ``in_chans == 3`` the weight passes through unchanged.

    Raises:
        NotImplementedError: When ``in_chans`` is neither 1 nor 3 and the
            weight does not have exactly 3 input channels.
    """
    orig_dtype = conv_weight.dtype
    weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU
    out_ch, in_ch, k_h, k_w = weight.shape
    if in_chans == 1:
        if in_ch > 3:
            assert weight.shape[1] % 3 == 0
            # For models with space2depth stems: sum within each group of 3 channels
            grouped = weight.reshape(out_ch, in_ch // 3, 3, k_h, k_w)
            weight = grouped.sum(dim=2, keepdim=False)
        else:
            weight = weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        if in_ch != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        # NOTE this strategy should be better than random init, but there could be other combinations of
        # the original RGB input layer weights that'd work better for specific cases.
        n_tiles = int(math.ceil(in_chans / 3))
        weight = weight.repeat(1, n_tiles, 1, 1)[:, :in_chans, :, :]
        weight *= (3 / float(in_chans))
    return weight.to(orig_dtype)
220
+
221
+
222
def load_pretrained(
        model: nn.Module,
        pretrained_cfg: Optional[Dict] = None,
        num_classes: int = 1000,
        in_chans: int = 3,
        filter_fn: Optional[Callable] = None,
        strict: bool = True,
):
    """ Load pretrained checkpoint

    Resolves the weight source from the pretrained cfg, fetches the state dict, optionally
    filters it, adapts the first conv to `in_chans`, drops or offsets the classifier
    weights, and finally calls `model.load_state_dict`.

    Args:
        model (nn.Module) : PyTorch model module
        pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
        num_classes (int): num_classes for model
        in_chans (int): in_chans for model
        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
        strict (bool): strict load of checkpoint

    """
    # Explicit cfg wins, then the cfg attached to the model, then an empty dict.
    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {}
    load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
    if load_from == 'file':
        _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
        state_dict = load_state_dict(pretrained_loc)
    elif load_from == 'url':
        _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
        state_dict = load_state_dict_from_url(
            pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
    elif load_from == 'hf-hub':
        _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
        # The location is a (hub_id, filename) tuple when a filename override is configured.
        if isinstance(pretrained_loc, (list, tuple)):
            state_dict = load_state_dict_from_hf(*pretrained_loc)
        else:
            state_dict = load_state_dict_from_hf(pretrained_loc)
    else:
        _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
        return

    if filter_fn is not None:
        # for backwards compat with filter fns that take one arg, try one arg first, then two
        # NOTE(review): a TypeError raised inside a one-arg filter_fn also triggers the retry.
        try:
            state_dict = filter_fn(state_dict)
        except TypeError:
            state_dict = filter_fn(state_dict, model)

    # Adapt the input conv(s) named by 'first_conv' when the model does not take 3 channels.
    input_convs = pretrained_cfg.get('first_conv', None)
    if input_convs is not None and in_chans != 3:
        if isinstance(input_convs, str):
            input_convs = (input_convs,)
        for input_conv_name in input_convs:
            weight_name = input_conv_name + '.weight'
            try:
                state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
                _logger.info(
                    f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
            except NotImplementedError as e:
                # Conversion unsupported: drop the weight and relax strict loading so the
                # layer keeps its current initialization.
                del state_dict[weight_name]
                strict = False
                _logger.warning(
                    f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')

    classifiers = pretrained_cfg.get('classifier', None)
    label_offset = pretrained_cfg.get('label_offset', 0)
    if classifiers is not None:
        if isinstance(classifiers, str):
            classifiers = (classifiers,)
        # NOTE: indexes pretrained_cfg['num_classes'] directly, so the key must be present
        # whenever a 'classifier' entry is.
        if num_classes != pretrained_cfg['num_classes']:
            for classifier_name in classifiers:
                # completely discard fully connected if model num_classes doesn't match pretrained weights
                state_dict.pop(classifier_name + '.weight', None)
                state_dict.pop(classifier_name + '.bias', None)
            strict = False
        elif label_offset > 0:
            for classifier_name in classifiers:
                # special case for pretrained weights with an extra background class in pretrained weights
                classifier_weight = state_dict[classifier_name + '.weight']
                state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
                classifier_bias = state_dict[classifier_name + '.bias']
                state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]

    model.load_state_dict(state_dict, strict=strict)
303
+
304
+
305
def extract_layer(model, layer):
    """Walk a dotted path such as 'blocks.0.conv' and return the sub-module found.

    A leading 'module' component is reconciled with whether ``model`` has a
    ``module`` attribute. Purely numeric components are used as indices. The
    walk stops at the first component that cannot be resolved and returns
    the module reached so far.
    """
    parts = layer.split('.')
    current = model
    wrapped = hasattr(model, 'module')
    if wrapped and parts[0] != 'module':
        current = model.module
    if not wrapped and parts[0] == 'module':
        parts = parts[1:]
    for part in parts:
        if not hasattr(current, part):
            return current
        current = current[int(part)] if part.isdigit() else getattr(current, part)
    return current
321
+
322
+
323
def set_layer(model, layer, val):
    """Replace the sub-module addressed by the dotted path ``layer`` with ``val``.

    The path is first walked to count how many components resolve; the parent
    of the last resolved component is then located and ``val`` is assigned
    on it with ``setattr``.
    """
    parts = layer.split('.')
    root = model
    if hasattr(model, 'module') and parts[0] != 'module':
        root = model.module

    # First pass: count the components that resolve (unresolved ones are skipped, not fatal).
    resolved = 0
    cursor = root
    for part in parts:
        if hasattr(cursor, part):
            cursor = cursor[int(part)] if part.isdigit() else getattr(cursor, part)
            resolved += 1
    target_idx = resolved - 1

    # Second pass: descend to the parent of the target component.
    parent = root
    for part in parts[:target_idx]:
        parent = parent[int(part)] if part.isdigit() else getattr(parent, part)
    setattr(parent, parts[target_idx], val)
345
+
346
+
347
def adapt_model_from_string(parent_module, model_string):
    """Return a copy of ``parent_module`` with layers rebuilt to the shapes in ``model_string``.

    ``model_string`` is a '***'-separated list of ``name:[d0,d1,...]`` entries
    giving the shape of each parameter. Conv, batch-norm and linear layers of
    the copy are replaced by freshly constructed layers of those sizes.
    Both the copy and ``parent_module`` are put in eval mode.
    """
    # Parse "key:[a,b,c]" entries into {key: [a, b, c]}; entries with an empty shape are skipped.
    shapes = {}
    for entry in model_string.split('***'):
        fields = entry.split(':')
        dims = fields[1][1:-1].split(',')
        if dims[0] != '':
            shapes[fields[0]] = [int(d) for d in dims]

    new_module = deepcopy(parent_module)
    for name, _ in parent_module.named_modules():
        old_module = extract_layer(parent_module, name)
        if isinstance(old_module, (nn.Conv2d, Conv2dSame)):
            conv_cls = Conv2dSame if isinstance(old_module, Conv2dSame) else nn.Conv2d
            weight_shape = shapes[name + '.weight']
            out_channels = weight_shape[0]
            in_channels = weight_shape[1]
            groups = 1
            if old_module.groups > 1:
                # Grouped convs are rebuilt with groups == in_channels == out_channels.
                in_channels = out_channels
                groups = in_channels
            replacement = conv_cls(
                in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
                bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                groups=groups, stride=old_module.stride)
            set_layer(new_module, name, replacement)
        elif isinstance(old_module, BatchNormAct2d):
            replacement = BatchNormAct2d(
                shapes[name + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            # Carry over the existing drop and activation sub-modules.
            replacement.drop = old_module.drop
            replacement.act = old_module.act
            set_layer(new_module, name, replacement)
        elif isinstance(old_module, nn.BatchNorm2d):
            replacement = nn.BatchNorm2d(
                num_features=shapes[name + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            set_layer(new_module, name, replacement)
        elif isinstance(old_module, nn.Linear):
            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
            num_features = shapes[name + '.weight'][1]
            replacement = Linear(
                in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
            set_layer(new_module, name, replacement)
            if hasattr(new_module, 'num_features'):
                new_module.num_features = num_features
    new_module.eval()
    parent_module.eval()

    return new_module
402
+
403
+
404
def adapt_model_from_file(parent_module, model_variant):
    """Adapt ``parent_module`` using the shape spec in ``pruned/<model_variant>.txt``.

    The file is looked up next to this module; its stripped contents are
    passed to ``adapt_model_from_string``.
    """
    spec_path = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
    with open(spec_path, 'r') as spec_file:
        spec = spec_file.read().strip()
    return adapt_model_from_string(parent_module, spec)
408
+
409
+
410
def pretrained_cfg_for_features(pretrained_cfg):
    """Return a deep copy of ``pretrained_cfg`` without the classifier-only fields.

    The input is left untouched. Removed keys (when present): 'num_classes',
    'crop_pct', 'classifier' and 'global_pool'.
    """
    feature_cfg = deepcopy(pretrained_cfg)
    # remove default pretrained cfg fields that don't have much relevance for feature backbone
    for field in ('num_classes', 'crop_pct', 'classifier', 'global_pool'):  # add default final pool size?
        feature_cfg.pop(field, None)
    return feature_cfg
417
+
418
+
419
def set_default_kwargs(kwargs, names, pretrained_cfg):
    """Fill ``kwargs`` in place with defaults taken from ``pretrained_cfg``.

    Values already present in ``kwargs`` are never overwritten, and cfg values
    of ``None`` are ignored. 'img_size' and 'in_chans' are derived from the
    cfg's 3-element 'input_size' entry.
    """
    for name in names:
        # for legacy reasons, model __init__args uses img_size + in_chans as separate args while
        # pretrained_cfg has one input_size=(C, H ,W) entry
        if name in ('img_size', 'in_chans'):
            input_size = pretrained_cfg.get('input_size', None)
            if input_size is None:
                continue
            assert len(input_size) == 3
            derived = input_size[-2:] if name == 'img_size' else input_size[0]
            kwargs.setdefault(name, derived)
        elif pretrained_cfg.get(name, None) is not None:
            kwargs.setdefault(name, pretrained_cfg[name])
437
+
438
+
439
def filter_kwargs(kwargs, names):
    """Remove every key listed in ``names`` from ``kwargs`` in place.

    Does nothing when either argument is empty or ``None``; missing keys are ignored.
    """
    if kwargs and names:
        for name in names:
            kwargs.pop(name, None)
444
+
445
+
446
def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
    """ Prepare the model kwargs from the pretrained cfg before model construction

    Args:
        pretrained_cfg: input pretrained cfg (read for default values)
        kwargs: keyword args passed to model build fn (updated in-place)
        kwargs_filter: keyword arg keys that must be removed before model __init__
    """
    # Model __init__ args that can be determined from the cfg (if not already passed as kwargs)
    names = ('num_classes', 'global_pool', 'in_chans')
    if pretrained_cfg.get('fixed_input_size', False):
        # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
        names = names + ('img_size',)
    set_default_kwargs(kwargs, names=names, pretrained_cfg=pretrained_cfg)
    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
    filter_kwargs(kwargs, names=kwargs_filter)
462
+
463
+
464
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None):
    """Pick the pretrained cfg to use for ``variant``.

    Priority: a non-empty dict passed as ``pretrained_cfg`` (deep-copied),
    then the registry entry for ``variant``, then a generic default (a
    warning is logged in that last case).
    """
    if pretrained_cfg and isinstance(pretrained_cfg, dict):
        # highest priority, pretrained_cfg available and passed as arg
        return deepcopy(pretrained_cfg)

    # fallback to looking up pretrained cfg in model registry by variant identifier
    registered = get_pretrained_cfg(variant)
    if registered:
        return registered

    _logger.warning(
        f"No pretrained configuration specified for {variant} model. Using a default."
        f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
    return dict(
        url='',
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        first_conv='',
        classifier='',
    )
485
+
486
+
487
def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        pretrained_cfg: Optional[Dict] = None,
        model_cfg: Optional[Any] = None,
        feature_cfg: Optional[Dict] = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Optional[Callable] = None,
        pretrained_custom_load: bool = False,
        kwargs_filter: Optional[Tuple[str]] = None,
        **kwargs):
    """ Build model with specified default_cfg and optional model_cfg

    This helper fn aids in the construction of a model including:
      * handling default_cfg and associated pretrained weight loading
      * passing through optional model_cfg for models with config based arch spec
      * features_only model adaptation
      * pruning config / model adaptation

    Args:
        model_cls (nn.Module): model class
        variant (str): model variant name
        pretrained (bool): load pretrained weights
        pretrained_cfg (dict): model's pretrained weight/task config
        model_cfg (Optional[Dict]): model's architecture config
        feature_cfg (Optional[Dict]): feature extraction adapter config (not modified)
        pretrained_strict (bool): load pretrained weights strictly
        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
        pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
        **kwargs: model args passed through to model __init__

    Returns:
        The constructed model, wrapped in a feature extraction module when
        ``features_only`` was passed in ``kwargs``.

    Raises:
        AssertionError: If ``feature_cfg['feature_cls']`` is an unrecognised string.
    """
    pruned = kwargs.pop('pruned', False)
    features = False
    # Work on a shallow copy: the setdefault / pop calls below would otherwise mutate the
    # caller's dict, and a shared dict would lose 'feature_cls' after the first call.
    feature_cfg = dict(feature_cfg) if feature_cfg else {}

    # resolve and update model pretrained config and model kwargs
    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=pretrained_cfg)
    update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter)
    pretrained_cfg.setdefault('architecture', variant)

    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    # Build the model
    model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg  # alias for backwards compat

    if pruned:
        model = adapt_model_from_file(model, variant)

    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        if pretrained_custom_load:
            # FIXME improve custom load trigger
            load_custom_pretrained(model, pretrained_cfg=pretrained_cfg)
        else:
            load_pretrained(
                model,
                pretrained_cfg=pretrained_cfg,
                num_classes=num_classes_pretrained,
                in_chans=kwargs.get('in_chans', 3),
                filter_fn=pretrained_filter_fn,
                strict=pretrained_strict)

    # Wrap the model in a feature extraction module if enabled
    if features:
        feature_cls = FeatureListNet
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                elif feature_cls == 'fx':
                    feature_cls = FeatureGraphNet
                else:
                    # Explicit raise (same exception type as the former ``assert False``) so
                    # the check is not stripped under -O.
                    raise AssertionError(f'Unknown feature class {feature_cls}')
        model = feature_cls(model, **feature_cfg)
        model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg)  # add back default_cfg
        model.default_cfg = model.pretrained_cfg  # alias for backwards compat

    return model
577
+
578
+
579
def model_parameters(model, exclude_head=False):
    """Return the parameters of ``model``, optionally without the last two.

    Args:
        model: object exposing a ``parameters()`` method.
        exclude_head: when true, materialize the parameters into a list and
            drop its final two entries.

    Returns:
        The result of ``model.parameters()`` unchanged when ``exclude_head``
        is false, otherwise a list missing the last two parameters.
    """
    if not exclude_head:
        return model.parameters()
    # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
    return list(model.parameters())[:-2]
585
+
586
+
587
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
    """Call ``fn(module=..., name=...)`` on the descendants of ``module``.

    Args:
        fn: callable invoked with keyword arguments ``module`` and ``name``.
        module: root of the traversal.
        name: dotted name of ``module``; child names are appended to it.
        depth_first: visit a module after its children when true, before them
            when false.
        include_root: whether ``fn`` is also called on ``module`` itself
            (children are always visited).

    Returns:
        ``module``, unchanged by this function itself.
    """
    if include_root and not depth_first:
        fn(module=module, name=name)
    for child_key, child in module.named_children():
        qualified = '.'.join([name, child_key]) if name else child_key
        # every descendant is a "root" of its own sub-traversal, so it is always visited
        named_apply(fn=fn, module=child, name=qualified, depth_first=depth_first, include_root=True)
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
596
+
597
+
598
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
    """Yield ``(dotted_name, module)`` pairs for the descendants of ``module``.

    Args:
        module: root of the traversal.
        name: dotted name of ``module``; child names are appended to it.
        depth_first: yield a module after its children when true, before them
            when false.
        include_root: whether ``module`` itself is yielded (children always are).
    """
    if include_root and not depth_first:
        yield name, module
    for child_key, child in module.named_children():
        qualified = '.'.join([name, child_key]) if name else child_key
        yield from named_modules(
            module=child, name=qualified, depth_first=depth_first, include_root=True)
    if include_root and depth_first:
        yield name, module
607
+
608
+
609
def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
    """Yield ``(dotted_name, module)`` pairs for descendants that directly own parameters.

    Traversal order and naming match ``named_modules``; a module is only
    yielded when its own ``_parameters`` mapping is non-empty.

    Args:
        module: root of the traversal.
        name: dotted name of ``module``; child names are appended to it.
        depth_first: yield a module after its children when true, before them
            when false.
        include_root: whether ``module`` itself is a candidate for being yielded.
    """
    if module._parameters and include_root and not depth_first:
        yield name, module
    for child_key, child in module.named_children():
        qualified = '.'.join([name, child_key]) if name else child_key
        yield from named_modules_with_params(
            module=child, name=qualified, depth_first=depth_first, include_root=True)
    if module._parameters and include_root and depth_first:
        yield name, module
618
+
619
+
620
# Suffix sentinel: a grouping key whose last element equals this value does not
# start a new group id in group_with_matcher(); it is merged into the previous one.
MATCH_PREV_GROUP = (99999,)


def group_with_matcher(
        named_objects,
        group_matcher: Union[Dict, Callable],
        output_values: bool = False,
        reverse: bool = False
):
    """Bucket named objects into consecutively numbered groups.

    Args:
        named_objects: iterable of ``(name, value)`` pairs.
        group_matcher: one of
            * a dict mapping group name to a regex string, to a list/tuple of
              ``(regex, suffix)`` pairs, or to ``None`` (entry skipped); the
              position of the entry in the dict is the leading sort ordinal;
            * a list/tuple of ``(compiled_regex, prefix, suffix)`` triples;
            * a callable mapping a name to an ordinal or an iterable of ordinals.
        output_values: collect the values instead of the names.
        reverse: return ``{name: group_id}`` instead of ``{group_id: [names]}``;
            only valid together with ``output_values=False``.

    Returns:
        A ``defaultdict(list)`` from group id to members, or the reverse
        name -> group id dict when ``reverse`` is true.
    """
    if isinstance(group_matcher, dict):
        # dictionary matcher contains raw-string regex expr that must be compiled
        # into (compiled re, prefix, suffix) triples
        specs = []
        for ordinal, (_, mspec) in enumerate(group_matcher.items()):
            if mspec is None:
                continue
            if isinstance(mspec, (tuple, list)):
                # multi-entry specifications: each sub-spec is a 2-tuple (re, suffix)
                specs.extend((re.compile(sub[0]), (ordinal,), sub[1]) for sub in mspec)
            else:
                specs.append((re.compile(mspec), (ordinal,), None))
        group_matcher = specs

    if isinstance(group_matcher, (list, tuple)):
        def _sort_key(obj_name):
            for pattern, prefix, suffix in group_matcher:
                hit = pattern.match(obj_name)
                if hit:
                    pieces = (prefix, hit.groups(), suffix)
                    # flatten the non-empty pieces to floats so keys sort numerically
                    return tuple(float(v) for piece in pieces if piece for v in piece)
            return (float('inf'),)  # un-matched layers (neck, head) mapped to largest ordinal
    else:
        def _sort_key(obj_name):
            raw = group_matcher(obj_name)
            if isinstance(raw, collections.abc.Iterable):
                return tuple(raw)
            return (raw,)

    # map layers into groups via ordinals (ints or tuples of ints) from matcher
    buckets = defaultdict(list)
    for obj_name, obj in named_objects:
        buckets[_sort_key(obj_name)].append(obj if output_values else obj_name)

    # remap the sorted keys to consecutive integer ids
    layer_id_to_param = defaultdict(list)
    layer_id = -1
    for key in sorted(k for k in buckets if k is not None):
        if layer_id < 0 or key[-1] != MATCH_PREV_GROUP[0]:
            layer_id += 1
        layer_id_to_param[layer_id].extend(buckets[key])

    if not reverse:
        return layer_id_to_param

    assert not output_values, "reverse mapping only sensible for name output"
    return {member: gid for gid, members in layer_id_to_param.items() for member in members}
682
+
683
+
684
def group_parameters(
        module: nn.Module,
        group_matcher,
        output_values=False,
        reverse=False,
):
    """Group the named parameters of ``module`` with ``group_with_matcher``.

    Args:
        module: module whose ``named_parameters()`` are grouped.
        group_matcher: matcher forwarded to ``group_with_matcher``.
        output_values: forwarded to ``group_with_matcher``.
        reverse: forwarded to ``group_with_matcher``.

    Returns:
        Whatever ``group_with_matcher`` returns for these arguments.
    """
    named_params = module.named_parameters()
    return group_with_matcher(
        named_params, group_matcher, output_values=output_values, reverse=reverse)
692
+
693
+
694
def group_modules(
        module: nn.Module,
        group_matcher,
        output_values=False,
        reverse=False,
):
    """Group the parameter-owning submodules of ``module`` with ``group_with_matcher``.

    Args:
        module: module traversed with ``named_modules_with_params``.
        group_matcher: matcher forwarded to ``group_with_matcher``.
        output_values: forwarded to ``group_with_matcher``.
        reverse: forwarded to ``group_with_matcher``.

    Returns:
        Whatever ``group_with_matcher`` returns for these arguments.
    """
    named_submodules = named_modules_with_params(module)
    return group_with_matcher(
        named_submodules, group_matcher, output_values=output_values, reverse=reverse)
702
+
703
+
704
def checkpoint_seq(
        functions,
        x,
        every=1,
        flatten=False,
        skip_last=False,
        preserve_rng_state=True
):
    r"""Run a sequence of modules/functions with activation checkpointing.

    The sequence is cut into consecutive segments of ``every`` functions and
    each segment is executed through :func:`torch.utils.checkpoint.checkpoint`.
    With ``skip_last`` the final function is excluded from the checkpointed
    segments and called directly afterwards.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works
    and for its restrictions on autograd usage.

    Args:
        functions: A :class:`torch.nn.Sequential` (its children are used) or an
            iterable of modules or functions to run sequentially.
        x: A Tensor that is input to :attr:`functions`
        every: checkpoint every-n functions (default: 1)
        flatten (bool): chain the elements of each entry of :attr:`functions`,
            e.g. to flatten an nn.Sequential of nn.Sequentials
        skip_last (bool): skip checkpointing the last function in the sequence if True
        preserve_rng_state (bool, optional, default=True): passed through to
            :func:`torch.utils.checkpoint.checkpoint`.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`x`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_seq(model, input_var, every=2)
    """
    # normalize the input into an indexable tuple/list of callables
    if isinstance(functions, torch.nn.Sequential):
        functions = functions.children()
    if flatten:
        functions = chain.from_iterable(functions)
    if not isinstance(functions, (tuple, list)):
        functions = tuple(functions)

    def _segment(first, last):
        # callable running functions[first..last] (inclusive) in order
        def _forward(_x):
            for idx in range(first, last + 1):
                _x = functions[idx](_x)
            return _x
        return _forward

    total = len(functions)
    num_checkpointed = total - 1 if skip_last else total

    last_idx = -1
    for first_idx in range(0, num_checkpointed, every):
        last_idx = min(first_idx + every - 1, num_checkpointed - 1)
        x = checkpoint(_segment(first_idx, last_idx), x, preserve_rng_state=preserve_rng_state)

    if not skip_last:
        return x
    # run whatever follows the last checkpointed segment without checkpointing
    return _segment(last_idx + 1, total - 1)(x)
772
+
773
+
774
def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
    """Yield ``(name, module)`` pairs, expanding container modules up to ``depth`` levels.

    Args:
        named_modules: iterable of ``(name, module)`` pairs, e.g.
            ``model.named_children()``.
        depth: how many nested container levels are expanded; ``0`` yields the
            input pairs as they are (with ``prefix`` applied).
        prefix: name prefix for everything yielded. A ``str`` produces
            dot-joined names, a ``tuple`` produces tuple names.
        module_types: ``'container'`` expands ``nn.Sequential``,
            ``nn.ModuleList`` and ``nn.ModuleDict``; any other string expands
            ``nn.Sequential`` only; a tuple of types is used as given.

    Yields:
        ``(name, module)`` for every module that is not expanded further, where
        ``name`` is the full path from ``prefix`` down to that module.
    """
    prefix_is_tuple = isinstance(prefix, tuple)
    if isinstance(module_types, str):
        if module_types == 'container':
            module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
        else:
            module_types = (nn.Sequential,)
    for name, module in named_modules:
        # Build the full path first so that expanded containers carry the
        # accumulated prefix into the recursion instead of only their own name.
        if prefix_is_tuple:
            full_name = prefix + (name,)
        else:
            full_name = '.'.join([prefix, name]) if prefix else name
        if depth and isinstance(module, module_types):
            yield from flatten_modules(
                module.named_children(),
                depth - 1,
                prefix=full_name,
                module_types=module_types,
            )
        else:
            yield full_name, module
src/custom_timm/models/hrnet.py ADDED
@@ -0,0 +1,858 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ HRNet
2
+
3
+ Copied from https://github.com/HRNet/HRNet-Image-Classification
4
+
5
+ Original header:
6
+ Copyright (c) Microsoft
7
+ Licensed under the MIT License.
8
+ Written by Bin Xiao (Bin.Xiao@microsoft.com)
9
+ Modified by Ke Sun (sunk@mail.ustc.edu.cn)
10
+ """
11
+ import logging
12
+ from typing import List
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
19
+ from .features import FeatureInfo
20
+ from .helpers import build_model_with_cfg, pretrained_cfg_for_features
21
+ from .layers import create_classifier
22
+ from .registry import register_model
23
+ from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
24
+
25
+ _BN_MOMENTUM = 0.1  # `momentum` passed to the nn.BatchNorm2d layers constructed in this module
26
+ _logger = logging.getLogger(__name__)  # module logger; used to report branch configuration errors
27
+
28
+
29
def _cfg(url='', **kwargs):
    """Build a pretrained-config dict for an HRNet variant.

    Args:
        url: value stored under ``'url'``.
        **kwargs: extra entries; they override the defaults below on key clash.

    Returns:
        A new dict with the defaults plus ``kwargs``.
    """
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (7, 7),
        'crop_pct': 0.875,
        'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv1',
        'classifier': 'classifier',
    }
    cfg.update(kwargs)
    return cfg
38
+
39
+
40
# Pretrained config per model variant; each entry is the `_cfg` defaults plus a weights URL.
default_cfgs = dict(
    hrnet_w18_small=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v1-f460c6bc.pth'),
    hrnet_w18_small_v2=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v2-4c50a8cb.pth'),
    hrnet_w18=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w18-8cb57bb9.pth'),
    hrnet_w30=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w30-8d7f8dab.pth'),
    hrnet_w32=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w32-90d8c5fb.pth'),
    hrnet_w40=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w40-7cd397a4.pth'),
    hrnet_w44=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w44-c9ac8c18.pth'),
    hrnet_w48=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w48-abd2e6ab.pth'),
    hrnet_w64=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w64-b47cc881.pth'),
)
60
+
61
def _stage_cfg(num_modules, block, num_blocks, num_channels):
    """Build one ``STAGE*`` entry with one branch per channel width."""
    num_branches = len(num_channels)
    return dict(
        NUM_MODULES=num_modules,
        NUM_BRANCHES=num_branches,
        BLOCK=block,
        NUM_BLOCKS=(num_blocks,) * num_branches,
        NUM_CHANNELS=tuple(num_channels),
        FUSE_METHOD='SUM',
    )


def _hrnet_cfg(width, stage1_channels=64, stage1_blocks=4, branch_blocks=4, num_modules=(1, 4, 3)):
    """Build the architecture config of one HRNet variant.

    Args:
        width: channel count of the first branch in stages 2-4; branch ``i``
            has ``width * 2 ** i`` channels.
        stage1_channels: ``NUM_CHANNELS`` of the single stage-1 branch.
        stage1_blocks: ``NUM_BLOCKS`` of the single stage-1 branch.
        branch_blocks: blocks per branch in stages 2-4.
        num_modules: ``NUM_MODULES`` for stages 2, 3 and 4.
    """
    cfg = dict(
        STEM_WIDTH=64,
        STAGE1=_stage_cfg(1, 'BOTTLENECK', stage1_blocks, (stage1_channels,)),
    )
    for stage_idx, modules in zip((2, 3, 4), num_modules):
        # stage N runs N parallel branches, doubling the width per branch
        widths = [width * 2 ** branch for branch in range(stage_idx)]
        cfg[f'STAGE{stage_idx}'] = _stage_cfg(modules, 'BASIC', branch_blocks, widths)
    return cfg


# Architecture config per model variant (stem width plus four stage specs).
cfg_cls = dict(
    hrnet_w18_small=_hrnet_cfg(16, stage1_channels=32, stage1_blocks=1, branch_blocks=2, num_modules=(1, 1, 1)),
    hrnet_w18_small_v2=_hrnet_cfg(18, stage1_blocks=2, branch_blocks=2, num_modules=(1, 3, 2)),
    hrnet_w18=_hrnet_cfg(18),
    hrnet_w30=_hrnet_cfg(30),
    hrnet_w32=_hrnet_cfg(32),
    hrnet_w40=_hrnet_cfg(40),
    hrnet_w44=_hrnet_cfg(44),
    hrnet_w48=_hrnet_cfg(48),
    hrnet_w64=_hrnet_cfg(64),
)
386
+
387
+
388
class HighResolutionModule(nn.Module):
    """One multi-branch HRNet module: per-branch block stacks followed by cross-branch fusion.

    Args:
        num_branches: number of parallel branches.
        blocks: block class used for every branch; it must expose ``expansion``
            and accept ``(inplanes, planes, stride, downsample)``.
        num_blocks: per-branch number of blocks.
        num_in_chs: per-branch input channel counts. The sequence is copied, so
            the caller's object is never modified.
        num_channels: per-branch block widths (before ``blocks.expansion``).
        fuse_method: stored on the module; not otherwise used by this class.
        multi_scale_output: when false, fusion layers are built for the first
            branch only.

    Raises:
        ValueError: if ``num_blocks``, ``num_channels`` or ``num_in_chs`` does
            not have exactly ``num_branches`` entries.
    """

    def __init__(self, num_branches, blocks, num_blocks, num_in_chs,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_in_chs, num_channels)

        # Private copy: _make_one_branch rewrites entries in place, which must not
        # leak into the caller's list (or into another module sharing that list).
        self.num_in_chs = list(num_in_chs)
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.fuse_act = nn.ReLU(False)

    def _check_branches(self, num_branches, blocks, num_blocks, num_in_chs, num_channels):
        """Log and raise ``ValueError`` when a per-branch sequence has the wrong length."""
        error_msg = ''
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks))
        elif num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels))
        elif num_branches != len(num_in_chs):
            error_msg = 'NUM_BRANCHES({}) <> num_in_chs({})'.format(num_branches, len(num_in_chs))
        if error_msg:
            _logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        """Build the block stack of one branch and record its output width in ``self.num_in_chs``."""
        out_chs = num_channels[branch_index] * block.expansion
        downsample = None
        if stride != 1 or self.num_in_chs[branch_index] != out_chs:
            # 1x1 projection so the residual path matches the block output
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_in_chs[branch_index], out_chs,
                    kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chs, momentum=_BN_MOMENTUM),
            )

        layers = [block(self.num_in_chs[branch_index], num_channels[branch_index], stride, downsample)]
        # from here on the branch runs at its expanded width
        self.num_in_chs[branch_index] = out_chs
        for _ in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_in_chs[branch_index], num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build all branches as an ``nn.ModuleList``, in branch order."""
        branches = []
        for i in range(num_branches):
            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build the ``fuse_layers[i][j]`` grid mapping branch ``j`` output onto branch ``i``.

        Returns ``nn.Identity()`` for a single-branch module. Otherwise, per
        target branch ``i`` (only ``i == 0`` when ``multi_scale_output`` is false):

        * ``j > i``: 1x1 conv + BN + nearest upsampling by ``2 ** (j - i)``;
        * ``j == i``: identity;
        * ``j < i``: ``i - j`` stride-2 3x3 conv + BN steps; all but the last
          keep ``num_in_chs[j]`` channels and end with ReLU, the last maps to
          ``num_in_chs[i]`` without an activation.
        """
        if self.num_branches == 1:
            return nn.Identity()

        num_branches = self.num_branches
        num_in_chs = self.num_in_chs
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_in_chs[j], num_in_chs[i], 1, 1, 0, bias=False),
                        nn.BatchNorm2d(num_in_chs[i], momentum=_BN_MOMENTUM),
                        nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(nn.Identity())
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_in_chs[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_in_chs[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM),
                                nn.ReLU(False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_in_chs(self):
        """Return the per-branch channel counts after the branches were built."""
        return self.num_in_chs

    def forward(self, x: List[torch.Tensor]):
        """Run every branch, then fuse across branches.

        Args:
            x: list with one tensor per branch. For multi-branch modules its
                entries are overwritten in place with the branch outputs.

        Returns:
            A list with one fused tensor per row of ``fuse_layers`` (a
            single-element list for a single-branch module).
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i, branch in enumerate(self.branches):
            x[i] = branch(x[i])

        x_fuse = []
        for i, fuse_outer in enumerate(self.fuse_layers):
            # sum the contributions of all branches resampled onto branch i
            y = x[0] if i == 0 else fuse_outer[0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + fuse_outer[j](x[j])
            x_fuse.append(self.fuse_act(y))

        return x_fuse
499
+
500
+
501
# Maps the `BLOCK` string of a stage config to the block class that implements it.
blocks_dict = dict(
    BASIC=BasicBlock,
    BOTTLENECK=Bottleneck,
)
505
+
506
+
507
+ class HighResolutionNet(nn.Module):
508
+
509
def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0, head='classification'):
    """Build the stem, the four stages and the selected head.

    Args:
        cfg: architecture config with ``STEM_WIDTH`` and ``STAGE1``..``STAGE4``.
        in_chans: number of input channels of the first conv.
        num_classes: stored and passed to ``create_classifier``.
        global_pool: pooling type passed to ``create_classifier``.
        drop_rate: stored on the module.
        head: ``'classification'`` builds the full head and classifier,
            ``'incre'`` builds only the channel-increase modules, anything
            else builds no head modules.
    """
    super(HighResolutionNet, self).__init__()
    self.num_classes = num_classes
    self.drop_rate = drop_rate

    def _stage_widths(stage_cfg):
        # per-branch channel counts of a stage after block expansion
        widths = stage_cfg['NUM_CHANNELS']
        stage_block = blocks_dict[stage_cfg['BLOCK']]
        return [w * stage_block.expansion for w in widths]

    # stem: two stride-2 3x3 convs
    stem_width = cfg['STEM_WIDTH']
    self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=3, stride=2, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(stem_width, momentum=_BN_MOMENTUM)
    self.act1 = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(stem_width, 64, kernel_size=3, stride=2, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(64, momentum=_BN_MOMENTUM)
    self.act2 = nn.ReLU(inplace=True)

    # stage 1: a single plain layer
    self.stage1_cfg = cfg['STAGE1']
    stage1_width = self.stage1_cfg['NUM_CHANNELS'][0]
    stage1_block = blocks_dict[self.stage1_cfg['BLOCK']]
    stage1_depth = self.stage1_cfg['NUM_BLOCKS'][0]
    self.layer1 = self._make_layer(stage1_block, 64, stage1_width, stage1_depth)
    stage1_out_channel = stage1_block.expansion * stage1_width

    # stages 2-4: transition into the new branch layout, then the stage itself
    self.stage2_cfg = cfg['STAGE2']
    num_channels = _stage_widths(self.stage2_cfg)
    self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels)
    self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)

    self.stage3_cfg = cfg['STAGE3']
    num_channels = _stage_widths(self.stage3_cfg)
    self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
    self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)

    self.stage4_cfg = cfg['STAGE4']
    num_channels = _stage_widths(self.stage4_cfg)
    self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
    self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True)

    self.head = head
    self.head_channels = None  # set if _make_head called
    if head == 'classification':
        # Classification Head
        self.num_features = 2048
        self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels)
        self.global_pool, self.classifier = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)
    elif head == 'incre':
        self.num_features = 2048
        self.incre_modules, _, _ = self._make_head(pre_stage_channels, True)
    else:
        self.incre_modules = None
        self.num_features = 256

    # module names aren't actually valid here, hook or FeatureNet based extraction would not work
    has_head = bool(self.head_channels)
    feature_widths = self.head_channels if has_head else num_channels
    width_scale = 4 if has_head else 1  # head block expansion factor of 4
    self.feature_info = [dict(num_chs=64, reduction=2, module='stem')]
    for idx, width in enumerate(feature_widths):
        # reduction doubles per stage, starting from the stem's stride of 2
        self.feature_info.append(
            dict(num_chs=width * width_scale, reduction=2 ** (idx + 2), module=f'stage{idx + 1}'))

    self.init_weights()
574
+
575
def _make_head(self, pre_stage_channels, incre_only=False):
    """Build the head modules and set ``self.head_channels``.

    Args:
        pre_stage_channels: per-branch channel counts entering the head.
        incre_only: when true, only the channel-increase modules are built.

    Returns:
        ``(incre_modules, downsamp_modules, final_layer)``; the last two are
        ``None`` when ``incre_only`` is true.
    """
    head_block = Bottleneck
    self.head_channels = [32, 64, 128, 256]

    # Increasing the #channels on each resolution
    # from C, 2C, 4C, 8C to 128, 256, 512, 1024
    incre_modules = nn.ModuleList([
        self._make_layer(head_block, in_chs, self.head_channels[idx], 1, stride=1)
        for idx, in_chs in enumerate(pre_stage_channels)
    ])
    if incre_only:
        return incre_modules, None, None

    # downsampling modules: stride-2 3x3 conv between consecutive head widths
    downsamp_modules = nn.ModuleList([
        nn.Sequential(
            nn.Conv2d(
                in_channels=self.head_channels[idx] * head_block.expansion,
                out_channels=self.head_channels[idx + 1] * head_block.expansion,
                kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(self.head_channels[idx + 1] * head_block.expansion, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        )
        for idx in range(len(pre_stage_channels) - 1)
    ])

    # 1x1 conv from the widest head branch to `num_features`
    final_layer = nn.Sequential(
        nn.Conv2d(
            in_channels=self.head_channels[3] * head_block.expansion,
            out_channels=self.num_features, kernel_size=1, stride=1, padding=0
        ),
        nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM),
        nn.ReLU(inplace=True)
    )

    return incre_modules, downsamp_modules, final_layer
612
+
613
+ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
614
+ num_branches_cur = len(num_channels_cur_layer)
615
+ num_branches_pre = len(num_channels_pre_layer)
616
+
617
+ transition_layers = []
618
+ for i in range(num_branches_cur):
619
+ if i < num_branches_pre:
620
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
621
+ transition_layers.append(nn.Sequential(
622
+ nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False),
623
+ nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM),
624
+ nn.ReLU(inplace=True)))
625
+ else:
626
+ transition_layers.append(nn.Identity())
627
+ else:
628
+ conv3x3s = []
629
+ for j in range(i + 1 - num_branches_pre):
630
+ inchannels = num_channels_pre_layer[-1]
631
+ outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
632
+ conv3x3s.append(nn.Sequential(
633
+ nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
634
+ nn.BatchNorm2d(outchannels, momentum=_BN_MOMENTUM),
635
+ nn.ReLU(inplace=True)))
636
+ transition_layers.append(nn.Sequential(*conv3x3s))
637
+
638
+ return nn.ModuleList(transition_layers)
639
+
640
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
    """Stack ``blocks`` instances of ``block`` into an ``nn.Sequential``.

    The first block receives ``stride`` and, when the stride is not 1 or the input
    width differs from ``planes * block.expansion``, a 1x1 conv + BN projection as
    its ``downsample`` argument. The remaining blocks are built with the expanded
    width as input and default stride/downsample.
    """
    out_chs = planes * block.expansion

    shortcut = None
    if stride != 1 or inplanes != out_chs:
        shortcut = nn.Sequential(
            nn.Conv2d(inplanes, out_chs, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_chs, momentum=_BN_MOMENTUM),
        )

    stack = [block(inplanes, planes, stride, shortcut)]
    stack.extend(block(out_chs, planes) for _ in range(1, blocks))
    return nn.Sequential(*stack)
654
+
655
def _make_stage(self, layer_config, num_in_chs, multi_scale_output=True):
    """Assemble a stage out of ``layer_config['NUM_MODULES']`` HighResolutionModule instances.

    Args:
        layer_config: mapping providing NUM_MODULES, NUM_BRANCHES, NUM_BLOCKS,
            NUM_CHANNELS, BLOCK (a key of ``blocks_dict``) and FUSE_METHOD.
        num_in_chs: input channel spec handed to the first module; each following
            module receives ``get_num_in_chs()`` of the one before it.
        multi_scale_output: only consulted for the last module, every earlier
            module is always built with a truthy multi-scale flag.

    Returns:
        Tuple of (``nn.Sequential`` of the modules, channel spec after the last module).
    """
    module_count = layer_config['NUM_MODULES']
    branch_count = layer_config['NUM_BRANCHES']
    block_counts = layer_config['NUM_BLOCKS']
    channel_counts = layer_config['NUM_CHANNELS']
    block_cls = blocks_dict[layer_config['BLOCK']]
    fuse = layer_config['FUSE_METHOD']

    built = []
    for index in range(module_count):
        is_not_last = index < module_count - 1
        stage_module = HighResolutionModule(
            branch_count, block_cls, block_counts, num_in_chs, channel_counts, fuse,
            multi_scale_output or is_not_last)
        built.append(stage_module)
        num_in_chs = stage_module.get_num_in_chs()

    return nn.Sequential(*built), num_in_chs
673
+
674
@torch.jit.ignore
def init_weights(self):
    """Re-initialise every Conv2d and BatchNorm2d found in ``self.modules()``.

    Conv weights get Kaiming-normal init (fan_out, relu); BatchNorm weights are
    set to 1 and biases to 0. Other module types are left untouched.
    """
    for module in self.modules():
        if isinstance(module, nn.BatchNorm2d):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
683
+
684
@torch.jit.ignore
def group_matcher(self, coarse=False):
    """Return the regex spec used to group parameters by stem and blocks.

    With ``coarse=True`` the ``blocks`` entry is a single pattern covering layer,
    stage and transition modules; otherwise it is a list of (pattern, ordinal)
    pairs with a fixed ordinal of ``(99999,)`` for transitions.
    """
    if coarse:
        block_spec = r'^(?:layer|stage|transition)(\d+)'
    else:
        block_spec = [
            (r'^layer(\d+)\.(\d+)', None),
            (r'^stage(\d+)\.(\d+)', None),
            (r'^transition(\d+)', (99999,)),
        ]
    return dict(stem=r'^conv[12]|bn[12]', blocks=block_spec)
695
+
696
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
    """Reject any attempt to enable gradient checkpointing; disabling is a no-op."""
    assert not enable, "gradient checkpointing not supported"
699
+
700
@torch.jit.ignore
def get_classifier(self):
    """Return the module currently stored as ``self.classifier``."""
    return self.classifier
703
+
704
def reset_classifier(self, num_classes, global_pool='avg'):
    """Replace the pooling layer and classifier for a new number of classes.

    Stores ``num_classes`` on the model, then rebuilds ``self.global_pool`` and
    ``self.classifier`` via ``create_classifier`` using ``self.num_features``.
    """
    self.num_classes = num_classes
    pool, classifier = create_classifier(
        self.num_features, self.num_classes, pool_type=global_pool)
    self.global_pool = pool
    self.classifier = classifier
708
+
709
def stages(self, x) -> List[torch.Tensor]:
    """Run layer1 and the three transition/stage pairs, returning stage4's outputs.

    Every adapter in ``transition1`` is applied to the layer1 output. For the later
    transitions, an ``nn.Identity`` adapter forwards the same-index output of the
    previous stage, while any other adapter is applied to the last output of the
    previous stage.
    """
    x = self.layer1(x)

    branch_inputs = [adapter(x) for _, adapter in enumerate(self.transition1)]
    branch_outputs = self.stage2(branch_inputs)

    branch_inputs = [
        branch_outputs[idx] if isinstance(adapter, nn.Identity) else adapter(branch_outputs[-1])
        for idx, adapter in enumerate(self.transition2)]
    branch_outputs = self.stage3(branch_inputs)

    branch_inputs = [
        branch_outputs[idx] if isinstance(adapter, nn.Identity) else adapter(branch_outputs[-1])
        for idx, adapter in enumerate(self.transition3)]
    branch_outputs = self.stage4(branch_inputs)
    return branch_outputs
721
+
722
def forward_features(self, x):
    """Run the stem and stages, then fold the branches through the head if present.

    When either ``self.incre_modules`` or ``self.downsamp_modules`` is None the raw
    list produced by ``self.stages`` is returned. Otherwise each branch is passed
    through its ``incre`` module and summed with the downsampled running result,
    and the final value goes through ``self.final_layer``.
    """
    # Stem: two conv/bn/act groups
    x = self.act1(self.bn1(self.conv1(x)))
    x = self.act2(self.bn2(self.conv2(x)))

    # Stages
    yl = self.stages(x)
    if self.incre_modules is None or self.downsamp_modules is None:
        return yl

    merged = self.incre_modules[0](yl[0])
    for idx, downsample in enumerate(self.downsamp_modules):
        merged = self.incre_modules[idx + 1](yl[idx + 1]) + downsample(merged)
    return self.final_layer(merged)
740
+
741
def forward_head(self, x, pre_logits: bool = False):
    """Pool the features, apply dropout when ``drop_rate`` is positive, then classify.

    With ``pre_logits=True`` the pooled (and possibly dropped-out) features are
    returned without going through ``self.classifier``.
    """
    # Classification Head
    pooled = self.global_pool(x)
    if self.drop_rate > 0.:
        pooled = F.dropout(pooled, p=self.drop_rate, training=self.training)
    if pre_logits:
        return pooled
    return self.classifier(pooled)
747
+
748
def forward(self, x):
    """Chain ``forward_features`` into ``forward_head`` and return the result."""
    return self.forward_head(self.forward_features(x))
752
+
753
+
754
class HighResolutionNetFeatures(HighResolutionNet):
    """HighResolutionNet feature extraction

    The design of HRNet makes it easy to grab feature maps, this class provides a simple wrapper to do so.
    It would be more complicated to use the FeatureNet helpers.

    The `feature_location=incre` allows grabbing increased channel count features using part of the
    classification head. If `feature_location=''` the default HRNet features are returned. First stem
    conv is used for stride 2 features.
    """

    def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0,
                 feature_location='incre', out_indices=(0, 1, 2, 3, 4)):
        """Build the backbone and record which feature indices ``forward`` should emit.

        Args:
            cfg: model configuration forwarded unchanged to ``HighResolutionNet``.
            in_chans, num_classes, global_pool, drop_rate: forwarded to the base class.
            feature_location: ``'incre'`` or ``''``; passed to the base class as ``head``.
            out_indices: indices to return. Index 0 is the output of the first stem
                conv/bn/act group; index ``i + 1`` is the i-th stage output.
        """
        assert feature_location in ('incre', '')
        super(HighResolutionNetFeatures, self).__init__(
            cfg, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool,
            drop_rate=drop_rate, head=feature_location)
        self.feature_info = FeatureInfo(self.feature_info, out_indices)
        self._out_idx = set(out_indices)

    def forward_features(self, x):
        """Not available on the feature wrapper; call ``forward`` instead."""
        assert False, 'Not supported'

    def forward(self, x) -> List[torch.Tensor]:
        """Return the selected feature maps as a list, ordered by index."""
        out = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        if 0 in self._out_idx:
            out.append(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act2(x)
        x = self.stages(x)
        if self.incre_modules is not None:
            # Widen each branch with its matching head module before selecting outputs.
            x = [incre(f) for f, incre in zip(x, self.incre_modules)]
        for i, f in enumerate(x):
            if i + 1 in self._out_idx:
                out.append(f)
        return out
794
+
795
+
796
def _create_hrnet(variant, pretrained, **model_kwargs):
    """Instantiate an HRNet variant, optionally as a feature-extraction wrapper.

    A truthy ``features_only`` entry in ``model_kwargs`` (removed before the build)
    selects ``HighResolutionNetFeatures``, filters out the ``num_classes`` and
    ``global_pool`` kwargs, disables strict pretrained loading, and rewrites the
    model's pretrained/default cfg via ``pretrained_cfg_for_features``.
    """
    features_only = bool(model_kwargs.pop('features_only', False))
    model_cls = HighResolutionNetFeatures if features_only else HighResolutionNet
    kwargs_filter = ('num_classes', 'global_pool') if features_only else None

    model = build_model_with_cfg(
        model_cls, variant, pretrained,
        model_cfg=cfg_cls[variant],
        pretrained_strict=not features_only,
        kwargs_filter=kwargs_filter,
        **model_kwargs)

    if not features_only:
        return model
    model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
    model.default_cfg = model.pretrained_cfg  # backwards compat
    return model
814
+
815
+
816
@register_model
def hrnet_w18_small(pretrained=False, **kwargs):
    """Create the 'hrnet_w18_small' variant."""
    return _create_hrnet('hrnet_w18_small', pretrained, **kwargs)


@register_model
def hrnet_w18_small_v2(pretrained=False, **kwargs):
    """Create the 'hrnet_w18_small_v2' variant."""
    return _create_hrnet('hrnet_w18_small_v2', pretrained, **kwargs)


@register_model
def hrnet_w18(pretrained=False, **kwargs):
    """Create the 'hrnet_w18' variant."""
    return _create_hrnet('hrnet_w18', pretrained, **kwargs)


@register_model
def hrnet_w30(pretrained=False, **kwargs):
    """Create the 'hrnet_w30' variant."""
    return _create_hrnet('hrnet_w30', pretrained, **kwargs)


@register_model
def hrnet_w32(pretrained=False, **kwargs):
    """Create the 'hrnet_w32' variant."""
    return _create_hrnet('hrnet_w32', pretrained, **kwargs)


@register_model
def hrnet_w40(pretrained=False, **kwargs):
    """Create the 'hrnet_w40' variant."""
    return _create_hrnet('hrnet_w40', pretrained, **kwargs)


@register_model
def hrnet_w44(pretrained=False, **kwargs):
    """Create the 'hrnet_w44' variant."""
    return _create_hrnet('hrnet_w44', pretrained, **kwargs)


@register_model
def hrnet_w48(pretrained=False, **kwargs):
    """Create the 'hrnet_w48' variant."""
    return _create_hrnet('hrnet_w48', pretrained, **kwargs)


@register_model
def hrnet_w64(pretrained=False, **kwargs):
    """Create the 'hrnet_w64' variant."""
    return _create_hrnet('hrnet_w64', pretrained, **kwargs)
src/custom_timm/models/hub.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from functools import partial
5
+ from pathlib import Path
6
+ from tempfile import TemporaryDirectory
7
+ from typing import Optional, Union
8
+
9
+ import torch
10
+ from torch.hub import HASH_REGEX, download_url_to_file, urlparse
11
+
12
+ try:
13
+ from torch.hub import get_dir
14
+ except ImportError:
15
+ from torch.hub import _get_torch_home as get_dir
16
+
17
+ from custom_timm import __version__
18
+
19
+ try:
20
+ from huggingface_hub import (create_repo, get_hf_file_metadata,
21
+ hf_hub_download, hf_hub_url,
22
+ repo_type_and_id_from_hf_id, upload_folder)
23
+ from huggingface_hub.utils import EntryNotFoundError
24
+ hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
25
+ _has_hf_hub = True
26
+ except ImportError:
27
+ hf_hub_download = None
28
+ _has_hf_hub = False
29
+
30
+ _logger = logging.getLogger(__name__)
31
+
32
+
33
def get_cache_dir(child_dir=''):
    """
    Return the checkpoint cache directory, creating it when it does not exist.

    The path is ``<torch hub dir>/checkpoints`` with ``child_dir`` appended when
    it is non-empty.
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    path_parts = [get_dir(), 'checkpoints']
    if child_dir:
        path_parts.append(child_dir)
    model_dir = os.path.join(*path_parts)
    os.makedirs(model_dir, exist_ok=True)
    return model_dir
46
+
47
+
48
def download_cached_file(url, check_hash=True, progress=False):
    """Download ``url`` into the cache directory unless the file is already there.

    The cached file name is the basename of the URL path. With ``check_hash`` set,
    a hash prefix is extracted from that name via ``HASH_REGEX`` (if it matches)
    and passed to ``download_url_to_file``. Returns the cached file path.
    """
    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if os.path.exists(cached_file):
        return cached_file

    _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
    hash_prefix = None
    if check_hash:
        match = HASH_REGEX.search(filename)  # Optional[Match[str]]
        if match:
            hash_prefix = match.group(1)
    download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file
60
+
61
+
62
def has_hf_hub(necessary=False):
    """Report whether the huggingface_hub import succeeded.

    Raises:
        RuntimeError: when the package is missing and ``necessary`` is truthy.
    """
    if necessary and not _has_hf_hub:
        # if no HF Hub module installed, and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub
68
+
69
+
70
def hf_split(hf_id):
    """Split ``'<model_id>[@<revision>]'`` into ``(model_id, revision)``.

    The revision is None when no ``@`` is present. More than one ``@`` trips the
    assertion.
    """
    # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
    parts = hf_id.split('@')
    part_count = len(parts)
    assert 0 < part_count <= 2, 'hf_hub id should only contain one @ character to identify revision.'
    revision = parts[-1] if part_count > 1 else None
    return parts[0], revision
77
+
78
+
79
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
    """Parse the UTF-8 encoded JSON file at ``json_file`` and return the decoded object."""
    with open(json_file, "r", encoding="utf-8") as reader:
        return json.load(reader)
83
+
84
+
85
def _download_from_hf(model_id: str, filename: str):
    """Fetch ``filename`` from the hub repo named by ``model_id`` (optional ``@revision`` suffix).

    Returns whatever ``hf_hub_download`` returns.
    """
    repo_name, revision = hf_split(model_id)
    return hf_hub_download(repo_name, filename, revision=revision)
88
+
89
+
90
def load_model_config_from_hf(model_id: str):
    """Download and parse ``config.json`` for ``model_id``.

    Returns:
        Tuple of (pretrained cfg dict, value of its ``'architecture'`` key or None).
        The cfg is tagged with ``hf_hub_id`` and ``source='hf-hub'``.
    """
    assert has_hf_hub(True)
    config_file = _download_from_hf(model_id, 'config.json')
    pretrained_cfg = load_cfg_from_json(config_file)
    # insert hf_hub id for pretrained weight load during model creation
    pretrained_cfg['hf_hub_id'] = model_id
    pretrained_cfg['source'] = 'hf-hub'
    return pretrained_cfg, pretrained_cfg.get('architecture')
98
+
99
+
100
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
    """Download ``filename`` for ``model_id`` and load it onto the CPU with ``torch.load``.

    NOTE(review): ``torch.load`` unpickles the downloaded file, so this should only
    be pointed at repositories whose contents are trusted.
    """
    assert has_hf_hub(True)
    weights_file = _download_from_hf(model_id, filename)
    return torch.load(weights_file, map_location='cpu')
105
+
106
+
107
def save_for_hf(model, save_directory, model_config=None):
    """Write ``pytorch_model.bin`` and ``config.json`` for ``model`` into ``save_directory``.

    Args:
        model: object providing ``state_dict()``, ``pretrained_cfg``, ``num_classes``
            and ``num_features``.
        save_directory: target directory; created (with parents) when missing.
        model_config: optional overrides merged into the written config. The keys
            ``num_classes``, ``num_features`` and ``labels`` replace the values
            otherwise derived from the model.

    Neither ``model.pretrained_cfg`` nor the caller's ``model_config`` is modified:
    both are copied before the config is assembled.
    """
    assert has_hf_hub(True)
    # Copy so the pops below do not strip keys from the caller's dict.
    overrides = dict(model_config) if model_config else {}
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    weights_path = save_directory / 'pytorch_model.bin'
    torch.save(model.state_dict(), weights_path)

    config_path = save_directory / 'config.json'
    # Copy so the export-only keys do not leak back into the live model's cfg.
    hf_config = dict(model.pretrained_cfg)
    hf_config['num_classes'] = overrides.pop('num_classes', model.num_classes)
    hf_config['num_features'] = overrides.pop('num_features', model.num_features)
    hf_config['labels'] = overrides.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
    hf_config.update(overrides)

    with config_path.open('w') as f:
        json.dump(hf_config, f, indent=2)
125
+
126
+
127
def push_to_hf_hub(
        model,
        repo_id: str,
        commit_message: str = 'Add model',
        token: Optional[str] = None,
        revision: Optional[str] = None,
        private: bool = False,
        create_pr: bool = False,
        model_config: Optional[dict] = None,
):
    """Save ``model`` to a temporary folder and upload that folder to the Hugging Face Hub.

    Args:
        model: model to export; handed to ``save_for_hf``.
        repo_id: target repository; created if it does not exist yet.
        commit_message: commit message for the upload.
        token: access token used for repo creation and for the upload.
        revision: revision to check for an existing README and to upload to.
        private: forwarded to ``create_repo``.
        create_pr: forwarded to ``upload_folder``.
        model_config: optional config overrides forwarded to ``save_for_hf``.

    Returns:
        Whatever ``upload_folder`` returns.

    Raises:
        RuntimeError: when the huggingface_hub package is not installed.
    """
    # Fail with the explicit install hint instead of a NameError on the hub helpers.
    has_hf_hub(True)

    # Create repo if doesn't exist yet
    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

    # Infer complete repo_id from repo_url
    # Can be different from the input `repo_id` if repo_owner was implicit
    _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
    repo_id = f"{repo_owner}/{repo_name}"

    # Check if README file already exist in repo
    # NOTE(review): this lookup is not given `token`; confirm behaviour for private repos.
    try:
        get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
        has_readme = True
    except EntryNotFoundError:
        has_readme = False

    # Dump model and push to Hub
    with TemporaryDirectory() as tmpdir:
        # Save model weights and config.
        save_for_hf(model, tmpdir, model_config=model_config)

        # Add readme if does not exist
        if not has_readme:
            readme_path = Path(tmpdir) / "README.md"
            readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
            readme_path.write_text(readme_text)

        # Upload model and return; use the same credentials as for repo creation.
        return upload_folder(
            repo_id=repo_id,
            folder_path=tmpdir,
            revision=revision,
            create_pr=create_pr,
            commit_message=commit_message,
            token=token,
        )
src/custom_timm/models/inception_resnet_v2.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Pytorch Inception-Resnet-V2 implementation
2
+ Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
3
+ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from custom_timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
10
+ from .helpers import build_model_with_cfg, flatten_modules
11
+ from .layers import create_classifier
12
+ from .registry import register_model
13
+
14
+ __all__ = ['InceptionResnetV2']
15
+
16
# Both variants share every setting except the weights URL.
default_cfgs = {
    variant: {
        'url': weights_url,
        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
        'crop_pct': 0.8975, 'interpolation': 'bicubic',
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
        'label_offset': 1,  # 1001 classes in pretrained weights
    }
    for variant, weights_url in (
        # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
        ('inception_resnet_v2',
         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth'),
        # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
        ('ens_adv_inception_resnet_v2',
         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth'),
    )
}
36
+
37
+
38
class BasicConv2d(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d (eps=0.001) and a non-inplace ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_planes, out_planes,
            kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_planes, eps=.001)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Apply conv, batch norm and ReLU in that order."""
        return self.relu(self.bn(self.conv(x)))
51
+
52
+
53
class Mixed_5b(nn.Module):
    """Four parallel branches over a 192-channel input, concatenated on the channel axis."""

    def __init__(self):
        super().__init__()

        # 1x1
        self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)

        # 1x1 -> 5x5
        self.branch1 = nn.Sequential(
            BasicConv2d(192, 48, kernel_size=1, stride=1),
            BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2),
        )

        # 1x1 -> 3x3 -> 3x3
        self.branch2 = nn.Sequential(
            BasicConv2d(192, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1),
        )

        # avg-pool -> 1x1
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(192, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results along dim 1."""
        return torch.cat(
            (self.branch0(x), self.branch1(x), self.branch2(x), self.branch3(x)), 1)
82
+
83
+
84
class Block35(nn.Module):
    """Residual block over 320 channels: three branches, 1x1 projection, scaled add, ReLU."""

    def __init__(self, scale=1.0):
        super().__init__()

        # Multiplier applied to the projected branch output before the residual add.
        self.scale = scale

        self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1),
        )

        # 32 + 32 + 64 concatenated channels projected back to 320.
        self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(conv2d(cat(branches)) * scale + x)``."""
        merged = torch.cat((self.branch0(x), self.branch1(x), self.branch2(x)), 1)
        residual = self.conv2d(merged)
        return self.relu(residual * self.scale + x)
115
+
116
+
117
class Mixed_6a(nn.Module):
    """Stride-2 block over 320 channels: two conv branches plus a max-pool branch."""

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv2d(320, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2),
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results along dim 1."""
        return torch.cat((self.branch0(x), self.branch1(x), self.branch2(x)), 1)
137
+
138
+
139
class Block17(nn.Module):
    """Residual block over 1088 channels: two branches, 1x1 projection, scaled add, ReLU."""

    def __init__(self, scale=1.0):
        super().__init__()

        # Multiplier applied to the projected branch output before the residual add.
        self.scale = scale

        self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)

        # 1x1 followed by a 1x7 / 7x1 pair
        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 128, kernel_size=1, stride=1),
            BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
        )

        # 192 + 192 concatenated channels projected back to 1088.
        self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(conv2d(cat(branches)) * scale + x)``."""
        merged = torch.cat((self.branch0(x), self.branch1(x)), 1)
        residual = self.conv2d(merged)
        return self.relu(residual * self.scale + x)
164
+
165
+
166
class Mixed_7a(nn.Module):
    """Stride-2 block over 1088 channels: three conv branches plus a max-pool branch."""

    def __init__(self):
        super().__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2),
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=2),
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv2d(288, 320, kernel_size=3, stride=2),
        )

        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results along dim 1."""
        return torch.cat(
            (self.branch0(x), self.branch1(x), self.branch2(x), self.branch3(x)), 1)
195
+
196
+
197
class Block8(nn.Module):
    """Residual block over 2080 channels with an optional trailing ReLU."""

    def __init__(self, scale=1.0, no_relu=False):
        super().__init__()

        # Multiplier applied to the projected branch output before the residual add.
        self.scale = scale

        self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)

        # 1x1 followed by a 1x3 / 3x1 pair
        self.branch1 = nn.Sequential(
            BasicConv2d(2080, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)),
        )

        # 192 + 256 concatenated channels projected back to 2080.
        self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
        # ``no_relu=True`` leaves the activation out entirely.
        self.relu = None if no_relu else nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``conv2d(cat(branches)) * scale + x``, passed through ReLU when configured."""
        merged = torch.cat((self.branch0(x), self.branch1(x)), 1)
        out = self.conv2d(merged) * self.scale + x
        if self.relu is not None:
            out = self.relu(out)
        return out
224
+
225
+
226
class InceptionResnetV2(nn.Module):
    """Inception-ResNet-V2 classifier: conv stem, three repeated residual stacks, pooled head."""

    def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
        super().__init__()
        self.drop_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = 1536
        # Only the native stride is implemented.
        assert output_stride == 32

        self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')]

        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.feature_info.append(dict(num_chs=192, reduction=4, module='conv2d_4a'))

        self.maxpool_5a = nn.MaxPool2d(3, stride=2)
        self.mixed_5b = Mixed_5b()
        # 10 x Block35
        self.repeat = nn.Sequential(*[Block35(scale=0.17) for _ in range(10)])
        self.feature_info.append(dict(num_chs=320, reduction=8, module='repeat'))

        self.mixed_6a = Mixed_6a()
        # 20 x Block17
        self.repeat_1 = nn.Sequential(*[Block17(scale=0.10) for _ in range(20)])
        self.feature_info.append(dict(num_chs=1088, reduction=16, module='repeat_1'))

        self.mixed_7a = Mixed_7a()
        # 9 x Block8, followed by one more Block8 without the trailing ReLU
        self.repeat_2 = nn.Sequential(*[Block8(scale=0.20) for _ in range(9)])
        self.block8 = Block8(no_relu=True)
        self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
        self.feature_info.append(dict(num_chs=self.num_features, reduction=32, module='conv2d_7b'))

        self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return a function mapping a parameter name to its group index.

        Stem convs map to groups 0 and 1, ``block8``/``conv2d_7`` map to one group
        past the flattened module count, everything else is looked up by module
        prefix; unmatched names map to ``float('inf')``.
        """
        module_map = {
            key: index
            for index, (key, _) in enumerate(flatten_modules(self.named_children(), prefix=()))
        }
        module_map.pop(('classif',))

        def _matcher(name):
            if name.startswith(('conv2d_1', 'conv2d_2')):
                return 0
            if name.startswith(('conv2d_3', 'conv2d_4')):
                return 1
            if name.startswith(('block8', 'conv2d_7')):
                return len(module_map) + 1
            name_parts = name.split('.')
            for key, group in module_map.items():
                if key == tuple(name_parts[:len(key)]):
                    return group
            return float('inf')
        return _matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Reject any attempt to enable checkpointing; disabling is a no-op."""
        assert not enable, "checkpointing not supported"

    @torch.jit.ignore
    def get_classifier(self):
        """Return the module currently stored as ``self.classif``."""
        return self.classif

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Rebuild the pooling layer and classifier for ``num_classes`` outputs."""
        self.num_classes = num_classes
        self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        """Run the stem and all blocks, returning the output of ``conv2d_7b``."""
        # Stem
        x = self.conv2d_2b(self.conv2d_2a(self.conv2d_1a(x)))
        x = self.maxpool_3a(x)
        x = self.conv2d_4a(self.conv2d_3b(x))
        x = self.maxpool_5a(x)
        # Block35 stack
        x = self.repeat(self.mixed_5b(x))
        # Block17 stack
        x = self.repeat_1(self.mixed_6a(x))
        # Block8 stack and final 1x1 conv
        x = self.repeat_2(self.mixed_7a(x))
        x = self.block8(x)
        return self.conv2d_7b(x)

    def forward_head(self, x, pre_logits: bool = False):
        """Pool, apply dropout when ``drop_rate`` is positive, then classify unless ``pre_logits``."""
        pooled = self.global_pool(x)
        if self.drop_rate > 0:
            pooled = F.dropout(pooled, p=self.drop_rate, training=self.training)
        if pre_logits:
            return pooled
        return self.classif(pooled)

    def forward(self, x):
        """Chain ``forward_features`` into ``forward_head``."""
        return self.forward_head(self.forward_features(x))
362
+
363
+
364
def _create_inception_resnet_v2(variant, pretrained=False, **kwargs):
    """Build ``InceptionResnetV2`` for ``variant`` through ``build_model_with_cfg``."""
    model = build_model_with_cfg(InceptionResnetV2, variant, pretrained, **kwargs)
    return model
366
+
367
+
368
@register_model
def inception_resnet_v2(pretrained=False, **kwargs):
    r"""Create the 'inception_resnet_v2' variant.

    Architecture from the `"InceptionV4, Inception-ResNet..."
    <https://arxiv.org/abs/1602.07261>` paper.
    """
    variant = 'inception_resnet_v2'
    return _create_inception_resnet_v2(variant, pretrained=pretrained, **kwargs)
374
+
375
+
376
@register_model
def ens_adv_inception_resnet_v2(pretrained=False, **kwargs):
    r"""Create the 'ens_adv_inception_resnet_v2' (ensemble adversarially trained) variant.

    As per https://arxiv.org/abs/1705.07204 and
    https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
    """
    variant = 'ens_adv_inception_resnet_v2'
    return _create_inception_resnet_v2(variant, pretrained=pretrained, **kwargs)
src/custom_timm/models/inception_v3.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Inception-V3
2
+
3
+ Originally from torchvision Inception3 model
4
+ Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from custom_timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
11
+ from .helpers import build_model_with_cfg, resolve_pretrained_cfg, flatten_modules
12
+ from .registry import register_model
13
+ from .layers import trunc_normal_, create_classifier, Linear
14
+
15
+
16
def _cfg(url='', **kwargs):
    """Return the default Inception-V3 pretrained cfg with ``kwargs`` applied on top."""
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 299, 299),
        'pool_size': (8, 8),
        'crop_pct': 0.875,
        'interpolation': 'bicubic',
        'mean': IMAGENET_INCEPTION_MEAN,
        'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'Conv2d_1a_3x3.conv',
        'classifier': 'fc',
    }
    cfg.update(kwargs)
    return cfg
25
+
26
+
27
default_cfgs = dict(
    # original PyTorch weights, ported from Tensorflow but modified
    inception_v3=_cfg(
        url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
        has_aux=True,  # checkpoint has aux logit layer weights
    ),
    # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
    tf_inception_v3=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
        num_classes=1000,
        has_aux=False,
        label_offset=1,
    ),
    # my port of Tensorflow adversarially trained Inception V3 from
    # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
    adv_inception_v3=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
        num_classes=1000,
        has_aux=False,
        label_offset=1,
    ),
    # from gluon pretrained models, best performing in terms of accuracy/loss metrics
    # https://gluon-cv.mxnet.io/model_zoo/classification.html
    gluon_inception_v3=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
        mean=IMAGENET_DEFAULT_MEAN,  # also works well with inception defaults
        std=IMAGENET_DEFAULT_STD,  # also works well with inception defaults
        has_aux=False,
    ),
)
50
+
51
+
52
class InceptionA(nn.Module):
    """Inception block with 1x1, 5x5, double-3x3 and pooled branches.

    Output channels are 64 + 64 + 96 + ``pool_features``. ``conv_block`` is the
    factory used for every conv unit and falls back to ``BasicConv2d`` when None.
    """

    def __init__(self, in_channels, pool_features, conv_block=None):
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        self.branch1x1 = make_conv(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = make_conv(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = make_conv(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, padding=1)

        self.branch_pool = make_conv(in_channels, pool_features, kernel_size=1)

    def _forward(self, x):
        """Return the four branch outputs as a list, without concatenating them."""
        out_1x1 = self.branch1x1(x)

        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))

        out_3x3dbl = self.branch3x3dbl_1(x)
        out_3x3dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(out_3x3dbl))

        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)

        return [out_1x1, out_5x5, out_3x3dbl, out_pool]

    def forward(self, x):
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)
88
+
89
+
90
class InceptionB(nn.Module):
    """Stride-2 Inception block: strided 3x3, double-3x3 and max-pool branches.

    Output channels are ``384 + 96 + in_channels`` (the pool branch has no conv).

    Args:
        in_channels: number of channels of the input tensor.
        conv_block: callable ``(in_chs, out_chs, **conv_kwargs) -> nn.Module``;
            ``BasicConv2d`` is used when ``None``.
    """

    def __init__(self, in_channels, conv_block=None):
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block

        self.branch3x3 = make(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = make(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make(96, 96, kernel_size=3, stride=2)

    def _forward(self, x):
        """Return the list of per-branch outputs (not yet concatenated)."""
        out_3x3 = self.branch3x3(x)
        out_dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [out_3x3, out_dbl, out_pool]

    def forward(self, x):
        """Concatenate all branch outputs along the channel dimension."""
        return torch.cat(self._forward(x), 1)
117
+
118
+
119
class InceptionC(nn.Module):
    """Inception block built from factorized 7x7 convolutions (1x7 / 7x1 pairs).

    Each of the four branches ends in 192 channels, so the output has 768.

    Args:
        in_channels: number of channels of the input tensor.
        channels_7x7: intermediate width used inside the two 7x7 branches.
        conv_block: callable ``(in_chs, out_chs, **conv_kwargs) -> nn.Module``;
            ``BasicConv2d`` is used when ``None``.
    """

    def __init__(self, in_channels, channels_7x7, conv_block=None):
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block
        mid = channels_7x7
        row = dict(kernel_size=(1, 7), padding=(0, 3))  # 1x7 convolution
        col = dict(kernel_size=(7, 1), padding=(3, 0))  # 7x1 convolution

        self.branch1x1 = make(in_channels, 192, kernel_size=1)

        self.branch7x7_1 = make(in_channels, mid, kernel_size=1)
        self.branch7x7_2 = make(mid, mid, **row)
        self.branch7x7_3 = make(mid, 192, **col)

        self.branch7x7dbl_1 = make(in_channels, mid, kernel_size=1)
        self.branch7x7dbl_2 = make(mid, mid, **col)
        self.branch7x7dbl_3 = make(mid, mid, **row)
        self.branch7x7dbl_4 = make(mid, mid, **col)
        self.branch7x7dbl_5 = make(mid, 192, **row)

        self.branch_pool = make(in_channels, 192, kernel_size=1)

    def _forward(self, x):
        """Return the list of per-branch outputs (not yet concatenated)."""
        out_1x1 = self.branch1x1(x)

        out_7x7 = self.branch7x7_1(x)
        for layer in (self.branch7x7_2, self.branch7x7_3):
            out_7x7 = layer(out_7x7)

        out_dbl = self.branch7x7dbl_1(x)
        for layer in (self.branch7x7dbl_2, self.branch7x7dbl_3, self.branch7x7dbl_4, self.branch7x7dbl_5):
            out_dbl = layer(out_dbl)

        out_pool = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return [out_1x1, out_7x7, out_dbl, out_pool]

    def forward(self, x):
        """Concatenate all branch outputs along the channel dimension."""
        return torch.cat(self._forward(x), 1)
162
+
163
+
164
class InceptionD(nn.Module):
    """Stride-2 Inception block: 3x3 branch, 7x7-then-3x3 branch and max-pool branch.

    Output channels are ``320 + 192 + in_channels``.

    Args:
        in_channels: number of channels of the input tensor.
        conv_block: callable ``(in_chs, out_chs, **conv_kwargs) -> nn.Module``;
            ``BasicConv2d`` is used when ``None``.
    """

    def __init__(self, in_channels, conv_block=None):
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block

        self.branch3x3_1 = make(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = make(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = make(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = make(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = make(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = make(192, 192, kernel_size=3, stride=2)

    def _forward(self, x):
        """Return the list of per-branch outputs (not yet concatenated)."""
        out_3x3 = self.branch3x3_2(self.branch3x3_1(x))

        out_7x7x3 = self.branch7x7x3_2(self.branch7x7x3_1(x))
        out_7x7x3 = self.branch7x7x3_4(self.branch7x7x3_3(out_7x7x3))

        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [out_3x3, out_7x7x3, out_pool]

    def forward(self, x):
        """Concatenate all branch outputs along the channel dimension."""
        return torch.cat(self._forward(x), 1)
194
+
195
+
196
class InceptionE(nn.Module):
    """Inception block whose 3x3 branches fan out into parallel 1x3 and 3x1 convs.

    Output channels are ``320 + 768 + 768 + 192 = 2048``.

    Args:
        in_channels: number of channels of the input tensor.
        conv_block: callable ``(in_chs, out_chs, **conv_kwargs) -> nn.Module``;
            ``BasicConv2d`` is used when ``None``.
    """

    def __init__(self, in_channels, conv_block=None):
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block

        self.branch1x1 = make(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = make(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = make(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = make(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = make(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = make(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = make(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = make(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = make(in_channels, 192, kernel_size=1)

    def _forward(self, x):
        """Return the list of per-branch outputs (not yet concatenated)."""
        out_1x1 = self.branch1x1(x)

        # 1x1 reduction, then the 1x3 and 3x1 halves are concatenated
        stem = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(stem), self.branch3x3_2b(stem)], 1)

        # same split, preceded by an extra 3x3 convolution
        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_dbl = torch.cat([self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)], 1)

        out_pool = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return [out_1x1, out_3x3, out_dbl, out_pool]

    def forward(self, x):
        """Concatenate all branch outputs along the channel dimension."""
        return torch.cat(self._forward(x), 1)
242
+
243
+
244
class InceptionAux(nn.Module):
    """Auxiliary classifier head: avg-pool, two conv blocks, global pool, linear.

    Args:
        in_channels: number of channels of the input tensor.
        num_classes: output size of the final linear layer.
        conv_block: callable ``(in_chs, out_chs, **conv_kwargs) -> nn.Module``;
            ``BasicConv2d`` is used when ``None``.
    """

    def __init__(self, in_channels, num_classes, conv_block=None):
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block
        self.conv0 = make(in_channels, 128, kernel_size=1)
        self.conv1 = make(128, 768, kernel_size=5)
        # NOTE(review): ``stddev`` is attached to whatever ``conv_block`` returns;
        # with the default ``BasicConv2d`` that is the wrapper module rather than
        # its inner ``nn.Conv2d`` -- confirm the init code reading ``stddev`` sees it.
        self.conv1.stddev = 0.01
        self.fc = Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # Shape comments assume the standard case of a N x 768 x 17 x 17 input.
        pooled = F.avg_pool2d(x, kernel_size=5, stride=3)  # N x 768 x 5 x 5
        feats = self.conv1(self.conv0(pooled))  # N x 128 x 5 x 5 -> N x 768 x 1 x 1
        feats = F.adaptive_avg_pool2d(feats, (1, 1))  # N x 768 x 1 x 1
        flat = torch.flatten(feats, 1)  # N x 768
        return self.fc(flat)  # N x num_classes
272
+
273
+
274
class BasicConv2d(nn.Module):
    """Bias-free ``nn.Conv2d`` followed by ``nn.BatchNorm2d(eps=0.001)`` and in-place ReLU.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution and the batch norm.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        normed = self.bn(self.conv(x))
        return F.relu(normed, inplace=True)
285
+
286
+
287
class InceptionV3(nn.Module):
    """Inception-V3 with no AuxLogits

    FIXME two class defs are redundant, but less screwing around with torchscript fussiness and inconsistent returns

    Args:
        num_classes: size of the classifier output.
        in_chans: number of input image channels.
        drop_rate: dropout probability applied to the pooled features in ``forward_head``.
        global_pool: pooling type handed to ``create_classifier``.
        aux_logits: when true an ``InceptionAux`` head is built as ``self.AuxLogits``
            (it is only used by the ``InceptionV3Aux`` subclass).
    """

    def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False):
        super(InceptionV3, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.aux_logits = aux_logits

        self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        else:
            self.AuxLogits = None
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)
        self.feature_info = [
            dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'),
            dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'),
            dict(num_chs=288, reduction=8, module='Mixed_5d'),
            dict(num_chs=768, reduction=16, module='Mixed_6e'),
            dict(num_chs=2048, reduction=32, module='Mixed_7c'),
        ]

        self.num_features = 2048
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

        # Truncated-normal init for conv / linear weights; a per-module ``stddev``
        # attribute overrides the default std of 0.1. Batch norms start as identity.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                trunc_normal_(m.weight, std=stddev)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return a function mapping a parameter/module name to a group index.

        The first two stem convs map to group 0, the next two to group 1, and
        anything else to the index of its entry in the flattened module map
        (``inf`` when nothing matches, which includes the ``fc`` classifier).
        """
        module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))}
        module_map.pop(('fc',))

        def _matcher(name):
            if name.startswith(('Conv2d_1', 'Conv2d_2')):
                return 0
            elif name.startswith(('Conv2d_3', 'Conv2d_4')):
                return 1
            else:
                for k in module_map.keys():
                    if k == tuple(name.split('.')[:len(k)]):
                        return module_map[k]
                return float('inf')
        return _matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling layer and classifier for a new number of classes."""
        self.num_classes = num_classes
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

    def forward_preaux(self, x):
        """Run the layers up to and including ``Mixed_6e`` (input of the aux head)."""
        x = self.Conv2d_1a_3x3(x)  # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)  # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)  # N x 64 x 147 x 147
        x = self.Pool1(x)  # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)  # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)  # N x 192 x 71 x 71
        x = self.Pool2(x)  # N x 192 x 35 x 35
        x = self.Mixed_5b(x)  # N x 256 x 35 x 35
        x = self.Mixed_5c(x)  # N x 288 x 35 x 35
        x = self.Mixed_5d(x)  # N x 288 x 35 x 35
        x = self.Mixed_6a(x)  # N x 768 x 17 x 17
        x = self.Mixed_6b(x)  # N x 768 x 17 x 17
        x = self.Mixed_6c(x)  # N x 768 x 17 x 17
        x = self.Mixed_6d(x)  # N x 768 x 17 x 17
        x = self.Mixed_6e(x)  # N x 768 x 17 x 17
        return x

    def forward_postaux(self, x):
        """Run the remaining ``Mixed_7*`` blocks."""
        x = self.Mixed_7a(x)  # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)  # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)  # N x 2048 x 8 x 8
        return x

    def forward_features(self, x):
        x = self.forward_preaux(x)
        x = self.forward_postaux(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool, apply dropout and classify.

        Args:
            x: feature map produced by ``forward_features``.
            pre_logits: when true, return the pooled (and dropped-out) features
                without applying ``self.fc``. Defaults to ``False``, which keeps
                the previous behaviour.
        """
        x = self.global_pool(x)
        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return x if pre_logits else self.fc(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
408
+
409
+
410
class InceptionV3Aux(InceptionV3):
    """InceptionV3 with AuxLogits

    ``forward`` returns a ``(logits, aux)`` pair; ``aux`` is ``None`` unless the
    module is in training mode.
    """

    def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=True):
        super().__init__(
            num_classes=num_classes,
            in_chans=in_chans,
            drop_rate=drop_rate,
            global_pool=global_pool,
            aux_logits=aux_logits,
        )

    def forward_features(self, x):
        """Return ``(features, aux)``; the aux head only runs while training."""
        x = self.forward_preaux(x)
        # Kept as a single conditional expression, as in the original code.
        aux = self.AuxLogits(x) if self.training else None
        return self.forward_postaux(x), aux

    def forward(self, x):
        feats, aux = self.forward_features(x)
        return self.forward_head(feats), aux
428
+
429
+
430
def _create_inception_v3(variant, pretrained=False, **kwargs):
    """Build an Inception-V3 variant through ``build_model_with_cfg``.

    Args:
        variant: name of the model variant.
        pretrained: forwarded to ``build_model_with_cfg``.
        **kwargs: ``pretrained_cfg`` and ``aux_logits`` are consumed here; the
            rest is forwarded to ``build_model_with_cfg``.

    Raises:
        AssertionError: if ``features_only`` is requested together with ``aux_logits``.
    """
    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
    aux_logits = kwargs.pop('aux_logits', False)
    if aux_logits:
        # The pop must not live inside an ``assert``: under ``python -O`` the
        # statement is removed, so the check would vanish and ``features_only``
        # would leak through to ``build_model_with_cfg``.
        if kwargs.pop('features_only', False):
            raise AssertionError('features_only is not supported together with aux_logits')
        model_cls = InceptionV3Aux
        # strict loading only when the checkpoint also carries aux weights
        load_strict = pretrained_cfg['has_aux']
    else:
        model_cls = InceptionV3
        # strict loading only when the checkpoint has no extra aux weights
        load_strict = not pretrained_cfg['has_aux']

    return build_model_with_cfg(
        model_cls, variant, pretrained,
        pretrained_cfg=pretrained_cfg,
        pretrained_strict=load_strict,
        **kwargs)
446
+
447
+
448
@register_model
def inception_v3(pretrained=False, **kwargs):
    """Inception-V3 using the original PyTorch weights (ported from Tensorflow but modified)."""
    return _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs)
453
+
454
+
455
@register_model
def tf_inception_v3(pretrained=False, **kwargs):
    """Inception-V3 using a port of the Tensorflow SLIM weights.

    Source: http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
    """
    return _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs)
460
+
461
+
462
@register_model
def adv_inception_v3(pretrained=False, **kwargs):
    """Inception-V3 using a port of the Tensorflow adversarially trained weights.

    Source: http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
    """
    return _create_inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs)
468
+
469
+
470
@register_model
def gluon_inception_v3(pretrained=False, **kwargs):
    """Inception-V3 using the gluon pretrained weights.

    Described upstream as best performing in terms of accuracy/loss metrics, see
    https://gluon-cv.mxnet.io/model_zoo/classification.html
    """
    return _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
src/custom_timm/models/inception_v4.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Pytorch Inception-V4 implementation
2
+ Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
3
+ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from custom_timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
10
+ from .helpers import build_model_with_cfg
11
+ from .layers import create_classifier
12
+ from .registry import register_model
13
+
14
+ __all__ = ['InceptionV4']
15
+
16
# Pretrained-weight configuration, keyed by the registered model name.
default_cfgs = dict(
    inception_v4=dict(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
        num_classes=1000,
        input_size=(3, 299, 299),
        pool_size=(8, 8),
        crop_pct=0.875,
        interpolation='bicubic',
        mean=IMAGENET_INCEPTION_MEAN,
        std=IMAGENET_INCEPTION_STD,
        first_conv='features.0.conv',
        classifier='last_linear',
        label_offset=1,  # 1001 classes in pretrained weights
    ),
)
26
+
27
+
28
class BasicConv2d(nn.Module):
    """Bias-free ``nn.Conv2d`` -> ``nn.BatchNorm2d(eps=0.001)`` -> in-place ``nn.ReLU``.

    Args:
        in_planes: input channels of the convolution.
        out_planes: output channels of the convolution and the batch norm.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_planes, eps=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
41
+
42
+
43
class Mixed3a(nn.Module):
    """Stride-2 stem block: max-pool and 3x3 conv (64 -> 96) in parallel, concatenated."""

    def __init__(self):
        super().__init__()
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)

    def forward(self, x):
        pooled = self.maxpool(x)
        convolved = self.conv(x)
        return torch.cat((pooled, convolved), 1)
54
+
55
+
56
class Mixed4a(nn.Module):
    """Stem block with two conv branches (160 -> 96 each), concatenated on dim 1."""

    def __init__(self):
        super().__init__()

        # 1x1 reduction followed by an unpadded 3x3
        self.branch0 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1),
        )

        # 1x1 reduction, factorized 7x7 (1x7 then 7x1), then an unpadded 3x3
        self.branch1 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(64, 96, kernel_size=(3, 3), stride=1),
        )

    def forward(self, x):
        short = self.branch0(x)
        long = self.branch1(x)
        return torch.cat((short, long), 1)
77
+
78
+
79
class Mixed5a(nn.Module):
    """Stride-2 stem block: 3x3 conv (192 -> 192) and max-pool in parallel, concatenated."""

    def __init__(self):
        super().__init__()
        self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
        self.maxpool = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        convolved = self.conv(x)
        pooled = self.maxpool(x)
        return torch.cat((convolved, pooled), 1)
90
+
91
+
92
class InceptionA(nn.Module):
    """Inception-A block: four branches of 96 channels each on a 384-channel input."""

    def __init__(self):
        super().__init__()
        # plain 1x1
        self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)

        # 1x1 -> 3x3
        self.branch1 = nn.Sequential(
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
        )

        # 1x1 -> 3x3 -> 3x3
        self.branch2 = nn.Sequential(
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1),
        )

        # avg-pool -> 1x1
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(384, 96, kernel_size=1, stride=1),
        )

    def forward(self, x):
        b0 = self.branch0(x)
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        return torch.cat((b0, b1, b2, b3), 1)
120
+
121
+
122
class ReductionA(nn.Module):
    """Stride-2 reduction block on a 384-channel input (384 + 256 + 384 output channels)."""

    def __init__(self):
        super().__init__()
        # strided 3x3
        self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)

        # 1x1 -> 3x3 -> strided 3x3
        self.branch1 = nn.Sequential(
            BasicConv2d(384, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
            BasicConv2d(224, 256, kernel_size=3, stride=2),
        )

        # strided max-pool, passes the input channels through
        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        b0 = self.branch0(x)
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        return torch.cat((b0, b1, b2), 1)
141
+
142
+
143
class InceptionB(nn.Module):
    """Inception-B block on a 1024-channel input using factorized 7x7 convolutions.

    Branch widths are 384 + 256 + 256 + 128, so the output again has 1024 channels.
    """

    def __init__(self):
        super().__init__()
        self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)

        # 1x1 -> 1x7 -> 7x1
        self.branch1 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)),
        )

        # 1x1 -> 7x1 -> 1x7 -> 7x1 -> 1x7
        self.branch2 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
        )

        # avg-pool -> 1x1
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(1024, 128, kernel_size=1, stride=1),
        )

    def forward(self, x):
        b0 = self.branch0(x)
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        return torch.cat((b0, b1, b2, b3), 1)
174
+
175
+
176
class ReductionB(nn.Module):
    """Stride-2 reduction block on a 1024-channel input (192 + 320 + 1024 output channels)."""

    def __init__(self):
        super().__init__()

        # 1x1 -> strided 3x3
        self.branch0 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        )

        # 1x1 -> 1x7 -> 7x1 -> strided 3x3
        self.branch1 = nn.Sequential(
            BasicConv2d(1024, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(320, 320, kernel_size=3, stride=2),
        )

        # strided max-pool, passes the input channels through
        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        b0 = self.branch0(x)
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        return torch.cat((b0, b1, b2), 1)
200
+
201
+
202
class InceptionC(nn.Module):
    """Inception-C block on a 1536-channel input with split 1x3 / 3x1 branch ends.

    Branch widths are 256 + 512 + 512 + 256, so the output again has 1536 channels.
    """

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)

        # branch 1: 1x1, then parallel 1x3 and 3x1
        self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        # branch 2: 1x1 -> 3x1 -> 1x3, then parallel 1x3 and 3x1
        self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        # branch 3: avg-pool -> 1x1
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(1536, 256, kernel_size=1, stride=1),
        )

    def forward(self, x):
        out0 = self.branch0(x)

        stem1 = self.branch1_0(x)
        out1 = torch.cat((self.branch1_1a(stem1), self.branch1_1b(stem1)), 1)

        stem2 = self.branch2_2(self.branch2_1(self.branch2_0(x)))
        out2 = torch.cat((self.branch2_3a(stem2), self.branch2_3b(stem2)), 1)

        out3 = self.branch3(x)

        return torch.cat((out0, out1, out2, out3), 1)
242
+
243
+
244
class InceptionV4(nn.Module):
    """Inception-V4 classifier: a sequential feature stack plus pooled linear head.

    Args:
        num_classes: size of the classifier output.
        in_chans: number of input image channels.
        output_stride: must be 32 (asserted).
        drop_rate: dropout probability applied to the pooled features.
        global_pool: pooling type handed to ``create_classifier``.
    """

    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
        super().__init__()
        assert output_stride == 32
        self.drop_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = 1536

        # Modules are instantiated strictly in network order so that indices
        # (used by ``feature_info`` / ``group_matcher``) and the order of weight
        # initialisation stay the same as a hand-written nn.Sequential(...).
        layers = [
            BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3, stride=1),
            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed3a(),
            Mixed4a(),
            Mixed5a(),
        ]
        layers.extend(InceptionA() for _ in range(4))
        layers.append(ReductionA())  # Mixed6a
        layers.extend(InceptionB() for _ in range(7))
        layers.append(ReductionB())  # Mixed7a
        layers.extend(InceptionC() for _ in range(3))
        self.features = nn.Sequential(*layers)

        self.feature_info = [
            dict(num_chs=64, reduction=2, module='features.2'),
            dict(num_chs=160, reduction=4, module='features.3'),
            dict(num_chs=384, reduction=8, module='features.9'),
            dict(num_chs=1024, reduction=16, module='features.17'),
            dict(num_chs=1536, reduction=32, module='features.21'),
        ]
        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex patterns for the stem (features 0-2) and the per-index blocks."""
        return dict(
            stem=r'^features\.[012]\.',
            blocks=r'^features\.(\d+)'
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        return self.last_linear

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling layer and classifier for a new number of classes."""
        self.num_classes = num_classes
        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        return self.features(x)

    def forward_head(self, x, pre_logits: bool = False):
        """Pool and apply dropout; classify unless ``pre_logits`` is true."""
        x = self.global_pool(x)
        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return x if pre_logits else self.last_linear(x)

    def forward(self, x):
        return self.forward_head(self.forward_features(x))
319
+
320
+
321
def _create_inception_v4(variant, pretrained=False, **kwargs):
    """Build ``InceptionV4`` through ``build_model_with_cfg`` with flattened-sequential features."""
    feature_cfg = dict(flatten_sequential=True)
    return build_model_with_cfg(InceptionV4, variant, pretrained, feature_cfg=feature_cfg, **kwargs)
326
+
327
+
328
@register_model
def inception_v4(pretrained=False, **kwargs):
    """Registered entry point for the ``inception_v4`` variant."""
    model = _create_inception_v4('inception_v4', pretrained, **kwargs)
    return model
src/custom_timm/models/levit.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ LeViT
2
+
3
+ Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
4
+ - https://arxiv.org/abs/2104.01136
5
+
6
+ @article{graham2021levit,
7
+ title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
8
+ author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
9
+ journal={arXiv preprint arXiv:2104.01136},
10
+ year={2021}
11
+ }
12
+
13
+ Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright below.
14
+
15
+ This version combines both conv/linear models and fixes torchscript compatibility.
16
+
17
+ Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
18
+ """
19
+
20
+ # Copyright (c) 2015-present, Facebook, Inc.
21
+ # All rights reserved.
22
+
23
+ # Modified from
24
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
25
+ # Copyright 2020 Ross Wightman, Apache-2.0 License
26
+ import itertools
27
+ from copy import deepcopy
28
+ from functools import partial
29
+ from typing import Dict
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+
34
+ from custom_timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
35
+ from .helpers import build_model_with_cfg, checkpoint_seq
36
+ from .layers import to_ntuple, get_act_layer
37
+ from .vision_transformer import trunc_normal_
38
+ from .registry import register_model
39
+
40
+
41
def _cfg(url='', **kwargs):
    """Return the default pretrained config dict for a LeViT model.

    Args:
        url: checkpoint URL stored under ``'url'``.
        **kwargs: extra entries; these override the defaults below.
    """
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': .9,
        'interpolation': 'bicubic',
        'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.0.c',
        'classifier': ('head.l', 'head_dist.l'),
    }
    cfg.update(kwargs)
    return cfg
50
+
51
+
52
# Pretrained-weight configurations, keyed by the registered model name.
default_cfgs = {
    'levit_128s': _cfg(url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'),
    'levit_128': _cfg(url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'),
    'levit_192': _cfg(url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'),
    'levit_256': _cfg(url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'),
    'levit_384': _cfg(url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'),

    # no checkpoint URL; a single classifier head instead of the default pair
    'levit_256d': _cfg(url='', classifier='head.l'),
}
71
+
72
# Architecture hyper-parameters per variant; each tuple has one entry per stage.
model_cfgs = {
    'levit_128s': {
        'embed_dim': (128, 256, 384), 'key_dim': 16, 'num_heads': (4, 6, 8), 'depth': (2, 3, 4)},
    'levit_128': {
        'embed_dim': (128, 256, 384), 'key_dim': 16, 'num_heads': (4, 8, 12), 'depth': (4, 4, 4)},
    'levit_192': {
        'embed_dim': (192, 288, 384), 'key_dim': 32, 'num_heads': (3, 5, 6), 'depth': (4, 4, 4)},
    'levit_256': {
        'embed_dim': (256, 384, 512), 'key_dim': 32, 'num_heads': (4, 6, 8), 'depth': (4, 4, 4)},
    'levit_384': {
        'embed_dim': (384, 512, 768), 'key_dim': 32, 'num_heads': (6, 9, 12), 'depth': (4, 4, 4)},

    'levit_256d': {
        'embed_dim': (256, 384, 512), 'key_dim': 32, 'num_heads': (4, 6, 8), 'depth': (4, 8, 6)},
}
87
+
88
+ __all__ = ['Levit']
89
+
90
+
91
@register_model
def levit_128s(pretrained=False, use_conv=False, **kwargs):
    """Registered entry point for the ``levit_128s`` variant."""
    return create_levit('levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs)
95
+
96
+
97
@register_model
def levit_128(pretrained=False, use_conv=False, **kwargs):
    """Registered entry point for the ``levit_128`` variant."""
    return create_levit('levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs)
101
+
102
+
103
@register_model
def levit_192(pretrained=False, use_conv=False, **kwargs):
    """Registered entry point for the ``levit_192`` variant."""
    return create_levit('levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs)
107
+
108
+
109
@register_model
def levit_256(pretrained=False, use_conv=False, **kwargs):
    """Build the 'levit_256' variant through ``create_levit``."""
    options = dict(kwargs, use_conv=use_conv)
    return create_levit('levit_256', pretrained, **options)
113
+
114
+
115
@register_model
def levit_384(pretrained=False, use_conv=False, **kwargs):
    """Build the 'levit_384' variant through ``create_levit``."""
    options = dict(kwargs, use_conv=use_conv)
    return create_levit('levit_384', pretrained, **options)
119
+
120
+
121
@register_model
def levit_256d(pretrained=False, use_conv=False, **kwargs):
    """Build the 'levit_256d' variant through ``create_levit`` with ``distilled=False``."""
    options = dict(kwargs, use_conv=use_conv)
    # `distilled` stays an explicit keyword so a caller-supplied duplicate still raises.
    return create_levit('levit_256d', pretrained, distilled=False, **options)
125
+
126
+
127
class ConvNorm(nn.Sequential):
    """Bias-free ``nn.Conv2d`` followed by ``nn.BatchNorm2d``.

    The sub-modules are registered under the names ``c`` (convolution) and
    ``bn`` (batch norm).

    Args:
        in_chs: Input channel count of the convolution.
        out_chs: Output channel count (also the batch-norm width).
        kernel_size, stride, pad, dilation, groups: Forwarded positionally to ``nn.Conv2d``.
        bn_weight_init: Constant the batch-norm weight is initialised to.
        resolution: Accepted but not used by this class.
    """

    def __init__(
            self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
            groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', nn.Conv2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', nn.BatchNorm2d(out_chs))

        nn.init.constant_(self.bn.weight, bn_weight_init)

    @torch.no_grad()
    def fuse(self):
        """Return a single ``nn.Conv2d`` (with bias) equivalent to conv + batch norm.

        The batch norm is folded using its running statistics, so the result
        matches this module's eval-mode output.
        """
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
        # The conv weight has shape (out, in / groups, kH, kW), so the true
        # input channel count is w.size(1) * groups, not w.size(1).
        m = nn.Conv2d(
            w.size(1) * self.c.groups, w.size(0), w.shape[2:], stride=self.c.stride,
            padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
149
+
150
+
151
class LinearNorm(nn.Sequential):
    """Bias-free ``nn.Linear`` (registered as ``c``) followed by ``nn.BatchNorm1d`` (``bn``).

    ``bn_weight_init`` sets the constant batch-norm weight; ``resolution`` is
    accepted but not used by this class.
    """

    def __init__(self, in_features, out_features, bn_weight_init=1, resolution=-100000):
        super().__init__()
        linear = nn.Linear(in_features, out_features, bias=False)
        self.add_module('c', linear)
        norm = nn.BatchNorm1d(out_features)
        self.add_module('bn', norm)

        nn.init.constant_(norm.weight, bn_weight_init)

    @torch.no_grad()
    def fuse(self):
        """Fold the batch norm's running statistics into one ``nn.Linear`` with bias."""
        linear, norm = self._modules.values()
        std = (norm.running_var + norm.eps) ** 0.5
        fused_weight = linear.weight * (norm.weight / std)[:, None]
        fused_bias = norm.bias - norm.running_mean * norm.weight / std
        fused = nn.Linear(fused_weight.size(1), fused_weight.size(0))
        fused.weight.data.copy_(fused_weight)
        fused.bias.data.copy_(fused_bias)
        return fused

    def forward(self, x):
        projected = self.c(x)
        # Batch norm is applied on the first two dims merged, then the shape is restored.
        normed = self.bn(projected.flatten(0, 1))
        return normed.reshape_as(projected)
173
+
174
+
175
class NormLinear(nn.Sequential):
    """``nn.BatchNorm1d`` (registered as ``bn``) followed by ``nn.Linear`` (``l``).

    The linear weight is initialised with ``trunc_normal_(std=std)`` and the
    bias, when present, with zeros.
    """

    def __init__(self, in_features, out_features, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', nn.BatchNorm1d(in_features))
        self.add_module('l', nn.Linear(in_features, out_features, bias=bias))

        trunc_normal_(self.l.weight, std=std)
        if self.l.bias is not None:
            nn.init.constant_(self.l.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Fold the batch norm's running statistics into one ``nn.Linear`` with bias."""
        norm, linear = self._modules.values()
        std = (norm.running_var + norm.eps) ** 0.5
        scale = norm.weight / std
        shift = norm.bias - norm.running_mean * norm.weight / std
        fused_weight = linear.weight * scale[None, :]
        # The batch-norm shift passes through the linear layer and joins its bias.
        if linear.bias is None:
            fused_bias = shift @ linear.weight.T
        else:
            fused_bias = (linear.weight @ shift[:, None]).view(-1) + linear.bias
        fused = nn.Linear(fused_weight.size(1), fused_weight.size(0))
        fused.weight.data.copy_(fused_weight)
        fused.bias.data.copy_(fused_bias)
        return fused
199
+
200
+
201
def stem_b16(in_chs, out_chs, activation, resolution=224):
    """Build the convolutional stem: four 3x3, stride-2 ``ConvNorm`` layers.

    Channel width grows ``in_chs -> out_chs/8 -> out_chs/4 -> out_chs/2 -> out_chs``
    and an ``activation()`` instance sits between consecutive ``ConvNorm`` layers.
    """
    widths = (in_chs, out_chs // 8, out_chs // 4, out_chs // 2, out_chs)
    resolutions = (resolution, resolution // 2, resolution // 4, resolution // 8)
    layers = []
    for stage, stage_resolution in enumerate(resolutions):
        if stage > 0:
            layers.append(activation())
        layers.append(
            ConvNorm(widths[stage], widths[stage + 1], 3, 2, 1, resolution=stage_resolution))
    return nn.Sequential(*layers)
210
+
211
+
212
class Residual(nn.Module):
    """Residual wrapper ``x + m(x)`` with optional per-sample drop of the branch.

    Args:
        m: Module computing the residual branch.
        drop: Probability of zeroing the branch for a sample while training;
            surviving branches are rescaled by ``1 / (1 - drop)``.
    """

    def __init__(self, m, drop):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            branch = self.m(x)
            # One mask value per sample, broadcast over every remaining dim.
            # A fixed (B, 1, 1) shape only fits 3-D inputs and mis-aligns
            # against the 4-D (B, C, H, W) activations of the conv variant.
            mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
            mask = torch.rand(mask_shape, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
            return x + branch * mask
        else:
            return x + self.m(x)
224
+
225
+
226
class Subsample(nn.Module):
    """Strided spatial subsampling of a ``(B, N, C)`` token sequence.

    Tokens are viewed as a ``resolution x resolution`` grid, every
    ``stride``-th row and column is kept, and the result is flattened back
    to ``(B, N', C)``.
    """

    def __init__(self, stride, resolution):
        super().__init__()
        self.stride = stride
        self.resolution = resolution

    def forward(self, x):
        batch, _, channels = x.shape
        side, step = self.resolution, self.stride
        grid = x.view(batch, side, side, channels)
        kept = grid[:, ::step, ::step]
        return kept.reshape(batch, -1, channels)
236
+
237
+
238
class Attention(nn.Module):
    """Multi-head self-attention with a learned relative-position bias.

    Operates on ``(B, C, H, W)`` feature maps when ``use_conv`` is True and on
    ``(B, N, C)`` token sequences otherwise. The bias table holds one entry per
    head and per absolute (row, column) offset; ``attention_bias_idxs`` maps
    every query/key pair to its table entry.

    Args:
        dim: Input and output channel width.
        key_dim: Per-head query/key width.
        num_heads: Number of attention heads.
        attn_ratio: Per-head value width as a multiple of ``key_dim``.
        act_layer: Activation class instantiated before the output projection.
        resolution: Side length of the (square) token grid.
        use_conv: Use ``ConvNorm`` projections instead of ``LinearNorm``.
    """
    ab: Dict[str, torch.Tensor]

    def __init__(
            self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
        super().__init__()
        ln_layer = ConvNorm if use_conv else LinearNorm
        self.use_conv = use_conv
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.key_attn_dim = key_dim * num_heads
        self.val_dim = int(attn_ratio * key_dim)
        self.val_attn_dim = int(attn_ratio * key_dim) * num_heads

        self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2, resolution=resolution)
        self.proj = nn.Sequential(
            act_layer(),
            ln_layer(self.val_attn_dim, dim, bn_weight_init=0, resolution=resolution)
        )

        self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution ** 2))
        pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1)
        rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
        # Flatten the (|d_row|, |d_col|) offset into one index into the bias table.
        rel_pos = (rel_pos[0] * resolution) + rel_pos[1]
        self.register_buffer('attention_bias_idxs', rel_pos)
        self.ab = {}  # per-device cache of gathered attention biases

    @torch.no_grad()
    def train(self, mode=True):
        """Switch train/eval mode; entering train mode clears the bias cache.

        Returns ``self`` like ``nn.Module.train`` so that ``.train()`` and
        ``.eval()`` stay chainable.
        """
        super().train(mode)
        if mode and self.ab:
            self.ab = {}  # clear ab cache
        return self

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        """Return the gathered bias tensor, cached per device while in eval mode."""
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.ab:
                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.ab[device_key]

    def forward(self, x):  # x (B,C,H,W)
        if self.use_conv:
            B, C, H, W = x.shape
            q, k, v = self.qkv(x).view(
                B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.val_dim], dim=2)

            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
        else:
            B, N, C = x.shape
            q, k, v = self.qkv(x).view(
                B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 3, 1)
            v = v.permute(0, 2, 1, 3)

            attn = q @ k * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim)
        x = self.proj(x)
        return x
305
+
306
+
307
class AttentionSubsample(nn.Module):
    """Attention block that reduces spatial resolution.

    Keys and values come from the full-resolution input; queries come from a
    strided subsample of it, so the output has ``resolution_out ** 2``
    positions and ``out_dim`` channels.

    Args:
        in_dim: Input channel width.
        out_dim: Output channel width.
        key_dim: Per-head query/key width.
        num_heads: Number of attention heads.
        attn_ratio: Per-head value width as a multiple of ``key_dim``.
        act_layer: Activation class instantiated before the output projection.
        stride: Subsampling stride applied to the query positions.
        resolution: Side length of the (square) input grid.
        resolution_out: Side length of the (square) output grid.
        use_conv: Use ``ConvNorm`` / ``nn.AvgPool2d`` on ``(B, C, H, W)`` inputs
            instead of ``LinearNorm`` / ``Subsample`` on ``(B, N, C)`` inputs.
    """
    ab: Dict[str, torch.Tensor]

    def __init__(
            self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
            act_layer=None, stride=2, resolution=14, resolution_out=7, use_conv=False):
        super().__init__()
        self.stride = stride
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.key_attn_dim = key_dim * num_heads
        self.val_dim = int(attn_ratio * key_dim)
        self.val_attn_dim = self.val_dim * self.num_heads
        self.resolution = resolution
        self.resolution_out = resolution_out
        self.resolution_out_area = resolution_out ** 2

        self.use_conv = use_conv
        if self.use_conv:
            ln_layer = ConvNorm
            sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0)
        else:
            ln_layer = LinearNorm
            sub_layer = partial(Subsample, resolution=resolution)

        self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim, resolution=resolution)
        self.q = nn.Sequential(
            sub_layer(stride=stride),
            ln_layer(in_dim, self.key_attn_dim, resolution=resolution_out)
        )
        self.proj = nn.Sequential(
            act_layer(),
            ln_layer(self.val_attn_dim, out_dim, resolution=resolution_out)
        )

        self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.resolution ** 2))
        k_pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1)
        q_pos = torch.stack(torch.meshgrid(
            torch.arange(0, resolution, step=stride),
            torch.arange(0, resolution, step=stride))).flatten(1)
        rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
        # Flatten the (|d_row|, |d_col|) offset into one index into the bias table.
        rel_pos = (rel_pos[0] * resolution) + rel_pos[1]
        self.register_buffer('attention_bias_idxs', rel_pos)

        self.ab = {}  # per-device attention_biases cache

    @torch.no_grad()
    def train(self, mode=True):
        """Switch train/eval mode; entering train mode clears the bias cache.

        Returns ``self`` like ``nn.Module.train`` so that ``.train()`` and
        ``.eval()`` stay chainable.
        """
        super().train(mode)
        if mode and self.ab:
            self.ab = {}  # clear ab cache
        return self

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        """Return the gathered bias tensor, cached per device while in eval mode."""
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.ab:
                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.ab[device_key]

    def forward(self, x):
        if self.use_conv:
            B, C, H, W = x.shape
            k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2)
            q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_out_area)

            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            # The product has one column per *query* position, i.e. per output
            # grid cell, so it must be laid out on the subsampled grid.
            # Reshaping to the input resolution would shrink the channel count
            # and break the following projection.
            x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_out, self.resolution_out)
        else:
            B, N, C = x.shape
            k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
            k = k.permute(0, 2, 3, 1)  # BHCN
            v = v.permute(0, 2, 1, 3)  # BHNC
            q = self.q(x).view(B, self.resolution_out_area, self.num_heads, self.key_dim).permute(0, 2, 1, 3)

            attn = q @ k * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim)
        x = self.proj(x)
        return x
391
+
392
+
393
class Levit(nn.Module):
    """ LeViT classifier: convolutional stem, staged attention/MLP blocks, single head.

    The network is assembled as:
      * ``patch_embed``: ``hybrid_backbone`` if given, otherwise ``stem_b16``.
      * ``blocks``: for every stage, ``depth[i]`` repetitions of a residual
        ``Attention`` block followed (when ``mlp_ratio > 0``) by a residual
        two-layer MLP; between stages an ``AttentionSubsample`` block (plus an
        optional residual MLP) as described by ``down_ops``.
      * ``head``: ``NormLinear`` classifier, or ``nn.Identity`` when ``num_classes <= 0``.

    With ``use_conv=False`` the stem output is flattened to a token sequence
    before the blocks; with ``use_conv=True`` it stays a feature map.

    NOTE: this class has a single head and returns one tensor; the two-head
    distilled variant is ``LevitDistilled``.
    NOTE: ``drop_rate`` is accepted but not referenced anywhere in this class.
    NOTE: the default ``down_ops`` has three entries and indexes ``embed_dim[1]``
    — presumably only 3-stage configurations are intended; verify before using others.
    """

    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            embed_dim=(192,),
            key_dim=64,
            depth=(12,),
            num_heads=(3,),
            attn_ratio=2,
            mlp_ratio=2,
            hybrid_backbone=None,
            down_ops=None,
            act_layer='hard_swish',
            attn_act_layer='hard_swish',
            use_conv=False,
            global_pool='avg',
            drop_rate=0.,
            drop_path_rate=0.):
        super().__init__()
        act_layer = get_act_layer(act_layer)
        attn_act_layer = get_act_layer(attn_act_layer)
        ln_layer = ConvNorm if use_conv else LinearNorm
        self.use_conv = use_conv
        if isinstance(img_size, tuple):
            # FIXME origin impl passes single img/res dim through whole hierarchy,
            # not sure this model will be used enough to spend time fixing it.
            assert img_size[0] == img_size[1]
            img_size = img_size[0]
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = embed_dim[-1]
        self.embed_dim = embed_dim
        self.grad_checkpointing = False

        num_stages = len(embed_dim)
        assert len(depth) == len(num_heads) == num_stages
        # Scalar per-stage settings are broadcast to one value per stage.
        key_dim = to_ntuple(num_stages)(key_dim)
        attn_ratio = to_ntuple(num_stages)(attn_ratio)
        mlp_ratio = to_ntuple(num_stages)(mlp_ratio)
        down_ops = down_ops or (
            # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
            ('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2),
            ('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2),
            ('',)
        )

        self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer)

        # Collected in a plain list first, wrapped into nn.Sequential once complete.
        self.blocks = []
        # Token-grid side length after the stem; updated after every subsample op.
        resolution = img_size // patch_size
        for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
                zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
            for _ in range(dpth):
                self.blocks.append(
                    Residual(
                        Attention(
                            ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
                            resolution=resolution, use_conv=use_conv),
                        drop_path_rate))
                if mr > 0:
                    h = int(ed * mr)
                    self.blocks.append(
                        Residual(nn.Sequential(
                            ln_layer(ed, h, resolution=resolution),
                            act_layer(),
                            ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
                        ), drop_path_rate))
            if do[0] == 'Subsample':
                # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
                resolution_out = (resolution - 1) // do[5] + 1
                self.blocks.append(
                    AttentionSubsample(
                        *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
                        attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
                        resolution=resolution, resolution_out=resolution_out, use_conv=use_conv))
                resolution = resolution_out
                if do[4] > 0:  # mlp_ratio
                    h = int(embed_dim[i + 1] * do[4])
                    self.blocks.append(
                        Residual(nn.Sequential(
                            ln_layer(embed_dim[i + 1], h, resolution=resolution),
                            act_layer(),
                            ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
                        ), drop_path_rate))
        self.blocks = nn.Sequential(*self.blocks)

        # Classifier head
        self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the state-dict keys of all attention-bias tables."""
        return {x for x in self.state_dict().keys() if 'attention_biases' in x}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex patterns grouping parameters into stem and per-index blocks.

        ``coarse`` is accepted but does not change the returned matcher.
        """
        matcher = dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable checkpointing of ``blocks`` in ``forward_features``."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """Return the classifier head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
        """Replace the head with a freshly initialised one for ``num_classes`` outputs.

        ``global_pool`` is updated only when given; ``distillation`` is accepted
        but not used here.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Run the stem and all blocks, returning the un-pooled features."""
        x = self.patch_embed(x)
        if not self.use_conv:
            # Feature map -> token sequence for the linear variant.
            x = x.flatten(2).transpose(1, 2)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the features (only when ``global_pool == 'avg'``) and classify.

        With ``pre_logits=True`` the pooled features are returned without
        applying the head.
        """
        if self.global_pool == 'avg':
            # Mean over the spatial dims (conv layout) or the token dim (linear layout).
            x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
536
+
537
+
538
class LevitDistilled(Levit):
    """``Levit`` with an additional distillation head (``head_dist``).

    Outside distilled training the two heads' predictions are averaged into a
    single tensor; with ``distilled_training`` enabled and the module in
    train mode, both predictions are returned as a tuple.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
        self.distilled_training = False  # must set this True to train w/ distillation token

    @torch.jit.ignore
    def get_classifier(self):
        """Return both classifier heads as ``(head, head_dist)``."""
        return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
        """Replace both heads with freshly initialised ones for ``num_classes`` outputs.

        ``global_pool`` is updated only when given; ``distillation`` is accepted
        but not used here.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def set_distilled_training(self, enable=True):
        """Toggle returning separate head outputs during training."""
        self.distilled_training = enable

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the features (only when ``global_pool == 'avg'``) and classify.

        Accepts ``pre_logits`` like ``Levit.forward_head`` so the override stays
        call-compatible with the base class; when True, the pooled features are
        returned without applying either head.
        """
        if self.global_pool == 'avg':
            x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
        if pre_logits:
            return x
        x, x_dist = self.head(x), self.head_dist(x)
        if self.distilled_training and self.training and not torch.jit.is_scripting():
            # only return separate classification predictions when training in distilled mode
            return x, x_dist
        else:
            # during standard train/finetune, inference average the classifier predictions
            return (x + x_dist) / 2
569
+
570
+
571
def checkpoint_filter_fn(state_dict, model):
    """Adapt a checkpoint's weights to ``model`` before loading.

    Unwraps a nested ``'model'`` entry when present, then expands every 2-D
    checkpoint tensor whose counterpart in ``model.state_dict()`` is 4-D by
    appending two singleton dims. The (unwrapped) dict is modified in place
    and returned.
    """
    # For deit models
    weights = state_dict['model'] if 'model' in state_dict else state_dict
    reference = model.state_dict()
    for name in weights.keys():
        if name not in reference:
            continue
        if reference[name].ndim == 4 and weights[name].ndim == 2:
            weights[name] = weights[name][:, :, None, None]
    return weights
580
+
581
+
582
def create_levit(variant, pretrained=False, distilled=True, **kwargs):
    """Instantiate a LeViT variant through ``build_model_with_cfg``.

    ``distilled`` selects ``LevitDistilled`` over ``Levit``. The variant's
    entry in ``model_cfgs`` is combined with ``kwargs`` (duplicate keys raise),
    and ``features_only`` is rejected.
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    model_cfg = dict(**model_cfgs[variant], **kwargs)
    model_cls = LevitDistilled if distilled else Levit
    return build_model_with_cfg(
        model_cls, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **model_cfg)
592
+
src/custom_timm/optim/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .adabelief import AdaBelief
2
+ from .adafactor import Adafactor
3
+ from .adahessian import Adahessian
4
+ from .adamp import AdamP
5
+ from .adamw import AdamW
6
+ from .lamb import Lamb
7
+ from .lars import Lars
8
+ from .lookahead import Lookahead
9
+ from .madgrad import MADGRAD
10
+ from .nadam import Nadam
11
+ from .nvnovograd import NvNovoGrad
12
+ from .radam import RAdam
13
+ from .rmsprop_tf import RMSpropTF
14
+ from .sgdp import SGDP
15
+ from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
src/custom_timm/optim/adabelief.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.optim.optimizer import Optimizer
4
+
5
+
6
class AdaBelief(Optimizer):
    r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-16)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        decoupled_decay (boolean, optional): (default: True) If set as True, then
            the optimizer uses decoupled weight decay as in AdamW
        fixed_decay (boolean, optional): (default: False) This is used when decoupled_decay
            is set as True.
            When fixed_decay == True, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay$.
            When fixed_decay == False, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the
            weight decay ratio decreases with learning rate (lr).
        rectify (boolean, optional): (default: True) If set as True, then perform the rectified
            update similar to RAdam
        degenerated_to_sgd (boolean, optional) (default: True) If set as True, then perform SGD update
            when variance of gradient is high
    reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020

    For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer
    For example train/args for EfficientNet see these gists
      - link to train script: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037
      - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
    """

    def __init__(
            self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False,
            decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True):

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        # Param groups that override `betas` get their own rectification cache,
        # so they do not share the cached values of the default group.
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]

        # `buffer` is a 10-slot cache of [step, num_sma, step_size] used by the rectified update.
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad,
            degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify,
            fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)])
        super(AdaBelief, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, defaulting a missing ``amsgrad`` flag to False."""
        super(AdaBelief, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def reset(self):
        """Re-initialise the step counter and moment estimates of every parameter.

        NOTE(review): buffers are created with ``zeros_like(p)`` (the parameter's
        own dtype), whereas ``step`` creates them from the fp32 copy of
        half-precision parameters — confirm whether that difference is intended.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                amsgrad = group['amsgrad']

                # State initialization
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p)

                # Exponential moving average of squared gradient values
                state['exp_avg_var'] = torch.zeros_like(p)
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state['max_exp_avg_var'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            The value returned by ``closure``, or None when no closure is given.
        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                # Half-precision gradients are processed as fp32.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBelief does not support sparse gradients, please consider SparseAdam instead')

                # Work on an fp32 copy of half-precision params; copied back at the end.
                p_fp32 = p
                if p.dtype in {torch.float16, torch.bfloat16}:
                    p_fp32 = p_fp32.float()

                amsgrad = group['amsgrad']
                beta1, beta2 = group['betas']
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p_fp32)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_var'] = torch.zeros_like(p_fp32)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_var'] = torch.zeros_like(p_fp32)

                # perform weight decay, check if decoupled weight decay
                if group['decoupled_decay']:
                    if not group['fixed_decay']:
                        p_fp32.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        p_fp32.mul_(1.0 - group['weight_decay'])
                else:
                    if group['weight_decay'] != 0:
                        # Classic L2 penalty: added to the gradient in place.
                        grad.add_(p_fp32, alpha=group['weight_decay'])

                # get current state variable
                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Update first and second moment running average
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # The second moment tracks the squared deviation of the gradient from its EMA.
                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

                # NOTE: `add_` below is in-place, so eps is also accumulated into
                # the stored `exp_avg_var` state on every step.
                if amsgrad:
                    max_exp_avg_var = state['max_exp_avg_var']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var)

                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                # update
                if not group['rectify']:
                    # Default update
                    step_size = group['lr'] / bias_correction1
                    p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    # Rectified update, forked from RAdam
                    # The cache slot is chosen by step % 10 and reused when its
                    # stored step matches the current one.
                    buffered = group['buffer'][int(state['step'] % 10)]
                    if state['step'] == buffered[0]:
                        num_sma, step_size = buffered[1], buffered[2]
                    else:
                        buffered[0] = state['step']
                        beta2_t = beta2 ** state['step']
                        num_sma_max = 2 / (1 - beta2) - 1
                        num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                        buffered[1] = num_sma

                        # more conservative since it's an approximated value
                        if num_sma >= 5:
                            step_size = math.sqrt(
                                (1 - beta2_t) *
                                (num_sma - 4) / (num_sma_max - 4) *
                                (num_sma - 2) / num_sma *
                                num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
                        elif group['degenerated_to_sgd']:
                            step_size = 1.0 / (1 - beta1 ** state['step'])
                        else:
                            # Negative step size marks "skip the update" below.
                            step_size = -1
                        buffered[2] = step_size

                    if num_sma >= 5:
                        # The rectified branch recomputes denom without bias_correction2
                        # (and ignores the amsgrad maximum computed above).
                        denom = exp_avg_var.sqrt().add_(group['eps'])
                        p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    elif step_size > 0:
                        p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

                if p.dtype in {torch.float16, torch.bfloat16}:
                    p.copy_(p_fp32)

        return loss
src/custom_timm/optim/adafactor.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Adafactor Optimizer
2
+
3
+ Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
4
+
5
+ Original header/copyright below.
6
+
7
+ """
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ #
10
+ # This source code is licensed under the MIT license found in the
11
+ # LICENSE file in the root directory of this source tree.
12
+ import torch
13
+ import math
14
+
15
+
16
class Adafactor(torch.optim.Optimizer):
    """Implements Adafactor algorithm.
    This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
    (see https://arxiv.org/abs/1804.04235)

    Note that this optimizer internally adjusts the learning rate depending on the
    *scale_parameter*, *relative_step* and *warmup_init* options.
    *relative_step* is not a constructor argument: it is enabled whenever `lr` is
    falsy (None or 0), in which case the learning rate is computed from the step count.

    To use a manual (external) learning rate schedule you should pass an explicit `lr`
    (which disables the relative step).

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (float): regularization constant added to the squared gradient (default: 1e-30)
        eps_scale (float): lower bound on the parameter RMS used for scaling the
            learning rate (default: 1e-3)
        clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
        decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
        betas (tuple, optional): only ``betas[0]`` is used, as the coefficient for the running
            average of the update; None disables the first moment (default: None)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
        warmup_init (bool): time-dependent learning rate computation depends on
            whether warm-up initialization is being used (default: False)
    """

    def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
                 decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
        # No explicit learning rate -> derive it from the step count.
        relative_step = not lr
        if warmup_init and not relative_step:
            raise ValueError('warmup_init requires relative_step=True')

        beta1 = None if betas is None else betas[0]  # make it compat with standard betas arg
        defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
                        beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
                        relative_step=relative_step, warmup_init=warmup_init)
        super(Adafactor, self).__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        """Return the learning rate for this step.

        With ``relative_step`` the rate is recomputed from the step count (and
        optionally the parameter RMS) and written back to ``param_group['lr']``.
        """
        if param_group['relative_step']:
            min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
            lr_t = min(min_step, 1.0 / math.sqrt(param_state['step']))
            param_scale = 1.0
            if param_group['scale_parameter']:
                param_scale = max(param_group['eps_scale'], param_state['RMS'])
            param_group['lr'] = lr_t * param_scale
        return param_group['lr']

    @staticmethod
    def _get_options(param_group, param_shape):
        """Return ``(factored, use_first_moment)`` for a parameter of the given shape."""
        # Tensors with at least two dims use the factored row/column second moment.
        factored = len(param_shape) >= 2
        use_first_moment = param_group['beta1'] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        """Root mean square of all elements of ``tensor``."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        """Combine the row and column statistics into an inverse-square-root scaling factor."""
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        Returns:
            The value returned by ``closure``, or None when no closure is given.
        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                # Half-precision gradients are processed as fp32.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError('Adafactor does not support sparse gradients.')

                state = self.state[p]

                factored, use_first_moment = self._get_options(group, grad.shape)
                # State Initialization
                if len(state) == 0:
                    state['step'] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(grad)
                    if factored:
                        # Row statistics drop the last dim, column statistics the second-to-last.
                        state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
                        state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)

                    state['RMS'] = 0
                else:
                    # Keep existing state on the gradient's device and dtype.
                    if use_first_moment:
                        state['exp_avg'] = state['exp_avg'].to(grad)
                    if factored:
                        state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
                        state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
                    else:
                        state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)

                # Work on an fp32 copy of half-precision params; copied back at the end.
                p_fp32 = p
                if p.dtype in {torch.float16, torch.bfloat16}:
                    p_fp32 = p_fp32.float()

                state['step'] += 1
                state['RMS'] = self._rms(p_fp32)
                lr_t = self._get_lr(group, state)

                # Step-dependent decay for the second-moment averages.
                beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
                update = grad ** 2 + group['eps']
                if factored:
                    exp_avg_sq_row = state['exp_avg_sq_row']
                    exp_avg_sq_col = state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state['exp_avg_sq']

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Scale the update down when its RMS exceeds clip_threshold.
                update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
                update.mul_(lr_t)

                if use_first_moment:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
                    update = exp_avg

                if group['weight_decay'] != 0:
                    # Weight decay is scaled by the current learning rate.
                    p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t)

                p_fp32.add_(-update)
                if p.dtype in {torch.float16, torch.bfloat16}:
                    p.copy_(p_fp32)

        return loss
src/custom_timm/optim/adahessian.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ AdaHessian Optimizer
2
+
3
+ Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
4
+ Originally licensed MIT, Copyright 2020, David Samuel
5
+ """
6
+ import torch
7
+
8
+
9
class Adahessian(torch.optim.Optimizer):
    """
    Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning"

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 0.1)
        betas ((float, float), optional): coefficients used for computing running averages of gradient and the
            squared hessian trace (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
        hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
        update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
            (to save time) (default: 1)
        n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
        avg_conv_kernel (bool, optional): if True, the hessian estimate of every 4-dimensional parameter is
            replaced by its absolute value averaged over the last two dimensions (default: False)
    """

    def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
                 hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError(f"Invalid Hessian power value: {hessian_power}")

        self.n_samples = n_samples
        self.update_each = update_each
        self.avg_conv_kernel = avg_conv_kernel

        # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
        self.seed = 2147483647
        self.generator = torch.Generator().manual_seed(self.seed)

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
        super(Adahessian, self).__init__(params, defaults)

        # Every trainable parameter starts with a float placeholder for its hessian estimate and a
        # counter used to decide on which steps the estimate is refreshed.
        for p in self.get_params():
            p.hess = 0.0
            self.state[p]["hessian step"] = 0

    @property
    def is_second_order(self):
        """Always True for this optimizer.

        NOTE(review): presumably read by the training loop to decide whether to call
        ``backward(create_graph=True)`` -- confirm against the caller.
        """
        return True

    def get_params(self):
        """
        Gets all parameters in all param_groups with gradients
        """

        return (p for group in self.param_groups for p in group['params'] if p.requires_grad)

    def zero_hessian(self):
        """
        Zeros out the accumulated hessian traces.

        Only parameters whose estimate is about to be recomputed (their "hessian step" counter is a
        multiple of ``update_each``) are reset; the initial float placeholder is left alone.
        """

        for p in self.get_params():
            if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
                p.hess.zero_()

    @torch.no_grad()
    def set_hessian(self):
        """
        Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.

        The gradients stored on the parameters are differentiated again via ``torch.autograd.grad``,
        so they must still be attached to a graph when this runs.
        """

        params = []
        for p in filter(lambda p: p.grad is not None, self.get_params()):
            if self.state[p]["hessian step"] % self.update_each == 0:  # compute the trace only each `update_each` step
                params.append(p)
            self.state[p]["hessian step"] += 1

        if len(params) == 0:
            return

        if self.generator.device != params[0].device:  # hackish way of casting the generator to the right device
            self.generator = torch.Generator(params[0].device).manual_seed(self.seed)

        grads = [p.grad for p in params]

        for i in range(self.n_samples):
            # Rademacher distribution {-1.0, 1.0}
            zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
            # keep the graph alive for every sample except the last one
            h_zs = torch.autograd.grad(
                grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
            for h_z, z, p in zip(h_zs, zs, params):
                p.hess += h_z * z / self.n_samples  # approximate the expected values of z*(H@z)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.
        Arguments:
            closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)

        Returns:
            The value returned by ``closure``, or None when no closure was given.
        """

        loss = None
        if closure is not None:
            # This method runs under no_grad; the closure has to build a graph and call backward,
            # so gradient tracking is re-enabled just for it.
            with torch.enable_grad():
                loss = closure()

        self.zero_hessian()
        self.set_hessian()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or p.hess is None:
                    continue

                if self.avg_conv_kernel and p.dim() == 4:
                    p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()

                # Perform correct stepweight decay as in AdamW
                p.mul_(1 - group['lr'] * group['weight_decay'])

                state = self.state[p]

                # State initialization. Keyed on the presence of 'step' rather than on the number of
                # entries, so it does not depend on how many other keys the state already holds.
                if 'step' not in state:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of Hessian diagonal square values
                    state['exp_hessian_diag_sq'] = torch.zeros_like(p)

                exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                k = group['hessian_power']
                denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])

                # make update
                step_size = group['lr'] / bias_correction1
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
src/custom_timm/optim/adamp.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
3
+
4
+ Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5
+ Code: https://github.com/clovaai/AdamP
6
+
7
+ Copyright (c) 2020-present NAVER Corp.
8
+ MIT license
9
+ """
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.optim.optimizer import Optimizer
14
+ import math
15
+
16
+
17
+ def _channel_view(x) -> torch.Tensor:
18
+ return x.reshape(x.size(0), -1)
19
+
20
+
21
+ def _layer_view(x) -> torch.Tensor:
22
+ return x.reshape(1, -1)
23
+
24
+
25
def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
    """Project ``perturb`` away from the direction of ``p`` when the gradient is nearly orthogonal to it.

    The channel view is tried first, then the layer view. For the first view whose largest absolute
    cosine similarity between gradient and parameter is below ``delta / sqrt(view width)``, the
    component of ``perturb`` along the normalised parameter is subtracted in place.

    Returns:
        ``(perturb, wd_ratio)`` when a projection was applied, otherwise ``(perturb, 1.)``.
    """
    broadcast_shape = (-1,) + (1,) * (len(p.shape) - 1)
    for to_view in (_channel_view, _layer_view):
        p_view = to_view(p)
        g_view = to_view(grad)
        similarity = F.cosine_similarity(g_view, p_view, dim=1, eps=eps).abs_()
        threshold = delta / math.sqrt(p_view.size(1))

        # FIXME this is a problem for PyTorch XLA
        if similarity.max() < threshold:
            p_norm = p_view.norm(p=2, dim=1).add_(eps).reshape(broadcast_shape)
            p_unit = p / p_norm
            along = to_view(p_unit * perturb).sum(dim=1).reshape(broadcast_shape)
            perturb -= p_unit * along
            return perturb, wd_ratio

    return perturb, 1.
41
+
42
+
43
class AdamP(Optimizer):
    """AdamP optimizer: Adam with a projection of the update for parameters with more than one dim.

    Per-group options: ``lr``, ``betas``, ``eps``, ``weight_decay``, ``delta``, ``wd_ratio``
    and ``nesterov``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
        super(AdamP, self).__init__(params, dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
            delta=delta, wd_ratio=wd_ratio, nesterov=nesterov))

    @torch.no_grad()
    def step(self, closure=None):
        """Run one optimization step; returns the closure's loss, or None without a closure."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                beta1, beta2 = group['betas']
                use_nesterov = group['nesterov']
                state = self.state[p]

                # Lazily create the per-parameter state on first use
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                # Adam moments
                first_moment = state['exp_avg']
                second_moment = state['exp_avg_sq']

                state['step'] += 1
                t = state['step']
                bias_correction1 = 1 - beta1 ** t
                bias_correction2 = 1 - beta2 ** t

                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = (second_moment.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                step_size = group['lr'] / bias_correction1

                if use_nesterov:
                    perturb = (beta1 * first_moment + (1 - beta1) * grad) / denom
                else:
                    perturb = first_moment / denom

                # Projection is only applied to parameters with more than one dimension
                decay_scale = 1.
                if len(p.shape) > 1:
                    perturb, decay_scale = projection(
                        p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])

                # Decoupled weight decay, scaled by the ratio returned from the projection
                if group['weight_decay'] > 0:
                    p.mul_(1. - group['lr'] * group['weight_decay'] * decay_scale)

                # Parameter update
                p.add_(perturb, alpha=-step_size)

        return loss
src/custom_timm/optim/adamw.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ AdamW Optimizer
2
+ Impl copied from PyTorch master
3
+
4
+ NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed
5
+ someday
6
+ """
7
+ import math
8
+ import torch
9
+ from torch.optim.optimizer import Optimizer
10
+
11
+
12
class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        for index in (0, 1):
            if not 0.0 <= betas[index] < 1.0:
                raise ValueError("Invalid beta parameter at index {}: {}".format(index, betas[index]))
        super(AdamW, self).__init__(params, dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad))

    def __setstate__(self, state):
        """Restore state and make sure every param group carries an ``amsgrad`` entry."""
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # Decoupled weight decay is applied to the weights before the Adam update
                p.data.mul_(1 - group['lr'] * group['weight_decay'])

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                use_amsgrad = group['amsgrad']

                state = self.state[p]

                # First visit of this parameter: create its buffers
                if len(state) == 0:
                    state['step'] = 0
                    # running average of the gradient
                    state['exp_avg'] = torch.zeros_like(p)
                    # running average of the squared gradient
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if use_amsgrad:
                        # element-wise maximum of all squared-gradient averages seen so far
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                first_moment = state['exp_avg']
                second_moment = state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                t = state['step']
                bias_correction1 = 1 - beta1 ** t
                bias_correction2 = 1 - beta2 ** t

                # Update both running averages in place
                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                normaliser = second_moment
                if use_amsgrad:
                    # AMSGrad normalises with the running maximum instead of the current average
                    normaliser = state['max_exp_avg_sq']
                    torch.max(normaliser, second_moment, out=normaliser)
                denom = (normaliser.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                p.addcdiv_(first_moment, denom, value=-step_size)

        return loss
src/custom_timm/optim/lamb.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
2
+
3
+ This optimizer code was adapted from the following (starting with latest)
4
+ * https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
5
+ * https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
6
+ * https://github.com/cybertronai/pytorch-lamb
7
+
8
+ Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
9
+ similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
10
+
11
+ In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
12
+
13
+ Original copyrights for above sources are below.
14
+
15
+ Modifications Copyright 2021 Ross Wightman
16
+ """
17
+ # Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
18
+
19
+ # Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
20
+ #
21
+ # Licensed under the Apache License, Version 2.0 (the "License");
22
+ # you may not use this file except in compliance with the License.
23
+ # You may obtain a copy of the License at
24
+ #
25
+ # http://www.apache.org/licenses/LICENSE-2.0
26
+ #
27
+ # Unless required by applicable law or agreed to in writing, software
28
+ # distributed under the License is distributed on an "AS IS" BASIS,
29
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30
+ # See the License for the specific language governing permissions and
31
+ # limitations under the License.
32
+
33
+ # MIT License
34
+ #
35
+ # Copyright (c) 2019 cybertronai
36
+ #
37
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
38
+ # of this software and associated documentation files (the "Software"), to deal
39
+ # in the Software without restriction, including without limitation the rights
40
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
41
+ # copies of the Software, and to permit persons to whom the Software is
42
+ # furnished to do so, subject to the following conditions:
43
+ #
44
+ # The above copyright notice and this permission notice shall be included in all
45
+ # copies or substantial portions of the Software.
46
+ #
47
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
48
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
49
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
50
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
51
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
52
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
53
+ # SOFTWARE.
54
+ import math
55
+
56
+ import torch
57
+ from torch.optim import Optimizer
58
+
59
+
60
class Lamb(Optimizer):
    """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py

    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        bias_correction (bool, optional): apply Adam-style bias correction to both
            running averages. (default: True)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-6)
        weight_decay (float, optional): weight decay, added to the update before the
            trust ratio is applied (default: 0.01)
        grad_averaging (bool, optional): whether apply (1-beta1) to grad when
            calculating running averages of gradient. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm; ``None``
            disables global clipping (default: 1.0)
        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)

    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
        defaults = dict(
            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
            trust_clip=trust_clip, always_adapt=always_adapt)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Note that when global clipping is active the stored gradients are scaled in place.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        device = self.param_groups[0]['params'][0].device
        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
        global_grad_norm = torch.zeros(1, device=device)
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
                global_grad_norm.add_(grad.pow(2).sum())

        global_grad_norm = torch.sqrt(global_grad_norm)
        # The clipping threshold is read from the constructor defaults, not from the individual groups.
        max_grad_norm = self.defaults['max_grad_norm']
        if max_grad_norm is None:
            # global clipping disabled: gradients are used as they are
            clip_global_grad_norm = None
        else:
            # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
            # scalar types properly https://github.com/pytorch/pytorch/issues/9190
            max_grad_norm = torch.tensor(max_grad_norm, device=device)
            clip_global_grad_norm = torch.where(
                global_grad_norm > max_grad_norm,
                global_grad_norm / max_grad_norm,
                one_tensor)

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0
            beta3 = 1 - beta1 if grad_averaging else 1.0

            # assume same step across group now to simplify things
            # per parameter step can be easily support by making it tensor, or pass list into kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            if bias_correction:
                bias_correction1 = 1 - beta1 ** group['step']
                bias_correction2 = 1 - beta2 ** group['step']
            else:
                bias_correction1, bias_correction2 = 1.0, 1.0

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if clip_global_grad_norm is not None:
                    grad.div_(clip_global_grad_norm)
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                update = (exp_avg / bias_correction1).div_(denom)

                weight_decay = group['weight_decay']
                if weight_decay != 0:
                    update.add_(p, alpha=weight_decay)

                if weight_decay != 0 or group['always_adapt']:
                    # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
                    # excluded from weight decay, unless always_adapt == True, then always enabled.
                    w_norm = p.norm(2.0)
                    g_norm = update.norm(2.0)
                    # FIXME nested where required since logical and/or not working in PT XLA
                    trust_ratio = torch.where(
                        w_norm > 0,
                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                        one_tensor,
                    )
                    if group['trust_clip']:
                        # LAMBC trust clipping, upper bound fixed at one
                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
                    update.mul_(trust_ratio)

                p.add_(update, alpha=-group['lr'])

        return loss
src/custom_timm/optim/lars.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ PyTorch LARS / LARC Optimizer
2
+
3
+ An implementation of LARS (SGD) + LARC in PyTorch
4
+
5
+ Based on:
6
+ * PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
7
+ * NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
8
+
9
+ Additional cleanup and modifications to properly support PyTorch XLA.
10
+
11
+ Copyright 2021 Ross Wightman
12
+ """
13
+ import torch
14
+ from torch.optim.optimizer import Optimizer
15
+
16
+
17
class Lars(Optimizer):
    """ LARS for PyTorch

    Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate (default: 1.0).
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
        eps (float): eps for division denominator (default: 1e-8)
        trust_clip (bool): enable LARC trust ratio clipping (default: False)
        always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False)
    """

    def __init__(
        self,
        params,
        lr=1.0,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        trust_coeff=0.001,
        eps=1e-8,
        trust_clip=False,
        always_adapt=False,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        super().__init__(params, dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            trust_coeff=trust_coeff,
            eps=eps,
            trust_clip=trust_clip,
            always_adapt=always_adapt,
        ))

    def __setstate__(self, state):
        """Restore state and make sure every param group carries a ``nesterov`` entry."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        The stored gradients are rescaled in place whenever the LARS adaptation branch runs.

        Args:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        device = self.param_groups[0]['params'][0].device
        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            trust_coeff = group['trust_coeff']
            eps = group['eps']

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                # LARS LR adaptation, LARC clipping and weight decay
                # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
                if weight_decay != 0 or group['always_adapt']:
                    w_norm = p.norm(2.0)
                    g_norm = grad.norm(2.0)
                    raw_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
                    # FIXME nested where required since logical and/or not working in PT XLA
                    ratio_if_weight = torch.where(g_norm > 0, raw_ratio, one_tensor)
                    trust_ratio = torch.where(w_norm > 0, ratio_if_weight, one_tensor)
                    if group['trust_clip']:
                        trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
                    grad.add_(p, alpha=weight_decay)
                    grad.mul_(trust_ratio)

                # plain SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' in param_state:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(grad, alpha=1. - dampening)
                    else:
                        # first step: the buffer starts as a copy of the gradient
                        buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
                    grad = grad.add(buf, alpha=momentum) if nesterov else buf

                p.add_(grad, alpha=-group['lr'])

        return loss
src/custom_timm/optim/lookahead.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Lookahead Optimizer Wrapper.
2
+ Implementation modified from: https://github.com/alphadl/lookahead.pytorch
3
+ Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
4
+
5
+ Hacked together by / Copyright 2020 Ross Wightman
6
+ """
7
+ import torch
8
+ from torch.optim.optimizer import Optimizer
9
+ from collections import defaultdict
10
+
11
+
12
class Lookahead(Optimizer):
    """Wrapper that adds Lookahead slow weights on top of another optimizer.

    Every ``k`` calls to ``step`` the slow copy of each parameter is moved towards the fast
    parameter by a factor ``alpha`` and the fast parameter is reset to the slow copy.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        # NOTE super().__init__() not called on purpose
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        lookahead_defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self._base_optimizer = base_optimizer
        self.param_groups = base_optimizer.param_groups
        # the defaults dict is shared with the wrapped optimizer and extended in place
        self.defaults = base_optimizer.defaults
        self.defaults.update(lookahead_defaults)
        self.state = defaultdict(dict)
        # push the lookahead settings into every existing param group
        for group in self._base_optimizer.param_groups:
            for key, value in lookahead_defaults.items():
                group.setdefault(key, value)

    @torch.no_grad()
    def update_slow(self, group):
        """Interpolate the slow weights of ``group`` towards the fast ones and copy them back."""
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self._base_optimizer.state[fast_p]
            if 'lookahead_slow_buff' not in param_state:
                # first sync: the slow buffer starts as a copy of the current fast weights
                fresh = torch.empty_like(fast_p)
                param_state['lookahead_slow_buff'] = fresh
                fresh.copy_(fast_p)
            slow = param_state['lookahead_slow_buff']
            delta = fast_p - slow
            slow.add_(delta, alpha=group['lookahead_alpha'])
            fast_p.copy_(slow)

    def sync_lookahead(self):
        """Force a slow-weight update for every param group."""
        for group in self._base_optimizer.param_groups:
            self.update_slow(group)

    @torch.no_grad()
    def step(self, closure=None):
        """Step the wrapped optimizer, then sync slow weights on every ``lookahead_k``-th step."""
        loss = self._base_optimizer.step(closure)
        for group in self._base_optimizer.param_groups:
            group['lookahead_step'] += 1
            due = group['lookahead_step'] % group['lookahead_k'] == 0
            if due:
                self.update_slow(group)
        return loss

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self._base_optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped optimizer and re-bind ``param_groups`` to it."""
        self._base_optimizer.load_state_dict(state_dict)
        self.param_groups = self._base_optimizer.param_groups
src/custom_timm/optim/madgrad.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ PyTorch MADGRAD optimizer
2
+
3
+ MADGRAD: https://arxiv.org/abs/2101.11075
4
+
5
+ Code from: https://github.com/facebookresearch/madgrad
6
+ """
7
+ # Copyright (c) Facebook, Inc. and its affiliates.
8
+ #
9
+ # This source code is licensed under the MIT license found in the
10
+ # LICENSE file in the root directory of this source tree.
11
+
12
+ import math
13
+ from typing import TYPE_CHECKING, Any, Callable, Optional
14
+
15
+ import torch
16
+ import torch.optim
17
+
18
+ if TYPE_CHECKING:
19
+ from torch.optim.optimizer import _params_t
20
+ else:
21
+ _params_t = Any
22
+
23
+
24
class MADGRAD(torch.optim.Optimizer):
    """
    MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
    Optimization.

    .. _MADGRAD: https://arxiv.org/abs/2101.11075

    MADGRAD is a general purpose optimizer that can be used in place of SGD or
    Adam, and may converge faster and generalize better.
    Typically, the same learning rate schedule that is used for SGD or Adam may
    be used. The overall learning rate is not comparable to either method and
    should be determined by a hyper-parameter sweep.

    MADGRAD requires less weight decay than other methods, often as little as
    zero. Momentum values used for SGD or Adam's beta1 should work here also.

    On sparse problems both weight_decay and momentum should be set to 0.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-2).
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float):
            Weight decay, i.e. a L2 penalty (default: 0).
        eps (float):
            Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6).
        decoupled_decay (bool):
            If True, weight decay scales the parameter by ``1 - lr * weight_decay``
            instead of adding ``weight_decay * p`` to the gradient (default: False).

    Raises:
        ValueError: if ``momentum`` is outside [0,1), ``lr`` is not positive,
            or ``weight_decay`` / ``eps`` is negative.
    """

    def __init__(
            self,
            params: _params_t,
            lr: float = 1e-2,
            momentum: float = 0.9,
            weight_decay: float = 0,
            eps: float = 1e-6,
            decoupled_decay: bool = False,
    ):
        if momentum < 0 or momentum >= 1:
            # the upper bound is exclusive: momentum == 1 is rejected
            raise ValueError(f"Momentum {momentum} must be in the range [0,1)")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError("Eps must be non-negative")

        defaults = dict(
            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    @property
    def supports_flat_params(self) -> bool:
        return True

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a sparse gradient is combined with non-zero momentum
                or with non-decoupled weight decay.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            eps = group['eps']
            lr = group['lr'] + eps
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            ck = 1 - momentum

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError("momentum != 0 is not compatible with sparse gradients")

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['grad_sum_sq'] = torch.zeros_like(p)
                    state['s'] = torch.zeros_like(p)
                    if momentum != 0:
                        # with momentum the starting point is stored rather than recomputed
                        state['x0'] = torch.clone(p).detach()

                state['step'] += 1
                grad_sum_sq = state['grad_sum_sq']
                s = state['s']
                lamb = lr * math.sqrt(state['step'])

                # Apply weight decay
                if weight_decay != 0:
                    if group['decoupled_decay']:
                        p.mul_(1.0 - group['lr'] * weight_decay)
                    else:
                        if grad.is_sparse:
                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                        # NOTE: modifies p.grad in place
                        grad.add_(p, alpha=weight_decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)

                    # Dense + sparse op
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # update masked copy of p
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.addcdiv(s, rms, value=1)
                    else:
                        x0 = state['x0']

                    # Accumulate second moments
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s
                    s.add_(grad, alpha=lamb)

                    # Step
                    if momentum == 0:
                        p.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z
                        p.mul_(1 - ck).add_(z, alpha=ck)

        return loss
src/custom_timm/optim/nadam.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch.optim.optimizer import Optimizer
5
+
6
+
7
class Nadam(Optimizer):
    """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).

    It has been proposed in `Incorporating Nesterov Momentum into Adam`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        schedule_decay (float, optional): momentum schedule decay (default: 4e-3)

    __ http://cs229.stanford.edu/proj2015/054_report.pdf
    __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf

        Originally taken from: https://github.com/pytorch/pytorch/pull/1408
        NOTE: Has potential issues but does work well on some problems.
    """

    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, schedule_decay=4e-3):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        super(Nadam, self).__init__(params, dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            schedule_decay=schedule_decay,
        ))

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # Lazily create the per-parameter state on first use
                if len(state) == 0:
                    state['step'] = 0
                    state['m_schedule'] = 1.
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                # running product of the momentum coefficients from earlier steps
                mu_product_prev = state['m_schedule']
                decay_rate = group['schedule_decay']
                first_moment = state['exp_avg']
                second_moment = state['exp_avg_sq']
                beta1, beta2 = group['betas']
                eps = group['eps']
                state['step'] += 1
                t = state['step']
                bias_correction2 = 1 - beta2 ** t

                if group['weight_decay'] != 0:
                    # out-of-place: p.grad itself is left untouched
                    grad = grad.add(p, alpha=group['weight_decay'])

                # Warming momentum schedule for this step and the next one
                mu_t = beta1 * (1. - 0.5 * (0.96 ** (t * decay_rate)))
                mu_next = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * decay_rate)))
                mu_product = mu_product_prev * mu_t
                mu_product_next = mu_product_prev * mu_t * mu_next
                state['m_schedule'] = mu_product

                # Decay the first and second moment running average coefficient
                first_moment.mul_(beta1).add_(grad, alpha=1. - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)

                denom = (second_moment.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                grad_coeff = -group['lr'] * (1. - mu_t) / (1. - mu_product)
                momentum_coeff = -group['lr'] * mu_next / (1. - mu_product_next)
                p.addcdiv_(grad, denom, value=grad_coeff)
                p.addcdiv_(first_moment, denom, value=momentum_coeff)

        return loss
src/custom_timm/optim/nvnovograd.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Nvidia NovoGrad Optimizer.
2
+ Original impl by Nvidia from Jasper example:
3
+ - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
4
+ Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
5
+ - https://arxiv.org/abs/1905.11286
6
+ """
7
+
8
+ import torch
9
+ from torch.optim.optimizer import Optimizer
10
+ import math
11
+
12
+
13
class NvNovoGrad(Optimizer):
    """
    Implements Novograd algorithm.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging: gradient averaging
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
    """

    def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
                 weight_decay=0, grad_averaging=False, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        super(NvNovoGrad, self).__init__(params, dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            grad_averaging=grad_averaging,
            amsgrad=amsgrad,
        ))

    def __setstate__(self, state):
        super(NvNovoGrad, self).__setstate__(state)
        # restored groups may lack the 'amsgrad' key; default it to off
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Sparse gradients are not supported.')
                use_amsgrad = group['amsgrad']

                state = self.state[p]

                # Lazily create the per-parameter state on first use
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Moving average of the squared gradient norm: one scalar per tensor
                    state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
                    if use_amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)

                first_moment = state['exp_avg']
                norm_avg = state['exp_avg_sq']
                if use_amsgrad:
                    norm_avg_max = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                sq_norm = torch.sum(torch.pow(grad, 2))

                if norm_avg == 0:
                    # a zero average is seeded with the current norm instead of being blended
                    norm_avg.copy_(sq_norm)
                else:
                    norm_avg.mul_(beta2).add_(sq_norm, alpha=1 - beta2)

                if use_amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(norm_avg_max, norm_avg, out=norm_avg_max)
                    # Use the max. for normalizing running avg. of gradient
                    denom = norm_avg_max.sqrt().add_(group['eps'])
                else:
                    denom = norm_avg.sqrt().add_(group['eps'])

                # NOTE: the following ops modify p.grad in place
                grad.div_(denom)
                if group['weight_decay'] != 0:
                    grad.add_(p, alpha=group['weight_decay'])
                if group['grad_averaging']:
                    grad.mul_(1 - beta1)
                first_moment.mul_(beta1).add_(grad)

                p.add_(first_moment, alpha=-group['lr'])

        return loss
src/custom_timm/optim/optim_factory.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Optimizer Factory w/ Custom Weight Decay
2
+ Hacked together by / Copyright 2021 Ross Wightman
3
+ """
4
+ import logging
5
+ from itertools import islice
6
+ from typing import Optional, Callable, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+
12
+ from custom_timm.models.helpers import group_parameters
13
+
14
+ from .adabelief import AdaBelief
15
+ from .adafactor import Adafactor
16
+ from .adahessian import Adahessian
17
+ from .adamp import AdamP
18
+ from .lamb import Lamb
19
+ from .lars import Lars
20
+ from .lookahead import Lookahead
21
+ from .madgrad import MADGRAD
22
+ from .nadam import Nadam
23
+ from .nvnovograd import NvNovoGrad
24
+ from .radam import RAdam
25
+ from .rmsprop_tf import RMSpropTF
26
+ from .sgdp import SGDP
27
+
28
+ try:
29
+ from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
30
+ has_apex = True
31
+ except ImportError:
32
+ has_apex = False
33
+
34
+ _logger = logging.getLogger(__name__)
35
+
36
+
37
def param_groups_weight_decay(
        model: nn.Module,
        weight_decay=1e-5,
        no_weight_decay_list=()
):
    """Split a model's trainable parameters into no-decay and decay groups.

    A parameter goes to the no-decay group when it has at most one dimension,
    its name ends with ``.bias``, or its name is in ``no_weight_decay_list``.
    Parameters with ``requires_grad=False`` are left out entirely.

    Returns:
        A two-element list of param-group dicts: the no-decay group
        (``weight_decay`` 0) followed by the decay group.
    """
    exempt_names = set(no_weight_decay_list)
    with_decay, without_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        is_exempt = param.ndim <= 1 or name.endswith(".bias") or name in exempt_names
        target = without_decay if is_exempt else with_decay
        target.append(param)

    return [
        {'params': without_decay, 'weight_decay': 0.},
        {'params': with_decay, 'weight_decay': weight_decay},
    ]
57
+
58
+
59
+ def _group(it, size):
60
+ it = iter(it)
61
+ return iter(lambda: tuple(islice(it, size)), ())
62
+
63
+
64
def _layer_map(model, layers_per_group=12, num_groups=None):
    """Map each parameter name of ``model`` to an integer layer-group id.

    Parameter names matching the ``classifier`` prefix(es) from
    ``model.pretrained_cfg`` are treated as head parameters; when no prefix is
    available every parameter counts as head. The remaining (trunk) names are
    chunked in order into groups of ``layers_per_group`` names, or into at most
    ``num_groups`` groups when that is given. Head parameters receive the id
    right after the last trunk group.
    """
    def _is_head_param(param_name, prefix):
        if not prefix:
            return True
        if isinstance(prefix, (tuple, list)):
            return any([param_name.startswith(item) for item in prefix])
        return param_name.startswith(prefix)

    head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
    head_names = []
    trunk_names = []
    for param_name, _ in model.named_parameters():
        bucket = head_names if _is_head_param(param_name, head_prefix) else trunk_names
        bucket.append(param_name)

    # group non-head layers
    if num_groups is not None:
        # ceiling division written with floor division on negated operands
        layers_per_group = -(len(trunk_names) // -num_groups)
    trunk_chunks = list(_group(trunk_names, layers_per_group))

    mapping = {}
    for group_id, chunk in enumerate(trunk_chunks):
        for param_name in chunk:
            mapping[param_name] = group_id
    head_id = len(trunk_chunks)
    for param_name in head_names:
        mapping[param_name] = head_id
    return mapping
89
+
90
+
91
def param_groups_layer_decay(
        model: nn.Module,
        weight_decay: float = 0.05,
        no_weight_decay_list: Tuple[str] = (),
        layer_decay: float = .75,
        end_layer_decay: Optional[float] = None,
        verbose: bool = False,
):
    """
    Parameter groups for layer-wise lr decay & weight decay
    Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58

    Args:
        model: model whose trainable parameters are grouped
        weight_decay: weight decay assigned to the "decay" groups
        no_weight_decay_list: parameter names that never receive weight decay
        layer_decay: per-layer multiplier; layer ``i`` of ``num_layers`` gets
            ``lr_scale = layer_decay ** (num_layers - 1 - i)``
        end_layer_decay: accepted but currently unused
        verbose: log the resulting group names and members

    Returns:
        List of param-group dicts with ``lr_scale``, ``weight_decay`` and ``params``
        keys, one per (layer id, decay / no_decay) combination that has parameters.
    """
    no_weight_decay_list = set(no_weight_decay_list)
    param_group_names = {}  # NOTE for debugging
    param_groups = {}

    if hasattr(model, 'group_matcher'):
        # FIXME interface needs more work
        layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
    else:
        # fallback
        layer_map = _layer_map(model)
    num_layers = max(layer_map.values()) + 1
    layer_max = num_layers - 1
    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # no decay: all 0-D / 1-D parameters and model specific ones
        # (`<= 1` so scalar parameters are exempt too, as in param_groups_weight_decay)
        if param.ndim <= 1 or name in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        # parameters missing from the map fall into the last layer (scale 1.0)
        layer_id = layer_map.get(name, layer_max)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_groups:
            this_scale = layer_scales[layer_id]
            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "param_names": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["param_names"].append(name)
        param_groups[group_name]["params"].append(param)

    if verbose:
        import json
        _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())
153
+
154
+
155
def optimizer_kwargs(cfg):
    """ cfg/argparse to kwargs helper
    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.

    ``opt``, ``lr``, ``weight_decay`` and ``momentum`` are required attributes of
    ``cfg``. ``opt_eps``, ``opt_betas`` and ``layer_decay`` are copied only when
    present and not None. ``opt_args``, if set, is merged last and therefore
    overrides earlier keys.
    """
    kwargs = {
        'opt': cfg.opt,
        'lr': cfg.lr,
        'weight_decay': cfg.weight_decay,
        'momentum': cfg.momentum,
    }
    # (attribute on cfg, keyword name expected by the create fn)
    optional_fields = (
        ('opt_eps', 'eps'),
        ('opt_betas', 'betas'),
        ('layer_decay', 'layer_decay'),
    )
    for attr_name, kwarg_name in optional_fields:
        value = getattr(cfg, attr_name, None)
        if value is not None:
            kwargs[kwarg_name] = value
    extra_args = getattr(cfg, 'opt_args', None)
    if extra_args is not None:
        kwargs.update(extra_args)
    return kwargs
173
+
174
+
175
def create_optimizer(args, model, filter_bias_and_bn=True):
    """ Legacy optimizer factory for backwards compatibility.
    NOTE: Use create_optimizer_v2 for new code.

    Translates ``args`` into keyword arguments via ``optimizer_kwargs`` and
    forwards them, together with ``filter_bias_and_bn``, to ``create_optimizer_v2``.
    """
    opt_kwargs = optimizer_kwargs(cfg=args)
    return create_optimizer_v2(model, **opt_kwargs, filter_bias_and_bn=filter_bias_and_bn)
184
+
185
+
186
def create_optimizer_v2(
        model_or_params,
        opt: str = 'sgd',
        lr: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        filter_bias_and_bn: bool = True,
        layer_decay: Optional[float] = None,
        param_group_fn: Optional[Callable] = None,
        **kwargs):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    The optimizer name is case-insensitive and split on ``_``: the last token selects
    the optimizer, and a leading ``lookahead`` token (e.g. ``lookahead_adam``) wraps the
    result in ``Lookahead``.

    Args:
        model_or_params (nn.Module): model containing parameters to optimize, or an
            iterable of parameters / param groups that is used as-is
        opt: name of optimizer to create
        lr: initial learning rate; when None the optimizer's own default is used
        weight_decay: weight decay to apply in optimizer
        momentum:  momentum for momentum based optimizers (others may use betas via kwargs)
        filter_bias_and_bn:  filter out bias, bn and other 1d params from weight decay
        layer_decay: when set (and a model is passed), build layer-wise lr decay groups
        param_group_fn: when set (and a model is passed), called with the model to produce
            the parameters / param groups; takes precedence over the other grouping options
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer

    Raises:
        ValueError: if ``opt`` does not name a known optimizer.
    """
    if isinstance(model_or_params, nn.Module):
        # a model was passed in, extract parameters and add weight decays to appropriate layers
        no_weight_decay = {}
        if hasattr(model_or_params, 'no_weight_decay'):
            no_weight_decay = model_or_params.no_weight_decay()

        if param_group_fn:
            parameters = param_group_fn(model_or_params)
        elif layer_decay is not None:
            parameters = param_groups_layer_decay(
                model_or_params,
                weight_decay=weight_decay,
                layer_decay=layer_decay,
                no_weight_decay_list=no_weight_decay)
            # decay now lives in the param groups; zero the optimizer-level default
            weight_decay = 0.
        elif weight_decay and filter_bias_and_bn:
            parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
            weight_decay = 0.
        else:
            parameters = model_or_params.parameters()
    else:
        # iterable of parameters or param groups passed in
        parameters = model_or_params

    opt_lower = opt.lower()
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)
    if lr is not None:
        opt_args.setdefault('lr', lr)

    # basic SGD & related
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)

    # adaptive
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'nadam':
        try:
            # NOTE PyTorch >= 1.10 should have native NAdam
            # NOTE(review): the lookup below is spelled `Nadam`; confirm it matches the
            # attribute name torch.optim exposes, otherwise the fallback is always used
            optimizer = optim.Nadam(parameters, **opt_args)
        except AttributeError:
            optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamax':
        optimizer = optim.Adamax(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
    elif opt_lower == 'radabelief':
        optimizer = AdaBelief(parameters, rectify=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adagrad':
        opt_args.setdefault('eps', 1e-8)
        optimizer = optim.Adagrad(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
    elif opt_lower == 'lambc':
        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
    elif opt_lower == 'larc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
    elif opt_lower == 'lars':
        optimizer = Lars(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'nlarc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
    elif opt_lower == 'nlars':
        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'madgrad':
        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'madgradw':
        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)

    # second order
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)

    # NVIDIA fused optimizers, require APEX to be installed
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)

    else:
        # explicit raise: an assert here would be stripped under `python -O`
        raise ValueError(f'Invalid optimizer: {opt}')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
src/custom_timm/optim/radam.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RAdam Optimizer.
2
+ Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
3
+ Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
4
+ """
5
+ import math
6
+ import torch
7
+ from torch.optim.optimizer import Optimizer
8
+
9
+
10
class RAdam(Optimizer):
    """Rectified Adam.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay, applied directly to the
            parameter scaled by the learning rate (default: 0)
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        # NOTE: `buffer` is kept only so param groups / state dicts keep their existing
        # keys. It is no longer used as a cache: defaults are shared by reference between
        # param groups, so caching a step size there leaked one group's learning rate
        # into the others.
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
            buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # step count -> (num_sma, step_size); local to this group and this call, so
            # values computed from one group's lr / betas never reach another group
            rectification_cache = {}

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # the update is computed in fp32 and copied back at the end
                p_fp32 = p.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                step = state['step']
                cached = rectification_cache.get(step)
                if cached is not None:
                    num_sma, step_size = cached
                else:
                    beta2_t = beta2 ** step
                    num_sma_max = 2 / (1 - beta2) - 1
                    num_sma = num_sma_max - 2 * step * beta2_t / (1 - beta2_t)

                    # more conservative since it's an approximated value
                    if num_sma >= 5:
                        step_size = group['lr'] * math.sqrt(
                            (1 - beta2_t) *
                            (num_sma - 4) / (num_sma_max - 4) *
                            (num_sma - 2) / num_sma *
                            num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** step)
                    else:
                        step_size = group['lr'] / (1 - beta1 ** step)
                    rectification_cache[step] = (num_sma, step_size)

                if group['weight_decay'] != 0:
                    p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])

                # more conservative since it's an approximated value
                if num_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    # variance not yet tractable: plain bias-corrected momentum step
                    p_fp32.add_(exp_avg, alpha=-step_size)

                p.copy_(p_fp32)

        return loss
src/custom_timm/optim/rmsprop_tf.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ RMSProp modified to behave like Tensorflow impl
2
+
3
+ Originally cut & paste from PyTorch RMSProp
4
+ https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
5
+ Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
6
+
7
+ Modifications Copyright 2021 Ross Wightman
8
+ """
9
+
10
+ import torch
11
+ from torch.optim import Optimizer
12
+
13
+
14
class RMSpropTF(Optimizer):
    """RMSprop optimizer tuned to follow TensorFlow's update rule.

    Derived from the PyTorch RMSprop implementation, with these differences:

      * epsilon is added to the running squared-gradient average *before*
        the square root is taken (not to the root afterwards),
      * the running squared-gradient average (``square_avg``) starts at ones
        instead of zeros,
      * with ``lr_in_momentum=True`` the learning rate is folded into the
        momentum buffer rather than applied to the buffer at update time.

    RMSprop was proposed by G. Hinton in his
    `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
    The centered variant first appears in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.

    Arguments:
        params (iterable): parameters to optimize, or dicts describing
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        alpha (float, optional): smoothing (decay) constant of the running
            averages (default: 0.9)
        eps (float, optional): stabilising term added inside the square root
            (default: 1e-10)
        weight_decay (float, optional): weight decay coefficient (default: 0)
        momentum (float, optional): momentum factor (default: 0)
        centered (bool, optional): if ``True``, subtract the squared running
            mean of the gradient from the running squared average, so the
            gradient is normalised by an estimate of its variance
        decoupled_decay (bool, optional): apply weight decay directly to the
            parameter (https://arxiv.org/abs/1711.05101) instead of adding an
            L2 term to the gradient
        lr_in_momentum (bool, optional): accumulate the learning-rate scaling
            inside the momentum buffer, as TensorFlow does by default
    """

    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
                 decoupled_decay=False, lr_in_momentum=True):
        # Every numeric hyper-parameter must be non-negative. The comparison is
        # written as ``not value >= 0`` so that NaN is rejected as well.
        non_negative = (
            ("Invalid learning rate: {}", lr),
            ("Invalid epsilon value: {}", eps),
            ("Invalid momentum value: {}", momentum),
            ("Invalid weight_decay value: {}", weight_decay),
            ("Invalid alpha value: {}", alpha),
        )
        for template, value in non_negative:
            if not value >= 0.0:
                raise ValueError(template.format(value))

        defaults = {
            'lr': lr,
            'momentum': momentum,
            'alpha': alpha,
            'eps': eps,
            'centered': centered,
            'weight_decay': weight_decay,
            'decoupled_decay': decoupled_decay,
            'lr_in_momentum': lr_in_momentum,
        }
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, filling in options a saved group may lack."""
        super().__setstate__(state)
        fallbacks = (('momentum', 0), ('centered', False))
        for group in self.param_groups:
            for key, value in fallbacks:
                group.setdefault(key, value)

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one optimization step to every parameter that has a gradient.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns
                the loss; it is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError('RMSprop does not support sparse gradients')

                state = self.state[param]
                if not state:
                    # First time this parameter is seen: allocate its buffers.
                    state['step'] = 0
                    state['square_avg'] = torch.ones_like(param)  # PyTorch starts from zeros
                    if group['momentum'] > 0:
                        state['momentum_buffer'] = torch.zeros_like(param)
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(param)

                sq_avg = state['square_avg']
                blend = 1. - group['alpha']
                state['step'] += 1

                decay = group['weight_decay']
                if decay != 0:
                    if group['decoupled_decay']:
                        # Shrink the weights directly, independent of the gradient.
                        param.mul_(1. - group['lr'] * decay)
                    else:
                        # Classic L2 penalty: fold the weights into the gradient.
                        grad = grad.add(param, alpha=decay)

                # TensorFlow ordering: move the average toward grad^2 by (1 - alpha).
                sq_delta = grad.pow(2) - sq_avg
                sq_avg.add_(sq_delta, alpha=blend)

                if group['centered']:
                    mean_grad = state['grad_avg']
                    mean_grad.add_(grad - mean_grad, alpha=blend)
                    spread = sq_avg.addcmul(mean_grad, mean_grad, value=-1)
                    rms = spread.add(group['eps']).sqrt_()  # eps inside the sqrt
                else:
                    rms = sq_avg.add(group['eps']).sqrt_()  # eps inside the sqrt

                momentum = group['momentum']
                if not momentum > 0:
                    # Plain RMSprop update, no momentum buffer involved.
                    param.addcdiv_(grad, rms, value=-group['lr'])
                    continue

                velocity = state['momentum_buffer']
                if group['lr_in_momentum']:
                    # TensorFlow style: the LR scaling lives inside the buffer.
                    velocity.mul_(momentum).addcdiv_(grad, rms, value=group['lr'])
                    param.add_(-velocity)
                else:
                    # PyTorch style: the buffer is unscaled, LR applied on update.
                    velocity.mul_(momentum).addcdiv_(grad, rms)
                    param.add_(velocity, alpha=-group['lr'])

        return loss